asgi-compression 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: asgi-compression
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.9
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: multidict>=6.2.0
8
+ Provides-Extra: all
9
+ Requires-Dist: brotli>=1.1.0; extra == "all"
10
+ Requires-Dist: zstandard>=0.23.0; extra == "all"
11
+ Provides-Extra: br
12
+ Requires-Dist: brotli>=1.1.0; extra == "br"
13
+ Provides-Extra: zstd
14
+ Requires-Dist: zstandard>=0.23.0; extra == "zstd"
File without changes
@@ -0,0 +1,17 @@
1
+ from .base import CompressionAlgorithm, ContentEncoding
2
+ from .brotli import BrotliAlgorithm, BrotliMode
3
+ from .gzip import GzipAlgorithm
4
+ from .identity import IdentityAlgorithm
5
+ from .middleware import CompressionMiddleware
6
+ from .zstd import ZstdAlgorithm
7
+
8
+ __all__ = [
9
+ "CompressionMiddleware",
10
+ "CompressionAlgorithm",
11
+ "ContentEncoding",
12
+ "GzipAlgorithm",
13
+ "BrotliAlgorithm",
14
+ "BrotliMode",
15
+ "IdentityAlgorithm",
16
+ "ZstdAlgorithm",
17
+ ]
@@ -0,0 +1,140 @@
1
+ import typing
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+
6
+ from .types import ASGIApp, Headers, Message, Receive, Scope, Send
7
+
8
+ DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
9
+ DEFAULT_MINIMUM_SIZE = 500
10
+
11
+
12
class ContentEncoding(str, Enum):
    """Content-coding tokens handled by this package.

    Every member is also a ``str`` equal to its value, so a member can be
    compared with (or used as) the plain header token.
    """

    # Token for gzip-compressed responses.
    GZIP = "gzip"
    # Token for brotli-compressed responses.
    BROTLI = "br"
    # Token for Zstandard-compressed responses.
    ZSTD = "zstd"
    # Token for responses that are passed through unmodified.
    IDENTITY = "identity"
17
+
18
+
19
async def unattached_send(message: Message) -> typing.NoReturn:
    """Placeholder ``send`` callable that always fails.

    Used as the initial value of a responder's ``_send`` attribute; awaiting
    it raises ``RuntimeError`` regardless of the message passed in.
    """
    failure = RuntimeError("send awaitable not set")
    raise failure  # pragma: no cover
