hyperpocket 0.1.10__py3-none-any.whl → 0.2.1__py3-none-any.whl

Sign up to get free protection for your applications and access to all features.
Files changed (92) hide show
  1. hyperpocket/__init__.py +4 -4
  2. hyperpocket/auth/__init__.py +12 -7
  3. hyperpocket/auth/calendly/oauth2_handler.py +24 -17
  4. hyperpocket/auth/calendly/oauth2_schema.py +3 -1
  5. hyperpocket/auth/context.py +2 -1
  6. hyperpocket/auth/github/oauth2_handler.py +13 -8
  7. hyperpocket/auth/github/token_handler.py +27 -21
  8. hyperpocket/auth/google/context.py +1 -3
  9. hyperpocket/auth/google/oauth2_context.py +1 -1
  10. hyperpocket/auth/google/oauth2_handler.py +22 -17
  11. hyperpocket/auth/gumloop/token_context.py +1 -4
  12. hyperpocket/auth/gumloop/token_handler.py +48 -20
  13. hyperpocket/auth/gumloop/token_schema.py +2 -1
  14. hyperpocket/auth/handler.py +21 -6
  15. hyperpocket/auth/linear/token_context.py +2 -5
  16. hyperpocket/auth/linear/token_handler.py +45 -21
  17. hyperpocket/auth/notion/context.py +2 -2
  18. hyperpocket/auth/notion/token_context.py +2 -4
  19. hyperpocket/auth/notion/token_handler.py +45 -21
  20. hyperpocket/auth/notion/token_schema.py +0 -1
  21. hyperpocket/auth/reddit/oauth2_handler.py +9 -10
  22. hyperpocket/auth/reddit/oauth2_schema.py +0 -2
  23. hyperpocket/auth/schema.py +4 -1
  24. hyperpocket/auth/slack/oauth2_context.py +3 -1
  25. hyperpocket/auth/slack/oauth2_handler.py +55 -35
  26. hyperpocket/auth/slack/token_context.py +2 -4
  27. hyperpocket/auth/slack/token_handler.py +42 -19
  28. hyperpocket/builtin.py +4 -2
  29. hyperpocket/cli/__main__.py +4 -2
  30. hyperpocket/cli/auth.py +59 -28
  31. hyperpocket/cli/codegen/auth/auth_context_template.py +3 -2
  32. hyperpocket/cli/codegen/auth/auth_token_context_template.py +3 -2
  33. hyperpocket/cli/codegen/auth/auth_token_handler_template.py +6 -5
  34. hyperpocket/cli/codegen/auth/auth_token_schema_template.py +3 -2
  35. hyperpocket/cli/codegen/auth/server_auth_template.py +3 -2
  36. hyperpocket/cli/pull.py +5 -5
  37. hyperpocket/config/__init__.py +3 -8
  38. hyperpocket/config/auth.py +3 -1
  39. hyperpocket/config/logger.py +20 -15
  40. hyperpocket/config/session.py +4 -2
  41. hyperpocket/config/settings.py +19 -2
  42. hyperpocket/futures/__init__.py +1 -1
  43. hyperpocket/futures/futurestore.py +3 -2
  44. hyperpocket/pocket_auth.py +171 -84
  45. hyperpocket/pocket_core.py +51 -33
  46. hyperpocket/pocket_main.py +122 -93
  47. hyperpocket/prompts.py +2 -2
  48. hyperpocket/repository/__init__.py +1 -1
  49. hyperpocket/repository/lock.py +47 -33
  50. hyperpocket/repository/lockfile.py +2 -2
  51. hyperpocket/repository/repository.py +1 -1
  52. hyperpocket/server/__init__.py +1 -1
  53. hyperpocket/server/auth/github.py +2 -1
  54. hyperpocket/server/auth/linear.py +1 -3
  55. hyperpocket/server/auth/notion.py +2 -5
  56. hyperpocket/server/auth/slack.py +1 -3
  57. hyperpocket/server/auth/token.py +17 -11
  58. hyperpocket/server/proxy.py +29 -13
  59. hyperpocket/server/server.py +75 -31
  60. hyperpocket/server/tool/dto/script.py +15 -10
  61. hyperpocket/server/tool/wasm.py +14 -11
  62. hyperpocket/session/__init__.py +6 -2
  63. hyperpocket/session/in_memory.py +44 -24
  64. hyperpocket/session/interface.py +42 -24
  65. hyperpocket/session/redis.py +48 -31
  66. hyperpocket/tool/__init__.py +10 -10
  67. hyperpocket/tool/function/__init__.py +1 -5
  68. hyperpocket/tool/function/annotation.py +11 -9
  69. hyperpocket/tool/function/tool.py +37 -27
  70. hyperpocket/tool/tool.py +59 -36
  71. hyperpocket/tool/wasm/__init__.py +1 -1
  72. hyperpocket/tool/wasm/browser.py +15 -10
  73. hyperpocket/tool/wasm/invoker.py +16 -16
  74. hyperpocket/tool/wasm/script.py +27 -14
  75. hyperpocket/tool/wasm/templates/__init__.py +22 -15
  76. hyperpocket/tool/wasm/templates/node.py +2 -2
  77. hyperpocket/tool/wasm/templates/python.py +2 -2
  78. hyperpocket/tool/wasm/tool.py +27 -14
  79. hyperpocket/tool_like.py +3 -3
  80. hyperpocket/util/__init__.py +1 -1
  81. hyperpocket/util/extract_func_param_desc_from_docstring.py +23 -7
  82. hyperpocket/util/find_all_leaf_class_in_package.py +4 -3
  83. hyperpocket/util/find_all_subclass_in_package.py +4 -2
  84. hyperpocket/util/flatten_json_schema.py +10 -6
  85. hyperpocket/util/function_to_model.py +33 -12
  86. hyperpocket/util/get_objects_from_subpackage.py +1 -1
  87. hyperpocket/util/json_schema_to_model.py +14 -5
  88. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.1.dist-info}/METADATA +11 -5
  89. hyperpocket-0.2.1.dist-info/RECORD +137 -0
  90. hyperpocket-0.1.10.dist-info/RECORD +0 -137
  91. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.1.dist-info}/WHEEL +0 -0
  92. {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.1.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  from http import HTTPStatus
2
- from urllib.parse import urlencode, urlunparse, urlparse, parse_qs
2
+ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
3
3
 
4
4
  from fastapi import APIRouter, Form
5
5
  from starlette.responses import HTMLResponse, RedirectResponse
@@ -29,8 +29,12 @@ async def token_form(redirect_uri: str, state: str = ""):
29
29
 
30
30
 
31
31
  @token_router.post("/submit", response_class=RedirectResponse)
32
- async def submit_token(user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)):
33
- new_callback_url = add_query_params(redirect_uri, {"token": user_token, "state": state})
32
+ async def submit_token(
33
+ user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)
34
+ ):
35
+ new_callback_url = add_query_params(
36
+ redirect_uri, {"token": user_token, "state": state}
37
+ )
34
38
  return RedirectResponse(url=new_callback_url, status_code=HTTPStatus.SEE_OTHER)
