twirpy 0.1.0.dev4__tar.gz → 0.3.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,13 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: twirpy
3
- Version: 0.1.0.dev4
3
+ Version: 0.3.0.dev2
4
4
  Summary: Twirp runtime library for Python
5
5
  Project-URL: repository, https://github.com/cryptact/twirpy
6
6
  License-Expression: BSD-3-Clause
7
7
  License-File: LICENSE
8
8
  Keywords: protobuf,rpc,twirp
9
- Requires-Python: >=3.10
9
+ Requires-Python: >=3.12
10
+ Requires-Dist: asgiref
10
11
  Requires-Dist: protobuf
11
12
  Requires-Dist: requests
12
13
  Requires-Dist: structlog
@@ -54,7 +55,7 @@ Unzip the archive and move the binary to a directory in your PATH.
54
55
  On macOS, you can use the following commands:
55
56
  ```sh
56
57
  curl -L -o- \
57
- https://github.com/Cryptact/twirpy/releases/latest/download/protoc-gen-twirpy_Darwin_arm64.tar.gz \
58
+ https://github.com/Cryptact/twirpy/releases/latest/download/protoc-gen-twirpy-darwin-arm64.tar.gz \
58
59
  | tar xz -C ~/.local/bin protoc-gen-twirpy
59
60
  ````
60
61
 
@@ -74,6 +75,8 @@ We use [`hatch`](https://hatch.pypa.io/latest/) to manage the development proces
74
75
  To open a shell with the development environment, run: `hatch shell`.
75
76
  To run the linter, run: `hatch fmt --check` or `hatch fmt` to fix the issues.
76
77
 
78
+ To run the type checker, run: `hatch run types:check`.
79
+
77
80
  ## Standing on the shoulders of giants
78
81
 
79
82
  - The initial version of twirpy was made from an internal copy of https://github.com/daroot/protoc-gen-twirp_python_srv
@@ -38,7 +38,7 @@ Unzip the archive and move the binary to a directory in your PATH.
38
38
  On macOS, you can use the following commands:
39
39
  ```sh
40
40
  curl -L -o- \
41
- https://github.com/Cryptact/twirpy/releases/latest/download/protoc-gen-twirpy_Darwin_arm64.tar.gz \
41
+ https://github.com/Cryptact/twirpy/releases/latest/download/protoc-gen-twirpy-darwin-arm64.tar.gz \
42
42
  | tar xz -C ~/.local/bin protoc-gen-twirpy
43
43
  ````
44
44
 
