hyperpocket 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. hyperpocket/__init__.py +4 -4
  2. hyperpocket/auth/__init__.py +12 -7
  3. hyperpocket/auth/calendly/oauth2_handler.py +24 -17
  4. hyperpocket/auth/calendly/oauth2_schema.py +3 -1
  5. hyperpocket/auth/context.py +2 -1
  6. hyperpocket/auth/github/oauth2_handler.py +13 -8
  7. hyperpocket/auth/github/token_handler.py +27 -21
  8. hyperpocket/auth/google/context.py +1 -3
  9. hyperpocket/auth/google/oauth2_context.py +1 -1
  10. hyperpocket/auth/google/oauth2_handler.py +22 -17
  11. hyperpocket/auth/gumloop/token_context.py +1 -4
  12. hyperpocket/auth/gumloop/token_handler.py +48 -20
  13. hyperpocket/auth/gumloop/token_schema.py +2 -1
  14. hyperpocket/auth/handler.py +21 -6
  15. hyperpocket/auth/linear/token_context.py +2 -5
  16. hyperpocket/auth/linear/token_handler.py +45 -21
  17. hyperpocket/auth/notion/context.py +2 -2
  18. hyperpocket/auth/notion/token_context.py +2 -4
  19. hyperpocket/auth/notion/token_handler.py +45 -21
  20. hyperpocket/auth/notion/token_schema.py +0 -1
  21. hyperpocket/auth/reddit/oauth2_handler.py +9 -10
  22. hyperpocket/auth/reddit/oauth2_schema.py +0 -2
  23. hyperpocket/auth/schema.py +4 -1
  24. hyperpocket/auth/slack/oauth2_context.py +3 -1
  25. hyperpocket/auth/slack/oauth2_handler.py +55 -35
  26. hyperpocket/auth/slack/token_context.py +2 -4
  27. hyperpocket/auth/slack/token_handler.py +42 -19
  28. hyperpocket/builtin.py +4 -2
  29. hyperpocket/cli/__main__.py +4 -2
  30. hyperpocket/cli/auth.py +59 -28
  31. hyperpocket/cli/codegen/auth/auth_context_template.py +3 -2
  32. hyperpocket/cli/codegen/auth/auth_token_context_template.py +3 -2
  33. hyperpocket/cli/codegen/auth/auth_token_handler_template.py +6 -5
  34. hyperpocket/cli/codegen/auth/auth_token_schema_template.py +3 -2
  35. hyperpocket/cli/codegen/auth/server_auth_template.py +3 -2
  36. hyperpocket/cli/pull.py +5 -5
  37. hyperpocket/config/__init__.py +3 -8
  38. hyperpocket/config/auth.py +3 -1
  39. hyperpocket/config/logger.py +20 -15
  40. hyperpocket/config/session.py +4 -2
  41. hyperpocket/config/settings.py +19 -2
  42. hyperpocket/futures/__init__.py +1 -1
  43. hyperpocket/futures/futurestore.py +3 -2
  44. hyperpocket/pocket_auth.py +171 -84
  45. hyperpocket/pocket_core.py +51 -33
  46. hyperpocket/pocket_main.py +122 -93
  47. hyperpocket/prompts.py +2 -2
  48. hyperpocket/repository/__init__.py +1 -1
  49. hyperpocket/repository/lock.py +47 -33
  50. hyperpocket/repository/lockfile.py +2 -2
  51. hyperpocket/repository/repository.py +1 -1
  52. hyperpocket/server/__init__.py +1 -1
  53. hyperpocket/server/auth/github.py +2 -1
  54. hyperpocket/server/auth/linear.py +1 -3
  55. hyperpocket/server/auth/notion.py +2 -5
  56. hyperpocket/server/auth/slack.py +1 -3
  57. hyperpocket/server/auth/token.py +17 -11
  58. hyperpocket/server/proxy.py +29 -13
  59. hyperpocket/server/server.py +75 -31
  60. hyperpocket/server/tool/dto/script.py +15 -10
  61. hyperpocket/server/tool/wasm.py +14 -11
  62. hyperpocket/session/__init__.py +6 -2
  63. hyperpocket/session/in_memory.py +44 -24
  64. hyperpocket/session/interface.py +42 -24
  65. hyperpocket/session/redis.py +48 -31
  66. hyperpocket/tool/__init__.py +10 -10
  67. hyperpocket/tool/function/__init__.py +1 -5
  68. hyperpocket/tool/function/annotation.py +11 -9
  69. hyperpocket/tool/function/tool.py +37 -27
  70. hyperpocket/tool/tool.py +59 -36
  71. hyperpocket/tool/wasm/__init__.py +1 -1
  72. hyperpocket/tool/wasm/browser.py +15 -10
  73. hyperpocket/tool/wasm/invoker.py +16 -16
  74. hyperpocket/tool/wasm/script.py +27 -14
  75. hyperpocket/tool/wasm/templates/__init__.py +22 -15
  76. hyperpocket/tool/wasm/templates/node.py +2 -2
  77. hyperpocket/tool/wasm/templates/python.py +2 -2
  78. hyperpocket/tool/wasm/tool.py +27 -14
  79. hyperpocket/tool_like.py +3 -3
  80. hyperpocket/util/__init__.py +1 -1
  81. hyperpocket/util/extract_func_param_desc_from_docstring.py +23 -7
  82. hyperpocket/util/find_all_leaf_class_in_package.py +4 -3
  83. hyperpocket/util/find_all_subclass_in_package.py +4 -2
  84. hyperpocket/util/flatten_json_schema.py +10 -6
  85. hyperpocket/util/function_to_model.py +33 -12
  86. hyperpocket/util/get_objects_from_subpackage.py +1 -1
  87. hyperpocket/util/json_schema_to_model.py +14 -5
  88. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/METADATA +11 -5
  89. hyperpocket-0.2.0.dist-info/RECORD +137 -0
  90. hyperpocket-0.1.10.dist-info/RECORD +0 -137
  91. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/WHEEL +0 -0
  92. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  from http import HTTPStatus