35
39
 
36
40
 
@@ -40,12 +44,14 @@ def add_query_params(url: str, params: dict):
40
44
  query_params.update(params)
41
45
  new_query = urlencode(query_params, doseq=True)
42
46
 
43
- new_url = urlunparse((
44
- url_parts.scheme,
45
- url_parts.netloc,
46
- url_parts.path,
47
- url_parts.params,
48
- new_query,
49
- url_parts.fragment
50
- ))
47
+ new_url = urlunparse(
48
+ (
49
+ url_parts.scheme,
50
+ url_parts.netloc,
51
+ url_parts.path,
52
+ url_parts.params,
53
+ new_query,
54
+ url_parts.fragment,
55
+ )
56
+ )
51
57
  return new_url
@@ -9,22 +9,27 @@ async def proxy(request: Request, path: str):
9
9
  async with httpx.AsyncClient() as client:
10
10
  resp = await client.request(
11
11
  method=request.method,
12
- url=f"{config.internal_base_url}/{path}",
12
+ url=f"{config().internal_base_url}/{path}",
13
13
  headers=request.headers,
14
14
  content=await request.body(),
15
15
  params=request.query_params,
16
16
  timeout=300,
17
17
  )
18
- return HTMLResponse(content=resp.text, headers=resp.headers, status_code=resp.status_code)
18
+ return HTMLResponse(
19
+ content=resp.text, headers=resp.headers, status_code=resp.status_code
20
+ )
19
21
 
