n6k 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- n6k-0.1.0/.claude/settings.local.json +9 -0
- n6k-0.1.0/.flake8 +2 -0
- n6k-0.1.0/.gitignore +1 -0
- n6k-0.1.0/.python-version +1 -0
- n6k-0.1.0/PKG-INFO +9 -0
- n6k-0.1.0/b +0 -0
- n6k-0.1.0/demo/app.py +26 -0
- n6k-0.1.0/frontend.md +30 -0
- n6k-0.1.0/pyproject.toml +28 -0
- n6k-0.1.0/src/n6k_server/__init__.py +1 -0
- n6k-0.1.0/src/n6k_server/app.py +44 -0
- n6k-0.1.0/src/n6k_server/formats.py +58 -0
- n6k-0.1.0/src/n6k_server/pushdown.py +70 -0
- n6k-0.1.0/src/n6k_server/routes/__init__.py +0 -0
- n6k-0.1.0/src/n6k_server/routes/tables.py +47 -0
- n6k-0.1.0/tests/__init__.py +0 -0
- n6k-0.1.0/tests/test_demo.py +128 -0
- n6k-0.1.0/tests/test_formats.py +84 -0
- n6k-0.1.0/tests/test_pushdown.py +129 -0
- n6k-0.1.0/uv.lock +793 -0
n6k-0.1.0/.flake8
ADDED
n6k-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__pycache__/
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
n6k-0.1.0/PKG-INFO
ADDED
n6k-0.1.0/b
ADDED
|
Binary file
|
n6k-0.1.0/demo/app.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import seaborn as sns
|
|
3
|
+
from starlette.responses import PlainTextResponse
|
|
4
|
+
|
|
5
|
+
import n6k_server
|
|
6
|
+
|
|
7
|
+
# Application instance that the data sources and the /health route below attach to.
app = n6k_server.create_app()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@n6k_server.data(app, "demo")
def demo():
    """Return a three-row in-memory frame registered as the ``demo`` source."""
    rows = {
        "id": [1, 2, 3],
        "name": ["alice", "bob", "charlie"],
        "score": [95.5, 87.3, 91.8],
    }
    return pd.DataFrame(rows)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@n6k_server.data(app, "iris")
def iris():
    """Return the dataset seaborn provides under the name ``iris``."""
    dataset = sns.load_dataset("iris")
    return dataset
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@app.get("/health")
async def health() -> PlainTextResponse:
    """Answer every request with the plain-text body ``ok``."""
    body = "ok"
    return PlainTextResponse(body)
