mcp-stata 1.22.1__cp311-abi3-macosx_11_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_stata/__init__.py +3 -0
- mcp_stata/__main__.py +4 -0
- mcp_stata/_native_ops.abi3.so +0 -0
- mcp_stata/config.py +20 -0
- mcp_stata/discovery.py +548 -0
- mcp_stata/graph_detector.py +601 -0
- mcp_stata/models.py +74 -0
- mcp_stata/native_ops.py +87 -0
- mcp_stata/server.py +1333 -0
- mcp_stata/sessions.py +264 -0
- mcp_stata/smcl/smcl2html.py +88 -0
- mcp_stata/stata_client.py +4710 -0
- mcp_stata/streaming_io.py +264 -0
- mcp_stata/test_stata.py +56 -0
- mcp_stata/ui_http.py +1034 -0
- mcp_stata/utils.py +159 -0
- mcp_stata/worker.py +167 -0
- mcp_stata-1.22.1.dist-info/METADATA +488 -0
- mcp_stata-1.22.1.dist-info/RECORD +22 -0
- mcp_stata-1.22.1.dist-info/WHEEL +4 -0
- mcp_stata-1.22.1.dist-info/entry_points.txt +2 -0
- mcp_stata-1.22.1.dist-info/licenses/LICENSE +661 -0
mcp_stata/ui_http.py
ADDED
|
@@ -0,0 +1,1034 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import hashlib
|
|
3
|
+
import io
|
|
4
|
+
import json
|
|
5
|
+
import secrets
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
import uuid
|
|
9
|
+
import logging
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
12
|
+
from typing import Any, Callable, Optional
|
|
13
|
+
|
|
14
|
+
from .stata_client import StataClient
|
|
15
|
+
from .config import (
|
|
16
|
+
DEFAULT_HOST,
|
|
17
|
+
DEFAULT_PORT,
|
|
18
|
+
MAX_ARROW_LIMIT,
|
|
19
|
+
MAX_CHARS,
|
|
20
|
+
MAX_LIMIT,
|
|
21
|
+
MAX_REQUEST_BYTES,
|
|
22
|
+
MAX_VARS,
|
|
23
|
+
TOKEN_TTL_S,
|
|
24
|
+
VIEW_TTL_S,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger("mcp_stata")
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from .native_ops import argsort_numeric as _native_argsort_numeric
|
|
32
|
+
from .native_ops import argsort_mixed as _native_argsort_mixed
|
|
33
|
+
except Exception:
|
|
34
|
+
_native_argsort_numeric = None
|
|
35
|
+
_native_argsort_mixed = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _try_native_argsort(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int] | None:
    """Fast-path argsort via the native extension, if it is available.

    Returns zero-based observation indices (derived from the table's ``_n``
    column) in sorted order, or ``None`` whenever the fast path cannot be
    used — missing native functions, unsupported column types, a missing
    ``_n`` column, or any runtime failure. Callers fall back to the slow path
    on ``None``.
    """
    if _native_argsort_numeric is None and _native_argsort_mixed is None:
        return None
    try:
        import numpy as np
        import pyarrow as pa

        string_flags: list[bool] = []
        prepared: list[object] = []
        for name in sort_cols:
            chunked = table.column(name).combine_chunks()
            col_type = chunked.type
            if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
                string_flags.append(True)
                prepared.append(chunked.to_pylist())
            elif pa.types.is_floating(col_type) or pa.types.is_integer(col_type):
                string_flags.append(False)
                values = chunked.to_numpy(zero_copy_only=False)
                if values.dtype != np.float64:
                    values = values.astype(np.float64, copy=False)
                # Stata encodes missing numerics as huge sentinel values;
                # map anything above the threshold to NaN before sorting.
                prepared.append(np.where(values > 8.0e307, np.nan, values))
            else:
                # Neither string nor numeric: the native path cannot handle it.
                return None

        if "_n" not in table.column_names:
            return None
        obs_numbers = (
            table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)
        )

        # Prefer the all-numeric kernel when no string column is involved.
        if not any(string_flags) and _native_argsort_numeric is not None:
            order = _native_argsort_numeric(prepared, descending, nulls_last)
        elif _native_argsort_mixed is not None:
            order = _native_argsort_mixed(prepared, string_flags, descending, nulls_last)
        else:
            return None

        # _n is 1-based; convert to zero-based observation indices.
        return [int(v) for v in (obs_numbers[order] - 1).tolist()]
    except Exception:
        # Best-effort fast path: any failure silently defers to the caller.
        return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _get_sorted_indices_polars(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int]:
    """Sort *table* with Polars and return zero-based observation indices.

    *table* is an Arrow table that must contain a 1-based ``_n`` observation
    column; the result is the values of ``_n - 1`` in sorted order. The
    ``descending`` / ``nulls_last`` lists are positional, matching *sort_cols*.
    """
    import polars as pl

    df = pl.from_arrow(table)
    # Normalize Stata missing values for numeric columns
    # (Stata stores missing numerics as sentinels above 8.0e307; replace with null
    # so nulls_last handling applies to them).
    exprs = []
    for col, dtype in zip(df.columns, df.dtypes):
        if col == "_n":
            # Keep the observation-number column untouched.
            exprs.append(pl.col(col))
            continue
        if dtype in (pl.Float32, pl.Float64):
            exprs.append(
                pl.when(pl.col(col) > 8.0e307)
                .then(None)
                .otherwise(pl.col(col))
                .alias(col)
            )
        else:
            exprs.append(pl.col(col))
    df = df.select(exprs)

    try:
        # Use expressions for arithmetic to avoid eager Series-scalar conversion issues
        # that have been observed in some environments with Int64 dtypes.
        res = df.select(
            idx=pl.arg_sort_by(
                [pl.col(c) for c in sort_cols],
                descending=descending,
                nulls_last=nulls_last,
            ),
            zero_based_n=pl.col("_n") - 1
        )
        # Reorder the zero-based observation numbers by the computed sort order.
        return res["zero_based_n"].take(res["idx"]).to_list()
    except Exception:
        # Fallback to eager sort if arg_sort_by fails or has issues
        return (
            df.sort(by=sort_cols, descending=descending, nulls_last=nulls_last)
            .select(pl.col("_n") - 1)
            .to_series()
            .to_list()
        )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _stable_hash(payload: dict[str, Any]) -> str:
|
|
135
|
+
return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class UIChannelInfo:
    """Connection details handed to a UI client for the local HTTP channel."""

    base_url: str  # e.g. "http://<host>:<port>" of the local UI HTTP server
    token: str  # bearer token the client must present in the Authorization header
    expires_at: int  # token expiry as a unix timestamp in milliseconds
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class ViewHandle:
    """A cached, filtered view of the current dataset for one session."""

    view_id: str  # "view_<uuid4 hex>" identifier
    dataset_id: str  # dataset digest the view was built against
    frame: str  # Stata frame name the filter applies to
    filter_expr: str  # user-supplied filter expression
    obs_indices: list[int]  # observation indices matching the filter
    filtered_n: int  # number of matching observations (len(obs_indices))
    created_at: float  # time.time() when the view was created
    last_access: float  # time.time() of last use; drives TTL eviction
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class UIChannelManager:
    """Manager for the localhost HTTP channel used by the data-browser UI.

    Responsibilities visible here: lazily start a background
    ``ThreadingHTTPServer``, mint and validate TTL-limited bearer tokens,
    track per-session filtered views (``ViewHandle``), and maintain small
    LRU caches of sort indices and Arrow sort tables. All shared mutable
    state is guarded by ``self._lock`` (a non-reentrant ``threading.Lock``,
    so no method may call another lock-taking method while holding it).
    """

    def __init__(
        self,
        client: StataClient,
        *,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        token_ttl_s: int = TOKEN_TTL_S,
        view_ttl_s: int = VIEW_TTL_S,
        max_limit: int = MAX_LIMIT,
        max_vars: int = MAX_VARS,
        max_chars: int = MAX_CHARS,
        max_request_bytes: int = MAX_REQUEST_BYTES,
        max_arrow_limit: int = MAX_ARROW_LIMIT,
    ):
        self._client = client
        self._host = host
        self._port = port
        self._token_ttl_s = token_ttl_s
        self._view_ttl_s = view_ttl_s
        self._max_limit = max_limit
        self._max_vars = max_vars
        self._max_chars = max_chars
        self._max_request_bytes = max_request_bytes
        self._max_arrow_limit = max_arrow_limit

        # Guards every mutable attribute below.
        self._lock = threading.Lock()
        self._httpd: ThreadingHTTPServer | None = None
        self._thread: threading.Thread | None = None

        # Bearer-token state; _expires_at is a unix timestamp in milliseconds.
        self._token: str | None = None
        self._expires_at: int = 0

        self._dataset_version: int = 0
        self._dataset_id_cache: str | None = None
        self._dataset_id_cache_at_version: int = -1

        self._views: dict[str, dict[str, ViewHandle]] = {}  # session_id -> {view_id -> ViewHandle}
        self._sort_index_cache: dict[tuple[str, str, tuple[str, ...]], list[int]] = {}  # (session_id, dataset_id, sort_spec)
        self._sort_cache_order: list[tuple[str, str, tuple[str, ...]]] = []
        self._sort_cache_max_entries: int = 10
        self._sort_table_cache: dict[tuple[str, str, tuple[str, ...]], Any] = {}  # (session_id, dataset_id, sort_cols)
        self._sort_table_order: list[tuple[str, str, tuple[str, ...]]] = []
        self._sort_table_max_entries: int = 4
        self._dataset_id_caches: dict[str, tuple[str, int]] = {}  # session_id -> (digest, version)

    def notify_potential_dataset_change(self, session_id: str = "default") -> None:
        """Invalidate the cached dataset id, views, and sort caches for *session_id*."""
        with self._lock:
            self._dataset_id_caches.pop(session_id, None)
            if session_id in self._views:
                self._views[session_id].clear()

            # Clear caches for this session
            self._sort_cache_order = [k for k in self._sort_cache_order if k[0] != session_id]
            self._sort_index_cache = {k: v for k, v in self._sort_index_cache.items() if k[0] != session_id}
            self._sort_table_order = [k for k in self._sort_table_order if k[0] != session_id]
            self._sort_table_cache = {k: v for k, v in self._sort_table_cache.items() if k[0] != session_id}

    @staticmethod
    def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
        """Canonicalize sort entries to "+var"/"-var" form.

        Raises ValueError for non-string, empty, or sign-only entries.
        """
        normalized: list[str] = []
        for spec in sort_spec:
            if not isinstance(spec, str) or not spec:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            raw = spec.strip()
            if not raw:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            sign = "-" if raw.startswith("-") else "+"
            varname = raw.lstrip("+-")
            if not varname:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            normalized.append(f"{sign}{varname}")
        return tuple(normalized)

    def _get_cached_sort_indices(
        self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...]
    ) -> list[int] | None:
        """Return cached sort indices, refreshing the key's LRU position; None on miss."""
        key = (session_id, dataset_id, sort_spec)
        with self._lock:
            cached = self._sort_index_cache.get(key)
            if cached is None:
                return None
            # Move the key to the most-recently-used end of the LRU order.
            if key in self._sort_cache_order:
                self._sort_cache_order.remove(key)
            self._sort_cache_order.append(key)
            return cached

    def _set_cached_sort_indices(
        self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
    ) -> None:
        """Insert sort indices into the LRU cache, evicting the oldest entries as needed."""
        key = (session_id, dataset_id, sort_spec)
        with self._lock:
            if key in self._sort_index_cache:
                self._sort_cache_order.remove(key)
            self._sort_index_cache[key] = indices
            self._sort_cache_order.append(key)
            while len(self._sort_cache_order) > self._sort_cache_max_entries:
                evict = self._sort_cache_order.pop(0)
                self._sort_index_cache.pop(evict, None)

    def _get_cached_sort_table(
        self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...]
    ) -> Any | None:
        """Return a cached Arrow sort table, refreshing its LRU position; None on miss."""
        key = (session_id, dataset_id, sort_cols)
        with self._lock:
            cached = self._sort_table_cache.get(key)
            if cached is None:
                return None
            if key in self._sort_table_order:
                self._sort_table_order.remove(key)
            self._sort_table_order.append(key)
            return cached

    def _set_cached_sort_table(
        self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...], table: Any
    ) -> None:
        """Insert an Arrow table into the (small) sort-table LRU cache."""
        key = (session_id, dataset_id, sort_cols)
        with self._lock:
            if key in self._sort_table_cache:
                self._sort_table_order.remove(key)
            self._sort_table_cache[key] = table
            self._sort_table_order.append(key)
            while len(self._sort_table_order) > self._sort_table_max_entries:
                evict = self._sort_table_order.pop(0)
                self._sort_table_cache.pop(evict, None)

    def _get_sort_table(self, session_id: str, dataset_id: str, sort_cols: list[str]) -> Any:
        """Fetch (or reuse) an Arrow table of *sort_cols* plus ``_n`` for sorting.

        Returns None when the dataset has no observations.
        """
        sort_cols_key = tuple(sort_cols)
        cached = self._get_cached_sort_table(session_id, dataset_id, sort_cols_key)
        if cached is not None:
            return cached

        # Use an appropriate client for the session
        proxy = self._get_proxy_for_session(session_id)
        state = proxy.get_dataset_state()
        n = int(state.get("n", 0) or 0)
        if n <= 0:
            return None

        # Pull full columns once via Arrow stream (Stata -> Arrow), then sort in Polars.
        arrow_bytes = proxy.get_arrow_stream(
            offset=0,
            limit=n,
            vars=sort_cols,
            include_obs_no=True,
            obs_indices=None,
        )

        import pyarrow as pa

        with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
            table = reader.read_all()

        self._set_cached_sort_table(session_id, dataset_id, sort_cols_key, table)
        return table

    def get_channel(self) -> UIChannelInfo:
        """Start the HTTP server if needed and return its URL plus a valid token."""
        self._ensure_http_server()
        with self._lock:
            self._ensure_token()
            assert self._httpd is not None
            # Read the actual bound port (supports port=0 / ephemeral binding).
            port = self._httpd.server_address[1]
            base_url = f"http://{self._host}:{port}"
            return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)

    def capabilities(self) -> dict[str, bool]:
        """Advertise the feature set of this UI channel."""
        return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}

    def _get_proxy_for_session(self, session_id: str) -> StataClient:
        """Return the Stata client to use for *session_id*."""
        # Prefer the injected client when present (used by unit tests and single-session setups).
        client = getattr(self, "_client", None)
        if client is not None:
            from .server import StataClientProxy
            if isinstance(client, StataClientProxy):
                # Re-wrap so the proxy targets the requested session.
                return StataClientProxy(session_id=session_id or "default")
            return client

        from .server import StataClientProxy
        return StataClientProxy(session_id=session_id or "default")

    def current_dataset_id(self, session_id: str = "default") -> str:
        """Return a stable digest identifying the session's current dataset state.

        The digest covers frame, n, k, and sortlist, and is cached per session
        until the dataset version changes. Note: the proxy call happens outside
        the lock, so concurrent callers may recompute the digest redundantly
        (harmless, since the hash is deterministic).
        """
        with self._lock:
            cached = self._dataset_id_caches.get(session_id)
            if cached:
                digest, version = cached
                if version == self._dataset_version:
                    return digest

        proxy = self._get_proxy_for_session(session_id)
        state = proxy.get_dataset_state()
        payload = {
            "version": self._dataset_version,
            "frame": state.get("frame"),
            "n": state.get("n"),
            "k": state.get("k"),
            "sortlist": state.get("sortlist"),
        }
        digest = _stable_hash(payload)

        with self._lock:
            self._dataset_id_caches[session_id] = (digest, self._dataset_version)
        return digest

    def get_view(self, session_id: str, view_id: str) -> Optional[ViewHandle]:
        """Look up a live view, refreshing its TTL; None when missing or expired."""
        now = time.time()
        with self._lock:
            self._evict_expired_locked(now)
            session_views = self._views.get(session_id)
            if session_views is None:
                return None
            view = session_views.get(view_id)
            if view is None:
                return None
            view.last_access = now
            return view

    def create_view(self, *, session_id: str, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
        """Materialize a filtered view of the current dataset and register it.

        Raises DatasetChangedError when *dataset_id* is stale,
        InvalidFilterError for a rejected filter expression, and
        NoDataInMemoryError when Stata has no data loaded.
        """
        current_id = self.current_dataset_id(session_id)
        if dataset_id != current_id:
            raise DatasetChangedError(current_id)

        proxy = self._get_proxy_for_session(session_id)
        try:
            obs_indices = proxy.compute_view_indices(filter_expr)
        except ValueError as e:
            raise InvalidFilterError(str(e))
        except RuntimeError as e:
            msg = str(e) or "No data in memory"
            if "no data" in msg.lower():
                raise NoDataInMemoryError(msg)
            raise
        now = time.time()
        view_id = f"view_{uuid.uuid4().hex}"
        view = ViewHandle(
            view_id=view_id,
            dataset_id=current_id,
            frame=frame,
            filter_expr=filter_expr,
            obs_indices=obs_indices,
            filtered_n=len(obs_indices),
            created_at=now,
            last_access=now,
        )
        with self._lock:
            self._evict_expired_locked(now)
            if session_id not in self._views:
                self._views[session_id] = {}
            self._views[session_id][view_id] = view
        return view

    def delete_view(self, session_id: str, view_id: str) -> bool:
        """Remove a view; return True when it existed."""
        with self._lock:
            session_views = self._views.get(session_id)
            if session_views is None:
                return False
            return session_views.pop(view_id, None) is not None

    def validate_token(self, header_value: str | None) -> bool:
        """Check an ``Authorization: Bearer <token>`` header against the current token."""
        if not header_value:
            return False
        if not header_value.startswith("Bearer "):
            return False
        token = header_value[len("Bearer ") :].strip()
        with self._lock:
            # _ensure_token refreshes an expired token, so the expiry check
            # below is defensive; a refreshed token will not match a stale one.
            self._ensure_token()
            if self._token is None:
                return False
            if time.time() * 1000 >= self._expires_at:
                return False
            # Constant-time comparison to avoid timing side channels.
            return secrets.compare_digest(token, self._token)

    def limits(self) -> tuple[int, int, int, int]:
        """Return (max_limit, max_vars, max_chars, max_request_bytes)."""
        return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes

    def _ensure_token(self) -> None:
        """Mint a fresh token when none exists or the current one has expired.

        Caller must hold self._lock.
        """
        now_ms = int(time.time() * 1000)
        if self._token is None or now_ms >= self._expires_at:
            self._token = secrets.token_urlsafe(32)
            self._expires_at = int((time.time() + self._token_ttl_s) * 1000)

    def _evict_expired_locked(self, now: float) -> None:
        """Drop views idle longer than the view TTL. Caller must hold self._lock."""
        for session_id in list(self._views.keys()):
            session_views = self._views[session_id]
            expired: list[str] = []
            for key, view in session_views.items():
                if now - view.last_access >= self._view_ttl_s:
                    expired.append(key)
            for key in expired:
                session_views.pop(key, None)
            if not session_views:
                self._views.pop(session_id, None)

    def _ensure_http_server(self) -> None:
        """Start the request-handler HTTP server on a daemon thread, once."""
        with self._lock:
            if self._httpd is not None:
                return

            # Captured by the nested Handler so request threads can reach the manager.
            manager = self

            class Handler(BaseHTTPRequestHandler):

                def _send_json(self, status: int, payload: dict[str, Any]) -> None:
                    """Serialize *payload* and send it with the given HTTP status."""
                    data = json.dumps(payload).encode("utf-8")
                    self.send_response(status)
                    self.send_header("Content-Type", "application/json")
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
                    """Send a raw binary body (used for Arrow IPC streams)."""
                    self.send_response(status)
                    self.send_header("Content-Type", content_type)
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
                    """Send a JSON error envelope; 5xx details are logged, not leaked."""
                    if status >= 500 or code == "internal_error":
                        logger.error("UI HTTP error %s: %s", code, message)
                        message = "Internal server error"
                    body: dict[str, Any] = {"error": {"code": code, "message": message}}
                    if stata_rc is not None:
                        body["error"]["stataRc"] = stata_rc
                    self._send_json(status, body)

                def _require_auth(self) -> bool:
                    """Validate the bearer token; send 401 and return False on failure."""
                    if manager.validate_token(self.headers.get("Authorization")):
                        return True
                    self._error(401, "auth_failed", "Unauthorized")
                    return False

                def _read_json(self) -> dict[str, Any] | None:
                    """Read and parse the request body as a JSON object.

                    Returns None after having already sent an error response.
                    """
                    max_limit, max_vars, max_chars, max_bytes = manager.limits()
                    # Only max_bytes is enforced here; the rest are unused at this layer.
                    _ = (max_limit, max_vars, max_chars)

                    length = int(self.headers.get("Content-Length", "0") or "0")
                    if length <= 0:
                        return {}
                    if length > max_bytes:
                        self._error(400, "request_too_large", "Request too large")
                        return None
                    raw = self.rfile.read(length)
                    try:
                        parsed = json.loads(raw.decode("utf-8"))
                    except Exception:
                        self._error(400, "invalid_request", "Invalid JSON")
                        return None
                    if not isinstance(parsed, dict):
                        self._error(400, "invalid_request", "Expected JSON object")
                        return None
                    return parsed

                def do_GET(self) -> None:
                    """Routes: /v1/dataset (dataset state) and /v1/vars (variable list)."""
                    if not self._require_auth():
                        return

                    if self.path.startswith("/v1/dataset"):
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]

                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            state = proxy.get_dataset_state()
                            dataset_id = manager.current_dataset_id(session_id)
                            self._send_json(
                                200,
                                {
                                    "dataset": {
                                        "id": dataset_id,
                                        "frame": state.get("frame"),
                                        "n": state.get("n"),
                                        "k": state.get("k"),
                                        "changed": state.get("changed"),
                                    }
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path.startswith("/v1/vars"):
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]

                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            state = proxy.get_dataset_state()
                            dataset_id = manager.current_dataset_id(session_id)
                            variables = proxy.list_variables_rich()
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": dataset_id, "frame": state.get("frame")},
                                    "variables": variables,
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    self._error(404, "not_found", "Not found")

                def do_POST(self) -> None:
                    """Routes: /v1/arrow, /v1/page, /v1/views, /v1/views/{id}/page,
                    /v1/views/{id}/arrow, /v1/filters/validate."""
                    if not self._require_auth():
                        return


                    if self.path == "/v1/arrow":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=None)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/page":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=None)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/views":
                        body = self._read_json()
                        if body is None:
                            return
                        dataset_id = str(body.get("datasetId", ""))
                        frame = str(body.get("frame", "default"))
                        filter_expr = str(body.get("filterExpr", ""))
                        session_id = str(body.get("sessionId", "default"))
                        if not dataset_id or not filter_expr:
                            self._error(400, "invalid_request", "datasetId and filterExpr are required")
                            return
                        try:
                            view = manager.create_view(session_id=session_id, dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": view.dataset_id, "frame": view.frame},
                                    "view": {"id": view.view_id, "filteredN": view.filtered_n},
                                },
                            )
                            return
                        except DatasetChangedError as e:
                            self._error(409, "dataset_changed", f"Dataset changed for session {session_id}")
                            return
                        # NOTE(review): create_view raises InvalidFilterError /
                        # NoDataInMemoryError, but the clauses below match
                        # ValueError / RuntimeError — confirm those exception
                        # types are compatible, otherwise such failures fall
                        # through to the generic 500 handler.
                        except ValueError as e:
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=view_id)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    if self.path == "/v1/filters/validate":
                        body = self._read_json()
                        if body is None:
                            return
                        filter_expr = str(body.get("filterExpr", ""))
                        session_id = str(body.get("sessionId", "default"))
                        if not filter_expr:
                            self._error(400, "invalid_request", "filterExpr is required")
                            return
                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            proxy.validate_filter_expr(filter_expr)
                            self._send_json(200, {"ok": True})
                            return
                        except ValueError as e:
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return

                    self._error(404, "not_found", "Not found")

                def do_DELETE(self) -> None:
                    """Route: DELETE /v1/views/{id}?sessionId=... — drop a view."""
                    if not self._require_auth():
                        return

                    if self.path.startswith("/v1/views/"):
                        parts = self.path.split("/")
                        if len(parts) != 4:
                            self._error(404, "not_found", "Not found")
                            return
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]
                        # NOTE(review): parts[3] is taken from the raw path, so when a
                        # query string is present it still carries "?sessionId=..." —
                        # verify whether view_id should come from parsed_url.path instead.
                        view_id = parts[3]
                        if manager.delete_view(session_id, view_id):
                            self._send_json(200, {"ok": True})
                        else:
                            self._error(404, "not_found", f"View {view_id} not found in session {session_id}")
                        return

                    self._error(404, "not_found", "Not found")

                def log_message(self, format: str, *args: Any) -> None:
                    # Silence BaseHTTPRequestHandler's default stderr access log.
                    return

            httpd = ThreadingHTTPServer((self._host, self._port), Handler)
            t = threading.Thread(target=httpd.serve_forever, daemon=True)
            t.start()
            self._httpd = httpd
            self._thread = t
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
class HTTPError(Exception):
    """An HTTP-level failure carrying status, machine-readable code, and message.

    ``stata_rc`` optionally holds the underlying Stata return code so it can be
    surfaced in the JSON error envelope.
    """

    def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
        super().__init__(message)
        self.status, self.code, self.message = status, code, message
        self.stata_rc = stata_rc
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
class DatasetChangedError(Exception):
    """Signals that the dataset a request referenced is no longer current.

    Carries the id of the dataset that is current now so callers can resync.
    """

    def __init__(self, current_dataset_id: str):
        self.current_dataset_id = current_dataset_id
        super().__init__("dataset_changed")
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
class NoDataInMemoryError(RuntimeError):
    """Raised when an operation requires a dataset but none is loaded.

    Subclasses ``RuntimeError`` (fix): the HTTP handlers in this module map a
    ``RuntimeError`` whose message contains "no data" to a 400
    ``no_data_in_memory`` response, but ``create_view`` re-raises such errors
    as ``NoDataInMemoryError`` — previously a bare ``Exception`` subclass, so
    they fell through to the generic 500 ``internal_error`` branch instead.

    :param message: human-readable description; defaults to "No data in memory".
    :param stata_rc: optional underlying Stata return code for the error envelope.
    """

    def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
        super().__init__(message)
        self.stata_rc = stata_rc
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
class InvalidFilterError(ValueError):
    """Raised when a user-supplied filter expression is rejected.

    Subclasses ``ValueError`` (fix): the ``/v1/views`` POST handler maps a
    ``ValueError`` to a 400 ``invalid_filter`` response, but ``create_view``
    converts the proxy's ``ValueError`` into ``InvalidFilterError`` —
    previously a bare ``Exception`` subclass, so invalid filters surfaced as
    generic 500 ``internal_error`` responses instead.

    :param message: human-readable description of the rejection.
    :param stata_rc: optional underlying Stata return code for the error envelope.
    """

    def __init__(self, message: str, *, stata_rc: int | None = None):
        super().__init__(message)
        self.message = message
        self.stata_rc = stata_rc
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def _resolve_proxy(manager: UIChannelManager, session_id: str) -> StataClient:
|
|
775
|
+
"""Resolve the Stata client proxy, preferring injected clients for tests."""
|
|
776
|
+
proxy = getattr(manager, "_client", None)
|
|
777
|
+
if proxy is not None:
|
|
778
|
+
from .server import StataClientProxy
|
|
779
|
+
if not isinstance(proxy, StataClientProxy):
|
|
780
|
+
return proxy
|
|
781
|
+
return manager._get_proxy_for_session(session_id)
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
    """Serve one page of dataset rows as a JSON-serializable dict.

    Validates the request *body*, rejects requests against a stale dataset,
    optionally applies a cached sort order and/or a saved view's filter
    indices, then fetches the page through the session's Stata client proxy.

    Args:
        manager: Channel manager owning sessions, views, and sort caches.
        body: Parsed JSON request payload.
        view_id: Optional saved-view id; when given, the view supplies the
            dataset id and filter indices instead of the request body.

    Returns:
        A dict with ``dataset``, ``view``, ``vars``, ``rows``, and
        ``display`` sections.

    Raises:
        HTTPError: 400 for invalid parameters, 404 for an unknown view,
            409 when the dataset changed, 500 for unexpected failures.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()

    session_id = str(body.get("sessionId", "default"))

    view = None  # bound explicitly so later branches never see an unbound name
    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        frame = str(body.get("frame", "default"))  # NOTE(review): currently unused below — confirm intent
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
        frame = view.frame  # NOTE(review): currently unused below — confirm intent

    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "offset must be a valid integer") from e

    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "limit must be a valid integer") from e

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])
    max_chars_raw = body.get("maxChars", max_chars)

    try:
        max_chars_req = int(max_chars_raw or max_chars)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "maxChars must be a valid integer") from e

    # Bounds checks against the server-configured limits.
    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > max_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
    if max_chars_req <= 0:
        raise HTTPError(400, "invalid_request", "maxChars must be > 0")
    if max_chars_req > max_chars:
        raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    if sort_by and (not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by)):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

    # Reject stale requests: the client must re-sync after the dataset changes.
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")

    if view_id is None:
        obs_indices = None
        filtered_n = None
    else:
        assert view is not None
        obs_indices = view.obs_indices
        filtered_n = view.filtered_n

    try:
        if sort_by:
            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                # "+var"/"-var" prefixes encode direction; strip them for column names.
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)

                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Prefer the native argsort; fall back to the polars path.
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)

            if obs_indices_sorted:
                if obs_indices:
                    # Keep the sort order but drop rows outside the view's filter.
                    filter_set = set(obs_indices)
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted

        proxy = _resolve_proxy(manager, session_id)
        dataset_state = proxy.get_dataset_state()
        page = proxy.get_page(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            max_chars=max_chars_req,
            obs_indices=obs_indices,
        )

        view_obj: dict[str, Any] = {
            "offset": offset,
            "limit": limit,
            "returned": page["returned"],
            "filteredN": filtered_n,
        }
        if view_id is not None:
            view_obj["viewId"] = view_id

        return {
            "dataset": {
                "id": current_id,
                "frame": dataset_state.get("frame"),
                "n": dataset_state.get("n"),
                "k": dataset_state.get("k"),
            },
            "view": view_obj,
            "vars": page["vars"],
            "rows": page["rows"],
            "display": {
                "maxChars": max_chars_req,
                "truncatedCells": page["truncated_cells"],
                "missing": ".",
            },
        }

    except HTTPError:
        # Deliberate re-raise so validation errors above keep their status.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg) from e
        raise HTTPError(500, "internal_error", msg) from e
    except ValueError as e:
        msg = str(e)
        # Consistency fix: use the same substring test as handle_arrow_request
        # (previously startswith, which missed prefixed messages such as
        # "sort: invalid variable foo").
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        raise HTTPError(400, "invalid_request", msg) from e
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e)) from e
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
    """Serve a page of dataset rows as a binary Arrow IPC stream.

    Mirrors :func:`handle_page_request`'s validation, staleness check, and
    sort handling, but returns the raw Arrow bytes produced by the proxy
    instead of a JSON payload, and uses the larger Arrow chunk limit.

    Args:
        manager: Channel manager owning sessions, views, and sort caches.
        body: Parsed JSON request payload.
        view_id: Optional saved-view id; when given, the view supplies the
            dataset id and filter indices instead of the request body.

    Returns:
        Arrow IPC stream bytes for the requested page.

    Raises:
        HTTPError: 400 for invalid parameters, 404 for an unknown view,
            409 when the dataset changed, 500 for unexpected failures.
    """
    # max_limit is intentionally unused: Arrow requests use chunk_limit below.
    max_limit, max_vars, _, _ = manager.limits()
    chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)
    session_id = str(body.get("sessionId", "default"))

    view = None  # bound explicitly so later branches never see an unbound name
    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        frame = str(body.get("frame", "default"))  # NOTE(review): currently unused below — confirm intent
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
        frame = view.frame  # NOTE(review): currently unused below — confirm intent

    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "offset must be a valid integer") from e

    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "limit must be a valid integer") from e

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])

    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > chunk_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    # Reject stale requests: the client must re-sync after the dataset changes.
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")

    if view_id is None:
        obs_indices = None
    else:
        assert view is not None
        obs_indices = view.obs_indices

    try:
        if sort_by:
            if not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by):
                raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                # "+var"/"-var" prefixes encode direction; strip them for column names.
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)

                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Prefer the native argsort; fall back to the polars path.
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)

            if obs_indices_sorted:
                if obs_indices:
                    # Keep the sort order but drop rows outside the view's filter.
                    filter_set = set(obs_indices)
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted

        proxy = _resolve_proxy(manager, session_id)
        return proxy.get_arrow_stream(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            obs_indices=obs_indices,
        )

    except HTTPError:
        # Bug fix: previously there was no HTTPError re-raise here, so the
        # 400 raised for an invalid sortBy inside this try block was caught
        # by the generic handler below and surfaced as a 500 internal_error.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg) from e
        raise HTTPError(500, "internal_error", msg) from e
    except ValueError as e:
        msg = str(e)
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        raise HTTPError(400, "invalid_request", msg) from e
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e)) from e
|
|
1034
|
+
|