20
22
 
21
23
  def add_callback_proxy(app: FastAPI):
22
- app.add_api_route(f"/{config.callback_url_rewrite_prefix}/{{path:path}}", proxy,
23
- methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
24
+ app.add_api_route(
25
+ f"/{config().callback_url_rewrite_prefix}/{{path:path}}",
26
+ proxy,
27
+ methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
28
+ )
24
29
 
25
30
 
26
31
  https_proxy_app = None
27
- if config.enable_local_callback_proxy:
32
+ if config().enable_local_callback_proxy:
28
33
  https_proxy_app = FastAPI()
29
34
  add_callback_proxy(https_proxy_app)
30
35
 
@@ -44,20 +49,31 @@ def _generate_ssl_certificates(ssl_keypath, ssl_certpath):
44
49
  "/emailAddress=local@example.com"
45
50
  )
46
51
  command = [
47
- "openssl", "req", "-x509",
48
- "-newkey", "rsa:4096",
49
- "-keyout", ssl_keypath,
50
- "-out", ssl_certpath,
51
- "-days", "1",
52
+ "openssl",
53
+ "req",
54
+ "-x509",
55
+ "-newkey",
56
+ "rsa:4096",
57
+ "-keyout",
58
+ ssl_keypath,
59
+ "-out",
60
+ ssl_certpath,
61
+ "-days",
62
+ "1",
52
63
  "-nodes",
53
- '-subj', subj,
64
+ "-subj",
65
+ subj,
54
66
  "-sha256",
55
67
  ]
56
68
 
57
69
  try:
58
70
  # 명령 실행
59
- subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
60
- pocket_logger.info("SSL Certificates generated: callback_server.key, callback_server.crt")
71
+ subprocess.run(
72
+ command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
73
+ )
74
+ pocket_logger.info(
75
+ "SSL Certificates generated: callback_server.key, callback_server.crt"
76
+ )
61
77
  except subprocess.CalledProcessError as e:
62
78
  pocket_logger.warning(f"An error occurred while generating certificates: {e}")
63
79
  raise e
@@ -30,9 +30,11 @@ class PocketServer(object):
30
30
  future_store: dict[str, asyncio.Future]
31
31
  torn_down: bool = False
32
32
 
33
- def __init__(self,
34
- internal_server_port: int = config.internal_server_port,
35
- proxy_port: int = config.public_server_port):
33
+ def __init__(
34
+ self,
35
+ internal_server_port: int = config().internal_server_port,
36
+ proxy_port: int = config().public_server_port,
37
+ ):
36
38
  self.internal_server_port = internal_server_port
37
39
  self.proxy_port = proxy_port
38
40
  self.future_store = dict()
@@ -49,7 +51,9 @@ class PocketServer(object):
49
51
  try:
50
52
  await asyncio.gather(
51
53
  self.main_server.serve(),
52
- self.proxy_server.serve() if self.proxy_server is not None else asyncio.sleep(0),
54
+ self.proxy_server.serve()
55
+ if self.proxy_server is not None
56
+ else asyncio.sleep(0),
53
57
  self.poll_in_child(),
54
58
  )
55
59
  except Exception as e:
@@ -75,7 +79,9 @@ class PocketServer(object):
75
79
  result = self.pocket_core.prepare_auth(*a, **kw)
76
80
  error = None
77
81
  except Exception as e:
78
- pocket_logger.error(f"failed to prepare in pocket subprocess. error: {e}")
82
+ pocket_logger.error(
83
+ f"failed to prepare in pocket subprocess. error: {e}"
84
+ )
79
85
  result = None
80
86
  error = e
81
87
 
@@ -86,7 +92,9 @@ class PocketServer(object):
86
92
  result = await self.pocket_core.authenticate(*a, **kw)
87
93
  error = None
88
94
  except Exception as e:
89
- pocket_logger.error(f"failed to authenticate in pocket subprocess. error: {e}")
95
+ pocket_logger.error(
96
+ f"failed to authenticate in pocket subprocess. error: {e}"
97
+ )
90
98
  result = None
91
99
  error = e