2
- from urllib.parse import urlencode, urlunparse, urlparse, parse_qs
2
+ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
3
3
 
4
4
  from fastapi import APIRouter, Form
5
5
  from starlette.responses import HTMLResponse, RedirectResponse
@@ -29,8 +29,12 @@ async def token_form(redirect_uri: str, state: str = ""):
29
29
 
30
30
 
31
31
  @token_router.post("/submit", response_class=RedirectResponse)
32
- async def submit_token(user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)):
33
- new_callback_url = add_query_params(redirect_uri, {"token": user_token, "state": state})
32
+ async def submit_token(
33
+ user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)
34
+ ):
35
+ new_callback_url = add_query_params(
36
+ redirect_uri, {"token": user_token, "state": state}
37
+ )
34
38
  return RedirectResponse(url=new_callback_url, status_code=HTTPStatus.SEE_OTHER)
35
39
 
36
40
 
@@ -40,12 +44,14 @@ def add_query_params(url: str, params: dict):
40
44
  query_params.update(params)
41
45
  new_query = urlencode(query_params, doseq=True)
42
46
 
43
- new_url = urlunparse((
44
- url_parts.scheme,
45
- url_parts.netloc,
46
- url_parts.path,
47
- url_parts.params,
48
- new_query,
49
- url_parts.fragment
50
- ))
47
+ new_url = urlunparse(
48
+ (
49
+ url_parts.scheme,
50
+ url_parts.netloc,
51
+ url_parts.path,
52
+ url_parts.params,
53
+ new_query,
54
+ url_parts.fragment,
55
+ )
56
+ )
51
57
  return new_url
@@ -9,22 +9,27 @@ async def proxy(request: Request, path: str):
9
9
  async with httpx.AsyncClient() as client:
10
10
  resp = await client.request(
11
11
  method=request.method,
12
- url=f"{config.internal_base_url}/{path}",
12
+ url=f"{config().internal_base_url}/{path}",
13
13
  headers=request.headers,
14
14
  content=await request.body(),
15
15
  params=request.query_params,
16
16
  timeout=300,
17
17
  )
18
- return HTMLResponse(content=resp.text, headers=resp.headers, status_code=resp.status_code)
18
+ return HTMLResponse(
19
+ content=resp.text, headers=resp.headers, status_code=resp.status_code
20
+ )
19
21
 
