hyperpocket 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperpocket/__init__.py +4 -4
- hyperpocket/auth/__init__.py +12 -7
- hyperpocket/auth/calendly/oauth2_handler.py +24 -17
- hyperpocket/auth/calendly/oauth2_schema.py +3 -1
- hyperpocket/auth/context.py +2 -1
- hyperpocket/auth/github/oauth2_handler.py +13 -8
- hyperpocket/auth/github/token_handler.py +27 -21
- hyperpocket/auth/google/context.py +1 -3
- hyperpocket/auth/google/oauth2_context.py +1 -1
- hyperpocket/auth/google/oauth2_handler.py +22 -17
- hyperpocket/auth/gumloop/token_context.py +1 -4
- hyperpocket/auth/gumloop/token_handler.py +48 -20
- hyperpocket/auth/gumloop/token_schema.py +2 -1
- hyperpocket/auth/handler.py +21 -6
- hyperpocket/auth/linear/token_context.py +2 -5
- hyperpocket/auth/linear/token_handler.py +45 -21
- hyperpocket/auth/notion/context.py +2 -2
- hyperpocket/auth/notion/token_context.py +2 -4
- hyperpocket/auth/notion/token_handler.py +45 -21
- hyperpocket/auth/notion/token_schema.py +0 -1
- hyperpocket/auth/reddit/oauth2_handler.py +9 -10
- hyperpocket/auth/reddit/oauth2_schema.py +0 -2
- hyperpocket/auth/schema.py +4 -1
- hyperpocket/auth/slack/oauth2_context.py +3 -1
- hyperpocket/auth/slack/oauth2_handler.py +55 -35
- hyperpocket/auth/slack/token_context.py +2 -4
- hyperpocket/auth/slack/token_handler.py +42 -19
- hyperpocket/builtin.py +4 -2
- hyperpocket/cli/__main__.py +4 -2
- hyperpocket/cli/auth.py +59 -28
- hyperpocket/cli/codegen/auth/auth_context_template.py +3 -2
- hyperpocket/cli/codegen/auth/auth_token_context_template.py +3 -2
- hyperpocket/cli/codegen/auth/auth_token_handler_template.py +6 -5
- hyperpocket/cli/codegen/auth/auth_token_schema_template.py +3 -2
- hyperpocket/cli/codegen/auth/server_auth_template.py +3 -2
- hyperpocket/cli/pull.py +5 -5
- hyperpocket/config/__init__.py +3 -8
- hyperpocket/config/auth.py +3 -1
- hyperpocket/config/logger.py +20 -15
- hyperpocket/config/session.py +4 -2
- hyperpocket/config/settings.py +19 -2
- hyperpocket/futures/__init__.py +1 -1
- hyperpocket/futures/futurestore.py +3 -2
- hyperpocket/pocket_auth.py +171 -84
- hyperpocket/pocket_core.py +51 -33
- hyperpocket/pocket_main.py +122 -93
- hyperpocket/prompts.py +2 -2
- hyperpocket/repository/__init__.py +1 -1
- hyperpocket/repository/lock.py +47 -33
- hyperpocket/repository/lockfile.py +2 -2
- hyperpocket/repository/repository.py +1 -1
- hyperpocket/server/__init__.py +1 -1
- hyperpocket/server/auth/github.py +2 -1
- hyperpocket/server/auth/linear.py +1 -3
- hyperpocket/server/auth/notion.py +2 -5
- hyperpocket/server/auth/slack.py +1 -3
- hyperpocket/server/auth/token.py +17 -11
- hyperpocket/server/proxy.py +29 -13
- hyperpocket/server/server.py +75 -31
- hyperpocket/server/tool/dto/script.py +15 -10
- hyperpocket/server/tool/wasm.py +14 -11
- hyperpocket/session/__init__.py +6 -2
- hyperpocket/session/in_memory.py +44 -24
- hyperpocket/session/interface.py +42 -24
- hyperpocket/session/redis.py +48 -31
- hyperpocket/tool/__init__.py +10 -10
- hyperpocket/tool/function/__init__.py +1 -5
- hyperpocket/tool/function/annotation.py +11 -9
- hyperpocket/tool/function/tool.py +37 -27
- hyperpocket/tool/tool.py +59 -36
- hyperpocket/tool/wasm/__init__.py +1 -1
- hyperpocket/tool/wasm/browser.py +15 -10
- hyperpocket/tool/wasm/invoker.py +16 -16
- hyperpocket/tool/wasm/script.py +27 -14
- hyperpocket/tool/wasm/templates/__init__.py +22 -15
- hyperpocket/tool/wasm/templates/node.py +2 -2
- hyperpocket/tool/wasm/templates/python.py +2 -2
- hyperpocket/tool/wasm/tool.py +27 -14
- hyperpocket/tool_like.py +3 -3
- hyperpocket/util/__init__.py +1 -1
- hyperpocket/util/extract_func_param_desc_from_docstring.py +23 -7
- hyperpocket/util/find_all_leaf_class_in_package.py +4 -3
- hyperpocket/util/find_all_subclass_in_package.py +4 -2
- hyperpocket/util/flatten_json_schema.py +10 -6
- hyperpocket/util/function_to_model.py +33 -12
- hyperpocket/util/get_objects_from_subpackage.py +1 -1
- hyperpocket/util/json_schema_to_model.py +14 -5
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/METADATA +11 -5
- hyperpocket-0.2.0.dist-info/RECORD +137 -0
- hyperpocket-0.1.10.dist-info/RECORD +0 -137
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/WHEEL +0 -0
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/entry_points.txt +0 -0
hyperpocket/server/auth/token.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
from http import HTTPStatus
|
2
|
-
from urllib.parse import
|
2
|
+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
3
3
|
|
4
4
|
from fastapi import APIRouter, Form
|
5
5
|
from starlette.responses import HTMLResponse, RedirectResponse
|
@@ -29,8 +29,12 @@ async def token_form(redirect_uri: str, state: str = ""):
|
|
29
29
|
|
30
30
|
|
31
31
|
@token_router.post("/submit", response_class=RedirectResponse)
|
32
|
-
async def submit_token(
|
33
|
-
|
32
|
+
async def submit_token(
|
33
|
+
user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)
|
34
|
+
):
|
35
|
+
new_callback_url = add_query_params(
|
36
|
+
redirect_uri, {"token": user_token, "state": state}
|
37
|
+
)
|
34
38
|
return RedirectResponse(url=new_callback_url, status_code=HTTPStatus.SEE_OTHER)
|
35
39
|
|
36
40
|
|
@@ -40,12 +44,14 @@ def add_query_params(url: str, params: dict):
|
|
40
44
|
query_params.update(params)
|
41
45
|
new_query = urlencode(query_params, doseq=True)
|
42
46
|
|
43
|
-
new_url = urlunparse(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
47
|
+
new_url = urlunparse(
|
48
|
+
(
|
49
|
+
url_parts.scheme,
|
50
|
+
url_parts.netloc,
|
51
|
+
url_parts.path,
|
52
|
+
url_parts.params,
|
53
|
+
new_query,
|
54
|
+
url_parts.fragment,
|
55
|
+
)
|
56
|
+
)
|
51
57
|
return new_url
|
hyperpocket/server/proxy.py
CHANGED
@@ -9,22 +9,27 @@ async def proxy(request: Request, path: str):
|
|
9
9
|
async with httpx.AsyncClient() as client:
|
10
10
|
resp = await client.request(
|
11
11
|
method=request.method,
|
12
|
-
url=f"{config.internal_base_url}/{path}",
|
12
|
+
url=f"{config().internal_base_url}/{path}",
|
13
13
|
headers=request.headers,
|
14
14
|
content=await request.body(),
|
15
15
|
params=request.query_params,
|
16
16
|
timeout=300,
|
17
17
|
)
|
18
|
-
return HTMLResponse(
|
18
|
+
return HTMLResponse(
|
19
|
+
content=resp.text, headers=resp.headers, status_code=resp.status_code
|
20
|
+
)
|
19
21
|
|
20
22
|
|
21
23
|
def add_callback_proxy(app: FastAPI):
|
22
|
-
app.add_api_route(
|
23
|
-
|
24
|
+
app.add_api_route(
|
25
|
+
f"/{config().callback_url_rewrite_prefix}/{{path:path}}",
|
26
|
+
proxy,
|
27
|
+
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
28
|
+
)
|
24
29
|
|
25
30
|
|
26
31
|
https_proxy_app = None
|
27
|
-
if config.enable_local_callback_proxy:
|
32
|
+
if config().enable_local_callback_proxy:
|
28
33
|
https_proxy_app = FastAPI()
|
29
34
|
add_callback_proxy(https_proxy_app)
|
30
35
|
|
@@ -44,20 +49,31 @@ def _generate_ssl_certificates(ssl_keypath, ssl_certpath):
|
|
44
49
|
"/emailAddress=local@example.com"
|
45
50
|
)
|
46
51
|
command = [
|
47
|
-
"openssl",
|
48
|
-
"
|
49
|
-
"-
|
50
|
-
"-
|
51
|
-
"
|
52
|
+
"openssl",
|
53
|
+
"req",
|
54
|
+
"-x509",
|
55
|
+
"-newkey",
|
56
|
+
"rsa:4096",
|
57
|
+
"-keyout",
|
58
|
+
ssl_keypath,
|
59
|
+
"-out",
|
60
|
+
ssl_certpath,
|
61
|
+
"-days",
|
62
|
+
"1",
|
52
63
|
"-nodes",
|
53
|
-
|
64
|
+
"-subj",
|
65
|
+
subj,
|
54
66
|
"-sha256",
|
55
67
|
]
|
56
68
|
|
57
69
|
try:
|
58
70
|
# 명령 실행
|
59
|
-
subprocess.run(
|
60
|
-
|
71
|
+
subprocess.run(
|
72
|
+
command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
73
|
+
)
|
74
|
+
pocket_logger.info(
|
75
|
+
"SSL Certificates generated: callback_server.key, callback_server.crt"
|
76
|
+
)
|
61
77
|
except subprocess.CalledProcessError as e:
|
62
78
|
pocket_logger.warning(f"An error occurred while generating certificates: {e}")
|
63
79
|
raise e
|
hyperpocket/server/server.py
CHANGED
@@ -30,9 +30,11 @@ class PocketServer(object):
|
|
30
30
|
future_store: dict[str, asyncio.Future]
|
31
31
|
torn_down: bool = False
|
32
32
|
|
33
|
-
def __init__(
|
34
|
-
|
35
|
-
|
33
|
+
def __init__(
|
34
|
+
self,
|
35
|
+
internal_server_port: int = config().internal_server_port,
|
36
|
+
proxy_port: int = config().public_server_port,
|
37
|
+
):
|
36
38
|
self.internal_server_port = internal_server_port
|
37
39
|
self.proxy_port = proxy_port
|
38
40
|
self.future_store = dict()
|
@@ -49,7 +51,9 @@ class PocketServer(object):
|
|
49
51
|
try:
|
50
52
|
await asyncio.gather(
|
51
53
|
self.main_server.serve(),
|
52
|
-
self.proxy_server.serve()
|
54
|
+
self.proxy_server.serve()
|
55
|
+
if self.proxy_server is not None
|
56
|
+
else asyncio.sleep(0),
|
53
57
|
self.poll_in_child(),
|
54
58
|
)
|
55
59
|
except Exception as e:
|
@@ -75,7 +79,9 @@ class PocketServer(object):
|
|
75
79
|
result = self.pocket_core.prepare_auth(*a, **kw)
|
76
80
|
error = None
|
77
81
|
except Exception as e:
|
78
|
-
pocket_logger.error(
|
82
|
+
pocket_logger.error(
|
83
|
+
f"failed to prepare in pocket subprocess. error: {e}"
|
84
|
+
)
|
79
85
|
result = None
|
80
86
|
error = e
|
81
87
|
|
@@ -86,7 +92,9 @@ class PocketServer(object):
|
|
86
92
|
result = await self.pocket_core.authenticate(*a, **kw)
|
87
93
|
error = None
|
88
94
|
except Exception as e:
|
89
|
-
pocket_logger.error(
|
95
|
+
pocket_logger.error(
|
96
|
+
f"failed to authenticate in pocket subprocess. error: {e}"
|
97
|
+
)
|
90
98
|
result = None
|
91
99
|
error = e
|
92
100
|
|
@@ -97,7 +105,9 @@ class PocketServer(object):
|
|
97
105
|
result = await self.pocket_core.tool_call(*a, **kw)
|
98
106
|
error = None
|
99
107
|
except Exception as e:
|
100
|
-
pocket_logger.error(
|
108
|
+
pocket_logger.error(
|
109
|
+
f"failed to tool_call in pocket subprocess. error: {e}"
|
110
|
+
)
|
101
111
|
result = None
|
102
112
|
error = e
|
103
113
|
|
@@ -119,10 +129,7 @@ class PocketServer(object):
|
|
119
129
|
else:
|
120
130
|
await asyncio.sleep(0)
|
121
131
|
|
122
|
-
def send_in_parent(self,
|
123
|
-
op: PocketServerOperations,
|
124
|
-
args: tuple,
|
125
|
-
kwargs: dict):
|
132
|
+
def send_in_parent(self, op: PocketServerOperations, args: tuple, kwargs: dict):
|
126
133
|
conn, _ = self.pipe
|
127
134
|
uid = str(uuid.uuid4())
|
128
135
|
message = (op.value, uid, args, kwargs)
|
@@ -145,10 +152,9 @@ class PocketServer(object):
|
|
145
152
|
else:
|
146
153
|
await asyncio.sleep(0)
|
147
154
|
|
148
|
-
async def call_in_subprocess(
|
149
|
-
|
150
|
-
|
151
|
-
kwargs: dict):
|
155
|
+
async def call_in_subprocess(
|
156
|
+
self, op: PocketServerOperations, args: tuple, kwargs: dict
|
157
|
+
):
|
152
158
|
uid = self.send_in_parent(op, args, kwargs)
|
153
159
|
loop = asyncio.get_running_loop()
|
154
160
|
loop.create_task(self.poll_in_parent())
|
@@ -157,21 +163,52 @@ class PocketServer(object):
|
|
157
163
|
def run(self, pocket_core: PocketCore):
|
158
164
|
self._set_mp_start_method()
|
159
165
|
|
166
|
+
error_queue = mp.Queue()
|
160
167
|
self.pipe = mp.Pipe()
|
161
|
-
self.process = mp.Process(
|
162
|
-
|
168
|
+
self.process = mp.Process(
|
169
|
+
target=self._run, args=(pocket_core, )
|
170
|
+
)
|
171
|
+
self.process.start() # process start
|
172
|
+
|
173
|
+
if not error_queue.empty():
|
174
|
+
error_message = error_queue.get()
|
175
|
+
raise error_message
|
176
|
+
|
177
|
+
def _report_initialized(self, error: Optional[Exception] = None):
|
178
|
+
_, conn = self.pipe
|
179
|
+
conn.send(('server-initialization', error,))
|
180
|
+
|
181
|
+
def wait_initialized(self):
|
182
|
+
conn, _ = self.pipe
|
183
|
+
while True:
|
184
|
+
if conn.poll():
|
185
|
+
_, error = conn.recv()
|
186
|
+
if error:
|
187
|
+
raise error
|
188
|
+
return
|
163
189
|
|
164
190
|
def _run(self, pocket_core):
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
191
|
+
try:
|
192
|
+
# init process
|
193
|
+
self.pocket_core = pocket_core
|
194
|
+
self.main_server = self._create_main_server()
|
195
|
+
self.proxy_server = self._create_https_proxy_server()
|
196
|
+
self._report_initialized()
|
197
|
+
|
198
|
+
loop = asyncio.new_event_loop()
|
199
|
+
loop.run_until_complete(self._run_async())
|
200
|
+
loop.close()
|
201
|
+
except Exception as error:
|
202
|
+
self._report_initialized(error)
|
171
203
|
|
172
204
|
def _create_main_server(self) -> Server:
|
173
205
|
app = FastAPI()
|
174
|
-
_config = Config(
|
206
|
+
_config = Config(
|
207
|
+
app,
|
208
|
+
host="0.0.0.0",
|
209
|
+
port=self.internal_server_port,
|
210
|
+
log_level=config().log_level,
|
211
|
+
)
|
175
212
|
app.include_router(tool_router)
|
176
213
|
app.include_router(auth_router)
|
177
214
|
app.add_api_route("/health", lambda: {"status": "ok"}, methods=["GET"])
|
@@ -180,20 +217,25 @@ class PocketServer(object):
|
|
180
217
|
return app
|
181
218
|
|
182
219
|
def _create_https_proxy_server(self) -> Optional[Server]:
|
183
|
-
if not config.enable_local_callback_proxy:
|
220
|
+
if not config().enable_local_callback_proxy:
|
184
221
|
return None
|
185
|
-
from hyperpocket.server.proxy import _generate_ssl_certificates
|
186
|
-
from hyperpocket.server.proxy import https_proxy_app
|
187
|
-
|
188
222
|
from hyperpocket.config.settings import POCKET_ROOT
|
223
|
+
from hyperpocket.server.proxy import _generate_ssl_certificates, https_proxy_app
|
224
|
+
|
189
225
|
ssl_keypath = POCKET_ROOT / "callback_server.key"
|
190
226
|
ssl_certpath = POCKET_ROOT / "callback_server.crt"
|
191
227
|
|
192
228
|
if not ssl_keypath.exists() or not ssl_certpath.exists():
|
193
229
|
_generate_ssl_certificates(ssl_keypath, ssl_certpath)
|
194
230
|
|
195
|
-
_config = Config(
|
196
|
-
|
231
|
+
_config = Config(
|
232
|
+
https_proxy_app,
|
233
|
+
host="0.0.0.0",
|
234
|
+
port=self.proxy_port,
|
235
|
+
ssl_keyfile=ssl_keypath,
|
236
|
+
ssl_certfile=ssl_certpath,
|
237
|
+
log_level=config().log_level,
|
238
|
+
)
|
197
239
|
proxy_server = Server(_config)
|
198
240
|
return proxy_server
|
199
241
|
|
@@ -211,4 +253,6 @@ class PocketServer(object):
|
|
211
253
|
mp.set_start_method("fork", force=True)
|
212
254
|
pocket_logger.debug("Process start method set to 'fork' for Linux.")
|
213
255
|
else:
|
214
|
-
pocket_logger.debug(
|
256
|
+
pocket_logger.debug(
|
257
|
+
f"Unrecognized OS: {os_name}. Default start method will be used."
|
258
|
+
)
|
@@ -1,28 +1,33 @@
|
|
1
1
|
from typing import Optional
|
2
|
+
|
2
3
|
from pydantic import BaseModel, Field
|
3
4
|
|
4
5
|
from hyperpocket.tool.wasm.script import ScriptFileNode
|
5
6
|
|
6
7
|
|
7
8
|
class Script(BaseModel):
|
8
|
-
id: str = Field(alias=
|
9
|
-
tool_id: str = Field(alias=
|
9
|
+
id: str = Field(alias="id")
|
10
|
+
tool_id: str = Field(alias="tool_id")
|
10
11
|
|
11
12
|
|
12
13
|
class ScriptResult(BaseModel):
|
13
|
-
stdout: Optional[str] = Field(alias=
|
14
|
-
stderr: Optional[str] = Field(alias=
|
15
|
-
error: Optional[str] = Field(alias=
|
14
|
+
stdout: Optional[str] = Field(alias="stdout", default=None)
|
15
|
+
stderr: Optional[str] = Field(alias="stderr", default=None)
|
16
|
+
error: Optional[str] = Field(alias="error", default=None)
|
17
|
+
|
16
18
|
|
17
19
|
class ScriptFileTree(BaseModel):
|
18
|
-
tree: dict[str, ScriptFileNode] = Field(alias=
|
20
|
+
tree: dict[str, ScriptFileNode] = Field(alias="tree")
|
21
|
+
|
19
22
|
|
20
23
|
class ScriptEntrypoint(BaseModel):
|
21
|
-
package_name: Optional[str] = Field(alias=
|
22
|
-
entrypoint: str = Field(alias=
|
24
|
+
package_name: Optional[str] = Field(alias="package_name")
|
25
|
+
entrypoint: str = Field(alias="entrypoint")
|
26
|
+
|
23
27
|
|
24
28
|
class ScriptEncodedFile(BaseModel):
|
25
|
-
encoded_file: str = Field(alias=
|
29
|
+
encoded_file: str = Field(alias="encoded_file")
|
30
|
+
|
26
31
|
|
27
32
|
class ScriptFileRequest(BaseModel):
|
28
|
-
path: str = Field(alias=
|
33
|
+
path: str = Field(alias="path")
|
hyperpocket/server/tool/wasm.py
CHANGED
@@ -1,13 +1,11 @@
|
|
1
1
|
from fastapi import APIRouter
|
2
|
-
from fastapi.responses import
|
2
|
+
from fastapi.responses import FileResponse, HTMLResponse
|
3
3
|
|
4
4
|
from hyperpocket.futures import FutureStore
|
5
5
|
from hyperpocket.server.tool.dto import script as scriptdto
|
6
6
|
from hyperpocket.tool.wasm.script import ScriptStore
|
7
7
|
|
8
|
-
wasm_tool_router = APIRouter(
|
9
|
-
prefix="/wasm"
|
10
|
-
)
|
8
|
+
wasm_tool_router = APIRouter(prefix="/wasm")
|
11
9
|
|
12
10
|
|
13
11
|
@wasm_tool_router.get("/scripts/{script_id}/browse", response_class=HTMLResponse)
|
@@ -17,19 +15,21 @@ async def browse_script_page(script_id: str):
|
|
17
15
|
|
18
16
|
|
19
17
|
@wasm_tool_router.post("/scripts/{script_id}/done")
|
20
|
-
async def done_script_page(
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
18
|
+
async def done_script_page(
|
19
|
+
script_id: str, req: scriptdto.ScriptResult
|
20
|
+
) -> scriptdto.ScriptResult:
|
21
|
+
FutureStore.resolve_future(
|
22
|
+
script_id, {"stdout": req.stdout, "stderr": req.stderr, "error": req.error}
|
23
|
+
)
|
26
24
|
return req
|
27
25
|
|
26
|
+
|
28
27
|
@wasm_tool_router.get("/scripts/{script_id}/file_tree")
|
29
28
|
async def get_file_tree(script_id: str) -> scriptdto.ScriptFileTree:
|
30
29
|
script = ScriptStore.get_script(script_id)
|
31
30
|
return scriptdto.ScriptFileTree(tree=script.load_file_tree())
|
32
31
|
|
32
|
+
|
33
33
|
@wasm_tool_router.get("/scripts/{script_id}/entrypoint")
|
34
34
|
async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
|
35
35
|
script = ScriptStore.get_script(script_id)
|
@@ -37,7 +37,10 @@ async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
|
|
37
37
|
entrypoint = f"/tools/wasm/scripts/{script_id}/file/{script.entrypoint}"
|
38
38
|
return scriptdto.ScriptEntrypoint(package_name=package_name, entrypoint=entrypoint)
|
39
39
|
|
40
|
-
|
40
|
+
|
41
|
+
@wasm_tool_router.get(
|
42
|
+
"/scripts/{script_id}/file/{file_name}", response_class=FileResponse
|
43
|
+
)
|
41
44
|
async def get_dist_file(script_id: str, file_name: str):
|
42
45
|
script = ScriptStore.get_script(script_id)
|
43
46
|
return FileResponse(script.dist_file_path(file_name))
|
hyperpocket/session/__init__.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1
1
|
from hyperpocket.session.interface import SessionStorageInterface
|
2
|
-
from hyperpocket.util.find_all_leaf_class_in_package import
|
2
|
+
from hyperpocket.util.find_all_leaf_class_in_package import (
|
3
|
+
find_all_leaf_class_in_package,
|
4
|
+
)
|
3
5
|
|
4
|
-
SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
|
6
|
+
SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
|
7
|
+
"hyperpocket.session", SessionStorageInterface
|
8
|
+
)
|
hyperpocket/session/in_memory.py
CHANGED
@@ -3,16 +3,22 @@ from typing import Dict, List, Optional
|
|
3
3
|
|
4
4
|
from hyperpocket.auth import AuthProvider
|
5
5
|
from hyperpocket.auth.context import AuthContext
|
6
|
-
from hyperpocket.config.session import SessionConfigInMemory
|
7
|
-
from hyperpocket.
|
8
|
-
|
9
|
-
BaseSessionValue,
|
6
|
+
from hyperpocket.config.session import SessionConfigInMemory, SessionType
|
7
|
+
from hyperpocket.session.interface import (
|
8
|
+
SESSION_KEY_DELIMITER,
|
9
|
+
BaseSessionValue,
|
10
|
+
K,
|
11
|
+
SessionStorageInterface,
|
12
|
+
V,
|
13
|
+
)
|
10
14
|
|
11
15
|
InMemorySessionKey = str
|
12
16
|
InMemorySessionValue = BaseSessionValue
|
13
17
|
|
14
18
|
|
15
|
-
class InMemorySessionStorage(
|
19
|
+
class InMemorySessionStorage(
|
20
|
+
SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]
|
21
|
+
):
|
16
22
|
# TODO(moon) : Force it to always take SessionConfig as an input
|
17
23
|
def __init__(self, session_config: SessionConfigInMemory):
|
18
24
|
super().__init__()
|
@@ -22,41 +28,54 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
|
|
22
28
|
def session_storage_type(cls) -> SessionType:
|
23
29
|
return SessionType.IN_MEMORY
|
24
30
|
|
25
|
-
def get(
|
31
|
+
def get(
|
32
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
33
|
+
) -> Optional[V]:
|
26
34
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
27
35
|
return self.storage.get(key, None)
|
28
36
|
|
29
|
-
def get_by_thread_id(
|
37
|
+
def get_by_thread_id(
|
38
|
+
self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
|
39
|
+
) -> List[V]:
|
30
40
|
if auth_provider is None:
|
31
41
|
auth_provider_name = ".*"
|
32
42
|
else:
|
33
43
|
auth_provider_name = auth_provider.name
|
34
44
|
|
35
|
-
pattern = rf
|
45
|
+
pattern = rf"{self._make_session_key(auth_provider_name, thread_id, '.*')}"
|
36
46
|
compiled = re.compile(pattern)
|
37
47
|
|
38
|
-
session_list = [
|
48
|
+
session_list = [
|
49
|
+
value for key, value in self.storage.items() if compiled.match(key)
|
50
|
+
]
|
39
51
|
return session_list
|
40
52
|
|
41
|
-
def set(
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
53
|
+
def set(
|
54
|
+
self,
|
55
|
+
auth_provider: AuthProvider,
|
56
|
+
thread_id: str,
|
57
|
+
profile: str,
|
58
|
+
auth_scopes: List[str],
|
59
|
+
auth_resolve_uid: Optional[str],
|
60
|
+
auth_context: Optional[AuthContext],
|
61
|
+
is_auth_scope_universal: bool,
|
62
|
+
**kwargs,
|
63
|
+
) -> V:
|
48
64
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
49
65
|
session = self._make_session(
|
50
66
|
auth_provider_name=auth_provider.name,
|
51
67
|
auth_scopes=auth_scopes,
|
52
68
|
auth_resolve_uid=auth_resolve_uid,
|
53
69
|
auth_context=auth_context,
|
54
|
-
is_auth_scope_universal=is_auth_scope_universal
|
70
|
+
is_auth_scope_universal=is_auth_scope_universal,
|
71
|
+
)
|
55
72
|
|
56
73
|
self.storage[key] = session
|
57
74
|
return session
|
58
75
|
|
59
|
-
def delete(
|
76
|
+
def delete(
|
77
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
78
|
+
) -> bool:
|
60
79
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
61
80
|
if key in self.storage:
|
62
81
|
self.storage.pop(key)
|
@@ -75,15 +94,16 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
|
|
75
94
|
|
76
95
|
@staticmethod
|
77
96
|
def _make_session(
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
97
|
+
auth_provider_name: str,
|
98
|
+
auth_scopes: List[str],
|
99
|
+
auth_context: AuthContext,
|
100
|
+
auth_resolve_uid: str,
|
101
|
+
is_auth_scope_universal: bool,
|
102
|
+
) -> V:
|
83
103
|
return InMemorySessionValue(
|
84
104
|
auth_provider_name=auth_provider_name,
|
85
105
|
auth_scopes=set(auth_scopes),
|
86
106
|
auth_context=auth_context,
|
87
107
|
auth_resolve_uid=auth_resolve_uid,
|
88
|
-
scoped=is_auth_scope_universal
|
108
|
+
scoped=is_auth_scope_universal,
|
89
109
|
)
|
hyperpocket/session/interface.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import datetime
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from typing import
|
3
|
+
from typing import Generic, Iterable, List, Optional, Set, TypeVar
|
4
4
|
|
5
5
|
from pydantic import BaseModel, Field
|
6
6
|
|
@@ -15,28 +15,37 @@ SESSION_KEY_DELIMITER = "__"
|
|
15
15
|
|
16
16
|
class BaseSessionValue(BaseModel):
|
17
17
|
auth_provider_name: str = Field(
|
18
|
-
description="The name of the authentication provider used to authenticate the current session"
|
18
|
+
description="The name of the authentication provider used to authenticate the current session"
|
19
|
+
)
|
19
20
|
auth_context: Optional[AuthContext] = Field(
|
20
21
|
default=None,
|
21
|
-
description="The authentication context containing the actual session details"
|
22
|
-
|
22
|
+
description="The authentication context containing the actual session details",
|
23
|
+
)
|
24
|
+
scoped: bool = Field(
|
25
|
+
description="Indicates whether the current session is a scoped session"
|
26
|
+
)
|
23
27
|
auth_scopes: Optional[Set[str]] = Field(
|
24
28
|
default=None,
|
25
|
-
description="The authentication scopes of the current session, present only for scoped sessions"
|
29
|
+
description="The authentication scopes of the current session, present only for scoped sessions",
|
30
|
+
)
|
26
31
|
auth_resolve_uid: Optional[str] = Field(
|
27
32
|
default=None,
|
28
|
-
description="A UID used to asynchronously verify whether the user has completed the authentication process"
|
33
|
+
description="A UID used to asynchronously verify whether the user has completed the authentication process",
|
34
|
+
)
|
29
35
|
|
30
36
|
def make_superset_auth_scope(
|
31
|
-
|
32
|
-
|
37
|
+
self, other_scopes: Optional[Iterable[str]] = None
|
38
|
+
) -> set[str]:
|
33
39
|
auth_scopes = self.auth_scopes or set()
|
34
40
|
other_scopes = other_scopes or set()
|
35
41
|
return auth_scopes.union(other_scopes)
|
36
42
|
|
37
|
-
def is_auth_applicable(
|
38
|
-
|
39
|
-
|
43
|
+
def is_auth_applicable(
|
44
|
+
self, auth_provider_name: str, auth_req: AuthenticateRequest
|
45
|
+
) -> bool:
|
46
|
+
return self.auth_provider_name == auth_provider_name and (
|
47
|
+
not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes)
|
48
|
+
)
|
40
49
|
|
41
50
|
def is_near_expires(self) -> bool:
|
42
51
|
if self.auth_context.expires_at is not None:
|
@@ -48,13 +57,15 @@ class BaseSessionValue(BaseModel):
|
|
48
57
|
return False
|
49
58
|
|
50
59
|
|
51
|
-
K = TypeVar(
|
52
|
-
V = TypeVar(
|
60
|
+
K = TypeVar("K")
|
61
|
+
V = TypeVar("V", bound=BaseSessionValue)
|
53
62
|
|
54
63
|
|
55
64
|
class SessionStorageInterface(ABC, Generic[K, V]):
|
56
65
|
@abstractmethod
|
57
|
-
def get(
|
66
|
+
def get(
|
67
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
68
|
+
) -> V:
|
58
69
|
"""
|
59
70
|
Get session
|
60
71
|
|
@@ -69,7 +80,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
69
80
|
raise NotImplementedError
|
70
81
|
|
71
82
|
@abstractmethod
|
72
|
-
def get_by_thread_id(
|
83
|
+
def get_by_thread_id(
|
84
|
+
self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
|
85
|
+
) -> List[V]:
|
73
86
|
"""
|
74
87
|
Get session list by thread id
|
75
88
|
|
@@ -83,14 +96,17 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
83
96
|
raise NotImplementedError
|
84
97
|
|
85
98
|
@abstractmethod
|
86
|
-
def set(
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
99
|
+
def set(
|
100
|
+
self,
|
101
|
+
auth_provider: AuthProvider,
|
102
|
+
thread_id: str,
|
103
|
+
profile: str,
|
104
|
+
auth_scopes: List[str],
|
105
|
+
auth_resolve_uid: Optional[str],
|
106
|
+
auth_context: Optional[AuthContext],
|
107
|
+
is_auth_scope_universal: bool,
|
108
|
+
**kwargs,
|
109
|
+
) -> V:
|
94
110
|
"""
|
95
111
|
Set session, if a session doesn't exist, create new session
|
96
112
|
If set auth_resolve_uid is None and auth_context is not None, created session is regarded as active session.
|
@@ -112,7 +128,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
112
128
|
raise NotImplementedError
|
113
129
|
|
114
130
|
@abstractmethod
|
115
|
-
def delete(
|
131
|
+
def delete(
|
132
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
133
|
+
) -> bool:
|
116
134
|
"""
|
117
135
|
Delete session
|
118
136
|
|