92
100
 
@@ -97,7 +105,9 @@ class PocketServer(object):
97
105
  result = await self.pocket_core.tool_call(*a, **kw)
98
106
  error = None
99
107
  except Exception as e:
100
- pocket_logger.error(f"failed to tool_call in pocket subprocess. error: {e}")
108
+ pocket_logger.error(
109
+ f"failed to tool_call in pocket subprocess. error: {e}"
110
+ )
101
111
  result = None
102
112
  error = e
103
113
 
@@ -119,10 +129,7 @@ class PocketServer(object):
119
129
  else:
120
130
  await asyncio.sleep(0)
121
131
 
122
- def send_in_parent(self,
123
- op: PocketServerOperations,
124
- args: tuple,
125
- kwargs: dict):
132
+ def send_in_parent(self, op: PocketServerOperations, args: tuple, kwargs: dict):
126
133
  conn, _ = self.pipe
127
134
  uid = str(uuid.uuid4())
128
135
  message = (op.value, uid, args, kwargs)
@@ -145,10 +152,9 @@ class PocketServer(object):
145
152
  else:
146
153
  await asyncio.sleep(0)
147
154
 
148
- async def call_in_subprocess(self,
149
- op: PocketServerOperations,
150
- args: tuple,
151
- kwargs: dict):
155
+ async def call_in_subprocess(
156
+ self, op: PocketServerOperations, args: tuple, kwargs: dict
157
+ ):
152
158
  uid = self.send_in_parent(op, args, kwargs)
153
159
  loop = asyncio.get_running_loop()
154
160
  loop.create_task(self.poll_in_parent())
@@ -157,21 +163,52 @@ class PocketServer(object):
157
163
  def run(self, pocket_core: PocketCore):
158
164
  self._set_mp_start_method()
159
165
 
166
+ error_queue = mp.Queue()
160
167
  self.pipe = mp.Pipe()
161
- self.process = mp.Process(target=self._run, args=(pocket_core,))
162
- self.process.start()
168
+ self.process = mp.Process(
169
+ target=self._run, args=(pocket_core, )
170
+ )
171
+ self.process.start() # process start
172
+
173
+ if not error_queue.empty():
174
+ error_message = error_queue.get()
175
+ raise error_message
176
+
177
+ def _report_initialized(self, error: Optional[Exception] = None):
178
+ _, conn = self.pipe
179
+ conn.send(('server-initialization', error,))
180
+
181
+ def wait_initialized(self):
182
+ conn, _ = self.pipe
183
+ while True:
184
+ if conn.poll():
185
+ _, error = conn.recv()
186
+ if error:
187
+ raise error
188
+ return
163
189
 
164
190
  def _run(self, pocket_core):
165
- self.pocket_core = pocket_core
166
- self.main_server = self._create_main_server()
167
- self.proxy_server = self._create_https_proxy_server()
168
- loop = asyncio.new_event_loop()
169
- loop.run_until_complete(self._run_async())
170
- loop.close()
191
+ try:
192
+ # init process
193
+ self.pocket_core = pocket_core
194
+ self.main_server = self._create_main_server()
195
+ self.proxy_server = self._create_https_proxy_server()
196
+ self._report_initialized()
197
+
198
+ loop = asyncio.new_event_loop()
199
+ loop.run_until_complete(self._run_async())
200
+ loop.close()
201
+ except Exception as error:
202
+ self._report_initialized(error)
171
203
 
172
204
  def _create_main_server(self) -> Server:
173
205
  app = FastAPI()
174
- _config = Config(app, host="0.0.0.0", port=self.internal_server_port)
206
+ _config = Config(
207
+ app,
208
+ host="0.0.0.0",
209
+ port=self.internal_server_port,
210
+ log_level=config().log_level,
211
+ )
175
212
  app.include_router(tool_router)
176
213
  app.include_router(auth_router)
177
214
  app.add_api_route("/health", lambda: {"status": "ok"}, methods=["GET"])
@@ -180,20 +217,25 @@ class PocketServer(object):
180
217
  return app
181
218
 
182
219
  def _create_https_proxy_server(self) -> Optional[Server]:
183
- if not config.enable_local_callback_proxy:
220
+ if not config().enable_local_callback_proxy:
184
221
  return None