20
22
 
21
23
  def add_callback_proxy(app: FastAPI):
22
- app.add_api_route(f"/{config.callback_url_rewrite_prefix}/{{path:path}}", proxy,
23
- methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
24
+ app.add_api_route(
25
+ f"/{config().callback_url_rewrite_prefix}/{{path:path}}",
26
+ proxy,
27
+ methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
28
+ )
24
29
 
25
30
 
26
31
  https_proxy_app = None
27
- if config.enable_local_callback_proxy:
32
+ if config().enable_local_callback_proxy:
28
33
  https_proxy_app = FastAPI()
29
34
  add_callback_proxy(https_proxy_app)
30
35
 
@@ -44,20 +49,31 @@ def _generate_ssl_certificates(ssl_keypath, ssl_certpath):
44
49
  "/emailAddress=local@example.com"
45
50
  )
46
51
  command = [
47
- "openssl", "req", "-x509",
48
- "-newkey", "rsa:4096",
49
- "-keyout", ssl_keypath,
50
- "-out", ssl_certpath,
51
- "-days", "1",
52
+ "openssl",
53
+ "req",
54
+ "-x509",
55
+ "-newkey",
56
+ "rsa:4096",
57
+ "-keyout",
58
+ ssl_keypath,
59
+ "-out",
60
+ ssl_certpath,
61
+ "-days",
62
+ "1",
52
63
  "-nodes",
53
- '-subj', subj,
64
+ "-subj",
65
+ subj,
54
66
  "-sha256",
55
67
  ]
56
68
 
57
69
  try:
58
70
  # 명령 실행
59
- subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
60
- pocket_logger.info("SSL Certificates generated: callback_server.key, callback_server.crt")
71
+ subprocess.run(
72
+ command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
73
+ )
74
+ pocket_logger.info(
75
+ "SSL Certificates generated: callback_server.key, callback_server.crt"
76
+ )
61
77
  except subprocess.CalledProcessError as e:
62
78
  pocket_logger.warning(f"An error occurred while generating certificates: {e}")
63
79
  raise e
@@ -30,9 +30,11 @@ class PocketServer(object):
30
30
  future_store: dict[str, asyncio.Future]
31
31
  torn_down: bool = False
32
32
 
33
- def __init__(self,
34
- internal_server_port: int = config.internal_server_port,
35
- proxy_port: int = config.public_server_port):
33
+ def __init__(
34
+ self,
35
+ internal_server_port: int = config().internal_server_port,
36
+ proxy_port: int = config().public_server_port,
37
+ ):
36
38
  self.internal_server_port = internal_server_port
37
39
  self.proxy_port = proxy_port
38
40
  self.future_store = dict()
@@ -49,7 +51,9 @@ class PocketServer(object):
49
51
  try:
50
52
  await asyncio.gather(
51
53
  self.main_server.serve(),
52
- self.proxy_server.serve() if self.proxy_server is not None else asyncio.sleep(0),
54
+ self.proxy_server.serve()
55
+ if self.proxy_server is not None
56
+ else asyncio.sleep(0),
53
57
  self.poll_in_child(),
54
58
  )
55
59
  except Exception as e:
@@ -75,7 +79,9 @@ class PocketServer(object):
75
79
  result = self.pocket_core.prepare_auth(*a, **kw)
76
80
  error = None
77
81
  except Exception as e:
78
- pocket_logger.error(f"failed to prepare in pocket subprocess. error: {e}")
82
+ pocket_logger.error(
83
+ f"failed to prepare in pocket subprocess. error: {e}"
84
+ )
79
85
  result = None
80
86
  error = e
81
87
 
@@ -86,7 +92,9 @@ class PocketServer(object):
86
92
  result = await self.pocket_core.authenticate(*a, **kw)
87
93
  error = None
88
94
  except Exception as e:
89
- pocket_logger.error(f"failed to authenticate in pocket subprocess. error: {e}")
95
+ pocket_logger.error(
96
+ f"failed to authenticate in pocket subprocess. error: {e}"
97
+ )
90
98
  result = None
91
99
  error = e
92
100
 
@@ -97,7 +105,9 @@ class PocketServer(object):
97
105
  result = await self.pocket_core.tool_call(*a, **kw)