21
+
22
+
23
class CompressionResponder(ABC):
    """Base class for all compression responders.

    A responder wraps the ``send`` callable of one HTTP response. The
    ``http.response.start`` message is held back until the first body
    message arrives, because only then is it known whether the body will be
    compressed and therefore which headers must be rewritten.

    Subclasses set ``content_encoding`` and implement ``apply_compression``.
    """

    content_encoding: ContentEncoding

    def __init__(self, app: ASGIApp, minimum_size: int) -> None:
        """Store the wrapped app and the size threshold.

        Args:
            app: The ASGI application whose response is intercepted.
            minimum_size: Complete (non-streaming) bodies shorter than this
                many bytes are forwarded without compression.
        """
        self.app = app
        self.minimum_size = minimum_size
        self._send: Send = unattached_send
        # The deferred "http.response.start" message.
        self._initial_message: Message = {}
        # True once the deferred start message has been forwarded.
        self._started = False
        self._content_encoding_set = False
        self._content_type_is_excluded = False

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """Run the wrapped app with ``send`` replaced by this responder."""
        self._send = send
        await self.app(scope, receive, self.send_with_compression)

    async def send_with_compression(self, message: Message) -> None:
        """Intercept one outgoing ASGI message.

        Start messages are stored, body messages are compressed (or passed
        through when the response already has a ``Content-Encoding`` or an
        excluded content type), and every other message type is forwarded
        unchanged instead of being dropped.
        """
        message_type = message["type"]
        if message_type == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self._initial_message = message
            # "headers" is optional in the start message, so do not index.
            headers = Headers(raw=message.get("headers", []))

            self._content_encoding_set = "content-encoding" in headers
            self._content_type_is_excluded = headers.get(
                "content-type", ""
            ).startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)

        elif message_type == "http.response.body":
            if self._content_encoding_set or self._content_type_is_excluded:
                # Pass-through: the response must not be re-encoded.
                await self._send_deferred_start()
                await self._send(message)
            elif not self._started:
                await self._send_first_body(message)
            else:
                # Remaining body in streaming response.
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                message["body"] = self.apply_compression(
                    body, more_body=more_body
                )
                await self._send(message)

        else:
            # Previously such messages (for example trailers) were silently
            # discarded. A path-send message replaces the body, so the
            # deferred start has to go out first, unmodified.
            if message_type == "http.response.pathsend":
                await self._send_deferred_start()
            await self._send(message)

    async def _send_deferred_start(self) -> None:
        """Forward the stored start message once, without modifying it."""
        if not self._started:
            self._started = True
            await self._send(self._initial_message)

    async def _send_first_body(self, message: Message) -> None:
        """Handle the first body message and the deferred start message."""
        self._started = True
        # "body" and "more_body" are optional keys; never index them.
        original = message.get("body", b"")
        more_body = message.get("more_body", False)

        if len(original) < self.minimum_size and not more_body:
            # Don't apply compression to small outgoing responses.
            # Don't add Vary header for small responses
            await self._send(self._initial_message)
            await self._send(message)
            return

        compressed = self.apply_compression(original, more_body=more_body)

        headers = Headers(raw=self._initial_message.get("headers", []))
        headers.add_vary_header("Accept-Encoding")

        encoded = compressed != original
        if more_body and not encoded:
            # A streaming compressor may emit nothing for the first chunk
            # (e.g. when that chunk is empty). Later chunks will still be
            # compressed, so the header must be set for every responder
            # except the identity one.
            encoded = self.content_encoding != ContentEncoding.IDENTITY

        if encoded:
            headers["Content-Encoding"] = self.content_encoding
            if more_body:
                # Total compressed size is unknown while streaming.
                if "Content-Length" in headers:
                    del headers["Content-Length"]
            else:
                headers["Content-Length"] = str(len(compressed))
            message["body"] = compressed

        self._initial_message["headers"] = headers.encode()
        await self._send(self._initial_message)
        await self._send(message)

    @abstractmethod
    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Apply compression on the response body.

        If more_body is False, any compression file should be closed. If it
        isn't, it won't be closed automatically until all background tasks
        complete.
        """
        raise NotImplementedError
126
+
127
+
128
@dataclass
class CompressionAlgorithm(ABC):
    """Configuration object describing one compression algorithm.

    Subclasses provide a default for ``type`` and override
    ``create_responder``; ``check_available`` is a no-op unless overridden.
    """

    # Content-coding token this algorithm produces.
    type: ContentEncoding
    # Size threshold handed to the responders this algorithm creates.
    minimum_size: int = DEFAULT_MINIMUM_SIZE

    def create_responder(self, app: ASGIApp) -> "CompressionResponder":
        """Build the responder that wraps ``app`` for a single response."""
        raise NotImplementedError()

    def check_available(self) -> None:
        """Verify the algorithm can be used in the current environment."""
        return None
@@ -0,0 +1,101 @@
1
+ import io
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import TYPE_CHECKING
5
+
6
+ from .base import CompressionAlgorithm, CompressionResponder, ContentEncoding
7
+ from .types import ASGIApp
8
+
9
+ if TYPE_CHECKING:
10
+ import brotli
11
+
12
+
13
def import_brotli() -> None:
    """Import ``brotli`` and bind it as a global of this module.

    The import is deferred to call time so that this module can be imported
    without the optional dependency installed.

    Raises:
        ImportError: If ``brotli`` cannot be imported. The original error is
            attached as the cause.
    """
    global brotli
    try:
        import brotli
    except ImportError as e:
        # The message previously ended with a stray "]".
        raise ImportError(
            "brotli is not installed, run `pip install brotli`"
        ) from e
21
+
22
+
23
class BrotliMode(Enum):
    """Selectable brotli compression modes."""

    TEXT = "text"
    FONT = "font"
    GENERIC = "generic"

    def to_brotli_mode(self) -> int:
        """Return the ``brotli`` module constant matching this member.

        Reads the module-level ``brotli`` global, so the module must have
        been imported before this is called.
        """
        if self is BrotliMode.TEXT:
            return brotli.MODE_TEXT
        if self is BrotliMode.FONT:
            return brotli.MODE_FONT
        if self is BrotliMode.GENERIC:
            return brotli.MODE_GENERIC
        assert False, f"Expected code to be unreachable, but got: {self}"
37
+
38
+
39
class BrotliResponder(CompressionResponder):
    """Responder that compresses the response body with brotli."""

    content_encoding = ContentEncoding.BROTLI

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int,
        quality: int = 4,
        mode: BrotliMode = BrotliMode.TEXT,
        lgwin: int = 22,
        lgblock: int = 0,
    ) -> None:
        """Create the brotli compressor used for this response.

        ``quality``, ``mode``, ``lgwin`` and ``lgblock`` are passed on to
        ``brotli.Compressor``.
        """
        super().__init__(app, minimum_size)

        # Binds the module-level ``brotli`` name (raises if not installed).
        import_brotli()

        self.brotli_buffer = io.BytesIO()
        compressor_options = {
            "quality": quality,
            "mode": mode.to_brotli_mode(),
            "lgwin": lgwin,
            "lgblock": lgblock,
        }
        self.compressor = brotli.Compressor(**compressor_options)

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Feed ``body`` to the compressor and return the bytes produced.

        When ``more_body`` is false the compressor is finished so the
        returned bytes complete the brotli stream.
        """
        buffer = self.brotli_buffer
        buffer.write(self.compressor.process(body))

        if not more_body:
            buffer.write(self.compressor.finish())

        output = buffer.getvalue()

        # Empty the buffer so the next call only returns new bytes.
        buffer.seek(0)
        buffer.truncate()
        return output
