mcp-stata 1.22.1__cp311-abi3-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/ui_http.py ADDED
@@ -0,0 +1,1034 @@
1
+ from __future__ import annotations
2
+ import hashlib
3
+ import io
4
+ import json
5
+ import secrets
6
+ import threading
7
+ import time
8
+ import uuid
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
12
+ from typing import Any, Callable, Optional
13
+
14
+ from .stata_client import StataClient
15
+ from .config import (
16
+ DEFAULT_HOST,
17
+ DEFAULT_PORT,
18
+ MAX_ARROW_LIMIT,
19
+ MAX_CHARS,
20
+ MAX_LIMIT,
21
+ MAX_REQUEST_BYTES,
22
+ MAX_VARS,
23
+ TOKEN_TTL_S,
24
+ VIEW_TTL_S,
25
+ )
26
+
27
+
28
+ logger = logging.getLogger("mcp_stata")
29
+
30
+ try:
31
+ from .native_ops import argsort_numeric as _native_argsort_numeric
32
+ from .native_ops import argsort_mixed as _native_argsort_mixed
33
+ except Exception:
34
+ _native_argsort_numeric = None
35
+ _native_argsort_mixed = None
36
+
37
+
38
def _try_native_argsort(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int] | None:
    """Sort via the optional native extension.

    Returns zero-based observation indices, or None when the native path is
    unavailable or cannot handle this table (caller falls back to Polars).
    """
    if _native_argsort_numeric is None and _native_argsort_mixed is None:
        return None
    try:
        import numpy as np
        import pyarrow as pa

        string_flags: list[bool] = []
        columns: list[object] = []
        for name in sort_cols:
            chunked = table.column(name).combine_chunks()
            col_type = chunked.type
            if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
                string_flags.append(True)
                columns.append(chunked.to_pylist())
            elif pa.types.is_floating(col_type) or pa.types.is_integer(col_type):
                string_flags.append(False)
                values = chunked.to_numpy(zero_copy_only=False)
                if values.dtype != np.float64:
                    values = values.astype(np.float64, copy=False)
                # Normalize Stata missing values for numeric columns
                columns.append(np.where(values > 8.0e307, np.nan, values))
            else:
                # Dtype the native sorter does not understand.
                return None

        if "_n" not in table.column_names:
            return None
        obs = table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)

        if not any(string_flags) and _native_argsort_numeric is not None:
            order = _native_argsort_numeric(columns, descending, nulls_last)
        elif _native_argsort_mixed is not None:
            order = _native_argsort_mixed(columns, string_flags, descending, nulls_last)
        else:
            return None

        # "_n" is 1-based; callers expect zero-based observation indices.
        return [int(i) for i in (obs[order] - 1).tolist()]
    except Exception:
        # Best-effort accelerator: any failure means "use the fallback path".
        return None
82
+
83
+
84
def _get_sorted_indices_polars(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int]:
    """Sort an Arrow table with Polars and return zero-based observation indices."""
    import polars as pl

    frame = pl.from_arrow(table)

    # Normalize Stata missing values for numeric columns: map sentinel values
    # (> 8.0e307) to null, leaving the "_n" bookkeeping column untouched.
    projections = []
    for name, dtype in zip(frame.columns, frame.dtypes):
        if name != "_n" and dtype in (pl.Float32, pl.Float64):
            projections.append(
                pl.when(pl.col(name) > 8.0e307)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
            )
        else:
            projections.append(pl.col(name))
    frame = frame.select(projections)

    try:
        # Use expressions for arithmetic to avoid eager Series-scalar conversion
        # issues that have been observed in some environments with Int64 dtypes.
        out = frame.select(
            idx=pl.arg_sort_by(
                [pl.col(name) for name in sort_cols],
                descending=descending,
                nulls_last=nulls_last,
            ),
            zero_based_n=pl.col("_n") - 1,
        )
        return out["zero_based_n"].take(out["idx"]).to_list()
    except Exception:
        # Fallback to eager sort if arg_sort_by fails or has issues
        sorted_frame = frame.sort(by=sort_cols, descending=descending, nulls_last=nulls_last)
        return sorted_frame.select(pl.col("_n") - 1).to_series().to_list()