98
106
  error = None
99
107
  except Exception as e:
100
- pocket_logger.error(f"failed to tool_call in pocket subprocess. error: {e}")
108
+ pocket_logger.error(
109
+ f"failed to tool_call in pocket subprocess. error: {e}"
110
+ )
101
111
  result = None
102
112
  error = e
103
113
 
@@ -119,10 +129,7 @@ class PocketServer(object):
119
129
  else:
120
130
  await asyncio.sleep(0)
121
131
 
122
- def send_in_parent(self,
123
- op: PocketServerOperations,
124
- args: tuple,
125
- kwargs: dict):
132
+ def send_in_parent(self, op: PocketServerOperations, args: tuple, kwargs: dict):
126
133
  conn, _ = self.pipe
127
134
  uid = str(uuid.uuid4())
128
135
  message = (op.value, uid, args, kwargs)
@@ -145,10 +152,9 @@ class PocketServer(object):
145
152
  else:
146
153
  await asyncio.sleep(0)
147
154
 
148
- async def call_in_subprocess(self,
149
- op: PocketServerOperations,
150
- args: tuple,
151
- kwargs: dict):
155
+ async def call_in_subprocess(
156
+ self, op: PocketServerOperations, args: tuple, kwargs: dict
157
+ ):
152
158
  uid = self.send_in_parent(op, args, kwargs)
153
159
  loop = asyncio.get_running_loop()
154
160
  loop.create_task(self.poll_in_parent())
@@ -157,21 +163,52 @@ class PocketServer(object):
157
163
  def run(self, pocket_core: PocketCore):
158
164
  self._set_mp_start_method()
159
165
 
166
+ error_queue = mp.Queue()
160
167
  self.pipe = mp.Pipe()
161
- self.process = mp.Process(target=self._run, args=(pocket_core,))
162
- self.process.start()
168
+ self.process = mp.Process(
169
+ target=self._run, args=(pocket_core, )
170
+ )
171
+ self.process.start() # process start
172
+
173
+ if not error_queue.empty():
174
+ error_message = error_queue.get()
175
+ raise error_message
176
+
177
+ def _report_initialized(self, error: Optional[Exception] = None):
178
+ _, conn = self.pipe
179
+ conn.send(('server-initialization', error,))
180
+
181
+ def wait_initialized(self):
182
+ conn, _ = self.pipe
183
+ while True:
184
+ if conn.poll():
185
+ _, error = conn.recv()
186
+ if error:
187
+ raise error
188
+ return
163
189
 
164
190
  def _run(self, pocket_core):
165
- self.pocket_core = pocket_core
166
- self.main_server = self._create_main_server()
167
- self.proxy_server = self._create_https_proxy_server()
168
- loop = asyncio.new_event_loop()
169
- loop.run_until_complete(self._run_async())
170
- loop.close()
191
+ try:
192
+ # init process
193
+ self.pocket_core = pocket_core
194
+ self.main_server = self._create_main_server()
195
+ self.proxy_server = self._create_https_proxy_server()
196
+ self._report_initialized()
197
+
198
+ loop = asyncio.new_event_loop()
199
+ loop.run_until_complete(self._run_async())
200
+ loop.close()
201
+ except Exception as error:
202
+ self._report_initialized(error)
171
203
 
172
204
  def _create_main_server(self) -> Server:
173
205
  app = FastAPI()
174
- _config = Config(app, host="0.0.0.0", port=self.internal_server_port)
206
+ _config = Config(
207
+ app,
208
+ host="0.0.0.0",
209
+ port=self.internal_server_port,
210
+ log_level=config().log_level,
211
+ )
175
212
  app.include_router(tool_router)
176
213
  app.include_router(auth_router)
177
214
  app.add_api_route("/health", lambda: {"status": "ok"}, methods=["GET"])
@@ -180,20 +217,25 @@ class PocketServer(object):
180
217
  return app
181
218
 
182
219
  def _create_https_proxy_server(self) -> Optional[Server]:
183
- if not config.enable_local_callback_proxy:
220
+ if not config().enable_local_callback_proxy:
184
221
  return None