@@ -58,6 +58,8 @@ We use [`hatch`](https://hatch.pypa.io/latest/) to manage the development proces
58
58
  To open a shell with the development environment, run: `hatch shell`.
59
59
  To run the linter, run: `hatch fmt --check` or `hatch fmt` to fix the issues.
60
60
 
61
+ To run the type checker, run: `hatch run types:check`.
62
+
61
63
  ## Standing on the shoulders of giants
62
64
 
63
65
  - The initial version of twirpy was made from an internal copy of https://github.com/daroot/protoc-gen-twirp_python_srv
@@ -10,7 +10,7 @@ name = "twirpy"
10
10
  dynamic = ["version"]
11
11
  description = "Twirp runtime library for Python"
12
12
  readme = "README.md"
13
- requires-python = ">=3.10"
13
+ requires-python = ">=3.12"
14
14
  license = "BSD-3-Clause"
15
15
  keywords = [
16
16
  "protobuf",
@@ -18,6 +18,7 @@ keywords = [
18
18
  "twirp",
19
19
  ]
20
20
  dependencies = [
21
+ "asgiref",
21
22
  "protobuf",
22
23
  "requests",
23
24
  "structlog",
@@ -42,11 +43,18 @@ include = ["/twirp"]
42
43
  packages = ["twirp"]
43
44
 
44
45
  [tool.hatch.envs.default]
45
- python = "3.12"
46
+ python = "3.13"
46
47
  dependencies = [
47
48
  "aiohttp",
48
49
  ]
49
50
 
51
+ [tool.hatch.envs.types]
52
+ extra-dependencies = [
53
+ "mypy>=1.0.0",
54
+ ]
55
+ [tool.hatch.envs.types.scripts]
56
+ check = "mypy --install-types --non-interactive twirp"
57
+
50
58
  [tool.ruff]
51
59
  line-length = 120
52
60
  lint.select = [
@@ -0,0 +1 @@
1
+ __version__ = "0.3.0-dev.2"
@@ -1,43 +1,27 @@
1
- import asyncio
2
- import functools
3
- import typing
4
1
  import traceback
5
2
 
3
+ from asgiref.typing import (
4
+ ASGIReceiveCallable,
5
+ ASGISendCallable,
6
+ Scope,
7
+ HTTPResponseStartEvent,
8
+ HTTPResponseBodyEvent,
9
+ ASGIReceiveEvent,
10
+ )
11
+ from google.protobuf.message import Message
12
+
13
+ from twirp.endpoint import Endpoint, TwirpMethod
6
14
  from . import base
7
15
  from . import exceptions
8
16
  from . import errors
9
17
  from . import ctxkeys
18
+ from . import context
10
19
 
11
- try:
12
- import contextvars # Python 3.7+ only.
13
- except ImportError: # pragma: no cover
14
- contextvars = None # type: ignore
15
-
16
-
17
- # Lifted from starlette.concurrency
18
- async def run_in_threadpool(func: typing.Callable, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
19
- loop = asyncio.get_event_loop()
20
- if contextvars is not None: # pragma: no cover
21
- # Ensure we run in the same context
22
- child = functools.partial(func, *args, **kwargs)
23
- context = contextvars.copy_context()
24
- func = context.run
25
- args = (child,)
26
- elif kwargs: # pragma: no cover
27
- # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
28
- func = functools.partial(func, **kwargs)
29
- return await loop.run_in_executor(None, func, *args)
30
-
31
-
32
- def thread_pool_runner(func):
33
- async def run(ctx, request):
34
- return await run_in_threadpool(func, ctx, request)
35
-
36
- return run
20
+ Headers = dict[str, str]
37
21
 
38
22
 
39
23
  class TwirpASGIApp(base.TwirpBaseApp):
40
- async def __call__(self, scope, receive, send):
24
+ async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
41
25
  assert scope["type"] == "http"
42
26
  ctx = self._ctx_class()
43
27
  try:
@@ -54,7 +38,7 @@ class TwirpASGIApp(base.TwirpBaseApp):
54
38
  ctx.set(ctxkeys.RAW_HEADERS, headers)
55
39
  self._hook.request_received(ctx=ctx)
56
40
 
57
- endpoint = self._get_endpoint(scope["path"])
41
+ endpoint: Endpoint = self._get_endpoint(scope["path"])
58
42
  headers = {k.decode("utf-8"): v.decode("utf-8") for (k, v) in scope["headers"]}
59
43
  self.validate_content_length(headers=headers)
60
44
  encoder, decoder = self._get_encoder_decoder(endpoint, headers)
@@ -66,7 +50,7 @@ class TwirpASGIApp(base.TwirpBaseApp):
66
50
  self._hook.request_routed(ctx=ctx)
67
51
  raw_receive = await self._recv_all(receive)
68
52
  request = decoder(raw_receive)
69
- response_data = await self._with_middlewares(func=endpoint.function, ctx=ctx, request=request)
53
+ response_data: Message = await self._with_middlewares(func=endpoint.function, ctx=ctx, request=request)
70
54
  self._hook.response_prepared(ctx=ctx)
71
55
 
72
56
  body_bytes, headers = encoder(response_data)
@@ -77,28 +61,27 @@ class TwirpASGIApp(base.TwirpBaseApp):
77
61
  except Exception as e:
78
62
  await self.handle_error(ctx, e, scope, receive, send)
79
63
 
80
- def _with_middlewares(self, *args, func, ctx, request):
81
- chain = iter(self._middlewares + (func,))
64
+ async def _with_middlewares(self, *, func: TwirpMethod, ctx: context.Context, request: Message) -> Message:
65
+ chain = iter(self._middlewares)
82
66
 
83
- def bind(fn):
84
- if not asyncio.iscoroutinefunction(fn):
85
- fn = thread_pool_runner(fn)
67
+ def _bind_next() -> TwirpMethod:
68
+ try:
69
+ middleware = next(chain)
86
70
 
87
- async def nxt(ctx, request):
88
- try:
89
- cur = next(chain)
90
- return await fn(ctx, request, bind(cur))
91
- except StopIteration:
92
- pass
93
- return await fn(ctx, request)
71
+ async def _next(ctx_: context.Context, request_: Message) -> Message:
72
+ return await middleware(ctx_, request_, _bind_next())
94
73
 
95
- return nxt
74
+ return _next
75
+ except StopIteration:
76
+ return func
96
77
 
97
- return bind(next(chain))(ctx, request)
78
+ fn = _bind_next()
79
+ return await fn(ctx, request)
98
80
 
99
- async def handle_error(self, ctx, exc, scope, receive, send):
81
+ async def handle_error(
82
+ self, ctx: context.Context, exc: Exception, scope_: Scope, receive_: ASGIReceiveCallable, send: ASGISendCallable
83
+ ) -> None:
100
84
  status = 500
101
- body_bytes = b"{}"
102
85
  logger = ctx.get_logger()
103
86
  error_data = {}
104
87
  ctx.set(ctxkeys.ORIGINAL_EXCEPTION, exc)
@@ -128,28 +111,36 @@ class TwirpASGIApp(base.TwirpBaseApp):
128
111
  )
129
112
  self._hook.response_sent(ctx=ctx)
130
113
 
131
- async def _respond(self, *args, send, status, headers, body_bytes):
114
+ @staticmethod
115
+ async def _respond(*, send: ASGISendCallable, status: int, headers: Headers, body_bytes: bytes) -> None:
132
116
  headers["Content-Length"] = str(len(body_bytes))
133
117
  resp_headers = [(k.encode("utf-8"), v.encode("utf-8")) for (k, v) in headers.items()]
134
118
  await send(
135
- {
136
- "type": "http.response.start",
137
- "status": status,
138
- "headers": resp_headers,
139
- }
119
+ HTTPResponseStartEvent(
120
+ type="http.response.start",
121
+ status=status,
122
+ headers=resp_headers,
123
+ trailers=False,
124
+ )
140
125
  )
141
126
  await send(
142
- {
143
- "type": "http.response.body",
144
- "body": body_bytes,
145
- }
127
+ HTTPResponseBodyEvent(
128
+ type="http.response.body",
129
+ body=body_bytes,
130
+ more_body=False,
131
+ )
146
132
  )
147
133
 
148
- async def _recv_all(self, receive):
134
+ async def _recv_all(self, receive: ASGIReceiveCallable) -> bytes:
149
135
  body = b""
150
136
  more_body = True
151
137
  while more_body:
152
- message = await receive()
138
+ message: ASGIReceiveEvent = await receive()
139
+ if message["type"] != "http.request":
140
+ raise exceptions.TwirpServerException(
141
+ code=errors.Errors.Internal,
142
+ message="expected http.request message type, got " + message["type"],
143
+ )
153
144
  body += message.get("body", b"")
154
145
  more_body = message.get("more_body", False)
155
146
 
@@ -165,9 +156,12 @@ class TwirpASGIApp(base.TwirpBaseApp):
165
156
 
166
157
  # we will check content-length header value and make sure it is
167
158
  # below the limit set
168
- def validate_content_length(self, headers):
159
+ def validate_content_length(self, headers: Headers) -> None:
169
160
  try:
170
- content_length = int(headers.get("content-length"))
161
+ raw_value = headers.get("content-length", None)
162
+ if not raw_value:
163
+ return
164
+ content_length = int(raw_value)
171
165
  except (ValueError, TypeError):
172
166
  return
173
167
 
@@ -1,10 +1,13 @@
1
- import asyncio
2
1
  import json
2
+ from typing import Any
3
3
 
4
4
  import aiohttp
5
+ from aiohttp.typedefs import StrOrURL
6
+ from google.protobuf.message import Message
5
7
 
6
8
  from . import exceptions
7
9
  from . import errors
10
+ from . import context
8
11
 
9
12
 
10
13
  class AsyncTwirpClient:
@@ -12,7 +15,16 @@ class AsyncTwirpClient:
12
15
  self._address = address
13
16
  self._session = session
14
17
 
15
- async def _make_request(self, *, url, ctx, request, response_obj, session=None, **kwargs):
18
+ async def _make_request[RQ: Message, RP: Message](
19
+ self,
20
+ *,
21
+ url: StrOrURL,
22
+ ctx: context.Context,
23
+ request: RQ,
24
+ response_obj: type[RP],
25
+ session: aiohttp.ClientSession | None = None,
26
+ **kwargs: Any,
27
+ ) -> RP:
16
28
  headers = ctx.get_headers()
17
29
  if "headers" in kwargs:
18
30
  headers.update(kwargs["headers"])
@@ -36,7 +48,7 @@ class AsyncTwirpClient:
36
48
  raise exceptions.twirp_error_from_intermediary(
37
49
  resp.status, resp.reason, resp.headers, await resp.text()
38
50
  ) from None
39
- except asyncio.TimeoutError as e:
51
+ except TimeoutError as e:
40
52
  raise exceptions.TwirpServerException(
41
53
  code=errors.Errors.DeadlineExceeded,
42
54
  message=str(e) or "request timeout",
@@ -1,10 +1,11 @@
1
1
  import functools
2
- from collections import namedtuple
2
+ from collections.abc import Awaitable
3
+ from typing import Any
4
+ from collections.abc import Callable
3
5
 
4
- from google.protobuf import json_format
5
- from google.protobuf import message
6
+ from google.protobuf import json_format, message
6
7
  from google.protobuf import symbol_database as _symbol_database
7
-
8
+ from google.protobuf.message import Message
8
9
 
9
10
  from . import context
10
11
 
@@ -12,31 +13,39 @@ from . import server
12
13
  from . import exceptions
13
14
  from . import errors
14
15
  from . import hook as vtwirp_hook
16
+ from .endpoint import Endpoint, TwirpMethod
15
17
 
16
18
  _sym_lookup = _symbol_database.Default().GetSymbol
17
19
 
18
- Endpoint = namedtuple("Endpoint", ["service_name", "name", "function", "input", "output"])
20
+ Middleware = Callable[[context.Context, Message, TwirpMethod], Awaitable[Message]]
19
21
 
20
22
 
21
23
  class TwirpBaseApp:
22
- def __init__(self, *middlewares, hook=None, prefix="", max_receive_message_length=1024 * 100 * 100, ctx_class=None):
23
- self._prefix = prefix
24
- self._services = {}
25
- self._max_receive_message_length = max_receive_message_length
24
+ def __init__(
25
+ self,
26
+ *middlewares: Middleware,
27
+ hook: vtwirp_hook.TwirpHook | None = None,
28
+ prefix: str = "",
29
+ max_receive_message_length: int = 1024 * 100 * 100,
30
+ ctx_class: type[context.Context] | None = None,
31
+ ) -> None:
32
+ self._prefix: str = prefix
33
+ self._services: dict[str, server.TwirpServer] = {}
34
+ self._max_receive_message_length: int = max_receive_message_length
26
35
  if ctx_class is None:
27
36
  ctx_class = context.Context
28
37
  assert issubclass(ctx_class, context.Context)
29
- self._ctx_class = ctx_class
30
- self._middlewares = middlewares
38
+ self._ctx_class: type[context.Context] = ctx_class
39
+ self._middlewares: tuple[Middleware, ...] = middlewares
31
40
  if hook is None:
32
41
  hook = vtwirp_hook.TwirpHook()
33
42
  assert isinstance(hook, vtwirp_hook.TwirpHook)
34
- self._hook = hook
43
+ self._hook: vtwirp_hook.TwirpHook = hook
35
44
 
36
- def add_service(self, svc: server.TwirpServer):
45
+ def add_service(self, svc: server.TwirpServer) -> None:
37
46
  self._services[self._prefix + svc.prefix] = svc
38
47
 
39
- def _get_endpoint(self, path):
48
+ def _get_endpoint(self, path: str) -> Endpoint:
40
49
  svc = self._services.get(path.rsplit("/", 1)[0], None)
41
50
  if svc is None:
42
51
  raise exceptions.TwirpServerException(code=errors.Errors.NotFound, message="not found")
@@ -44,7 +53,7 @@ class TwirpBaseApp:
44
53
  return svc.get_endpoint(path[len(self._prefix) :])
45
54
 
46
55
  @staticmethod
47
- def json_decoder(body, data_obj=None):
56
+ def json_decoder(body: bytes, data_obj: type[Message]) -> Message:
48
57
  data = data_obj()
49
58
  try:
50
59
  json_format.Parse(body, data)
@@ -56,7 +65,7 @@ class TwirpBaseApp:
56
65
  return data
57
66
 
58
67
  @staticmethod
59
- def json_encoder(value, data_obj=None):
68
+ def json_encoder(value: Any, data_obj: type[Message]) -> tuple[bytes, dict[str, str]]:
60
69
  if not isinstance(value, data_obj):
61
70
  raise exceptions.TwirpServerException(
62
71
  code=errors.Errors.Internal,
@@ -70,7 +79,7 @@ class TwirpBaseApp:
70
79
  }
71
80
 
72
81
  @staticmethod
73
- def proto_decoder(body, data_obj=None):
82
+ def proto_decoder(body: bytes, data_obj: type[Message]) -> Message:
74
83
  data = data_obj()
75
84
  try:
76
85
  data.ParseFromString(body)
@@ -82,7 +91,7 @@ class TwirpBaseApp:
82
91
  return data
83
92
 
84
93
  @staticmethod
85
- def proto_encoder(value, data_obj=None):
94
+ def proto_encoder(value: Any, data_obj: type[Message]) -> tuple[bytes, dict[str, str]]:
86
95
  if not isinstance(value, data_obj):
87
96
  raise exceptions.TwirpServerException(
88
97
  code=errors.Errors.Internal,
@@ -93,7 +102,9 @@ class TwirpBaseApp:
93
102
 
94
103
  return value.SerializeToString(), {"Content-Type": "application/protobuf"}
95
104
 
96
- def _get_encoder_decoder(self, endpoint, headers):
105
+ def _get_encoder_decoder(
106
+ self, endpoint: Endpoint, headers: dict[str, str]
107
+ ) -> tuple[Callable[[Any], tuple[bytes, dict[str, str]]], Callable[[bytes], Message]]:
97
108
  ctype = headers.get("content-type", None)
98
109
  if "application/json" == ctype:
99
110
  decoder = functools.partial(self.json_decoder, data_obj=endpoint.input)
@@ -1,15 +1,21 @@
1
+ from typing import Any
2
+
1
3
  import requests
4
+ from google.protobuf.message import Message
2
5
 
6
+ from . import context
3
7
  from . import exceptions
4
8
  from . import errors
5
9
 
6
10
 
7
11
  class TwirpClient:
8
- def __init__(self, address, timeout=5):
12
+ def __init__(self, address: str, timeout: int = 5) -> None:
9
13
  self._address = address
10
14
  self._timeout = timeout
11
15
 
12
- def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
16
+ def _make_request[Req: Message, Resp: Message](
17
+ self, *, url: str, ctx: context.Context, request: Req, response_obj: type[Resp], **kwargs: Any
18
+ ) -> Resp:
13
19
  if "timeout" not in kwargs:
14
20
  kwargs["timeout"] = self._timeout
15
21
  headers = ctx.get_headers()
@@ -20,7 +26,7 @@ class TwirpClient:
20
26
  try:
21
27
  resp = requests.post(url=self._address + url, data=request.SerializeToString(), **kwargs)
22
28
  if resp.status_code == 200:
23
- response = response_obj()
29
+ response: Resp = response_obj()
24
30
  response.ParseFromString(resp.content)
25
31
  return response
26
32
  try:
@@ -1,3 +1,7 @@
1
+ from typing import Any
2
+
3
+ from structlog.stdlib import BoundLogger
4
+
1
5
  from . import logging
2
6
 
3
7
 
@@ -6,23 +10,23 @@ class Context:
6
10
  request currently being processed.
7
11
  """
8
12
 
9
- def __init__(self, *args, logger=None, headers=None):
13
+ def __init__(self, *, logger: BoundLogger | None = None, headers: dict[str, Any] | None = None):
10
14
  """Create a new Context object
11
15
 
12
16
  Keyword arguments:
13
17
  logger: Logger that will be used for logging.
14
18
  headers: Headers for the request.
15
19
  """
16
- self._values = {}
20
+ self._values: dict[str, Any] = {}
17
21
  if logger is None:
18
22
  logger = logging.get_logger()
19
- self._logger = logger
23
+ self._logger: BoundLogger = logger
20
24
  if headers is None:
21
25
  headers = {}
22
- self._headers = headers
23
- self._response_headers = {}
26
+ self._headers: dict[str, Any] = headers
27
+ self._response_headers: dict[str, Any] = {}
24
28
 
25
- def set(self, key, value):
29
+ def set(self, key: str, value: Any) -> None:
26
30
  """Set a Context value
27
31
 
28
32
  Arguments:
@@ -31,7 +35,7 @@ class Context:
31
35
  """
32
36
  self._values[key] = value
33
37
 
34
- def get(self, key):
38
+ def get(self, key: str) -> Any:
35
39
  """Get a Context value
36
40
 
37
41
  Arguments:
@@ -39,11 +43,11 @@ class Context:
39
43
  """
40
44
  return self._values[key]
41
45
 
42
- def get_logger(self):
46
+ def get_logger(self) -> BoundLogger:
43
47
  """Get current logger used by Context."""
44
48
  return self._logger
45
49
 
46
- def set_logger(self, logger):
50
+ def set_logger(self, logger: BoundLogger) -> None:
47
51
  """Set logger for this Context
48
52
 
49
53
  Arguments:
@@ -51,11 +55,11 @@ class Context:
51
55
  """
52
56
  self._logger = logger
53
57
 
54
- def get_headers(self):
58
+ def get_headers(self) -> dict[str, Any]:
55
59
  """Get request headers that are currently stored."""
56
60
  return self._headers
57
61
 
58
- def set_header(self, key, value):
62
+ def set_header(self, key: str, value: Any) -> None:
59
63
  """Set a request header
60
64
 
61
65
  Arguments:
@@ -64,11 +68,11 @@ class Context:
64
68
  """
65
69
  self._headers[key] = value
66
70
 
67
- def get_response_headers(self):
71
+ def get_response_headers(self) -> dict[str, Any]:
68
72
  """Get response headers that are currently stored."""
69
73
  return self._response_headers
70
74
 
71
- def set_response_header(self, key, value):
75
+ def set_response_header(self, key: str, value: Any) -> None:
72
76
  """Set a response header
73
77
 
74
78
  Arguments:
@@ -0,0 +1,17 @@
1
+ from dataclasses import dataclass
2
+ from collections.abc import Callable, Awaitable
3
+
4
+ from google.protobuf.message import Message
5
+
6
+ from twirp import context
7
+
8
+ TwirpMethod = Callable[[context.Context, Message], Awaitable[Message]]
9
+
10
+
11
+ @dataclass
12
+ class Endpoint:
13
+ service_name: str
14
+ name: str
15
+ function: TwirpMethod
16
+ input: type[Message]
17
+ output: type[Message]
@@ -1,4 +1,5 @@
1
1
  from enum import Enum
2
+ from typing import Self
2
3
 
3
4
 
4
5
  class Errors(Enum):
@@ -22,8 +23,8 @@ class Errors(Enum):
22
23
  Malformed = "malformed"
23
24
  NoError = ""
24
25
 
25
- @staticmethod
26
- def get_status_code(code):
26
+ @classmethod
27
+ def get_status_code(cls, code: Self) -> int:
27
28
  return {
28
29
  Errors.Canceled: 408,
29
30
  Errors.Unknown: 500,
@@ -1,12 +1,21 @@
1
1
  import json
2
2
  from http.client import HTTPException
3
- from typing import Any
3
+ from typing import Any, TypedDict, Self
4
+
5
+ from multidict import CIMultiDictProxy
6
+ from requests.structures import CaseInsensitiveDict
4
7
 
5
8
  from . import errors
6
9
 
7
10
 
11
+ class TwirpServerExceptionDict(TypedDict, total=False):
12
+ code: str
13
+ msg: str
14
+ meta: dict[str, Any]
15
+
16
+
8
17
  class TwirpServerException(HTTPException):
9
- def __init__(self, *args, code, message, meta: dict[str, Any] | None = None):
18
+ def __init__(self, *, code: errors.Errors | str, message: str, meta: dict[str, Any] | None = None):
10
19
  try:
11
20
  self._code = errors.Errors(code)
12
21
  except ValueError:
@@ -16,50 +25,52 @@ class TwirpServerException(HTTPException):
16
25
  super().__init__(message)
17
26
 
18
27
  @property
19
- def code(self):
28
+ def code(self) -> errors.Errors:
20
29
  if isinstance(self._code, errors.Errors):
21
30
  return self._code
22
31
  return errors.Errors.Unknown
23
32
 
24
33
  @property
25
- def message(self):
34
+ def message(self) -> str:
26
35
  return self._message
27
36
 
28
37
  @property
29
- def meta(self):
38
+ def meta(self) -> dict[str, Any]:
30
39
  return self._meta
31
40
 
32
- def to_dict(self):
33
- err = {"code": self._code.value, "msg": self._message, "meta": {}}
41
+ def to_dict(self) -> TwirpServerExceptionDict:
42
+ err: TwirpServerExceptionDict = {"code": self._code.value, "msg": self._message, "meta": {}}
34
43
  for k, v in self._meta.items():
35
44
  err["meta"][k] = str(v)
36
45
  return err
37
46
 
38
- def to_json_bytes(self):
47
+ def to_json_bytes(self) -> bytes:
39
48
  return json.dumps(self.to_dict()).encode("utf-8")
40
49
 
41
- @staticmethod
42
- def from_json(err_dict):
43
- return TwirpServerException(
50
+ @classmethod
51
+ def from_json(cls, err_dict: TwirpServerExceptionDict) -> Self:
52
+ return cls(
44
53
  code=err_dict.get("code", errors.Errors.Unknown),
45
54
  message=err_dict.get("msg", ""),
46
55
  meta=err_dict.get("meta", {}),
47
56
  )
48
57
 
49
58
 
50
- def InvalidArgument(*args, argument, error):
59
+ def InvalidArgument(*, argument: str, error: str) -> TwirpServerException:
51
60
  return TwirpServerException(
52
61
  code=errors.Errors.InvalidArgument, message=f"{argument} {error}", meta={"argument": argument}
53
62
  )
54
63
 
55
64
 
56
- def RequiredArgument(*args, argument):
65
+ def RequiredArgument(*, argument: str) -> TwirpServerException:
57
66
  return InvalidArgument(argument=argument, error="is required")
58
67
 
59
68
 
60
- def twirp_error_from_intermediary(status, reason, headers, body):
69
+ def twirp_error_from_intermediary(
70
+ status: int, reason: str | None, headers: CaseInsensitiveDict[str] | CIMultiDictProxy[str], body: str
71
+ ) -> TwirpServerException:
61
72
  # see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
62
- meta = {
73
+ meta: dict[str, str | None] = {
63
74
  "http_error_from_intermediary": "true",
64
75
  "status_code": str(status),
65
76
  }
@@ -1,47 +1,50 @@
1
+ from . import context
2
+
3
+
1
4
  class TwirpHook:
2
5
  # Called as soon as a request is received, always called
3
- def request_received(self, *args, ctx):
6
+ def request_received(self, *, ctx: context.Context) -> None:
4
7
  pass
5
8
 
6
9
  # Called once the request is routed, service name known, only called if request is routable
7
- def request_routed(self, *args, ctx):
10
+ def request_routed(self, *, ctx: context.Context) -> None:
8
11
  pass
9
12
 
10
13
  # Called once the response is prepared, not called for error cases
11
- def response_prepared(self, *args, ctx):
14
+ def response_prepared(self, *, ctx: context.Context) -> None:
12
15
  pass
13
16
 
14
17
  # Called if an error occurs
15
- def error(self, *args, ctx, exc):
18
+ def error(self, *, ctx: context.Context, exc: Exception) -> None:
16
19
  pass
17
20
 
18
21
  # Called after error is sent, always called
19
- def response_sent(self, *args, ctx):
22
+ def response_sent(self, *, ctx: context.Context) -> None:
20
23
  pass
21
24
 
22
25
 
23
26
  class ChainHooks(TwirpHook):
24
- def __init__(self, *hooks):
27
+ def __init__(self, *hooks: TwirpHook) -> None:
25
28
  for hook in hooks:
26
29
  assert isinstance(hook, TwirpHook)
27
30
  self._hooks = hooks
28
31
 
29
- def request_received(self, *args, ctx):
32
+ def request_received(self, *, ctx: context.Context) -> None:
30
33
  for hook in self._hooks:
31
34
  hook.request_received(ctx=ctx)
32
35
 
33
- def request_routed(self, *args, ctx):
36
+ def request_routed(self, *, ctx: context.Context) -> None:
34
37
  for hook in self._hooks:
35
38
  hook.request_routed(ctx=ctx)
36
39
 
37
- def response_prepared(self, *args, ctx):
40
+ def response_prepared(self, *, ctx: context.Context) -> None:
38
41
  for hook in self._hooks:
39
42
  hook.response_prepared(ctx=ctx)
40
43
 
41
- def error(self, *args, ctx, exc):
44
+ def error(self, *, ctx: context.Context, exc: Exception) -> None:
42
45
  for hook in self._hooks:
43
46
  hook.error(ctx=ctx, exc=exc)
44
47
 
45
- def response_sent(self, *args, ctx):
48
+ def response_sent(self, *, ctx: context.Context) -> None:
46
49
  for hook in self._hooks:
47
50
  hook.response_sent(ctx=ctx)
@@ -1,5 +1,6 @@
1
1
  import os
2
2
  import logging
3
+ from typing import Any
3
4
 
4
5
  import structlog
5
6
  from structlog.stdlib import LoggerFactory, add_log_level
@@ -7,7 +8,7 @@ from structlog.stdlib import LoggerFactory, add_log_level
7
8
  _configured = False
8
9
 
9
10
 
10
- def configure(force=False):
11
+ def configure(force: bool = False) -> None:
11
12
  """
12
13
  Configures logging & structlog modules
13
14
 
@@ -52,11 +53,11 @@ def configure(force=False):
52
53
  _configured = True
53
54
 
54
55
 
55
- def get_logger(**kwargs):
56
+ def get_logger(**kwargs: Any) -> structlog.stdlib.BoundLogger:
56
57
  """
57
58
  Get the structlog logger
58
59
  """
59
60
  # Configure logging modules
60
61
  configure()
61
62
  # Return structlog
62
- return structlog.get_logger(**kwargs)
63
+ return structlog.stdlib.get_logger(**kwargs)
File without changes
@@ -1,18 +1,21 @@
1
+ from typing import Any
2
+
3
+ from .endpoint import Endpoint
1
4
  from . import exceptions
2
5
  from . import errors
3
6
 
4
7
 
5
8
  class TwirpServer:
6
- def __init__(self, *args, service):
7
- self.service = service
8
- self._endpoints = {}
9
- self._prefix = ""
9
+ def __init__(self, *, service: Any) -> None:
10
+ self.service: Any = service
11
+ self._endpoints: dict[str, Endpoint] = {}
12
+ self._prefix: str = ""
10
13
 
11
14
  @property
12
- def prefix(self):
15
+ def prefix(self) -> str:
13
16
  return self._prefix
14
17
 
15
- def get_endpoint(self, path):
18
+ def get_endpoint(self, path: str) -> Endpoint:
16
19
  (_, url_pre, rpc_method) = path.rpartition(self._prefix + "/")
17
20
  if not url_pre or not rpc_method:
18
21
  raise exceptions.TwirpServerException(
@@ -21,7 +24,7 @@ class TwirpServer:
21
24
  meta={"twirp_invalid_route": "POST " + path},
22
25
  )
23
26
 
24
- endpoint = self._endpoints.get(rpc_method, None)
27
+ endpoint: Endpoint | None = self._endpoints.get(rpc_method, None)
25
28
  if not endpoint:
26
29
  raise exceptions.TwirpServerException(
27
30
  code=errors.Errors.Unimplemented,
@@ -1 +0,0 @@
1
- __version__ = "0.1.0-dev.4"
File without changes
File without changes