|
n6k-0.1.0/frontend.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Data Protocol
|
|
2
|
+
|
|
3
|
+
All data responses are binary. No JSON.
|
|
4
|
+
|
|
5
|
+
## Endpoints
|
|
6
|
+
|
|
7
|
+
| Method | Path | Description |
|
|
8
|
+
|--------|------|-------------|
|
|
9
|
+
| GET | `/health` | Plain text `ok` |
|
|
10
|
+
| GET | `/tables/demo` | Table: `id` (int64), `name` (utf8), `score` (float64) |
|
|
11
|
+
|
|
12
|
+
## Content Negotiation
|
|
13
|
+
|
|
14
|
+
Set the `Accept` header to choose the wire format:
|
|
15
|
+
|
|
16
|
+
| Accept | Format |
|
|
17
|
+
|--------|--------|
|
|
18
|
+
| `application/vnd.apache.arrow.stream` | Arrow IPC Stream |
|
|
19
|
+
| `application/vnd.apache.arrow.file` | Arrow IPC File |
|
|
20
|
+
| `application/x-parquet` | Parquet |
|
|
21
|
+
|
|
22
|
+
Default (missing or unrecognized Accept): **Arrow IPC Stream**.
|
|
23
|
+
|
|
24
|
+
First matching media type wins. Quality parameters are stripped, not weighted.
|
|
25
|
+
|
|
26
|
+
The response `Content-Type` will match the negotiated format.
|
|
27
|
+
|
|
28
|
+
## Errors
|
|
29
|
+
|
|
30
|
+
All errors are `text/plain`. No JSON error envelopes.
|
n6k-0.1.0/pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "n6k"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Arrow/Parquet REST server"
|
|
5
|
+
requires-python = ">=3.12"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"fastapi>=0.115",
|
|
8
|
+
"uvicorn[standard]>=0.34",
|
|
9
|
+
"pyarrow>=19",
|
|
10
|
+
"pandas>=2",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[dependency-groups]
|
|
14
|
+
dev = [
|
|
15
|
+
"httpx>=0.28",
|
|
16
|
+
"pytest>=8",
|
|
17
|
+
"pytest-asyncio>=0.25",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["hatchling"]
|
|
22
|
+
build-backend = "hatchling.build"
|
|
23
|
+
|
|
24
|
+
[tool.hatch.build.targets.wheel]
|
|
25
|
+
packages = ["src/n6k_server"]
|
|
26
|
+
|
|
27
|
+
[tool.pytest.ini_options]
|
|
28
|
+
asyncio_mode = "auto"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from n6k_server.app import create_app, data
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import pyarrow as pa
|
|
3
|
+
from fastapi import FastAPI, Request
|
|
4
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
5
|
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
6
|
+
from starlette.responses import PlainTextResponse
|
|
7
|
+
|
|
8
|
+
from n6k_server.routes import tables
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_app() -> FastAPI:
    """Build the FastAPI application that serves registered data sources.

    The returned application has:

    * an empty ``state.data_sources`` registry,
    * CORS opened to every origin, method and header,
    * exception handlers that render errors as plain text,
    * the table routes mounted under ``/tables``.
    """
    application = FastAPI()
    application.state.data_sources = {}

    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @application.exception_handler(StarletteHTTPException)
    async def http_exception_handler(
        request: Request, exc: StarletteHTTPException
    ) -> PlainTextResponse:
        # ``detail`` is not guaranteed to be a string; stringify it so that
        # rendering the plain-text body cannot fail. Headers attached to the
        # exception are forwarded instead of being dropped.
        return PlainTextResponse(
            str(exc.detail),
            status_code=exc.status_code,
            headers=getattr(exc, "headers", None),
        )

    @application.exception_handler(Exception)
    async def plain_text_exception_handler(
        request: Request, exc: Exception
    ) -> PlainTextResponse:
        # Falls back to 500 unless the exception carries its own status_code.
        # NOTE(review): str(exc) is sent to the client verbatim - confirm that
        # exposing internal error messages is acceptable.
        status = getattr(exc, "status_code", 500)
        return PlainTextResponse(str(exc), status_code=status)

    application.include_router(tables.router, prefix="/tables")

    return application
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def data(app: FastAPI, name: str):
    """Return a decorator that registers a callable as data source *name*.

    The decorated callable is stored in ``app.state.data_sources[name]`` and
    handed back unchanged, so it stays usable as a normal function.
    """

    def register(source):
        app.state.data_sources[name] = source
        return source

    return register
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
from io import BytesIO
|
|
3
|
+
|
|
4
|
+
import pyarrow as pa
|
|
5
|
+
import pyarrow.ipc as ipc
|
|
6
|
+
import pyarrow.parquet as pq
|
|
7
|
+
from starlette.requests import Request
|
|
8
|
+
from starlette.responses import Response
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ArrowFormat(StrEnum):
    """Wire formats a table can be served in; each value is its media type."""

    # Arrow IPC streaming format.
    IPC_STREAM = "application/vnd.apache.arrow.stream"
    # Arrow IPC file format.
    IPC_FILE = "application/vnd.apache.arrow.file"
    # Parquet file.
    PARQUET = "application/x-parquet"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Lookup from media-type string to the ArrowFormat member carrying that value.
_ACCEPT_MAP: dict[str, ArrowFormat] = {f.value: f for f in ArrowFormat}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def negotiate(request: Request) -> ArrowFormat:
    """Choose the response format from the request's ``Accept`` header.

    The header is split on commas and scanned left to right; the first entry
    that names a supported media type wins. Parameters such as ``;q=0.9`` are
    stripped, not weighted. Matching is case-insensitive.

    Returns:
        The matching format, or ``ArrowFormat.IPC_STREAM`` when the header is
        missing, empty, or names no supported media type.
    """
    accept = request.headers.get("accept", "")
    for media_range in accept.split(","):
        # Keep only the type/subtype part and normalise its case: media types
        # compare case-insensitively, and the lookup keys are all lower case.
        key = media_range.partition(";")[0].strip().lower()
        fmt = _ACCEPT_MAP.get(key)
        if fmt is not None:
            return fmt
    return ArrowFormat.IPC_STREAM
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _serialize_ipc_stream(table: pa.Table) -> bytes:
    """Encode *table* with the Arrow IPC stream writer and return the bytes."""
    out = pa.BufferOutputStream()
    with ipc.new_stream(out, table.schema) as stream_writer:
        stream_writer.write_table(table)
    # The buffer is read only after the with-block has closed the writer.
    payload = out.getvalue()
    return payload.to_pybytes()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _serialize_ipc_file(table: pa.Table) -> bytes:
    """Encode *table* with the Arrow IPC file writer and return the bytes."""
    out = pa.BufferOutputStream()
    with ipc.new_file(out, table.schema) as file_writer:
        file_writer.write_table(table)
    # The buffer is read only after the with-block has closed the writer.
    payload = out.getvalue()
    return payload.to_pybytes()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _serialize_parquet(table: pa.Table) -> bytes:
    """Write *table* as Parquet into memory and return the resulting bytes."""
    sink = BytesIO()
    pq.write_table(table, sink)
    encoded = sink.getvalue()
    return encoded
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Dispatch table: one serializer per ArrowFormat member, used by arrow_response.
_SERIALIZERS = {
    ArrowFormat.IPC_STREAM: _serialize_ipc_stream,
    ArrowFormat.IPC_FILE: _serialize_ipc_file,
    ArrowFormat.PARQUET: _serialize_parquet,
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def arrow_response(table: pa.Table, fmt: ArrowFormat) -> Response:
    """Serialize *table* as *fmt* and wrap it in a response.

    The response's media type is the value of *fmt*, so the ``Content-Type``
    matches the encoding of the body.
    """
    serialize = _SERIALIZERS[fmt]
    payload = serialize(table)
    return Response(content=payload, media_type=fmt.value)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
import pyarrow as pa
|
|
4
|
+
import pyarrow.compute as pc
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def apply_columns(table: pa.Table, columns_param: str | None) -> pa.Table:
    """Project *table* onto the columns listed in *columns_param*.

    *columns_param* is a comma-separated list of column names. Blank entries
    and names that are not in the table's schema are ignored. The requested
    order is kept.

    Returns:
        The projected table, or *table* unchanged when *columns_param* is
        empty/``None`` or none of the requested names exist.
    """
    if not columns_param:
        return table

    available = table.schema.names
    selected = []
    for raw in columns_param.split(","):
        candidate = raw.strip()
        if candidate and candidate in available:
            selected.append(candidate)

    return table.select(selected) if selected else table
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Comparison operators accepted in a filter predicate's "op" field, mapped to
# the pyarrow.compute function that evaluates them.
_OPS = {
    "=": pc.equal,
    "!=": pc.not_equal,
    ">": pc.greater,
    ">=": pc.greater_equal,
    "<": pc.less,
    "<=": pc.less_equal,
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _predicate_mask(table: pa.Table, pred: object):
    """Return the boolean mask for one predicate, or ``None`` to skip it."""
    if not isinstance(pred, dict):
        return None
    col = pred.get("col")
    op = pred.get("op")
    # JSON may put any value here. A list or dict "op" is unhashable and would
    # make the `op in _OPS` lookup below raise TypeError, so reject non-strings
    # up front.
    if not isinstance(col, str) or not isinstance(op, str):
        return None
    if col not in table.schema.names:
        return None

    column = table.column(col)
    if op == "is_null":
        return pc.is_null(column)
    if op == "is_not_null":
        return pc.is_valid(column)

    value = pred.get("value")
    try:
        if op == "in":
            if not isinstance(value, list):
                return None
            return pc.is_in(column, pa.array(value))
        if op in _OPS:
            if value is None:
                return None
            return _OPS[op](column, pa.scalar(value))
    except pa.ArrowException:
        # The value cannot be converted to Arrow or compared with this column
        # (e.g. a string against an integer column). Skip the predicate like
        # every other malformed one rather than failing the request.
        return None
    return None


def apply_filters(table: pa.Table, filters_param: str | None) -> pa.Table:
    """Filter *table* by the JSON-encoded predicates in *filters_param*.

    *filters_param* must decode to a list of objects of the form
    ``{"col": <column name>, "op": <operator>, "value": <operand>}``.
    Supported operators are ``=``, ``!=``, ``>``, ``>=``, ``<``, ``<=``,
    ``in`` (``value`` must be a list), ``is_null`` and ``is_not_null``
    (``value`` is not used). All usable predicates are combined with AND.

    Filtering is best-effort: invalid JSON, a non-list payload, and any
    predicate that is malformed, names an unknown column or operator, lacks a
    required value, or has a value incompatible with the column is ignored.

    Returns:
        The filtered table, or *table* unchanged when no predicate applied.
    """
    if not filters_param:
        return table
    try:
        predicates = json.loads(filters_param)
    except (json.JSONDecodeError, TypeError):
        return table
    if not isinstance(predicates, list):
        return table

    mask: pa.ChunkedArray | None = None
    for pred in predicates:
        m = _predicate_mask(table, pred)
        if m is None:
            continue
        mask = m if mask is None else pc.and_(mask, m)

    return table if mask is None else table.filter(mask)
|
|
File without changes
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import pyarrow as pa
|
|
3
|
+
from fastapi import APIRouter, Request
|
|
4
|
+
from starlette.responses import Response
|
|
5
|
+
|
|
6
|
+
from n6k_server.formats import arrow_response, negotiate
|
|
7
|
+
from n6k_server.pushdown import apply_columns, apply_filters
|
|
8
|
+
|
|
9
|
+
# Router holding the table endpoints defined below.
router = APIRouter()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@router.get("/")
async def list_tables(request: Request) -> Response:
    """List the registered data sources as a table.

    The table has one row per source and three string columns: ``name``,
    ``path`` (``/tables/<name>``) and ``schema_path``
    (``/tables/<name>/_schema``). It is encoded in the negotiated format.
    """
    sources = request.app.state.data_sources
    names = list(sources)
    paths = [f"/tables/{n}" for n in names]
    schema_paths = [f"/tables/{n}/_schema" for n in names]
    catalog = pa.table({
        "name": pa.array(names, type=pa.utf8()),
        "path": pa.array(paths, type=pa.utf8()),
        "schema_path": pa.array(schema_paths, type=pa.utf8()),
    })
    return arrow_response(catalog, negotiate(request))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@router.get("/{name}")
async def get_table(name: str, request: Request, columns: str | None = None, filters: str | None = None) -> Response:
    """Serve the data source registered as *name* in the negotiated format.

    Query parameters:
        columns: comma-separated column names to keep (``apply_columns``).
        filters: JSON-encoded list of predicates (``apply_filters``).

    Filters run before the projection, so a predicate may reference a column
    that is not part of the returned columns.

    Responds with a ``text/plain`` 404 when no source is registered as *name*.

    Raises:
        TypeError: the data source returned something other than a DataFrame.
    """
    fn = request.app.state.data_sources.get(name)
    if fn is None:
        # Explicit media type: a bare Response carries no Content-Type, while
        # errors are documented to be text/plain.
        return Response(
            content=f"table not found: {name}",
            status_code=404,
            media_type="text/plain",
        )
    # NOTE(review): fn() runs synchronously inside this coroutine; a slow
    # source would presumably stall other requests - confirm this is acceptable.
    df = fn()
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"data source '{name}' must return a DataFrame, got {type(df).__name__}")
    table = pa.Table.from_pandas(df)
    table = apply_filters(table, filters)
    table = apply_columns(table, columns)
    return arrow_response(table, negotiate(request))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@router.get("/{name}/_schema")
async def get_schema(name: str, request: Request) -> Response:
    """Serve the schema of data source *name* as a zero-row table.

    The source is evaluated and converted to Arrow; the response is an empty
    table with that schema, encoded in the negotiated format.

    Responds with a ``text/plain`` 404 when no source is registered as *name*.

    Raises:
        TypeError: the data source returned something other than a DataFrame.
    """
    fn = request.app.state.data_sources.get(name)
    if fn is None:
        # Explicit media type: a bare Response carries no Content-Type, while
        # errors are documented to be text/plain.
        return Response(
            content=f"table not found: {name}",
            status_code=404,
            media_type="text/plain",
        )
    df = fn()
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"data source '{name}' must return a DataFrame, got {type(df).__name__}")
    table = pa.Table.from_pandas(df)
    empty = table.schema.empty_table()
    return arrow_response(empty, negotiate(request))
|
|
File without changes
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from io import BytesIO
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import pyarrow as pa
|
|
6
|
+
import pyarrow.ipc as ipc
|
|
7
|
+
import pyarrow.parquet as pq
|
|
8
|
+
import pytest
|
|
9
|
+
from httpx import AsyncClient, ASGITransport
|
|
10
|
+
from starlette.responses import PlainTextResponse
|
|
11
|
+
|
|
12
|
+
import n6k_server
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture
def app():
    """Fresh application with a ``demo`` data source and a ``/health`` route."""
    application = n6k_server.create_app()

    def demo():
        rows = {
            "id": [1, 2, 3],
            "name": ["alice", "bob", "charlie"],
            "score": [95.5, 87.3, 91.8],
        }
        return pd.DataFrame(rows)

    # Same registration the decorator form performs.
    n6k_server.data(application, "demo")(demo)

    async def health() -> PlainTextResponse:
        return PlainTextResponse("ok")

    application.get("/health")(health)

    return application
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture
def transport(app):
    """ASGI transport that routes client requests to the ``app`` fixture."""
    asgi_transport = ASGITransport(app=app)
    return asgi_transport
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.fixture
async def client(transport):
    """Yield an ``AsyncClient`` bound to the test transport; closed on teardown."""
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TestHealth:
    """The ``/health`` route registered by the ``app`` fixture."""

    async def test_returns_ok(self, client: AsyncClient):
        response = await client.get("/health")
        status, body = response.status_code, response.text
        assert status == 200
        assert body == "ok"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TestDemo:
    """Content negotiation and payload integrity for ``/tables/demo``."""

    async def test_default_returns_ipc_stream(self, client: AsyncClient):
        response = await client.get("/tables/demo")
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        decoded = ipc.open_stream(response.content).read_all()
        assert decoded.num_rows == 3
        for expected in ("id", "name", "score"):
            assert expected in decoded.column_names

    async def test_accept_ipc_file(self, client: AsyncClient):
        media_type = "application/vnd.apache.arrow.file"
        response = await client.get("/tables/demo", headers={"accept": media_type})
        assert response.status_code == 200
        assert response.headers["content-type"] == media_type
        decoded = ipc.open_file(response.content).read_all()
        assert decoded.num_rows == 3

    async def test_accept_parquet(self, client: AsyncClient):
        media_type = "application/x-parquet"
        response = await client.get("/tables/demo", headers={"accept": media_type})
        assert response.status_code == 200
        assert response.headers["content-type"] == media_type
        decoded = pq.read_table(BytesIO(response.content))
        assert decoded.num_rows == 3

    async def test_unknown_accept_defaults_to_stream(self, client: AsyncClient):
        response = await client.get(
            "/tables/demo",
            headers={"accept": "application/json"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"

    async def test_roundtrip_data_integrity(self, client: AsyncClient):
        response = await client.get("/tables/demo")
        decoded = ipc.open_stream(response.content).read_all()
        expected_columns = {
            "id": [1, 2, 3],
            "name": ["alice", "bob", "charlie"],
            "score": [95.5, 87.3, 91.8],
        }
        for column_name, values in expected_columns.items():
            assert decoded.column(column_name).to_pylist() == values
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TestTablesPushdown:
    """Column projection and row filtering via the query parameters.

    Query values are passed through ``params=`` so the client encodes them.
    The JSON filter payload contains quotes, spaces and brackets that do not
    belong in a hand-built URL.
    """

    async def test_columns_projection(self, client: AsyncClient):
        resp = await client.get("/tables/demo", params={"columns": "id,name"})
        assert resp.status_code == 200
        table = ipc.open_stream(resp.content).read_all()
        assert table.column_names == ["id", "name"]
        assert table.num_rows == 3

    async def test_filter_equal(self, client: AsyncClient):
        filters = json.dumps([{"col": "id", "op": "=", "value": 1}])
        resp = await client.get("/tables/demo", params={"filters": filters})
        assert resp.status_code == 200
        table = ipc.open_stream(resp.content).read_all()
        assert table.num_rows == 1
        assert table.column("name").to_pylist() == ["alice"]

    async def test_filter_and_columns(self, client: AsyncClient):
        filters = json.dumps([{"col": "score", "op": ">", "value": 90.0}])
        resp = await client.get(
            "/tables/demo",
            params={"columns": "id,name", "filters": filters},
        )
        assert resp.status_code == 200
        table = ipc.open_stream(resp.content).read_all()
        # The filter column ("score") is not among the projected columns.
        assert table.column_names == ["id", "name"]
        assert table.num_rows == 2
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class TestExceptionHandler:
    """Errors must be rendered as plain text, never as a JSON envelope."""

    async def test_404_returns_plain_text(self, client: AsyncClient):
        resp = await client.get("/nonexistent")
        assert resp.status_code == 404
        content_type = resp.headers.get("content-type", "")
        # Assert the positive contract, not just the absence of JSON: a
        # response with no or an unrelated Content-Type must fail too.
        assert content_type.startswith("text/plain")
        assert "application/json" not in content_type
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import pyarrow as pa
|
|
2
|
+
import pyarrow.ipc as ipc
|
|
3
|
+
import pyarrow.parquet as pq
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from starlette.testclient import TestClient
|
|
6
|
+
|
|
7
|
+
from n6k_server.formats import ArrowFormat, negotiate, arrow_response, _serialize_ipc_stream, _serialize_ipc_file, _serialize_parquet
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _make_table() -> pa.Table:
    """Return a small two-column table (``x`` ints, ``y`` strings) for tests."""
    columns = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
    return pa.table(columns)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _FakeRequest:
|
|
15
|
+
def __init__(self, accept: str = ""):
|
|
16
|
+
self.headers = {"accept": accept}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestNegotiate:
    """``negotiate`` maps an Accept header value to an ``ArrowFormat``."""

    def test_ipc_stream(self):
        chosen = negotiate(_FakeRequest("application/vnd.apache.arrow.stream"))
        assert chosen == ArrowFormat.IPC_STREAM

    def test_ipc_file(self):
        chosen = negotiate(_FakeRequest("application/vnd.apache.arrow.file"))
        assert chosen == ArrowFormat.IPC_FILE

    def test_parquet(self):
        chosen = negotiate(_FakeRequest("application/x-parquet"))
        assert chosen == ArrowFormat.PARQUET

    def test_default_when_missing(self):
        # An empty Accept value falls back to the default format.
        chosen = negotiate(_FakeRequest(""))
        assert chosen == ArrowFormat.IPC_STREAM

    def test_default_when_unknown(self):
        chosen = negotiate(_FakeRequest("application/json"))
        assert chosen == ArrowFormat.IPC_STREAM

    def test_first_match_wins(self):
        header = "application/x-parquet, application/vnd.apache.arrow.file"
        chosen = negotiate(_FakeRequest(header))
        assert chosen == ArrowFormat.PARQUET

    def test_quality_params_ignored_gracefully(self):
        header = "application/vnd.apache.arrow.file;q=0.9"
        chosen = negotiate(_FakeRequest(header))
        assert chosen == ArrowFormat.IPC_FILE
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class TestSerializers:
    """Each serializer's output must decode back to an equal table."""

    def test_ipc_stream_roundtrip(self):
        original = _make_table()
        encoded = _serialize_ipc_stream(original)
        decoded = ipc.open_stream(encoded).read_all()
        assert decoded.equals(original)

    def test_ipc_file_roundtrip(self):
        original = _make_table()
        encoded = _serialize_ipc_file(original)
        decoded = ipc.open_file(encoded).read_all()
        assert decoded.equals(original)

    def test_parquet_roundtrip(self):
        original = _make_table()
        encoded = _serialize_parquet(original)
        decoded = pq.read_table(BytesIO(encoded))
        assert decoded.equals(original)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class TestArrowResponse:
    """``arrow_response`` sets the media type and a decodable body."""

    def test_content_type_ipc_stream(self):
        response = arrow_response(_make_table(), ArrowFormat.IPC_STREAM)
        assert response.media_type == "application/vnd.apache.arrow.stream"

    def test_content_type_parquet(self):
        response = arrow_response(_make_table(), ArrowFormat.PARQUET)
        assert response.media_type == "application/x-parquet"

    def test_body_is_deserializable(self):
        original = _make_table()
        response = arrow_response(original, ArrowFormat.IPC_STREAM)
        decoded = ipc.open_stream(response.body).read_all()
        assert decoded.equals(original)
|