185
- from hyperpocket.server.proxy import _generate_ssl_certificates
186
- from hyperpocket.server.proxy import https_proxy_app
187
-
188
222
  from hyperpocket.config.settings import POCKET_ROOT
223
+ from hyperpocket.server.proxy import _generate_ssl_certificates, https_proxy_app
224
+
189
225
  ssl_keypath = POCKET_ROOT / "callback_server.key"
190
226
  ssl_certpath = POCKET_ROOT / "callback_server.crt"
191
227
 
192
228
  if not ssl_keypath.exists() or not ssl_certpath.exists():
193
229
  _generate_ssl_certificates(ssl_keypath, ssl_certpath)
194
230
 
195
- _config = Config(https_proxy_app, host="0.0.0.0", port=self.proxy_port, ssl_keyfile=ssl_keypath,
196
- ssl_certfile=ssl_certpath)
231
+ _config = Config(
232
+ https_proxy_app,
233
+ host="0.0.0.0",
234
+ port=self.proxy_port,
235
+ ssl_keyfile=ssl_keypath,
236
+ ssl_certfile=ssl_certpath,
237
+ log_level=config().log_level,
238
+ )
197
239
  proxy_server = Server(_config)
198
240
  return proxy_server
199
241
 
@@ -211,4 +253,6 @@ class PocketServer(object):
211
253
  mp.set_start_method("fork", force=True)
212
254
  pocket_logger.debug("Process start method set to 'fork' for Linux.")
213
255
  else:
214
- pocket_logger.debug(f"Unrecognized OS: {os_name}. Default start method will be used.")
256
+ pocket_logger.debug(
257
+ f"Unrecognized OS: {os_name}. Default start method will be used."
258
+ )
@@ -1,28 +1,33 @@
1
1
  from typing import Optional
2
+
2
3
  from pydantic import BaseModel, Field
3
4
 
4
5
  from hyperpocket.tool.wasm.script import ScriptFileNode
5
6
 
6
7
 
7
8
  class Script(BaseModel):
8
- id: str = Field(alias='id')
9
- tool_id: str = Field(alias='tool_id')
9
+ id: str = Field(alias="id")
10
+ tool_id: str = Field(alias="tool_id")
10
11
 
11
12
 
12
13
  class ScriptResult(BaseModel):
13
- stdout: Optional[str] = Field(alias='stdout', default=None)
14
- stderr: Optional[str] = Field(alias='stderr', default=None)
15
- error: Optional[str] = Field(alias='error', default=None)
14
+ stdout: Optional[str] = Field(alias="stdout", default=None)
15
+ stderr: Optional[str] = Field(alias="stderr", default=None)
16
+ error: Optional[str] = Field(alias="error", default=None)
17
+
16
18
 
17
19
  class ScriptFileTree(BaseModel):
18
- tree: dict[str, ScriptFileNode] = Field(alias='tree')
20
+ tree: dict[str, ScriptFileNode] = Field(alias="tree")
21
+
19
22
 
20
23
  class ScriptEntrypoint(BaseModel):
21
- package_name: Optional[str] = Field(alias='package_name')
22
- entrypoint: str = Field(alias='entrypoint')
24
+ package_name: Optional[str] = Field(alias="package_name")
25
+ entrypoint: str = Field(alias="entrypoint")
26
+
23
27
 
24
28
  class ScriptEncodedFile(BaseModel):
25
- encoded_file: str = Field(alias='encoded_file')
29
+ encoded_file: str = Field(alias="encoded_file")
30
+
26
31
 
27
32
  class ScriptFileRequest(BaseModel):
28
- path: str = Field(alias='path')
33
+ path: str = Field(alias="path")
@@ -1,13 +1,11 @@
1
1
  from fastapi import APIRouter
2
- from fastapi.responses import HTMLResponse, FileResponse
2
+ from fastapi.responses import FileResponse, HTMLResponse
3
3
 
4
4
  from hyperpocket.futures import FutureStore
5
5
  from hyperpocket.server.tool.dto import script as scriptdto
6
6
  from hyperpocket.tool.wasm.script import ScriptStore
7
7
 
8
- wasm_tool_router = APIRouter(
9
- prefix="/wasm"
10
- )
8
+ wasm_tool_router = APIRouter(prefix="/wasm")
11
9
 