185
- from hyperpocket.server.proxy import _generate_ssl_certificates
186
- from hyperpocket.server.proxy import https_proxy_app
187
-
188
222
  from hyperpocket.config.settings import POCKET_ROOT
223
+ from hyperpocket.server.proxy import _generate_ssl_certificates, https_proxy_app
224
+
189
225
  ssl_keypath = POCKET_ROOT / "callback_server.key"
190
226
  ssl_certpath = POCKET_ROOT / "callback_server.crt"
191
227
 
192
228
  if not ssl_keypath.exists() or not ssl_certpath.exists():
193
229
  _generate_ssl_certificates(ssl_keypath, ssl_certpath)
194
230
 
195
- _config = Config(https_proxy_app, host="0.0.0.0", port=self.proxy_port, ssl_keyfile=ssl_keypath,
196
- ssl_certfile=ssl_certpath)
231
+ _config = Config(
232
+ https_proxy_app,
233
+ host="0.0.0.0",
234
+ port=self.proxy_port,
235
+ ssl_keyfile=ssl_keypath,
236
+ ssl_certfile=ssl_certpath,
237
+ log_level=config().log_level,
238
+ )
197
239
  proxy_server = Server(_config)
198
240
  return proxy_server
199
241
 
@@ -211,4 +253,6 @@ class PocketServer(object):
211
253
  mp.set_start_method("fork", force=True)
212
254
  pocket_logger.debug("Process start method set to 'fork' for Linux.")
213
255
  else:
214
- pocket_logger.debug(f"Unrecognized OS: {os_name}. Default start method will be used.")
256
+ pocket_logger.debug(
257
+ f"Unrecognized OS: {os_name}. Default start method will be used."
258
+ )
@@ -1,28 +1,33 @@
1
1
  from typing import Optional
2
+
2
3
  from pydantic import BaseModel, Field
3
4
 
4
5
  from hyperpocket.tool.wasm.script import ScriptFileNode
5
6
 
6
7
 
7
8
  class Script(BaseModel):
8
- id: str = Field(alias='id')
9
- tool_id: str = Field(alias='tool_id')
9
+ id: str = Field(alias="id")
10
+ tool_id: str = Field(alias="tool_id")
10
11
 
11
12
 
12
13
  class ScriptResult(BaseModel):
13
- stdout: Optional[str] = Field(alias='stdout', default=None)
14
- stderr: Optional[str] = Field(alias='stderr', default=None)
15
- error: Optional[str] = Field(alias='error', default=None)
14
+ stdout: Optional[str] = Field(alias="stdout", default=None)
15
+ stderr: Optional[str] = Field(alias="stderr", default=None)
16
+ error: Optional[str] = Field(alias="error", default=None)
17
+
16
18
 
17
19
  class ScriptFileTree(BaseModel):
18
- tree: dict[str, ScriptFileNode] = Field(alias='tree')
20
+ tree: dict[str, ScriptFileNode] = Field(alias="tree")
21
+
19
22
 
20
23
  class ScriptEntrypoint(BaseModel):
21
- package_name: Optional[str] = Field(alias='package_name')
22
- entrypoint: str = Field(alias='entrypoint')
24
+ package_name: Optional[str] = Field(alias="package_name")
25
+ entrypoint: str = Field(alias="entrypoint")
26
+
23
27
 
24
28
  class ScriptEncodedFile(BaseModel):
25
- encoded_file: str = Field(alias='encoded_file')
29
+ encoded_file: str = Field(alias="encoded_file")
30
+
26
31
 
27
32
  class ScriptFileRequest(BaseModel):
28
- path: str = Field(alias='path')
33
+ path: str = Field(alias="path")
@@ -1,13 +1,11 @@
1
1
  from fastapi import APIRouter
2
- from fastapi.responses import HTMLResponse, FileResponse
2
+ from fastapi.responses import FileResponse, HTMLResponse
3
3
 
4
4
  from hyperpocket.futures import FutureStore
5
5
  from hyperpocket.server.tool.dto import script as scriptdto
6
6
  from hyperpocket.tool.wasm.script import ScriptStore
7
7
 
8
- wasm_tool_router = APIRouter(
9
- prefix="/wasm"
10
- )
8
+ wasm_tool_router = APIRouter(prefix="/wasm")
11
9
 