78
+
79
+
80
@dataclass
class BrotliAlgorithm(CompressionAlgorithm):
    """Brotli compression settings and responder factory."""

    type: ContentEncoding = ContentEncoding.BROTLI
    quality: int = 4
    mode: BrotliMode = BrotliMode.TEXT
    lgwin: int = 22
    lgblock: int = 0

    def create_responder(self, app: ASGIApp) -> "BrotliResponder":
        """Return a ``BrotliResponder`` configured from this instance."""
        responder = BrotliResponder(
            app,
            self.minimum_size,
            quality=self.quality,
            mode=self.mode,
            lgwin=self.lgwin,
            lgblock=self.lgblock,
        )
        return responder

    def check_available(self) -> None:
        """Raise ``ImportError`` when ``brotli`` cannot be imported."""
        import_brotli()
@@ -0,0 +1,61 @@
1
+ import gzip
2
+ import io
3
+ from dataclasses import dataclass
4
+
5
+ from .base import CompressionAlgorithm, CompressionResponder, ContentEncoding
6
+ from .types import ASGIApp, Receive, Scope, Send
7
+
8
+
9
class GzipResponder(CompressionResponder):
    """Responder that compresses the response body with gzip."""

    content_encoding = ContentEncoding.GZIP

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int,
        compresslevel: int = 9,
    ) -> None:
        """Create the in-memory buffer and the gzip writer on top of it."""
        super().__init__(app, minimum_size)

        buffer = io.BytesIO()
        self.gzip_buffer = buffer
        self.gzip_file = gzip.GzipFile(
            mode="wb",
            fileobj=buffer,
            compresslevel=compresslevel,
        )

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """Run the wrapped app, closing the gzip writer and buffer after."""
        # The gzip writer is exited (closed) before the buffer it writes to.
        with self.gzip_buffer:
            with self.gzip_file:
                await super().__call__(scope, receive, send)

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Write ``body`` through gzip and return the bytes now buffered.

        The gzip writer is closed when ``more_body`` is false.
        """
        writer = self.gzip_file
        writer.write(body)
        if not more_body:
            writer.close()

        sink = self.gzip_buffer
        chunk = sink.getvalue()
        # Empty the buffer so the next call only returns new bytes.
        sink.seek(0)
        sink.truncate()
        return chunk
47
+
48
+
49
@dataclass
class GzipAlgorithm(CompressionAlgorithm):
    """Gzip compression settings and responder factory."""

    type: ContentEncoding = ContentEncoding.GZIP
    compresslevel: int = 9

    def create_responder(self, app: ASGIApp) -> GzipResponder:
        """Return a ``GzipResponder`` configured from this instance."""
        responder = GzipResponder(
            app,
            self.minimum_size,
            compresslevel=self.compresslevel,
        )
        return responder
@@ -0,0 +1,23 @@
1
+ from dataclasses import dataclass
2
+
3
+ from .base import CompressionAlgorithm, CompressionResponder, ContentEncoding
4
+ from .types import ASGIApp
5
+
6
+
7
class IdentityResponder(CompressionResponder):
    """Responder that forwards the body bytes untouched."""

    content_encoding = ContentEncoding.IDENTITY

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Return ``body`` unchanged; ``more_body`` is ignored."""
        unchanged = body
        return unchanged
14
+
15
+
16
@dataclass
class IdentityAlgorithm(CompressionAlgorithm):
    """Algorithm that applies no compression."""

    type: ContentEncoding = ContentEncoding.IDENTITY

    def create_responder(self, app: ASGIApp) -> IdentityResponder:
        """Return an ``IdentityResponder`` wrapping ``app``."""
        responder = IdentityResponder(
            app=app,
            minimum_size=self.minimum_size,
        )
        return responder