12
10
 
13
11
  @wasm_tool_router.get("/scripts/{script_id}/browse", response_class=HTMLResponse)
@@ -17,19 +15,21 @@ async def browse_script_page(script_id: str):
17
15
 
18
16
 
19
17
  @wasm_tool_router.post("/scripts/{script_id}/done")
20
- async def done_script_page(script_id: str, req: scriptdto.ScriptResult) -> scriptdto.ScriptResult:
21
- FutureStore.resolve_future(script_id, {
22
- 'stdout': req.stdout,
23
- 'stderr': req.stderr,
24
- 'error': req.error
25
- })
18
+ async def done_script_page(
19
+ script_id: str, req: scriptdto.ScriptResult
20
+ ) -> scriptdto.ScriptResult:
21
+ FutureStore.resolve_future(
22
+ script_id, {"stdout": req.stdout, "stderr": req.stderr, "error": req.error}
23
+ )
26
24
  return req
27
25
 
26
+
28
27
  @wasm_tool_router.get("/scripts/{script_id}/file_tree")
29
28
  async def get_file_tree(script_id: str) -> scriptdto.ScriptFileTree:
30
29
  script = ScriptStore.get_script(script_id)
31
30
  return scriptdto.ScriptFileTree(tree=script.load_file_tree())
32
31
 
32
+
33
33
  @wasm_tool_router.get("/scripts/{script_id}/entrypoint")
34
34
  async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
35
35
  script = ScriptStore.get_script(script_id)
@@ -37,7 +37,10 @@ async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
37
37
  entrypoint = f"/tools/wasm/scripts/{script_id}/file/{script.entrypoint}"
38
38
  return scriptdto.ScriptEntrypoint(package_name=package_name, entrypoint=entrypoint)
39
39
 
40
- @wasm_tool_router.get("/scripts/{script_id}/file/{file_name}", response_class=FileResponse)
40
+
41
+ @wasm_tool_router.get(
42
+ "/scripts/{script_id}/file/{file_name}", response_class=FileResponse
43
+ )
41
44
  async def get_dist_file(script_id: str, file_name: str):
42
45
  script = ScriptStore.get_script(script_id)
43
46
  return FileResponse(script.dist_file_path(file_name))
@@ -1,4 +1,8 @@
1
1
  from hyperpocket.session.interface import SessionStorageInterface
2
- from hyperpocket.util.find_all_leaf_class_in_package import find_all_leaf_class_in_package
2
+ from hyperpocket.util.find_all_leaf_class_in_package import (
3
+ find_all_leaf_class_in_package,
4
+ )
3
5
 
4
- SESSION_STORAGE_LIST = find_all_leaf_class_in_package("hyperpocket.session", SessionStorageInterface)
6
+ SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
7
+ "hyperpocket.session", SessionStorageInterface
8
+ )
@@ -3,16 +3,22 @@ from typing import Dict, List, Optional
3
3
 
4
4
  from hyperpocket.auth import AuthProvider
5
5
  from hyperpocket.auth.context import AuthContext
6
- from hyperpocket.config.session import SessionConfigInMemory
7
- from hyperpocket.config.session import SessionType
8
- from hyperpocket.session.interface import SessionStorageInterface, SESSION_KEY_DELIMITER, \
9
- BaseSessionValue, V, K
6
+ from hyperpocket.config.session import SessionConfigInMemory, SessionType
7
+ from hyperpocket.session.interface import (
8
+ SESSION_KEY_DELIMITER,
9
+ BaseSessionValue,
10
+ K,
11
+ SessionStorageInterface,
12
+ V,
13
+ )
10
14
 
11
15
  InMemorySessionKey = str
12
16
  InMemorySessionValue = BaseSessionValue
13
17
 
14
18
 
15
- class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]):
19
+ class InMemorySessionStorage(
20
+ SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]
21
+ ):
16
22
  # TODO(moon) : Force it to always take SessionConfig as an input
17
23
  def __init__(self, session_config: SessionConfigInMemory):
18
24
  super().__init__()
@@ -22,41 +28,54 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
22
28
  def session_storage_type(cls) -> SessionType:
23
29
  return SessionType.IN_MEMORY
24
30
 