12
10
 
13
11
  @wasm_tool_router.get("/scripts/{script_id}/browse", response_class=HTMLResponse)
@@ -17,19 +15,21 @@ async def browse_script_page(script_id: str):
17
15
 
18
16
 
19
17
  @wasm_tool_router.post("/scripts/{script_id}/done")
20
- async def done_script_page(script_id: str, req: scriptdto.ScriptResult) -> scriptdto.ScriptResult:
21
- FutureStore.resolve_future(script_id, {
22
- 'stdout': req.stdout,
23
- 'stderr': req.stderr,
24
- 'error': req.error
25
- })
18
+ async def done_script_page(
19
+ script_id: str, req: scriptdto.ScriptResult
20
+ ) -> scriptdto.ScriptResult:
21
+ FutureStore.resolve_future(
22
+ script_id, {"stdout": req.stdout, "stderr": req.stderr, "error": req.error}
23
+ )
26
24
  return req
27
25
 
26
+
28
27
  @wasm_tool_router.get("/scripts/{script_id}/file_tree")
29
28
  async def get_file_tree(script_id: str) -> scriptdto.ScriptFileTree:
30
29
  script = ScriptStore.get_script(script_id)
31
30
  return scriptdto.ScriptFileTree(tree=script.load_file_tree())
32
31
 
32
+
33
33
  @wasm_tool_router.get("/scripts/{script_id}/entrypoint")
34
34
  async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
35
35
  script = ScriptStore.get_script(script_id)
@@ -37,7 +37,10 @@ async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
37
37
  entrypoint = f"/tools/wasm/scripts/{script_id}/file/{script.entrypoint}"
38
38
  return scriptdto.ScriptEntrypoint(package_name=package_name, entrypoint=entrypoint)
39
39
 
40
- @wasm_tool_router.get("/scripts/{script_id}/file/{file_name}", response_class=FileResponse)
40
+
41
+ @wasm_tool_router.get(
42
+ "/scripts/{script_id}/file/{file_name}", response_class=FileResponse
43
+ )
41
44
  async def get_dist_file(script_id: str, file_name: str):
42
45
  script = ScriptStore.get_script(script_id)
43
46
  return FileResponse(script.dist_file_path(file_name))
@@ -1,4 +1,8 @@
1
1
  from hyperpocket.session.interface import SessionStorageInterface
2
- from hyperpocket.util.find_all_leaf_class_in_package import find_all_leaf_class_in_package
2
+ from hyperpocket.util.find_all_leaf_class_in_package import (
3
+ find_all_leaf_class_in_package,
4
+ )
3
5
 
4
- SESSION_STORAGE_LIST = find_all_leaf_class_in_package("hyperpocket.session", SessionStorageInterface)
6
+ SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
7
+ "hyperpocket.session", SessionStorageInterface
8
+ )
@@ -3,16 +3,22 @@ from typing import Dict, List, Optional
3
3
 
4
4
  from hyperpocket.auth import AuthProvider
5
5
  from hyperpocket.auth.context import AuthContext
6
- from hyperpocket.config.session import SessionConfigInMemory
7
- from hyperpocket.config.session import SessionType
8
- from hyperpocket.session.interface import SessionStorageInterface, SESSION_KEY_DELIMITER, \
9
- BaseSessionValue, V, K
6
+ from hyperpocket.config.session import SessionConfigInMemory, SessionType
7
+ from hyperpocket.session.interface import (
8
+ SESSION_KEY_DELIMITER,
9
+ BaseSessionValue,
10
+ K,
11
+ SessionStorageInterface,
12
+ V,
13
+ )
10
14
 
11
15
  InMemorySessionKey = str
12
16
  InMemorySessionValue = BaseSessionValue
13
17
 
14
18
 
15
- class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]):
19
+ class InMemorySessionStorage(
20
+ SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]
21
+ ):
16
22
  # TODO(moon) : Force it to always take SessionConfig as an input
17
23
  def __init__(self, session_config: SessionConfigInMemory):
18
24
  super().__init__()
@@ -22,41 +28,54 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
22
28
  def session_storage_type(cls) -> SessionType:
23
29
  return SessionType.IN_MEMORY