@@ -0,0 +1,82 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from .base import (
4
+ DEFAULT_MINIMUM_SIZE,
5
+ CompressionAlgorithm,
6
+ CompressionResponder,
7
+ )
8
+ from .identity import IdentityAlgorithm
9
+ from .types import ASGIApp, Headers, Receive, Scope, Send
10
+
11
+
12
class CompressionMiddleware:
    """
    Unified ASGI middleware for response compression.

    Supports multiple compression algorithms and picks the first configured
    algorithm that the client's Accept-Encoding header accepts.
    """

    def __init__(
        self,
        app: ASGIApp,
        algorithms: Optional[List[CompressionAlgorithm]] = None,
        minimum_size: int = DEFAULT_MINIMUM_SIZE,
    ) -> None:
        """
        Initialize the compression middleware.

        Args:
            app: The ASGI application.
            algorithms: List of compression algorithms to use, in order of preference.
                If not provided, no compression will be applied.
            minimum_size: The minimum response size to apply compression.
                This will be used as the default for algorithms that don't specify it.

        Raises:
            ImportError: If a configured algorithm reports that it is not
                available in the current environment.
        """

        self.app = app
        self.minimum_size = minimum_size

        self.algorithms = algorithms or []
        for algorithm in self.algorithms:
            # Fail fast at construction time. The ImportError propagates
            # as raised; re-raising it "from" itself added nothing.
            algorithm.check_available()

        self._default_algorithm = IdentityAlgorithm(minimum_size=minimum_size)

        # Set minimum_size if not explicitly set in the algorithm
        if minimum_size != DEFAULT_MINIMUM_SIZE:
            for algorithm in self.algorithms:
                if algorithm.minimum_size == DEFAULT_MINIMUM_SIZE:
                    algorithm.minimum_size = minimum_size

    @staticmethod
    def _accepted_encodings(accept_encoding: str) -> List[str]:
        """Return the lower-cased codings the header accepts.

        The header is split into comma-separated members; a member whose
        ``q`` parameter parses to zero or less is treated as refused and
        left out. Wildcards are returned verbatim and not expanded.
        """
        accepted: List[str] = []
        for member in accept_encoding.split(","):
            coding, _, params = member.partition(";")
            coding = coding.strip().lower()
            if not coding:
                continue

            refused = False
            for param in params.split(";"):
                name, _, value = param.partition("=")
                if name.strip().lower() != "q":
                    continue
                try:
                    refused = float(value.strip()) <= 0
                except ValueError:
                    # Unparseable weight: keep the coding rather than
                    # refusing it on malformed input.
                    refused = False

            if not refused:
                accepted.append(coding)
        return accepted

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """ASGI application interface."""
        if scope["type"] != "http":  # pragma: no cover
            await self.app(scope, receive, send)
            return

        headers = Headers(scope=scope)
        accepted = self._accepted_encodings(
            headers.get("Accept-Encoding", "")
        )

        # Find the first configured algorithm the client accepts. Whole
        # tokens are compared, so "gzip" no longer matches inside "x-gzip"
        # and "gzip;q=0" is honoured as a refusal.
        responder: Union[CompressionResponder, None] = None
        for algorithm in self.algorithms:
            if str(algorithm.type.value) in accepted:
                responder = algorithm.create_responder(self.app)
                break

        # If no matching algorithm, use identity (no compression)
        if responder is None:
            responder = self._default_algorithm.create_responder(self.app)

        await responder(scope, receive, send)
