mcp-stata 1.18.0__cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-stata might be problematic. Click here for more details.
- mcp_stata/__init__.py +4 -0
- mcp_stata/_native_ops.cpython-312-aarch64-linux-gnu.so +0 -0
- mcp_stata/config.py +20 -0
- mcp_stata/discovery.py +550 -0
- mcp_stata/graph_detector.py +401 -0
- mcp_stata/models.py +62 -0
- mcp_stata/native_ops.py +87 -0
- mcp_stata/server.py +1130 -0
- mcp_stata/smcl/smcl2html.py +88 -0
- mcp_stata/stata_client.py +3692 -0
- mcp_stata/streaming_io.py +263 -0
- mcp_stata/test_stata.py +54 -0
- mcp_stata/ui_http.py +998 -0
- mcp_stata-1.18.0.dist-info/METADATA +471 -0
- mcp_stata-1.18.0.dist-info/RECORD +18 -0
- mcp_stata-1.18.0.dist-info/WHEEL +5 -0
- mcp_stata-1.18.0.dist-info/entry_points.txt +2 -0
- mcp_stata-1.18.0.dist-info/licenses/LICENSE +661 -0
mcp_stata/ui_http.py
ADDED
|
@@ -0,0 +1,998 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import io
|
|
3
|
+
import json
|
|
4
|
+
import secrets
|
|
5
|
+
import threading
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
import logging
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
11
|
+
from typing import Any, Callable, Optional
|
|
12
|
+
|
|
13
|
+
from .stata_client import StataClient
|
|
14
|
+
from .config import (
|
|
15
|
+
DEFAULT_HOST,
|
|
16
|
+
DEFAULT_PORT,
|
|
17
|
+
MAX_ARROW_LIMIT,
|
|
18
|
+
MAX_CHARS,
|
|
19
|
+
MAX_LIMIT,
|
|
20
|
+
MAX_REQUEST_BYTES,
|
|
21
|
+
MAX_VARS,
|
|
22
|
+
TOKEN_TTL_S,
|
|
23
|
+
VIEW_TTL_S,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Module-level logger shared with the rest of the mcp_stata package.
logger = logging.getLogger("mcp_stata")

# Optional accelerated argsort kernels from the compiled native extension.
# When the extension is missing or fails to import for any reason, both names
# are set to None and callers fall back to the pure-Python/Polars sort paths.
try:
    from .native_ops import argsort_numeric as _native_argsort_numeric
    from .native_ops import argsort_mixed as _native_argsort_mixed
except Exception:
    _native_argsort_numeric = None
    _native_argsort_mixed = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _try_native_argsort(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int] | None:
    """Attempt a multi-key argsort via the native extension.

    Returns zero-based observation indices in sorted order, or ``None`` when
    the native kernels are unavailable, a column type is unsupported, the
    ``_n`` bookkeeping column is missing, or anything raises — callers treat
    ``None`` as "use the fallback sort path".
    """
    if _native_argsort_numeric is None and _native_argsort_mixed is None:
        return None
    try:
        import numpy as np
        import pyarrow as pa

        string_flags: list[bool] = []
        prepared: list[object] = []
        for name in sort_cols:
            column = table.column(name).combine_chunks()
            if pa.types.is_string(column.type) or pa.types.is_large_string(column.type):
                string_flags.append(True)
                prepared.append(column.to_pylist())
            elif pa.types.is_floating(column.type) or pa.types.is_integer(column.type):
                string_flags.append(False)
                values = column.to_numpy(zero_copy_only=False)
                if values.dtype != np.float64:
                    values = values.astype(np.float64, copy=False)
                # Stata encodes numeric missing values as huge sentinels
                # (> 8.0e307); map them to NaN before sorting.
                values = np.where(values > 8.0e307, np.nan, values)
                prepared.append(values)
            else:
                # Unsupported column type for the native kernels.
                return None

        if "_n" not in table.column_names:
            return None
        obs = table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)

        if not any(string_flags) and _native_argsort_numeric is not None:
            order = _native_argsort_numeric(prepared, descending, nulls_last)
        elif _native_argsort_mixed is not None:
            order = _native_argsort_mixed(prepared, string_flags, descending, nulls_last)
        else:
            return None

        # "_n" is 1-based; convert to zero-based indices after reordering.
        return [int(v) for v in (obs[order] - 1).tolist()]
    except Exception:
        return None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_sorted_indices_polars(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int]:
    """Sort the Arrow table with Polars and return zero-based obs indices.

    The table must carry a 1-based ``_n`` column; the returned list holds
    ``_n - 1`` values in sorted order.
    """
    import polars as pl

    frame = pl.from_arrow(table)

    # Map Stata missing sentinels (> 8.0e307) to null in float columns,
    # leaving the "_n" bookkeeping column untouched.
    normalized = []
    for name, dtype in zip(frame.columns, frame.dtypes):
        if name != "_n" and dtype in (pl.Float32, pl.Float64):
            normalized.append(
                pl.when(pl.col(name) > 8.0e307)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
            )
        else:
            normalized.append(pl.col(name))
    frame = frame.select(normalized)

    try:
        # Use expressions for arithmetic to avoid eager Series-scalar
        # conversion issues observed in some environments with Int64 dtypes.
        out = frame.select(
            idx=pl.arg_sort_by(
                [pl.col(name) for name in sort_cols],
                descending=descending,
                nulls_last=nulls_last,
            ),
            zero_based_n=pl.col("_n") - 1,
        )
        return out["zero_based_n"].take(out["idx"]).to_list()
    except Exception:
        # Fallback to an eager sort if arg_sort_by fails or has issues.
        return (
            frame.sort(by=sort_cols, descending=descending, nulls_last=nulls_last)
            .select(pl.col("_n") - 1)
            .to_series()
            .to_list()
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _stable_hash(payload: dict[str, Any]) -> str:
    """Return a deterministic SHA-1 hex digest of a JSON-serializable mapping.

    Keys are sorted before serialization so logically equal payloads always
    produce the same fingerprint (used as the dataset id).
    """
    canonical = json.dumps(payload, sort_keys=True)
    return hashlib.sha1(canonical.encode("utf-8")).hexdigest()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass
class UIChannelInfo:
    """Connection details handed to the UI client for the local HTTP channel."""

    # Root URL of the embedded HTTP server, e.g. "http://<host>:<port>".
    base_url: str
    # Bearer token the client must present in the Authorization header.
    token: str
    # Token expiry as a Unix timestamp in milliseconds.
    expires_at: int
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@dataclass
class ViewHandle:
    """A server-side filtered view over the current dataset."""

    # Opaque identifier of the form "view_<hex>", used in /v1/views/<id> routes.
    view_id: str
    # Dataset fingerprint the view was created against.
    dataset_id: str
    # Stata frame name this view applies to.
    frame: str
    # The filter expression supplied by the client.
    filter_expr: str
    # Observation indices selected by filter_expr — indexing convention as
    # produced by StataClient.compute_view_indices (not visible here; verify).
    obs_indices: list[int]
    # len(obs_indices), cached for quick reporting.
    filtered_n: int
    # Unix timestamps (seconds); last_access drives TTL-based eviction.
    created_at: float
    last_access: float
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class UIChannelManager:
    """Owns the local HTTP channel that serves dataset pages to the UI.

    Responsibilities visible in this class:
      * lazily starts a daemonized ThreadingHTTPServer bound to (host, port);
      * issues and validates a bearer token with a millisecond expiry;
      * fingerprints the current dataset so clients can detect changes (409);
      * manages filtered ViewHandles with TTL-based eviction;
      * keeps small LRU caches of computed sort indices and Arrow sort tables.

    All mutable state is guarded by ``self._lock``.
    """

    def __init__(
        self,
        client: StataClient,
        *,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        token_ttl_s: int = TOKEN_TTL_S,
        view_ttl_s: int = VIEW_TTL_S,
        max_limit: int = MAX_LIMIT,
        max_vars: int = MAX_VARS,
        max_chars: int = MAX_CHARS,
        max_request_bytes: int = MAX_REQUEST_BYTES,
        max_arrow_limit: int = MAX_ARROW_LIMIT,
    ):
        self._client = client
        self._host = host
        self._port = port
        self._token_ttl_s = token_ttl_s
        self._view_ttl_s = view_ttl_s
        self._max_limit = max_limit
        self._max_vars = max_vars
        self._max_chars = max_chars
        self._max_request_bytes = max_request_bytes
        self._max_arrow_limit = max_arrow_limit

        self._lock = threading.Lock()
        self._httpd: ThreadingHTTPServer | None = None
        self._thread: threading.Thread | None = None

        # Bearer token and its expiry (Unix time in milliseconds).
        self._token: str | None = None
        self._expires_at: int = 0

        # Version counter bumped on every potential dataset change; the cached
        # dataset id is only valid while the version matches.
        self._dataset_version: int = 0
        self._dataset_id_cache: str | None = None
        self._dataset_id_cache_at_version: int = -1

        self._views: dict[str, ViewHandle] = {}
        self._last_sort_spec: tuple[str, ...] | None = None
        self._last_sort_dataset_id: str | None = None
        # LRU cache of computed sort orders keyed by (dataset_id, sort_spec).
        self._sort_index_cache: dict[tuple[str, tuple[str, ...]], list[int]] = {}
        self._sort_cache_order: list[tuple[str, tuple[str, ...]]] = []
        self._sort_cache_max_entries: int = 4
        # LRU cache of Arrow tables holding the raw sort columns.
        self._sort_table_cache: dict[tuple[str, tuple[str, ...]], Any] = {}
        self._sort_table_order: list[tuple[str, tuple[str, ...]]] = []
        self._sort_table_max_entries: int = 2

    def notify_potential_dataset_change(self) -> None:
        """Invalidate every dataset-derived cache after a possible data change."""
        with self._lock:
            self._dataset_version += 1
            self._dataset_id_cache = None
            self._views.clear()
            self._last_sort_spec = None
            self._last_sort_dataset_id = None
            self._sort_index_cache.clear()
            self._sort_cache_order.clear()
            self._sort_table_cache.clear()
            self._sort_table_order.clear()

    @staticmethod
    def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
        """Canonicalize sort specs to "+var"/"-var"; raise ValueError when malformed."""
        normalized: list[str] = []
        for spec in sort_spec:
            if not isinstance(spec, str) or not spec:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            raw = spec.strip()
            if not raw:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            sign = "-" if raw.startswith("-") else "+"
            varname = raw.lstrip("+-")
            if not varname:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            normalized.append(f"{sign}{varname}")
        return tuple(normalized)

    def _get_cached_sort_indices(
        self, dataset_id: str, sort_spec: tuple[str, ...]
    ) -> list[int] | None:
        """Return cached sort indices for (dataset_id, sort_spec), refreshing LRU order."""
        key = (dataset_id, sort_spec)
        with self._lock:
            cached = self._sort_index_cache.get(key)
            if cached is None:
                return None
            # refresh LRU order
            if key in self._sort_cache_order:
                self._sort_cache_order.remove(key)
            self._sort_cache_order.append(key)
            return cached

    def _set_cached_sort_indices(
        self, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
    ) -> None:
        """Store sort indices, evicting the least-recently-used entries past the cap."""
        key = (dataset_id, sort_spec)
        with self._lock:
            if key in self._sort_index_cache:
                self._sort_cache_order.remove(key)
            self._sort_index_cache[key] = indices
            self._sort_cache_order.append(key)
            while len(self._sort_cache_order) > self._sort_cache_max_entries:
                evict = self._sort_cache_order.pop(0)
                self._sort_index_cache.pop(evict, None)

    def _get_cached_sort_table(
        self, dataset_id: str, sort_cols: tuple[str, ...]
    ) -> Any | None:
        """Return the cached Arrow sort table for the key, refreshing LRU order."""
        key = (dataset_id, sort_cols)
        with self._lock:
            cached = self._sort_table_cache.get(key)
            if cached is None:
                return None
            if key in self._sort_table_order:
                self._sort_table_order.remove(key)
            self._sort_table_order.append(key)
            return cached

    def _set_cached_sort_table(
        self, dataset_id: str, sort_cols: tuple[str, ...], table: Any
    ) -> None:
        """Store an Arrow sort table, evicting LRU entries past the (small) cap."""
        key = (dataset_id, sort_cols)
        with self._lock:
            if key in self._sort_table_cache:
                self._sort_table_order.remove(key)
            self._sort_table_cache[key] = table
            self._sort_table_order.append(key)
            while len(self._sort_table_order) > self._sort_table_max_entries:
                evict = self._sort_table_order.pop(0)
                self._sort_table_cache.pop(evict, None)

    def _get_sort_table(self, dataset_id: str, sort_cols: list[str]) -> Any:
        """Fetch (or reuse) an Arrow table of the sort columns plus "_n".

        Returns None when the dataset has no observations.
        """
        sort_cols_key = tuple(sort_cols)
        cached = self._get_cached_sort_table(dataset_id, sort_cols_key)
        if cached is not None:
            return cached

        state = self._client.get_dataset_state()
        n = int(state.get("n", 0) or 0)
        if n <= 0:
            return None

        # Pull full columns once via Arrow stream (Stata -> Arrow), then sort in Polars.
        arrow_bytes = self._client.get_arrow_stream(
            offset=0,
            limit=n,
            vars=sort_cols,
            include_obs_no=True,
            obs_indices=None,
        )

        import pyarrow as pa

        with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
            table = reader.read_all()

        self._set_cached_sort_table(dataset_id, sort_cols_key, table)
        return table

    def get_channel(self) -> UIChannelInfo:
        """Start the HTTP server if needed and return its URL, token, and expiry."""
        self._ensure_http_server()
        with self._lock:
            self._ensure_token()
            assert self._httpd is not None
            # Read the bound port from the server: it may differ from the
            # requested one (e.g. port 0 picks an ephemeral port).
            port = self._httpd.server_address[1]
            base_url = f"http://{self._host}:{port}"
            return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)

    def capabilities(self) -> dict[str, bool]:
        """Advertise the feature set of this channel implementation."""
        return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}

    def current_dataset_id(self) -> str:
        """Return a stable fingerprint of the current dataset (cached per version)."""
        with self._lock:
            if self._dataset_id_cache is not None and self._dataset_id_cache_at_version == self._dataset_version:
                return self._dataset_id_cache

        # NOTE: the client call happens outside the lock to avoid blocking
        # other channel operations on Stata I/O.
        state = self._client.get_dataset_state()
        payload = {
            "version": self._dataset_version,
            "frame": state.get("frame"),
            "n": state.get("n"),
            "k": state.get("k"),
            "sortlist": state.get("sortlist"),
        }
        digest = _stable_hash(payload)

        with self._lock:
            self._dataset_id_cache = digest
            self._dataset_id_cache_at_version = self._dataset_version
            return digest

    def get_view(self, view_id: str) -> Optional[ViewHandle]:
        """Look up a live view, bumping its last_access; None if missing/expired."""
        now = time.time()
        with self._lock:
            self._evict_expired_locked(now)
            view = self._views.get(view_id)
            if view is None:
                return None
            view.last_access = now
            return view

    def create_view(self, *, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
        """Materialize a filtered view of the current dataset.

        Raises DatasetChangedError when dataset_id is stale, InvalidFilterError
        for bad filter expressions, NoDataInMemoryError when Stata reports no
        data, and re-raises other RuntimeErrors.
        """
        current_id = self.current_dataset_id()
        if dataset_id != current_id:
            raise DatasetChangedError(current_id)

        try:
            obs_indices = self._client.compute_view_indices(filter_expr)
        except ValueError as e:
            raise InvalidFilterError(str(e))
        except RuntimeError as e:
            msg = str(e) or "No data in memory"
            if "no data" in msg.lower():
                raise NoDataInMemoryError(msg)
            raise
        now = time.time()
        view_id = f"view_{uuid.uuid4().hex}"
        view = ViewHandle(
            view_id=view_id,
            dataset_id=current_id,
            frame=frame,
            filter_expr=filter_expr,
            obs_indices=obs_indices,
            filtered_n=len(obs_indices),
            created_at=now,
            last_access=now,
        )
        with self._lock:
            self._evict_expired_locked(now)
            self._views[view_id] = view
        return view

    def delete_view(self, view_id: str) -> bool:
        """Drop a view; True when it existed."""
        with self._lock:
            return self._views.pop(view_id, None) is not None

    def validate_token(self, header_value: str | None) -> bool:
        """Check an Authorization header value ("Bearer <token>") against the live token."""
        if not header_value:
            return False
        if not header_value.startswith("Bearer "):
            return False
        token = header_value[len("Bearer ") :].strip()
        with self._lock:
            self._ensure_token()
            if self._token is None:
                return False
            if time.time() * 1000 >= self._expires_at:
                return False
            # Constant-time comparison to avoid timing side channels.
            return secrets.compare_digest(token, self._token)

    def limits(self) -> tuple[int, int, int, int]:
        """Return (max_limit, max_vars, max_chars, max_request_bytes)."""
        return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes

    def _ensure_token(self) -> None:
        """Mint a fresh token when none exists or the current one expired.

        Caller must hold self._lock.
        """
        now_ms = int(time.time() * 1000)
        if self._token is None or now_ms >= self._expires_at:
            self._token = secrets.token_urlsafe(32)
            self._expires_at = int((time.time() + self._token_ttl_s) * 1000)

    def _evict_expired_locked(self, now: float) -> None:
        """Remove views idle for at least view_ttl_s. Caller must hold self._lock."""
        expired: list[str] = []
        for key, view in self._views.items():
            if now - view.last_access >= self._view_ttl_s:
                expired.append(key)
        for key in expired:
            self._views.pop(key, None)

    def _ensure_http_server(self) -> None:
        """Start the background HTTP server exactly once (idempotent)."""
        # NOTE(review): indentation reconstructed from a diff rendering — the
        # whole construction is kept under the lock so concurrent callers
        # cannot start two servers; confirm against the original layout.
        with self._lock:
            if self._httpd is not None:
                return

            manager = self

            class Handler(BaseHTTPRequestHandler):
                """Request handler implementing the /v1 JSON + Arrow API."""

                def _send_json(self, status: int, payload: dict[str, Any]) -> None:
                    # Serialize once so Content-Length is exact.
                    data = json.dumps(payload).encode("utf-8")
                    self.send_response(status)
                    self.send_header("Content-Type", "application/json")
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
                    self.send_response(status)
                    self.send_header("Content-Type", content_type)
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
                    # Log server-side faults but never leak internals to the client.
                    if status >= 500 or code == "internal_error":
                        logger.error("UI HTTP error %s: %s", code, message)
                        message = "Internal server error"
                    body: dict[str, Any] = {"error": {"code": code, "message": message}}
                    if stata_rc is not None:
                        body["error"]["stataRc"] = stata_rc
                    self._send_json(status, body)

                def _require_auth(self) -> bool:
                    """Send a 401 and return False unless the bearer token is valid."""
                    if manager.validate_token(self.headers.get("Authorization")):
                        return True
                    self._error(401, "auth_failed", "Unauthorized")
                    return False

                def _read_json(self) -> dict[str, Any] | None:
                    """Read and parse the request body; None means an error was already sent."""
                    max_limit, max_vars, max_chars, max_bytes = manager.limits()
                    # Only max_bytes is used here; the other limits are unpacked
                    # and deliberately ignored.
                    _ = (max_limit, max_vars, max_chars)

                    length = int(self.headers.get("Content-Length", "0") or "0")
                    if length <= 0:
                        return {}
                    if length > max_bytes:
                        self._error(400, "request_too_large", "Request too large")
                        return None
                    raw = self.rfile.read(length)
                    try:
                        parsed = json.loads(raw.decode("utf-8"))
                    except Exception:
                        self._error(400, "invalid_request", "Invalid JSON")
                        return None
                    if not isinstance(parsed, dict):
                        self._error(400, "invalid_request", "Expected JSON object")
                        return None
                    return parsed

                def do_GET(self) -> None:
                    if not self._require_auth():
                        return

                    if self.path == "/v1/dataset":
                        try:
                            state = manager._client.get_dataset_state()
                            dataset_id = manager.current_dataset_id()
                            self._send_json(
                                200,
                                {
                                    "dataset": {
                                        "id": dataset_id,
                                        "frame": state.get("frame"),
                                        "n": state.get("n"),
                                        "k": state.get("k"),
                                        "changed": state.get("changed"),
                                    }
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/vars":
                        try:
                            state = manager._client.get_dataset_state()
                            dataset_id = manager.current_dataset_id()
                            variables = manager._client.list_variables_rich()
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": dataset_id, "frame": state.get("frame")},
                                    "variables": variables,
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    self._error(404, "not_found", "Not found")

                def do_POST(self) -> None:
                    if not self._require_auth():
                        return

                    if self.path == "/v1/arrow":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=None)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/page":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=None)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/views":
                        body = self._read_json()
                        if body is None:
                            return
                        dataset_id = str(body.get("datasetId", ""))
                        frame = str(body.get("frame", "default"))
                        filter_expr = str(body.get("filterExpr", ""))
                        if not dataset_id or not filter_expr:
                            self._error(400, "invalid_request", "datasetId and filterExpr are required")
                            return
                        try:
                            view = manager.create_view(dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": view.dataset_id, "frame": view.frame},
                                    "view": {"id": view.view_id, "filteredN": view.filtered_n},
                                },
                            )
                            return
                        except DatasetChangedError as e:
                            self._error(409, "dataset_changed", "Dataset changed")
                            return
                        except ValueError as e:
                            # NOTE(review): create_view converts ValueError into
                            # InvalidFilterError and RuntimeError into
                            # NoDataInMemoryError, neither of which is caught by
                            # the ValueError/RuntimeError branches below — they
                            # fall through to the generic 500 handler. Verify
                            # whether 400 responses were intended for them.
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=view_id)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/filters/validate":
                        body = self._read_json()
                        if body is None:
                            return
                        filter_expr = str(body.get("filterExpr", ""))
                        if not filter_expr:
                            self._error(400, "invalid_request", "filterExpr is required")
                            return
                        try:
                            manager._client.validate_filter_expr(filter_expr)
                            self._send_json(200, {"ok": True})
                            return
                        except ValueError as e:
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    self._error(404, "not_found", "Not found")

                def do_DELETE(self) -> None:
                    if not self._require_auth():
                        return

                    if self.path.startswith("/v1/views/"):
                        parts = self.path.split("/")
                        if len(parts) != 4:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        if manager.delete_view(view_id):
                            self._send_json(200, {"ok": True})
                        else:
                            self._error(404, "not_found", "Not found")
                        return

                    self._error(404, "not_found", "Not found")

                def log_message(self, format: str, *args: Any) -> None:
                    # Silence BaseHTTPRequestHandler's per-request stderr logging.
                    return

            httpd = ThreadingHTTPServer((self._host, self._port), Handler)
            t = threading.Thread(target=httpd.serve_forever, daemon=True)
            t.start()
            self._httpd = httpd
            self._thread = t
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
class HTTPError(Exception):
|
|
699
|
+
def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
|
|
700
|
+
super().__init__(message)
|
|
701
|
+
self.status = status
|
|
702
|
+
self.code = code
|
|
703
|
+
self.message = message
|
|
704
|
+
self.stata_rc = stata_rc
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
class DatasetChangedError(Exception):
    """Raised when a request references a dataset id that is no longer current.

    Carries the fresh dataset id so callers can resynchronize.
    """

    def __init__(self, current_dataset_id: str):
        self.current_dataset_id = current_dataset_id
        super().__init__("dataset_changed")
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
class NoDataInMemoryError(Exception):
|
|
714
|
+
def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
|
|
715
|
+
super().__init__(message)
|
|
716
|
+
self.stata_rc = stata_rc
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
class InvalidFilterError(Exception):
|
|
720
|
+
def __init__(self, message: str, *, stata_rc: int | None = None):
|
|
721
|
+
super().__init__(message)
|
|
722
|
+
self.message = message
|
|
723
|
+
self.stata_rc = stata_rc
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
    """Validate a JSON page request and return one page of the dataset.

    Resolves the target dataset either from the request body (``datasetId``)
    or from a stored view (``view_id``), validates paging/sort/display
    parameters against the manager's configured limits, optionally applies a
    non-mutating sort (cached per dataset + sort spec), and fetches the page
    from the Stata client.

    Args:
        manager: UI channel manager providing limits, views, sort caches and
            the Stata client.
        body: Decoded JSON request body.
        view_id: Id of a stored view to page through, or ``None`` to page the
            live dataset directly.

    Returns:
        A JSON-serializable dict with ``dataset``, ``view``, ``vars``,
        ``rows`` and ``display`` sections.

    Raises:
        HTTPError: 400 for invalid parameters or no data in memory, 404 for
            an unknown view, 409 when the dataset changed, 500 otherwise.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(view_id)
        if view is None:
            raise HTTPError(404, "not_found", "View not found")
        dataset_id = view.dataset_id
        frame = view.frame

    # Parse offset (default 0 is valid since offset >= 0).
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"offset must be a valid integer, got: {body.get('offset')!r}") from e

    # Parse limit (no default - must be explicitly provided).
    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"limit must be a valid integer, got: {limit_raw!r}") from e

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))

    # Parse sortBy parameter.
    sort_by = body.get("sortBy", [])
    if sort_by is not None and not isinstance(sort_by, list):
        raise HTTPError(400, "invalid_request", f"sortBy must be an array, got: {type(sort_by).__name__}")
    if sort_by and not all(isinstance(s, str) for s in sort_by):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

    # Parse maxChars (defaults to the configured maximum).
    max_chars_raw = body.get("maxChars", max_chars)
    try:
        max_chars_req = int(max_chars_raw or max_chars)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"maxChars must be a valid integer, got: {max_chars_raw!r}") from e

    if offset < 0:
        raise HTTPError(400, "invalid_request", f"offset must be >= 0, got: {offset}")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0, got: {limit}")
    if limit > max_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
    if max_chars_req <= 0:
        raise HTTPError(400, "invalid_request", "maxChars must be > 0")
    if max_chars_req > max_chars:
        raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    current_id = manager.current_dataset_id()
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", "Dataset changed")

    if view_id is None:
        obs_indices = None
        filtered_n: int | None = None
    else:
        assert view is not None
        obs_indices = view.obs_indices
        filtered_n = view.filtered_n

    try:
        # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
        if sort_by:
            try:
                normalized_sort = manager._normalize_sort_spec(sort_by)
                sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
                if sorted_indices is None:
                    # Cache miss: compute the permutation and memoize it.
                    sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
                    descending = [spec.startswith("-") for spec in normalized_sort]
                    nulls_last = [not desc for desc in descending]

                    table = manager._get_sort_table(current_id, sort_cols)
                    if table is None:
                        sorted_indices = []
                    else:
                        sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
                        if sorted_indices is None:
                            sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)

                    manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)

                if view_id is None:
                    obs_indices = sorted_indices
                else:
                    # Restrict the sorted permutation to the view's rows,
                    # preserving sort order.
                    assert view is not None
                    view_set = set(view.obs_indices)
                    obs_indices = [idx for idx in sorted_indices if idx in view_set]
                    filtered_n = len(obs_indices)
            except ValueError as e:
                raise HTTPError(400, "invalid_request", f"Invalid sort specification: {e}") from e
            except RuntimeError as e:
                raise HTTPError(500, "internal_error", f"Failed to apply sort: {e}") from e

        dataset_state = manager._client.get_dataset_state()
        page = manager._client.get_page(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            max_chars=max_chars_req,
            obs_indices=obs_indices,
        )
    except HTTPError:
        # Re-raise HTTPError exceptions as-is
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg) from e
        raise HTTPError(500, "internal_error", msg) from e
    except ValueError as e:
        msg = str(e)
        if msg.lower().startswith("invalid variable"):
            raise HTTPError(400, "invalid_variable", msg) from e
        raise HTTPError(400, "invalid_request", msg) from e
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e)) from e

    view_obj: dict[str, Any] = {
        "offset": offset,
        "limit": limit,
        "returned": page["returned"],
        "filteredN": filtered_n,
    }
    if view_id is not None:
        view_obj["viewId"] = view_id

    return {
        "dataset": {
            "id": current_id,
            "frame": dataset_state.get("frame"),
            "n": dataset_state.get("n"),
            "k": dataset_state.get("k"),
        },
        "view": view_obj,
        "vars": page["vars"],
        "rows": page["rows"],
        "display": {
            "maxChars": max_chars_req,
            "truncatedCells": page["truncated_cells"],
            "missing": ".",
        },
    }
def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
    """Validate an Arrow page request and return the page as Arrow IPC bytes.

    Mirrors :func:`handle_page_request` (dataset resolution, parameter
    validation, optional cached sort) but streams the page through the Stata
    client's Arrow encoder and honors the larger Arrow-specific row limit.

    Args:
        manager: UI channel manager providing limits, views, sort caches and
            the Stata client.
        body: Decoded JSON request body.
        view_id: Id of a stored view to page through, or ``None`` to page the
            live dataset directly.

    Returns:
        The requested page serialized as an Arrow IPC stream.

    Raises:
        HTTPError: 400 for invalid parameters or no data in memory, 404 for
            an unknown view, 409 when the dataset changed, 500 otherwise.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()
    # Use the specific Arrow limit instead of the general UI page limit.
    chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(view_id)
        if view is None:
            raise HTTPError(404, "not_found", "View not found")
        dataset_id = view.dataset_id
        frame = view.frame

    # Parse offset (default 0).
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "offset must be a valid integer") from e

    # Parse limit (required).
    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "limit must be a valid integer") from e

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])

    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", "limit must be > 0")
    # Arrow streams are efficient, but we still respect a (much larger) max limit.
    if limit > chunk_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    current_id = manager.current_dataset_id()
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", "Dataset changed")

    if view_id is None:
        obs_indices = None
    else:
        assert view is not None
        obs_indices = view.obs_indices

    try:
        # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
        if sort_by:
            if not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by):
                raise HTTPError(400, "invalid_request", "sortBy must be a list of strings")
            try:
                normalized_sort = manager._normalize_sort_spec(sort_by)
                sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
                if sorted_indices is None:
                    # Cache miss: compute the permutation and memoize it.
                    sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
                    descending = [spec.startswith("-") for spec in normalized_sort]
                    nulls_last = [not desc for desc in descending]

                    table = manager._get_sort_table(current_id, sort_cols)
                    if table is None:
                        sorted_indices = []
                    else:
                        sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
                        if sorted_indices is None:
                            sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)

                    manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)

                if view_id is None:
                    obs_indices = sorted_indices
                else:
                    # Restrict the sorted permutation to the view's rows,
                    # preserving sort order.
                    assert view is not None
                    view_set = set(view.obs_indices)
                    obs_indices = [idx for idx in sorted_indices if idx in view_set]
            except ValueError as e:
                raise HTTPError(400, "invalid_request", f"Invalid sort: {e}") from e
            except RuntimeError as e:
                raise HTTPError(500, "internal_error", f"Sort failed: {e}") from e

        arrow_bytes = manager._client.get_arrow_stream(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            obs_indices=obs_indices,
        )
        return arrow_bytes

    except HTTPError:
        # BUGFIX: previously absent, so 4xx HTTPErrors raised above (e.g. the
        # sortBy validation) were swallowed by the broad `except Exception`
        # below and re-wrapped as 500 internal_error. Re-raise them as-is,
        # matching handle_page_request.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg) from e
        raise HTTPError(500, "internal_error", msg) from e
    except ValueError as e:
        msg = str(e)
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        raise HTTPError(400, "invalid_request", msg) from e
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e)) from e