24
30
 
25
- def get(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> Optional[V]:
31
+ def get(
32
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
33
+ ) -> Optional[V]:
26
34
  key = self._make_session_key(auth_provider.name, thread_id, profile)
27
35
  return self.storage.get(key, None)
28
36
 
29
- def get_by_thread_id(self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs) -> List[V]:
37
+ def get_by_thread_id(
38
+ self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
39
+ ) -> List[V]:
30
40
  if auth_provider is None:
31
41
  auth_provider_name = ".*"
32
42
  else:
33
43
  auth_provider_name = auth_provider.name
34
44
 
35
- pattern = rf'{self._make_session_key(auth_provider_name, thread_id, ".*")}'
45
+ pattern = rf"{self._make_session_key(auth_provider_name, thread_id, '.*')}"
36
46
  compiled = re.compile(pattern)
37
47
 
38
- session_list = [value for key, value in self.storage.items() if compiled.match(key)]
48
+ session_list = [
49
+ value for key, value in self.storage.items() if compiled.match(key)
50
+ ]
39
51
  return session_list
40
52
 
41
- def set(self, auth_provider: AuthProvider,
42
- thread_id: str,
43
- profile: str,
44
- auth_scopes: List[str],
45
- auth_resolve_uid: Optional[str],
46
- auth_context: Optional[AuthContext],
47
- is_auth_scope_universal: bool, **kwargs) -> V:
53
+ def set(
54
+ self,
55
+ auth_provider: AuthProvider,
56
+ thread_id: str,
57
+ profile: str,
58
+ auth_scopes: List[str],
59
+ auth_resolve_uid: Optional[str],
60
+ auth_context: Optional[AuthContext],
61
+ is_auth_scope_universal: bool,
62
+ **kwargs,
63
+ ) -> V:
48
64
  key = self._make_session_key(auth_provider.name, thread_id, profile)
49
65
  session = self._make_session(
50
66
  auth_provider_name=auth_provider.name,
51
67
  auth_scopes=auth_scopes,
52
68
  auth_resolve_uid=auth_resolve_uid,
53
69
  auth_context=auth_context,
54
- is_auth_scope_universal=is_auth_scope_universal)
70
+ is_auth_scope_universal=is_auth_scope_universal,
71
+ )
55
72
 
56
73
  self.storage[key] = session
57
74
  return session
58
75
 
59
- def delete(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> bool:
76
+ def delete(
77
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
78
+ ) -> bool:
60
79
  key = self._make_session_key(auth_provider.name, thread_id, profile)
61
80
  if key in self.storage:
62
81
  self.storage.pop(key)
@@ -75,15 +94,16 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
75
94
 
76
95
  @staticmethod
77
96
  def _make_session(
78
- auth_provider_name: str,
79
- auth_scopes: List[str],
80
- auth_context: AuthContext,
81
- auth_resolve_uid: str,
82
- is_auth_scope_universal: bool) -> V:
97
+ auth_provider_name: str,
98
+ auth_scopes: List[str],
99
+ auth_context: AuthContext,
100
+ auth_resolve_uid: str,
101
+ is_auth_scope_universal: bool,
102
+ ) -> V:
83
103
  return InMemorySessionValue(
84
104
  auth_provider_name=auth_provider_name,
85
105
  auth_scopes=set(auth_scopes),
86
106
  auth_context=auth_context,
87
107
  auth_resolve_uid=auth_resolve_uid,
88
- scoped=is_auth_scope_universal
108
+ scoped=is_auth_scope_universal,
89
109
  )
@@ -1,6 +1,6 @@
1
1
  import datetime
2
2
  from abc import ABC, abstractmethod
3
- from typing import TypeVar, Generic, List, Set, Optional, Iterable
3
+ from typing import Generic, Iterable, List, Optional, Set, TypeVar
4
4
 
5
5
  from pydantic import BaseModel, Field
6
6
 
@@ -15,28 +15,37 @@ SESSION_KEY_DELIMITER = "__"
15
15
 
16
16
  class BaseSessionValue(BaseModel):