@@ -0,0 +1,63 @@
1
+ from typing import (
2
+ Any,
3
+ Awaitable,
4
+ Callable,
5
+ List,
6
+ Mapping,
7
+ MutableMapping,
8
+ Optional,
9
+ )
10
+
11
+ from multidict import CIMultiDict
12
+
13
+ Scope = MutableMapping[str, Any]
14
+ Message = MutableMapping[str, Any]
15
+ Receive = Callable[[], Awaitable[Message]]
16
+ Send = Callable[[Message], Awaitable[None]]
17
+ ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
18
+
19
+
20
class Headers(CIMultiDict[str]):
    """Case-insensitive header multi-dict with ASGI conversion helpers.

    Exactly one of ``headers``, ``raw`` or ``scope`` may be given; with none
    of them the instance starts empty.
    """

    def __init__(
        self,
        headers: Optional[Mapping[str, str]] = None,
        raw: Optional[list[tuple[bytes, bytes]]] = None,
        scope: Optional[Scope] = None,
    ) -> None:
        """Build the headers from one of three sources.

        Args:
            headers: Mapping of already-decoded header names to values.
            raw: ``(name, value)`` byte pairs, decoded as latin-1.
            scope: A scope whose ``"headers"`` entry holds such byte pairs.
                The entry is replaced in ``scope`` by a list copy.
        """
        headers_list: List[tuple[str, str]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            headers_list = list(headers.items())
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            headers_list = [
                (key.decode("latin-1"), value.decode("latin-1"))
                for key, value in raw
            ]
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            scope_headers = scope["headers"] = list(scope["headers"])
            headers_list = [
                (key.decode("latin-1"), value.decode("latin-1"))
                for key, value in scope_headers
            ]

        super().__init__(headers_list)

    def add_vary_header(self, vary: str) -> None:
        """Ensure ``vary`` is listed in the Vary header.

        All existing Vary lines are merged into one and kept. Previously,
        when ``vary`` was already listed, the whole header was overwritten
        with just ``vary``, and only the first Vary line was considered.
        """
        existing_lines = self.getall("vary", [])
        if not existing_lines:
            self["vary"] = vary
            return

        combined = ", ".join(existing_lines)
        # Field names listed in Vary are compared case-insensitively.
        listed = {item.strip().lower() for item in combined.split(",")}
        if vary.strip().lower() not in listed:
            combined = f"{combined}, {vary}"

        # Assignment replaces every existing Vary entry with the merged one.
        self["vary"] = combined

    def encode(self) -> list[tuple[bytes, bytes]]:
        """Return the headers as latin-1 encoded ``(name, value)`` pairs."""
        return [
            (key.encode("latin-1"), value.encode("latin-1"))
            for key, value in self.items()
        ]
@@ -0,0 +1,92 @@
1
+ import io
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .base import CompressionAlgorithm, CompressionResponder, ContentEncoding
6
+ from .types import ASGIApp, Receive, Scope, Send
7
+
8
+ if TYPE_CHECKING:
9
+ import zstandard
10
+
11
+
12
def import_zstandard() -> None:
    """Import ``zstandard`` and bind it as a global of this module.

    The import is deferred to call time so that this module can be imported
    without the optional dependency installed.

    Raises:
        ImportError: If ``zstandard`` cannot be imported. The original error
            is attached as the cause.
    """
    global zstandard
    try:
        import zstandard
    except ImportError as e:
        # The message previously ended with a stray "]".
        raise ImportError(
            "zstandard is not installed, run `pip install zstandard`"
        ) from e
20
+
21
+
22
class ZstdResponder(CompressionResponder):
    """Responder that compresses the response body with Zstandard."""

    content_encoding = ContentEncoding.ZSTD

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int,
        level: int = 3,
        threads: int = 0,
        write_checksum: bool = False,
        write_content_size: bool = True,
    ) -> None:
        """Create the compressor and a stream writer over a memory buffer.

        ``level``, ``threads``, ``write_checksum`` and ``write_content_size``
        are passed on to ``zstandard.ZstdCompressor``.
        """
        super().__init__(app, minimum_size)

        # Binds the module-level ``zstandard`` name (raises if missing).
        import_zstandard()

        buffer = io.BytesIO()
        self.zstd_buffer = buffer
        self.compressor = zstandard.ZstdCompressor(
            level=level,
            threads=threads,
            write_checksum=write_checksum,
            write_content_size=write_content_size,
        )
        self.compression_stream = self.compressor.stream_writer(buffer)

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """Run the wrapped app inside the buffer and stream contexts."""
        # The stream writer is exited before the buffer it writes to.
        with self.zstd_buffer:
            with self.compression_stream:
                await super().__call__(scope, receive, send)

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Write ``body`` to the stream and return the bytes now buffered.

        The stream is flushed with ``FLUSH_FRAME`` when ``more_body`` is
        false.
        """
        stream = self.compression_stream
        stream.write(body)
        if not more_body:
            stream.flush(zstandard.FLUSH_FRAME)

        sink = self.zstd_buffer
        chunk = sink.getvalue()
        # Empty the buffer so the next call only returns new bytes.
        sink.seek(0)
        sink.truncate()
        return chunk
69
+
70
+
71
@dataclass
class ZstdAlgorithm(CompressionAlgorithm):
    """Zstandard compression settings and responder factory."""

    type: ContentEncoding = ContentEncoding.ZSTD
    level: int = 3
    threads: int = 0
    write_checksum: bool = False
    write_content_size: bool = True

    def create_responder(self, app: ASGIApp) -> ZstdResponder:
        """Return a ``ZstdResponder`` configured from this instance."""
        responder = ZstdResponder(
            app,
            self.minimum_size,
            level=self.level,
            threads=self.threads,
            write_checksum=self.write_checksum,
            write_content_size=self.write_content_size,
        )
        return responder

    def check_available(self) -> None:
        """Raise ``ImportError`` when ``zstandard`` cannot be imported."""
        import_zstandard()
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: asgi-compression
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.9
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: multidict>=6.2.0
8
+ Provides-Extra: all
9
+ Requires-Dist: brotli>=1.1.0; extra == "all"
10
+ Requires-Dist: zstandard>=0.23.0; extra == "all"
11
+ Provides-Extra: br
12
+ Requires-Dist: brotli>=1.1.0; extra == "br"
13
+ Provides-Extra: zstd
14
+ Requires-Dist: zstandard>=0.23.0; extra == "zstd"
@@ -0,0 +1,16 @@
1
+ README.md
2
+ pyproject.toml
3
+ asgi_compression/__init__.py
4
+ asgi_compression/base.py
5
+ asgi_compression/brotli.py
6
+ asgi_compression/gzip.py
7
+ asgi_compression/identity.py
8
+ asgi_compression/middleware.py
9
+ asgi_compression/types.py
10
+ asgi_compression/zstd.py
11
+ asgi_compression.egg-info/PKG-INFO
12
+ asgi_compression.egg-info/SOURCES.txt
13
+ asgi_compression.egg-info/dependency_links.txt
14
+ asgi_compression.egg-info/requires.txt
15
+ asgi_compression.egg-info/top_level.txt
16
+ tests/test_middleware.py
@@ -0,0 +1,11 @@
1
+ multidict>=6.2.0
2
+
3
+ [all]
4
+ brotli>=1.1.0
5
+ zstandard>=0.23.0
6
+
7
+ [br]
8
+ brotli>=1.1.0
9
+
10
+ [zstd]
11
+ zstandard>=0.23.0
@@ -0,0 +1 @@
1
+ asgi_compression
@@ -0,0 +1,52 @@
1
+ [project]
2
+ name = "asgi-compression"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.9"
7
+ dependencies = [
8
+ "multidict>=6.2.0",
9
+ ]
10
+
11
+ [project.optional-dependencies]
12
+ all = [
13
+ "brotli>=1.1.0",
14
+ "zstandard>=0.23.0",
15
+ ]
16
+ br = [
17
+ "brotli>=1.1.0",
18
+ ]
19
+ zstd = [
20
+ "zstandard>=0.23.0",
21
+ ]
22
+
23
+ [dependency-groups]
24
+ dev = [
25
+ "brotli>=1.1.0",
26
+ "django>=4.2.20",
27
+ "django-eventstream>=4.5.1",
28
+ "httpx>=0.28.1",
29
+ "litestar>=2.15.1",
30
+ "pyright>=1.1.399",
31
+ "pytest>=8.3.5",
32
+ "pytest-asyncio>=0.25.3",
33
+ "ruff>=0.11.6",
34
+ "sse-starlette>=2.2.1",
35
+ "starlette>=0.46.1",
36
+ "typing-extensions>=4.12.2",
37
+ "zstandard>=0.23.0",
38
+ ]
39
+
40
+ [tool.pytest.ini_options]
41
+ asyncio_mode = "auto"
42
+ asyncio_default_fixture_loop_scope = "session"
43
+
44
+ [tool.ruff]
45
+ line-length = 80
46
+
47
+ [tool.ruff.lint.isort]
48
+ known-first-party = ["asgi_compression"]
49
+
50
+ [build-system]
51
+ requires = ["setuptools"]
52
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,303 @@
1
+ from typing import Any
2
+
3
+ import pytest
4
+ from httpx import AsyncClient
5
+ from starlette.applications import Starlette
6
+ from starlette.responses import PlainTextResponse
7
+ from starlette.routing import Route
8
+ from typing_extensions import assert_never
9
+
10
+ from asgi_compression import brotli, zstd
11
+ from asgi_compression.brotli import BrotliAlgorithm
12
+ from asgi_compression.gzip import GzipAlgorithm
13
+ from asgi_compression.middleware import CompressionMiddleware
14
+ from asgi_compression.zstd import ZstdAlgorithm
15
+
16
+ from .types import Encoding
17
+ from .utils import get_test_client, unimport_module
18
+
19
+
20
def determine_encoding(request: pytest.FixtureRequest) -> Encoding:
    """Derive the encoding under test from the parametrized test id.

    The id is searched for "gzip", then "br", then "zstd"; the first token
    found is returned.
    """
    test_id = request.node.callspec.id
    if "gzip" in test_id:
        return "gzip"
    if "br" in test_id:
        return "br"
    if "zstd" in test_id:
        return "zstd"
    assert_never(test_id)
30
+
31
+
32
async def test_middleware_responds_to_proper_encoding(
    client: AsyncClient,
    request: pytest.FixtureRequest,
) -> None:
    """A large response is compressed with the requested encoding."""
    encoding = determine_encoding(request)
    request_headers = {"accept-encoding": encoding}

    response = await client.get("/", headers=request_headers)

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert response_headers["Content-Encoding"] == encoding
    assert response_headers["Vary"] == "Accept-Encoding"
    assert int(response_headers["Content-Length"]) < 4000
43
+
44
+
45
async def test_middleware_handles_identity_encoding(
    client: AsyncClient,
) -> None:
    """Requesting "identity" yields an uncompressed body with Vary set."""
    request_headers = {"accept-encoding": "identity"}

    response = await client.get("/", headers=request_headers)

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert "Content-Encoding" not in response_headers
    assert response_headers["Vary"] == "Accept-Encoding"
53
+
54
+
55
async def test_compression_not_in_accept_encoding(client: AsyncClient) -> None:
    """Without a matching encoding the body is sent at its full length."""
    request_headers = {"accept-encoding": "identity"}

    response = await client.get("/", headers=request_headers)

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert "Content-Encoding" not in response_headers
    assert response_headers["Vary"] == "Accept-Encoding"

    declared_length = response_headers.get("Content-Length")
    if declared_length:
        # Django doesn't set 'Content-Length' header automatically
        assert int(declared_length) == 4000
66
+
67
+
68
async def test_compression_ignored_for_small_responses(
    client: AsyncClient,
) -> None:
    """A short response is neither compressed nor given a Vary header."""
    expected_response = "Hello world!"
    request_headers = {"accept-encoding": "gzip, br"}

    response = await client.get("/small_response", headers=request_headers)

    assert response.status_code == 200
    assert response.text == expected_response

    response_headers = response.headers
    assert "Content-Encoding" not in response_headers
    assert "Vary" not in response_headers

    declared_length = response_headers.get("Content-Length")
    if declared_length:
        # Django doesn't set 'Content-Length' header automatically
        assert int(declared_length) == len(expected_response)
85
+
86
+
87
async def test_compression_streaming_response(
    client: AsyncClient,
    request: pytest.FixtureRequest,
) -> None:
    """A streamed response is compressed and carries no Content-Length."""
    encoding = determine_encoding(request)
    request_headers = {"accept-encoding": encoding}

    response = await client.get(
        "/streaming_response",
        headers=request_headers,
    )

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert response_headers["Content-Encoding"] == encoding
    assert response_headers["Vary"] == "Accept-Encoding"
    assert "Content-Length" not in response_headers
102
+
103
+
104
async def test_compression_streaming_response_identity(
    client: AsyncClient,
) -> None:
    """A streamed response requested as "identity" stays uncompressed."""
    request_headers = {"accept-encoding": "identity"}

    response = await client.get(
        "/streaming_response",
        headers=request_headers,
    )

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert "Content-Encoding" not in response_headers
    assert response_headers["Vary"] == "Accept-Encoding"
    assert "Content-Length" not in response_headers
116
+
117
+
118
async def test_compression_ignored_for_responses_with_encoding_set(
    client: AsyncClient,
) -> None:
    """A response that already sets Content-Encoding is left untouched."""
    request_headers = {"accept-encoding": "gzip, br, text"}

    response = await client.get(
        "/streaming_response_with_content_encoding",
        headers=request_headers,
    )

    assert response.status_code == 200
    assert response.text == "x" * 4000

    response_headers = response.headers
    assert response_headers["Content-Encoding"] == "text"
    assert "Vary" not in response_headers
    assert "Content-Length" not in response_headers
130
+
131
+
132
async def test_compression_ignored_on_server_sent_events(
    client: AsyncClient,
) -> None:
    """Server-sent events are streamed without compression.

    The event stream is parsed line by line into field/value dicts, and
    every completed event is compared with the expected payload.
    """
    request_headers = {"accept-encoding": "gzip, br"}

    async with client.stream(
        "GET",
        "/server_sent_events",
        headers=request_headers,
    ) as response:
        assert response.status_code == 200
        assert "Content-Encoding" not in response.headers
        assert "Content-Length" not in response.headers
        assert response.headers["Content-Type"].startswith("text/event-stream")

        pending_event: dict[str, Any] = {}
        finished_events = []

        async for raw_line in response.aiter_lines():
            line = raw_line.rstrip("\r")

            if line.strip():
                if ":" in line:
                    field, value = line.split(":", 1)
                    # Drop the single optional space after the colon.
                    if value.startswith(" "):
                        value = value[1:]
                    pending_event[field] = value
                continue

            # A blank line ends the event collected so far.
            if pending_event:
                finished_events.append(pending_event)
            pending_event = {}

        for index, event in enumerate(finished_events):
            assert event == {
                "id": str(index),
                "event": "message",
                "data": "x" * 400,
            }
169
+
170
+
171
async def test_multiple_algorithms_negotiation():
    """Test that the first configured algorithm the client accepts wins."""

    async def homepage(request):
        return PlainTextResponse("x" * 4000)

    app = Starlette(routes=[Route("/", endpoint=homepage)])

    # Order matters - first match will be used
    preferred_algorithms = [
        BrotliAlgorithm(quality=4),
        GzipAlgorithm(compresslevel=6),
        ZstdAlgorithm(level=3),
    ]
    middleware = CompressionMiddleware(
        app=app,
        algorithms=preferred_algorithms,
    )

    async with get_test_client(middleware) as client:
        all_three = await client.get(
            "/",
            headers={"accept-encoding": "br, gzip, zstd"},
        )
        assert all_three.status_code == 200
        assert all_three.headers["Content-Encoding"] == "br"

        without_brotli = await client.get(
            "/",
            headers={"accept-encoding": "gzip, zstd"},
        )
        assert without_brotli.status_code == 200
        assert without_brotli.headers["Content-Encoding"] == "gzip"

        zstd_only = await client.get(
            "/",
            headers={"accept-encoding": "zstd"},
        )
        assert zstd_only.status_code == 200
        assert zstd_only.headers["Content-Encoding"] == "zstd"

        identity_only = await client.get(
            "/",
            headers={"accept-encoding": "identity"},
        )
        assert identity_only.status_code == 200
        assert "Content-Encoding" not in identity_only.headers
213
+
214
+
215
async def test_custom_minimum_size():
    """Test that the middleware-level minimum size is respected."""

    async def small_response(request):
        return PlainTextResponse("x" * 100)

    async def large_response(request):
        return PlainTextResponse("x" * 2000)

    routes = [
        Route("/small", endpoint=small_response),
        Route("/large", endpoint=large_response),
    ]
    app = Starlette(routes=routes)

    # Set a high minimum size
    middleware = CompressionMiddleware(
        app=app,
        algorithms=[GzipAlgorithm()],
        minimum_size=1000,  # Only compress responses larger than 1KB
    )
    gzip_request = {"accept-encoding": "gzip"}

    async with get_test_client(middleware) as client:
        # Small response should not be compressed
        small = await client.get("/small", headers=gzip_request)
        assert small.status_code == 200
        assert "Content-Encoding" not in small.headers

        # Large response should be compressed
        large = await client.get("/large", headers=gzip_request)
        assert large.status_code == 200
        assert large.headers["Content-Encoding"] == "gzip"
254
+
255
+
256
async def test_zstd_compression():
    """Test a Zstandard-only middleware end to end."""

    async def homepage(request):
        return PlainTextResponse("x" * 4000)

    routes = [Route("/", endpoint=homepage)]
    app = Starlette(routes=routes)

    middleware = CompressionMiddleware(app=app, algorithms=[ZstdAlgorithm()])

    async with get_test_client(middleware) as client:
        response = await client.get(
            "/",
            headers={"accept-encoding": "zstd"},
        )

        assert response.status_code == 200
        assert response.text == "x" * 4000

        response_headers = response.headers
        assert response_headers["Content-Encoding"] == "zstd"
        assert response_headers["Vary"] == "Accept-Encoding"
        assert int(response_headers["Content-Length"]) < 4000
276
+
277
+
278
def test_brotli_not_available(monkeypatch: pytest.MonkeyPatch):
    """Building the middleware fails when brotli cannot be imported."""
    unimport_module(
        monkeypatch=monkeypatch,
        module_name="brotli",
        to_reload=brotli,
    )
    empty_app = Starlette(routes=[])
    requested_algorithms = [BrotliAlgorithm()]

    with pytest.raises(ImportError):
        CompressionMiddleware(
            app=empty_app,
            algorithms=requested_algorithms,
        )
290
+
291
+
292
def test_zstd_not_available(monkeypatch: pytest.MonkeyPatch):
    """Building the middleware fails when zstandard cannot be imported."""
    unimport_module(
        monkeypatch=monkeypatch,
        module_name="zstandard",
        to_reload=zstd,
    )
    empty_app = Starlette(routes=[])
    requested_algorithms = [ZstdAlgorithm()]

    with pytest.raises(ImportError):
        CompressionMiddleware(
            app=empty_app,
            algorithms=requested_algorithms,
        )