25
- def get(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> Optional[V]:
31
+ def get(
32
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
33
+ ) -> Optional[V]:
26
34
  key = self._make_session_key(auth_provider.name, thread_id, profile)
27
35
  return self.storage.get(key, None)
28
36
 
29
- def get_by_thread_id(self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs) -> List[V]:
37
+ def get_by_thread_id(
38
+ self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
39
+ ) -> List[V]:
30
40
  if auth_provider is None:
31
41
  auth_provider_name = ".*"
32
42
  else:
33
43
  auth_provider_name = auth_provider.name
34
44
 
35
- pattern = rf'{self._make_session_key(auth_provider_name, thread_id, ".*")}'
45
+ pattern = rf"{self._make_session_key(auth_provider_name, thread_id, '.*')}"
36
46
  compiled = re.compile(pattern)
37
47
 
38
- session_list = [value for key, value in self.storage.items() if compiled.match(key)]
48
+ session_list = [
49
+ value for key, value in self.storage.items() if compiled.match(key)
50
+ ]
39
51
  return session_list
40
52
 
41
- def set(self, auth_provider: AuthProvider,
42
- thread_id: str,
43
- profile: str,
44
- auth_scopes: List[str],
45
- auth_resolve_uid: Optional[str],
46
- auth_context: Optional[AuthContext],
47
- is_auth_scope_universal: bool, **kwargs) -> V:
53
+ def set(
54
+ self,
55
+ auth_provider: AuthProvider,
56
+ thread_id: str,
57
+ profile: str,
58
+ auth_scopes: List[str],
59
+ auth_resolve_uid: Optional[str],
60
+ auth_context: Optional[AuthContext],
61
+ is_auth_scope_universal: bool,
62
+ **kwargs,
63
+ ) -> V:
48
64
  key = self._make_session_key(auth_provider.name, thread_id, profile)
49
65
  session = self._make_session(
50
66
  auth_provider_name=auth_provider.name,
51
67
  auth_scopes=auth_scopes,
52
68
  auth_resolve_uid=auth_resolve_uid,
53
69
  auth_context=auth_context,
54
- is_auth_scope_universal=is_auth_scope_universal)
70
+ is_auth_scope_universal=is_auth_scope_universal,
71
+ )
55
72
 
56
73
  self.storage[key] = session
57
74
  return session
58
75
 
59
- def delete(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> bool:
76
+ def delete(
77
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
78
+ ) -> bool:
60
79
  key = self._make_session_key(auth_provider.name, thread_id, profile)
61
80
  if key in self.storage:
62
81
  self.storage.pop(key)
@@ -75,15 +94,16 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
75
94
 
76
95
  @staticmethod
77
96
  def _make_session(
78
- auth_provider_name: str,
79
- auth_scopes: List[str],
80
- auth_context: AuthContext,
81
- auth_resolve_uid: str,
82
- is_auth_scope_universal: bool) -> V:
97
+ auth_provider_name: str,
98
+ auth_scopes: List[str],
99
+ auth_context: AuthContext,
100
+ auth_resolve_uid: str,
101
+ is_auth_scope_universal: bool,
102
+ ) -> V:
83
103
  return InMemorySessionValue(
84
104
  auth_provider_name=auth_provider_name,
85
105
  auth_scopes=set(auth_scopes),
86
106
  auth_context=auth_context,
87
107
  auth_resolve_uid=auth_resolve_uid,
88
- scoped=is_auth_scope_universal
108
+ scoped=is_auth_scope_universal,
89
109
  )
@@ -1,6 +1,6 @@
1
1
  import datetime
2
2
  from abc import ABC, abstractmethod
3
- from typing import TypeVar, Generic, List, Set, Optional, Iterable
3
+ from typing import Generic, Iterable, List, Optional, Set, TypeVar
4
4
 
5
5
  from pydantic import BaseModel, Field
6
6
 
@@ -15,28 +15,37 @@ SESSION_KEY_DELIMITER = "__"
15
15
 
16
16
  class BaseSessionValue(BaseModel):