17
17
  auth_provider_name: str = Field(
18
- description="The name of the authentication provider used to authenticate the current session")
18
+ description="The name of the authentication provider used to authenticate the current session"
19
+ )
19
20
  auth_context: Optional[AuthContext] = Field(
20
21
  default=None,
21
- description="The authentication context containing the actual session details")
22
- scoped: bool = Field(description="Indicates whether the current session is a scoped session")
22
+ description="The authentication context containing the actual session details",
23
+ )
24
+ scoped: bool = Field(
25
+ description="Indicates whether the current session is a scoped session"
26
+ )
23
27
  auth_scopes: Optional[Set[str]] = Field(
24
28
  default=None,
25
- description="The authentication scopes of the current session, present only for scoped sessions")
29
+ description="The authentication scopes of the current session, present only for scoped sessions",
30
+ )
26
31
  auth_resolve_uid: Optional[str] = Field(
27
32
  default=None,
28
- description="A UID used to asynchronously verify whether the user has completed the authentication process")
33
+ description="A UID used to asynchronously verify whether the user has completed the authentication process",
34
+ )
29
35
 
30
36
  def make_superset_auth_scope(
31
- self,
32
- other_scopes: Optional[Iterable[str]] = None) -> set[str]:
37
+ self, other_scopes: Optional[Iterable[str]] = None
38
+ ) -> set[str]:
33
39
  auth_scopes = self.auth_scopes or set()
34
40
  other_scopes = other_scopes or set()
35
41
  return auth_scopes.union(other_scopes)
36
42
 
37
- def is_auth_applicable(self, auth_provider_name: str, auth_req: AuthenticateRequest) -> bool:
38
- return self.auth_provider_name == auth_provider_name \
39
- and (not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes))
43
+ def is_auth_applicable(
44
+ self, auth_provider_name: str, auth_req: AuthenticateRequest
45
+ ) -> bool:
46
+ return self.auth_provider_name == auth_provider_name and (
47
+ not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes)
48
+ )
40
49
 
41
50
  def is_near_expires(self) -> bool:
42
51
  if self.auth_context.expires_at is not None:
@@ -48,13 +57,15 @@ class BaseSessionValue(BaseModel):
48
57
  return False
49
58
 
50
59
 
51
- K = TypeVar('K')
52
- V = TypeVar('V', bound=BaseSessionValue)
60
+ K = TypeVar("K")
61
+ V = TypeVar("V", bound=BaseSessionValue)
53
62
 
54
63
 
55
64
  class SessionStorageInterface(ABC, Generic[K, V]):
56
65
  @abstractmethod
57
- def get(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> V:
66
+ def get(
67
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
68
+ ) -> V:
58
69
  """
59
70
  Get session
60
71
 
@@ -69,7 +80,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
69
80
  raise NotImplementedError
70
81
 
71
82
  @abstractmethod
72
- def get_by_thread_id(self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs) -> List[V]:
83
+ def get_by_thread_id(
84
+ self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
85
+ ) -> List[V]:
73
86
  """
74
87
  Get session list by thread id
75
88
 
@@ -83,14 +96,17 @@ class SessionStorageInterface(ABC, Generic[K, V]):
83
96
  raise NotImplementedError
84
97
 
85
98
  @abstractmethod
86
- def set(self,
87
- auth_provider: AuthProvider,
88
- thread_id: str,
89
- profile: str,
90
- auth_scopes: List[str],
91
- auth_resolve_uid: Optional[str],
92
- auth_context: Optional[AuthContext],
93
- is_auth_scope_universal: bool, **kwargs) -> V:
99
+ def set(
100
+ self,
101
+ auth_provider: AuthProvider,
102
+ thread_id: str,
103
+ profile: str,
104
+ auth_scopes: List[str],
105
+ auth_resolve_uid: Optional[str],
106
+ auth_context: Optional[AuthContext],
107
+ is_auth_scope_universal: bool,
108
+ **kwargs,
109
+ ) -> V:
94
110
  """
95
111
  Set session, if a session doesn't exist, create new session
96
112
  If set auth_resolve_uid is None and auth_context is not None, created session is regarded as active session.
@@ -112,7 +128,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
112
128
  raise NotImplementedError
113
129
 
114
130
  @abstractmethod
115
- def delete(self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs) -> bool:
131
+ def delete(
132
+ self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
133
+ ) -> bool:
116
134
  """
117
135
  Delete session
118
136