130
+
131
+
132
+
133
+
134
+ def _stable_hash(payload: dict[str, Any]) -> str:
135
+ return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()
136
+
137
+
138
@dataclass
class UIChannelInfo:
    """Connection details handed to a UI client."""

    # Base URL of the local UI HTTP server, e.g. "http://127.0.0.1:8765".
    base_url: str
    # Bearer token expected in the Authorization header.
    token: str
    # Token expiry in milliseconds since the epoch.
    expires_at: int
143
+
144
+
145
@dataclass
class ViewHandle:
    """A materialized, filtered view over the current dataset."""

    # Server-assigned view identifier ("view_<hex>").
    view_id: str
    # Dataset digest the view was built against.
    dataset_id: str
    # Stata frame the view belongs to.
    frame: str
    # Filter expression that produced the view.
    filter_expr: str
    # Zero-based observation indices matching the filter.
    obs_indices: list[int]
    # Number of matching observations (== len(obs_indices)).
    filtered_n: int
    # Creation timestamp (time.time()).
    created_at: float
    # Last access timestamp; drives TTL eviction.
    last_access: float
155
+
156
+
157
+ class UIChannelManager:
158
+ def __init__(
159
+ self,
160
+ client: StataClient,
161
+ *,
162
+ host: str = DEFAULT_HOST,
163
+ port: int = DEFAULT_PORT,
164
+ token_ttl_s: int = TOKEN_TTL_S,
165
+ view_ttl_s: int = VIEW_TTL_S,
166
+ max_limit: int = MAX_LIMIT,
167
+ max_vars: int = MAX_VARS,
168
+ max_chars: int = MAX_CHARS,
169
+ max_request_bytes: int = MAX_REQUEST_BYTES,
170
+ max_arrow_limit: int = MAX_ARROW_LIMIT,
171
+ ):
172
+ self._client = client
173
+ self._host = host
174
+ self._port = port
175
+ self._token_ttl_s = token_ttl_s
176
+ self._view_ttl_s = view_ttl_s
177
+ self._max_limit = max_limit
178
+ self._max_vars = max_vars
179
+ self._max_chars = max_chars
180
+ self._max_request_bytes = max_request_bytes
181
+ self._max_arrow_limit = max_arrow_limit
182
+
183
+ self._lock = threading.Lock()
184
+ self._httpd: ThreadingHTTPServer | None = None
185
+ self._thread: threading.Thread | None = None
186
+
187
+ self._token: str | None = None
188
+ self._expires_at: int = 0
189
+
190
+ self._dataset_version: int = 0
191
+ self._dataset_id_cache: str | None = None
192
+ self._dataset_id_cache_at_version: int = -1
193
+
194
+ self._views: dict[str, dict[str, ViewHandle]] = {} # session_id -> {view_id -> ViewHandle}
195
+ self._sort_index_cache: dict[tuple[str, str, tuple[str, ...]], list[int]] = {} # (session_id, dataset_id, sort_spec)
196
+ self._sort_cache_order: list[tuple[str, str, tuple[str, ...]]] = []
197
+ self._sort_cache_max_entries: int = 10
198
+ self._sort_table_cache: dict[tuple[str, str, tuple[str, ...]], Any] = {} # (session_id, dataset_id, sort_cols)
199
+ self._sort_table_order: list[tuple[str, str, tuple[str, ...]]] = []
200
+ self._sort_table_max_entries: int = 4
201
+ self._dataset_id_caches: dict[str, tuple[str, int]] = {} # session_id -> (digest, version)
202
+
203
+ def notify_potential_dataset_change(self, session_id: str = "default") -> None:
204
+ with self._lock:
205
+ self._dataset_id_caches.pop(session_id, None)
206
+ if session_id in self._views:
207
+ self._views[session_id].clear()
208
+
209
+ # Clear caches for this session
210
+ self._sort_cache_order = [k for k in self._sort_cache_order if k[0] != session_id]
211
+ self._sort_index_cache = {k: v for k, v in self._sort_index_cache.items() if k[0] != session_id}
212
+ self._sort_table_order = [k for k in self._sort_table_order if k[0] != session_id]
213
+ self._sort_table_cache = {k: v for k, v in self._sort_table_cache.items() if k[0] != session_id}
214
+
215
+ @staticmethod
216
+ def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
217
+ normalized: list[str] = []
218
+ for spec in sort_spec:
219
+ if not isinstance(spec, str) or not spec:
220
+ raise ValueError(f"Invalid sort specification: {spec!r}")
221
+ raw = spec.strip()
222
+ if not raw:
223
+ raise ValueError(f"Invalid sort specification: {spec!r}")
224
+ sign = "-" if raw.startswith("-") else "+"
225
+ varname = raw.lstrip("+-")
226
+ if not varname:
227
+ raise ValueError(f"Invalid sort specification: {spec!r}")
228
+ normalized.append(f"{sign}{varname}")
229
+ return tuple(normalized)
230
+
231
+ def _get_cached_sort_indices(
232
+ self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...]
233
+ ) -> list[int] | None:
234
+ key = (session_id, dataset_id, sort_spec)
235
+ with self._lock:
236
+ cached = self._sort_index_cache.get(key)
237
+ if cached is None:
238
+ return None
239
+ if key in self._sort_cache_order:
240
+ self._sort_cache_order.remove(key)
241
+ self._sort_cache_order.append(key)
242
+ return cached
243
+
244
+ def _set_cached_sort_indices(
245
+ self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
246
+ ) -> None:
247
+ key = (session_id, dataset_id, sort_spec)
248
+ with self._lock:
249
+ if key in self._sort_index_cache:
250
+ self._sort_cache_order.remove(key)
251
+ self._sort_index_cache[key] = indices
252
+ self._sort_cache_order.append(key)
253
+ while len(self._sort_cache_order) > self._sort_cache_max_entries:
254
+ evict = self._sort_cache_order.pop(0)
255
+ self._sort_index_cache.pop(evict, None)
256
+
257
+ def _get_cached_sort_table(
258
+ self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...]
259
+ ) -> Any | None:
260
+ key = (session_id, dataset_id, sort_cols)
261
+ with self._lock:
262
+ cached = self._sort_table_cache.get(key)
263
+ if cached is None:
264
+ return None
265
+ if key in self._sort_table_order:
266
+ self._sort_table_order.remove(key)
267
+ self._sort_table_order.append(key)
268
+ return cached
269
+
270
+ def _set_cached_sort_table(
271
+ self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...], table: Any
272
+ ) -> None:
273
+ key = (session_id, dataset_id, sort_cols)
274
+ with self._lock:
275
+ if key in self._sort_table_cache:
276
+ self._sort_table_order.remove(key)
277
+ self._sort_table_cache[key] = table
278
+ self._sort_table_order.append(key)
279
+ while len(self._sort_table_order) > self._sort_table_max_entries:
280
+ evict = self._sort_table_order.pop(0)
281
+ self._sort_table_cache.pop(evict, None)
282
+
283
+ def _get_sort_table(self, session_id: str, dataset_id: str, sort_cols: list[str]) -> Any:
284
+ sort_cols_key = tuple(sort_cols)
285
+ cached = self._get_cached_sort_table(session_id, dataset_id, sort_cols_key)
286
+ if cached is not None:
287
+ return cached
288
+
289
+ # Use an appropriate client for the session
290
+ proxy = self._get_proxy_for_session(session_id)
291
+ state = proxy.get_dataset_state()
292
+ n = int(state.get("n", 0) or 0)
293
+ if n <= 0:
294
+ return None
295
+
296
+ # Pull full columns once via Arrow stream (Stata -> Arrow), then sort in Polars.
297
+ arrow_bytes = proxy.get_arrow_stream(
298
+ offset=0,
299
+ limit=n,
300
+ vars=sort_cols,
301
+ include_obs_no=True,
302
+ obs_indices=None,
303
+ )
304
+
305
+ import pyarrow as pa
306
+
307
+ with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
308
+ table = reader.read_all()
309
+
310
+ self._set_cached_sort_table(session_id, dataset_id, sort_cols_key, table)
311
+ return table
312
+
313
+ def get_channel(self) -> UIChannelInfo:
314
+ self._ensure_http_server()
315
+ with self._lock:
316
+ self._ensure_token()
317
+ assert self._httpd is not None
318
+ port = self._httpd.server_address[1]
319
+ base_url = f"http://{self._host}:{port}"
320
+ return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)
321
+
322
+ def capabilities(self) -> dict[str, bool]:
323
+ return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}
324
+
325
+ def _get_proxy_for_session(self, session_id: str) -> StataClient:
326
+ # Prefer the injected client when present (used by unit tests and single-session setups).
327
+ client = getattr(self, "_client", None)
328
+ if client is not None:
329
+ from .server import StataClientProxy
330
+ if isinstance(client, StataClientProxy):
331
+ return StataClientProxy(session_id=session_id or "default")
332
+ return client
333
+
334
+ from .server import StataClientProxy
335
+ return StataClientProxy(session_id=session_id or "default")
336
+
337
+ def current_dataset_id(self, session_id: str = "default") -> str:
338
+ with self._lock:
339
+ cached = self._dataset_id_caches.get(session_id)
340
+ if cached:
341
+ digest, version = cached
342
+ if version == self._dataset_version:
343
+ return digest
344
+
345
+ proxy = self._get_proxy_for_session(session_id)
346
+ state = proxy.get_dataset_state()
347
+ payload = {
348
+ "version": self._dataset_version,
349
+ "frame": state.get("frame"),
350
+ "n": state.get("n"),
351
+ "k": state.get("k"),
352
+ "sortlist": state.get("sortlist"),
353
+ }
354
+ digest = _stable_hash(payload)
355
+
356
+ with self._lock:
357
+ self._dataset_id_caches[session_id] = (digest, self._dataset_version)
358
+ return digest
359
+
360
+ def get_view(self, session_id: str, view_id: str) -> Optional[ViewHandle]:
361
+ now = time.time()
362
+ with self._lock:
363
+ self._evict_expired_locked(now)
364
+ session_views = self._views.get(session_id)
365
+ if session_views is None:
366
+ return None
367
+ view = session_views.get(view_id)
368
+ if view is None:
369
+ return None
370
+ view.last_access = now
371
+ return view
372
+
373
+ def create_view(self, *, session_id: str, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
374
+ current_id = self.current_dataset_id(session_id)
375
+ if dataset_id != current_id:
376
+ raise DatasetChangedError(current_id)
377
+
378
+ proxy = self._get_proxy_for_session(session_id)
379
+ try:
380
+ obs_indices = proxy.compute_view_indices(filter_expr)
381
+ except ValueError as e:
382
+ raise InvalidFilterError(str(e))
383
+ except RuntimeError as e:
384
+ msg = str(e) or "No data in memory"
385
+ if "no data" in msg.lower():
386
+ raise NoDataInMemoryError(msg)
387
+ raise
388
+ now = time.time()
389
+ view_id = f"view_{uuid.uuid4().hex}"
390
+ view = ViewHandle(
391
+ view_id=view_id,
392
+ dataset_id=current_id,
393
+ frame=frame,
394
+ filter_expr=filter_expr,
395
+ obs_indices=obs_indices,
396
+ filtered_n=len(obs_indices),
397
+ created_at=now,
398
+ last_access=now,
399
+ )
400
+ with self._lock:
401
+ self._evict_expired_locked(now)
402
+ if session_id not in self._views:
403
+ self._views[session_id] = {}
404
+ self._views[session_id][view_id] = view
405
+ return view
406
+
407
+ def delete_view(self, session_id: str, view_id: str) -> bool:
408
+ with self._lock:
409
+ session_views = self._views.get(session_id)
410
+ if session_views is None:
411
+ return False
412
+ return session_views.pop(view_id, None) is not None
413
+
414
+ def validate_token(self, header_value: str | None) -> bool:
415
+ if not header_value:
416
+ return False
417
+ if not header_value.startswith("Bearer "):
418
+ return False
419
+ token = header_value[len("Bearer ") :].strip()
420
+ with self._lock:
421
+ self._ensure_token()
422
+ if self._token is None:
423
+ return False
424
+ if time.time() * 1000 >= self._expires_at:
425
+ return False
426
+ return secrets.compare_digest(token, self._token)
427
+
428
+ def limits(self) -> tuple[int, int, int, int]:
429
+ return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes
430
+
431
+ def _ensure_token(self) -> None:
432
+ now_ms = int(time.time() * 1000)
433
+ if self._token is None or now_ms >= self._expires_at:
434
+ self._token = secrets.token_urlsafe(32)
435
+ self._expires_at = int((time.time() + self._token_ttl_s) * 1000)
436
+
437
+ def _evict_expired_locked(self, now: float) -> None:
438
+ for session_id in list(self._views.keys()):
439
+ session_views = self._views[session_id]
440
+ expired: list[str] = []
441
+ for key, view in session_views.items():
442
+ if now - view.last_access >= self._view_ttl_s:
443
+ expired.append(key)
444
+ for key in expired:
445
+ session_views.pop(key, None)
446
+ if not session_views:
447
+ self._views.pop(session_id, None)
448
+
449
    def _ensure_http_server(self) -> None:
        """Start the local UI HTTP server (once) on a daemon thread.

        The nested Handler routes all /v1/* endpoints and translates
        manager/handler exceptions into JSON error bodies.
        """
        with self._lock:
            if self._httpd is not None:
                return

        # NOTE(review): the lock is released before the server is created
        # below, so two concurrent first calls could race and start two
        # servers — confirm whether callers are single-threaded here.
        manager = self

        class Handler(BaseHTTPRequestHandler):

            def _send_json(self, status: int, payload: dict[str, Any]) -> None:
                """Serialize *payload* and send it with Content-Length set."""
                data = json.dumps(payload).encode("utf-8")
                self.send_response(status)
                self.send_header("Content-Type", "application/json")
                self.send_header("Content-Length", str(len(data)))
                self.end_headers()
                self.wfile.write(data)

            def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
                """Send raw bytes (used for Arrow IPC streams)."""
                self.send_response(status)
                self.send_header("Content-Type", content_type)
                self.send_header("Content-Length", str(len(data)))
                self.end_headers()
                self.wfile.write(data)

            def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
                """Send a JSON error body; 5xx details are logged, not leaked."""
                if status >= 500 or code == "internal_error":
                    logger.error("UI HTTP error %s: %s", code, message)
                    # Replace internal detail with a generic message for clients.
                    message = "Internal server error"
                body: dict[str, Any] = {"error": {"code": code, "message": message}}
                if stata_rc is not None:
                    body["error"]["stataRc"] = stata_rc
                self._send_json(status, body)

            def _require_auth(self) -> bool:
                """Return True when the bearer token checks out; else send 401."""
                if manager.validate_token(self.headers.get("Authorization")):
                    return True
                self._error(401, "auth_failed", "Unauthorized")
                return False

            def _read_json(self) -> dict[str, Any] | None:
                """Read and parse the request body as a JSON object.

                Returns {} for an empty body, or None after an error response
                has already been sent.
                """
                max_limit, max_vars, max_chars, max_bytes = manager.limits()
                # Only max_bytes is used here; the rest are discarded.
                _ = (max_limit, max_vars, max_chars)

                length = int(self.headers.get("Content-Length", "0") or "0")
                if length <= 0:
                    return {}
                if length > max_bytes:
                    self._error(400, "request_too_large", "Request too large")
                    return None
                raw = self.rfile.read(length)
                try:
                    parsed = json.loads(raw.decode("utf-8"))
                except Exception:
                    self._error(400, "invalid_request", "Invalid JSON")
                    return None
                if not isinstance(parsed, dict):
                    self._error(400, "invalid_request", "Expected JSON object")
                    return None
                return parsed

            def do_GET(self) -> None:
                """GET routes: /v1/dataset (state) and /v1/vars (variable list)."""
                if not self._require_auth():
                    return

                if self.path.startswith("/v1/dataset"):
                    from urllib.parse import urlparse, parse_qs
                    parsed_url = urlparse(self.path)
                    params = parse_qs(parsed_url.query)
                    session_id = params.get("sessionId", ["default"])[0]

                    try:
                        proxy = manager._get_proxy_for_session(session_id)
                        state = proxy.get_dataset_state()
                        dataset_id = manager.current_dataset_id(session_id)
                        self._send_json(
                            200,
                            {
                                "dataset": {
                                    "id": dataset_id,
                                    "frame": state.get("frame"),
                                    "n": state.get("n"),
                                    "k": state.get("k"),
                                    "changed": state.get("changed"),
                                }
                            },
                        )
                        return
                    except NoDataInMemoryError as e:
                        self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path.startswith("/v1/vars"):
                    from urllib.parse import urlparse, parse_qs
                    parsed_url = urlparse(self.path)
                    params = parse_qs(parsed_url.query)
                    session_id = params.get("sessionId", ["default"])[0]

                    try:
                        proxy = manager._get_proxy_for_session(session_id)
                        state = proxy.get_dataset_state()
                        dataset_id = manager.current_dataset_id(session_id)
                        variables = proxy.list_variables_rich()
                        self._send_json(
                            200,
                            {
                                "dataset": {"id": dataset_id, "frame": state.get("frame")},
                                "variables": variables,
                            },
                        )
                        return
                    except NoDataInMemoryError as e:
                        self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                self._error(404, "not_found", "Not found")

            def do_POST(self) -> None:
                """POST routes: paging, Arrow streams, view CRUD, filter validation."""
                if not self._require_auth():
                    return


                if self.path == "/v1/arrow":
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp_bytes = handle_arrow_request(manager, body, view_id=None)
                        self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/page":
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp = handle_page_request(manager, body, view_id=None)
                        self._send_json(200, resp)
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/views":
                    body = self._read_json()
                    if body is None:
                        return
                    dataset_id = str(body.get("datasetId", ""))
                    frame = str(body.get("frame", "default"))
                    filter_expr = str(body.get("filterExpr", ""))
                    session_id = str(body.get("sessionId", "default"))
                    if not dataset_id or not filter_expr:
                        self._error(400, "invalid_request", "datasetId and filterExpr are required")
                        return
                    try:
                        view = manager.create_view(session_id=session_id, dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
                        self._send_json(
                            200,
                            {
                                "dataset": {"id": view.dataset_id, "frame": view.frame},
                                "view": {"id": view.view_id, "filteredN": view.filtered_n},
                            },
                        )
                        return
                    except DatasetChangedError as e:
                        self._error(409, "dataset_changed", f"Dataset changed for session {session_id}")
                        return
                    except ValueError as e:
                        # NOTE(review): create_view raises InvalidFilterError /
                        # NoDataInMemoryError, which are not ValueError/RuntimeError
                        # subclasses as defined below — they currently fall
                        # through to the generic 500 handler. Confirm intent.
                        self._error(400, "invalid_filter", str(e))
                        return
                    except RuntimeError as e:
                        msg = str(e) or "No data in memory"
                        if "no data" in msg.lower():
                            self._error(400, "no_data_in_memory", msg)
                            return
                        self._error(500, "internal_error", msg)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
                    parts = self.path.split("/")
                    if len(parts) != 5:
                        self._error(404, "not_found", "Not found")
                        return
                    view_id = parts[3]
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp = handle_page_request(manager, body, view_id=view_id)
                        self._send_json(200, resp)
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
                    parts = self.path.split("/")
                    if len(parts) != 5:
                        self._error(404, "not_found", "Not found")
                        return
                    view_id = parts[3]
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
                        self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/filters/validate":
                    body = self._read_json()
                    if body is None:
                        return
                    filter_expr = str(body.get("filterExpr", ""))
                    session_id = str(body.get("sessionId", "default"))
                    if not filter_expr:
                        self._error(400, "invalid_request", "filterExpr is required")
                        return
                    try:
                        proxy = manager._get_proxy_for_session(session_id)
                        proxy.validate_filter_expr(filter_expr)
                        self._send_json(200, {"ok": True})
                        return
                    except ValueError as e:
                        self._error(400, "invalid_filter", str(e))
                        return
                    except RuntimeError as e:
                        msg = str(e) or "No data in memory"
                        if "no data" in msg.lower():
                            self._error(400, "no_data_in_memory", msg)
                            return
                        self._error(500, "internal_error", msg)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                self._error(404, "not_found", "Not found")

            def do_DELETE(self) -> None:
                """DELETE /v1/views/{id}: drop a view for the given session."""
                if not self._require_auth():
                    return

                if self.path.startswith("/v1/views/"):
                    # NOTE(review): parts is split from the raw path, so a query
                    # string (?sessionId=...) ends up inside parts[3] and makes
                    # the length check fail or the view_id wrong — the id should
                    # come from parsed_url.path instead. Confirm and fix.
                    parts = self.path.split("/")
                    if len(parts) != 4:
                        self._error(404, "not_found", "Not found")
                        return
                    from urllib.parse import urlparse, parse_qs
                    parsed_url = urlparse(self.path)
                    params = parse_qs(parsed_url.query)
                    session_id = params.get("sessionId", ["default"])[0]
                    view_id = parts[3]
                    if manager.delete_view(session_id, view_id):
                        self._send_json(200, {"ok": True})
                    else:
                        self._error(404, "not_found", f"View {view_id} not found in session {session_id}")
                    return

                self._error(404, "not_found", "Not found")

            def log_message(self, format: str, *args: Any) -> None:
                # Silence BaseHTTPRequestHandler's default per-request stderr log.
                return

        httpd = ThreadingHTTPServer((self._host, self._port), Handler)
        t = threading.Thread(target=httpd.serve_forever, daemon=True)
        t.start()
        self._httpd = httpd
        self._thread = t
744
+
745
+
746
+ class HTTPError(Exception):
747
+ def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
748
+ super().__init__(message)
749
+ self.status = status
750
+ self.code = code
751
+ self.message = message
752
+ self.stata_rc = stata_rc
753
+
754
+
755
class DatasetChangedError(Exception):
    """Raised when a request references a stale dataset id."""

    def __init__(self, current_dataset_id: str):
        super().__init__("dataset_changed")
        # The id clients should resynchronize to.
        self.current_dataset_id = current_dataset_id
759
+
760
+
761
+ class NoDataInMemoryError(Exception):
762
+ def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
763
+ super().__init__(message)
764
+ self.stata_rc = stata_rc
765
+
766
+
767
+ class InvalidFilterError(Exception):
768
+ def __init__(self, message: str, *, stata_rc: int | None = None):
769
+ super().__init__(message)
770
+ self.message = message
771
+ self.stata_rc = stata_rc
772
+
773
+
774
def _resolve_proxy(manager: UIChannelManager, session_id: str) -> StataClient:
    """Resolve the Stata client proxy, preferring injected clients for tests."""
    injected = getattr(manager, "_client", None)
    if injected is not None:
        from .server import StataClientProxy

        # A directly injected non-proxy client wins (unit tests rely on this).
        if not isinstance(injected, StataClientProxy):
            return injected
    return manager._get_proxy_for_session(session_id)
782
+
783
+
784
def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
    """Validate a page request and return a JSON-serializable page of rows.

    With a *view_id*, rows are restricted to that view's filtered indices;
    without one, the whole dataset is paged. Raises HTTPError for every
    client/servers error so callers can map it straight onto a response.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()

    session_id = str(body.get("sessionId", "default"))

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        # NOTE(review): frame is assigned in both branches but never read
        # again below — confirm whether it was meant to be returned.
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
        frame = view.frame

    # --- Parameter parsing -------------------------------------------------
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "offset must be a valid integer")

    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "limit must be a valid integer")

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])
    max_chars_raw = body.get("maxChars", max_chars)

    try:
        max_chars_req = int(max_chars_raw or max_chars)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "maxChars must be a valid integer")

    # --- Range/limit validation -------------------------------------------
    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > max_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
    if max_chars_req <= 0:
        raise HTTPError(400, "invalid_request", "maxChars must be > 0")
    if max_chars_req > max_chars:
        raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    if sort_by and (not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by)):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

    # Reject requests made against a stale dataset snapshot.
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")

    if view_id is None:
        obs_indices = None
        filtered_n = None
    else:
        assert view is not None
        obs_indices = view.obs_indices
        filtered_n = view.filtered_n

    try:
        if sort_by:
            # Compute (or reuse) the full sort order for the requested spec.
            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)

                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Native extension first; Polars as the fallback sorter.
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)

            if obs_indices_sorted:
                if obs_indices:
                    # Keep only filtered rows, preserving the sorted order.
                    filter_set = set(obs_indices)
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted

        proxy = _resolve_proxy(manager, session_id)
        dataset_state = proxy.get_dataset_state()
        page = proxy.get_page(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            max_chars=max_chars_req,
            obs_indices=obs_indices,
        )

        view_obj: dict[str, Any] = {
            "offset": offset,
            "limit": limit,
            "returned": page["returned"],
            "filteredN": filtered_n,
        }
        if view_id is not None:
            view_obj["viewId"] = view_id

        return {
            "dataset": {
                "id": current_id,
                "frame": dataset_state.get("frame"),
                "n": dataset_state.get("n"),
                "k": dataset_state.get("k"),
            },
            "view": view_obj,
            "vars": page["vars"],
            "rows": page["rows"],
            "display": {
                "maxChars": max_chars_req,
                "truncatedCells": page["truncated_cells"],
                "missing": ".",
            },
        }

    except HTTPError:
        raise
    except RuntimeError as e:
        # Map known Stata-side failure messages onto specific 400 codes.
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg)
        raise HTTPError(500, "internal_error", msg)
    except ValueError as e:
        msg = str(e)
        if msg.lower().startswith("invalid variable"):
            raise HTTPError(400, "invalid_variable", msg)
        raise HTTPError(400, "invalid_request", msg)
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e))
929
+
930
+
931
def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
    """Serve one chunk of dataset rows as a serialized Arrow stream.

    Validates the paging request in ``body``, optionally applies a cached (or
    freshly computed) sort order, intersects it with a saved view's row filter,
    and delegates to the session proxy's ``get_arrow_stream``.

    Args:
        manager: UI channel manager holding sessions, views, and sort caches.
        body: Decoded JSON request. Recognized keys: ``sessionId``,
            ``datasetId``, ``frame``, ``offset``, ``limit`` (required),
            ``vars``, ``includeObsNo``, ``sortBy``.
        view_id: When given, resolve dataset/frame and row filter from the
            saved view instead of the request body.

    Returns:
        Raw bytes produced by the proxy (an Arrow IPC stream, per the
        proxy method's name — exact framing is defined by the proxy).

    Raises:
        HTTPError: 400 for invalid parameters, 404 for a missing view,
            409 when the dataset changed since the client last saw it,
            500 for unexpected internal failures.
    """
    max_limit, max_vars, _, _ = manager.limits()
    # Arrow chunks are allowed to be larger than JSON pages; fall back to a
    # generous cap when the manager does not expose one.
    chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)
    session_id = str(body.get("sessionId", "default"))

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
        # NOTE(review): `frame` is resolved here for symmetry with the page
        # handler but is not used further in this function.
        frame = view.frame

    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "offset must be a valid integer") from e

    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", "limit must be a valid integer") from e

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])

    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > chunk_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    # Reject stale requests: the client must re-sync after a dataset swap.
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")

    if view_id is None:
        obs_indices = None
    else:
        assert view is not None  # guaranteed by the 404 check above
        obs_indices = view.obs_indices

    try:
        if sort_by:
            if not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by):
                raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

            # Sort indices are cached per (session, dataset, normalized spec)
            # so repeated chunk fetches don't re-sort the whole column set.
            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)

                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Prefer the native argsort; fall back to polars when the
                    # native path declines (e.g. unsupported column types).
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)

            if obs_indices_sorted:
                if obs_indices:
                    # View filter + sort: keep sorted order, drop rows outside
                    # the view. Set membership keeps this O(n).
                    filter_set = set(obs_indices)
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted

        proxy = _resolve_proxy(manager, session_id)
        return proxy.get_arrow_stream(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            obs_indices=obs_indices,
        )

    except HTTPError:
        # Bug fix: without this clause the deliberate 400 raised for a bad
        # sortBy above was swallowed by the generic handler below and
        # re-emitted as a 500 internal_error. Mirrors the page handler.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg) from e
        raise HTTPError(500, "internal_error", msg) from e
    except ValueError as e:
        msg = str(e)
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg) from e
        raise HTTPError(400, "invalid_request", msg) from e
    except Exception as e:
        # Top-level boundary for this request: map anything unexpected to 500.
        raise HTTPError(500, "internal_error", str(e)) from e
1034
+