17
17
  auth_provider_name: str = Field(
18
- description="The name of the authentication provider used to authenticate the current session")
18
+ description="The name of the authentication provider used to authenticate the current session"
19
+ )
19
20
  auth_context: Optional[AuthContext] = Field(
20
21
  default=None,
21
- description="The authentication context containing the actual session details")
22
- scoped: bool = Field(description="Indicates whether the current session is a scoped session")
22
+ description="The authentication context containing the actual session details",
23
+ )
24
+ scoped: bool = Field(
25
+ description="Indicates whether the current session is a scoped session"
26
+ )
23
27
  auth_scopes: Optional[Set[str]] = Field(
24
28
  default=None,
25
- description="The authentication scopes of the current session, present only for scoped sessions")
29
+ description="The authentication scopes of the current session, present only for scoped sessions",
30
+ )
26
31
  auth_resolve_uid: Optional[str] = Field(
27
32
  default=None,
28
- description="A UID used to asynchronously verify whether the user has completed the authentication process")
33
+ description="A UID used to asynchronously verify whether the user has completed the authentication process",
34
+ )
29
35
 
30
36
  def make_superset_auth_scope(
31
- self,
32
- other_scopes: Optional[Iterable[str]] = None) -> set[str]:
37
+ self, other_scopes: Optional[Iterable[str]] = None
38
+ ) -> set[str]:
33
39
  auth_scopes = self.auth_scopes or set()
34
40
  other_scopes = other_scopes or set()
35
41
  return auth_scopes.union(other_scopes)
36
42
 
37
- def is_auth_applicable(self, auth_provider_name: str, auth_req: AuthenticateRequest) -> bool:
38
- return self.auth_provider_name == auth_provider_name \
39
- and (not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes))
43
+ def is_auth_applicable(
44
+ self, auth_provider_name: str, auth_req: AuthenticateRequest
45
+ ) -> bool:
46
+ return self.auth_provider_name == auth_provider_name and (
47
+ not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes)
48
+ )
40
49
 
41
50
  def is_near_expires(self) -> bool:
42
51
  if self.auth_context.expires_at is not None:
@@ -48,13 +57,15 @@ class BaseSessionValue(BaseModel):
48
57
  return False
49
58
 
50
59
 
51
- K = TypeVar('K')
52
- V = TypeVar('V', bound=BaseSessionValue)
60
+ K = TypeVar("K")
61
+ V = TypeVar("V", bound=BaseSessionValue)
53
62
 
54
63
 
55
64
  class SessionStorageInterface(ABC, Generic[K, V]):
56
65
  @abstractmethod
57
- def get(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> V:
66
+ def get(
67
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
68
+ ) -> V:
58
69
  """
59
70
  Get session
60
71
 
@@ -69,7 +80,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
69
80
  raise NotImplementedError
70
81
 
71
82
  @abstractmethod
72
- def get_by_thread_id(self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs) -> List[V]:
83
+ def get_by_thread_id(
84
+ self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
85
+ ) -> List[V]:
73
86
  """
74
87
  Get session list by thread id
75
88
 
@@ -83,14 +96,17 @@ class SessionStorageInterface(ABC, Generic[K, V]):
83
96
  raise NotImplementedError
84
97
 
85
98
  @abstractmethod
86
- def set(self,
87
- auth_provider: AuthProvider,
88
- thread_id: str,
89
- profile: str,
90
- auth_scopes: List[str],
91
- auth_resolve_uid: Optional[str],
92
- auth_context: Optional[AuthContext],
93
- is_auth_scope_universal: bool, **kwargs) -> V:
99
+ def set(
100
+ self,
101
+ auth_provider: AuthProvider,
102
+ thread_id: str,
103
+ profile: str,
104
+ auth_scopes: List[str],
105
+ auth_resolve_uid: Optional[str],
106
+ auth_context: Optional[AuthContext],
107
+ is_auth_scope_universal: bool,
108
+ **kwargs,
109
+ ) -> V:
94
110
  """
95
111
  Set session, if a session doesn't exist, create new session
96
112
  If set auth_resolve_uid is None and auth_context is not None, created session is regarded as active session.
@@ -112,7 +128,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
112
128
  raise NotImplementedError
113
129
 
114
130
  @abstractmethod
115
- def delete(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> bool:
131
+ def delete(
132
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
133
+ ) -> bool:
116
134
  """
117
135
  Delete session
118
136