mcp-stata 1.21.0__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/ui_http.py ADDED
@@ -0,0 +1,999 @@
+ from __future__ import annotations
+ import hashlib
+ import io
+ import json
+ import secrets
+ import threading
+ import time
+ import uuid
+ import logging
+ from dataclasses import dataclass
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+ from typing import Any, Optional
+
+ from .stata_client import StataClient
+ from .config import (
+     DEFAULT_HOST,
+     DEFAULT_PORT,
+     MAX_ARROW_LIMIT,
+     MAX_CHARS,
+     MAX_LIMIT,
+     MAX_REQUEST_BYTES,
+     MAX_VARS,
+     TOKEN_TTL_S,
+     VIEW_TTL_S,
+ )
+
+
+ logger = logging.getLogger("mcp_stata")
+
+ try:
+     from .native_ops import argsort_numeric as _native_argsort_numeric
+     from .native_ops import argsort_mixed as _native_argsort_mixed
+ except Exception:
+     _native_argsort_numeric = None
+     _native_argsort_mixed = None
+
+
+ def _try_native_argsort(
+     table: Any,
+     sort_cols: list[str],
+     descending: list[bool],
+     nulls_last: list[bool],
+ ) -> list[int] | None:
+     if _native_argsort_numeric is None and _native_argsort_mixed is None:
+         return None
+     try:
+         import pyarrow as pa
+         import numpy as np
+
+         is_string: list[bool] = []
+         cols: list[object] = []
+         for col in sort_cols:
+             arr = table.column(col).combine_chunks()
+             if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
+                 is_string.append(True)
+                 cols.append(arr.to_pylist())
+                 continue
+             if not (pa.types.is_floating(arr.type) or pa.types.is_integer(arr.type)):
+                 return None
+             is_string.append(False)
+             np_arr = arr.to_numpy(zero_copy_only=False)
+             if np_arr.dtype != np.float64:
+                 np_arr = np_arr.astype(np.float64, copy=False)
+             # Normalize Stata missing values (stored as huge floats) to NaN
+             np_arr = np.where(np_arr > 8.0e307, np.nan, np_arr)
+             cols.append(np_arr)
+
+         if "_n" not in table.column_names:
+             return None
+         obs = table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)
+
+         if all(not flag for flag in is_string) and _native_argsort_numeric is not None:
+             idx = _native_argsort_numeric(cols, descending, nulls_last)
+         elif _native_argsort_mixed is not None:
+             idx = _native_argsort_mixed(cols, is_string, descending, nulls_last)
+         else:
+             return None
+
+         # Stata's _n is 1-based; convert to 0-based observation indices
+         return [int(x) for x in (obs[idx] - 1).tolist()]
+     except Exception:
+         return None
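+ # Illustrative note (not in the original source): _try_native_argsort returns
+ # 0-based observation indices in sorted order, or None whenever the native
+ # extension is unavailable or a column type is unsupported, in which case the
+ # caller falls through to _get_sorted_indices_polars below.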
+
+
+ def _get_sorted_indices_polars(
+     table: Any,
+     sort_cols: list[str],
+     descending: list[bool],
+     nulls_last: list[bool],
+ ) -> list[int]:
+     import polars as pl
+
+     df = pl.from_arrow(table)
+     # Normalize Stata missing values to null for numeric columns
+     exprs = []
+     for col, dtype in zip(df.columns, df.dtypes):
+         if col == "_n":
+             exprs.append(pl.col(col))
+             continue
+         if dtype in (pl.Float32, pl.Float64):
+             exprs.append(
+                 pl.when(pl.col(col) > 8.0e307)
+                 .then(None)
+                 .otherwise(pl.col(col))
+                 .alias(col)
+             )
+         else:
+             exprs.append(pl.col(col))
+     df = df.select(exprs)
+
+     try:
+         # Use expressions for the arithmetic to avoid eager Series-scalar conversion
+         # issues that have been observed in some environments with Int64 dtypes.
+         res = df.select(
+             idx=pl.arg_sort_by(
+                 [pl.col(c) for c in sort_cols],
+                 descending=descending,
+                 nulls_last=nulls_last,
+             ),
+             zero_based_n=pl.col("_n") - 1,
+         )
+         return res["zero_based_n"].take(res["idx"]).to_list()
+     except Exception:
+         # Fall back to an eager sort if arg_sort_by fails
+         return (
+             df.sort(by=sort_cols, descending=descending, nulls_last=nulls_last)
+             .select(pl.col("_n") - 1)
+             .to_series()
+             .to_list()
+         )
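+ # Example (illustrative only, not part of the original source): for a table with
+ # columns {"price", "_n"}, _get_sorted_indices_polars(tbl, ["price"], [False], [True])
+ # returns 0-based observation indices ordered by ascending price, nulls last.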
+
+
+ def _stable_hash(payload: dict[str, Any]) -> str:
+     return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()
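+ # Illustrative note (not in the original source): sort_keys=True makes the digest
+ # independent of dict key order, so _stable_hash({"a": 1, "b": 2}) equals
+ # _stable_hash({"b": 2, "a": 1}).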
+
+
+ @dataclass
+ class UIChannelInfo:
+     base_url: str
+     token: str
+     expires_at: int
+
+
+ @dataclass
+ class ViewHandle:
+     view_id: str
+     dataset_id: str
+     frame: str
+     filter_expr: str
+     obs_indices: list[int]
+     filtered_n: int
+     created_at: float
+     last_access: float
+
+
+ class UIChannelManager:
+     def __init__(
+         self,
+         client: StataClient,
+         *,
+         host: str = DEFAULT_HOST,
+         port: int = DEFAULT_PORT,
+         token_ttl_s: int = TOKEN_TTL_S,
+         view_ttl_s: int = VIEW_TTL_S,
+         max_limit: int = MAX_LIMIT,
+         max_vars: int = MAX_VARS,
+         max_chars: int = MAX_CHARS,
+         max_request_bytes: int = MAX_REQUEST_BYTES,
+         max_arrow_limit: int = MAX_ARROW_LIMIT,
+     ):
+         self._client = client
+         self._host = host
+         self._port = port
+         self._token_ttl_s = token_ttl_s
+         self._view_ttl_s = view_ttl_s
+         self._max_limit = max_limit
+         self._max_vars = max_vars
+         self._max_chars = max_chars
+         self._max_request_bytes = max_request_bytes
+         self._max_arrow_limit = max_arrow_limit
+
+         self._lock = threading.Lock()
+         self._httpd: ThreadingHTTPServer | None = None
+         self._thread: threading.Thread | None = None
+
+         self._token: str | None = None
+         self._expires_at: int = 0
+
+         self._dataset_version: int = 0
+         self._dataset_id_cache: str | None = None
+         self._dataset_id_cache_at_version: int = -1
+
+         self._views: dict[str, ViewHandle] = {}
+         self._last_sort_spec: tuple[str, ...] | None = None
+         self._last_sort_dataset_id: str | None = None
+         self._sort_index_cache: dict[tuple[str, tuple[str, ...]], list[int]] = {}
+         self._sort_cache_order: list[tuple[str, tuple[str, ...]]] = []
+         self._sort_cache_max_entries: int = 4
+         self._sort_table_cache: dict[tuple[str, tuple[str, ...]], Any] = {}
+         self._sort_table_order: list[tuple[str, tuple[str, ...]]] = []
+         self._sort_table_max_entries: int = 2
+
+     def notify_potential_dataset_change(self) -> None:
+         with self._lock:
+             self._dataset_version += 1
+             self._dataset_id_cache = None
+             self._views.clear()
+             self._last_sort_spec = None
+             self._last_sort_dataset_id = None
+             self._sort_index_cache.clear()
+             self._sort_cache_order.clear()
+             self._sort_table_cache.clear()
+             self._sort_table_order.clear()
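+     # Note (illustrative): callers are expected to invoke this after any Stata
+     # command that may have mutated the data in memory; it bumps the dataset
+     # version and drops every cached view, sort index, and sort table.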
+
+     @staticmethod
+     def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
+         normalized: list[str] = []
+         for spec in sort_spec:
+             if not isinstance(spec, str) or not spec:
+                 raise ValueError(f"Invalid sort specification: {spec!r}")
+             raw = spec.strip()
+             if not raw:
+                 raise ValueError(f"Invalid sort specification: {spec!r}")
+             sign = "-" if raw.startswith("-") else "+"
+             varname = raw.lstrip("+-")
+             if not varname:
+                 raise ValueError(f"Invalid sort specification: {spec!r}")
+             normalized.append(f"{sign}{varname}")
+         return tuple(normalized)
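+     # Example (illustrative): _normalize_sort_spec(["price", "-mpg", "+weight"])
+     # returns ("+price", "-mpg", "+weight"); a bare variable name sorts ascending.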
+
+     def _get_cached_sort_indices(
+         self, dataset_id: str, sort_spec: tuple[str, ...]
+     ) -> list[int] | None:
+         key = (dataset_id, sort_spec)
+         with self._lock:
+             cached = self._sort_index_cache.get(key)
+             if cached is None:
+                 return None
+             # Refresh LRU order
+             if key in self._sort_cache_order:
+                 self._sort_cache_order.remove(key)
+             self._sort_cache_order.append(key)
+             return cached
+
+     def _set_cached_sort_indices(
+         self, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
+     ) -> None:
+         key = (dataset_id, sort_spec)
+         with self._lock:
+             if key in self._sort_index_cache:
+                 self._sort_cache_order.remove(key)
+             self._sort_index_cache[key] = indices
+             self._sort_cache_order.append(key)
+             while len(self._sort_cache_order) > self._sort_cache_max_entries:
+                 evict = self._sort_cache_order.pop(0)
+                 self._sort_index_cache.pop(evict, None)
+
+     def _get_cached_sort_table(
+         self, dataset_id: str, sort_cols: tuple[str, ...]
+     ) -> Any | None:
+         key = (dataset_id, sort_cols)
+         with self._lock:
+             cached = self._sort_table_cache.get(key)
+             if cached is None:
+                 return None
+             if key in self._sort_table_order:
+                 self._sort_table_order.remove(key)
+             self._sort_table_order.append(key)
+             return cached
+
+     def _set_cached_sort_table(
+         self, dataset_id: str, sort_cols: tuple[str, ...], table: Any
+     ) -> None:
+         key = (dataset_id, sort_cols)
+         with self._lock:
+             if key in self._sort_table_cache:
+                 self._sort_table_order.remove(key)
+             self._sort_table_cache[key] = table
+             self._sort_table_order.append(key)
+             while len(self._sort_table_order) > self._sort_table_max_entries:
+                 evict = self._sort_table_order.pop(0)
+                 self._sort_table_cache.pop(evict, None)
+
+     def _get_sort_table(self, dataset_id: str, sort_cols: list[str]) -> Any:
+         sort_cols_key = tuple(sort_cols)
+         cached = self._get_cached_sort_table(dataset_id, sort_cols_key)
+         if cached is not None:
+             return cached
+
+         state = self._client.get_dataset_state()
+         n = int(state.get("n", 0) or 0)
+         if n <= 0:
+             return None
+
+         # Pull the sort columns once via Arrow stream (Stata -> Arrow), then sort
+         # with the native sorter or Polars.
+         arrow_bytes = self._client.get_arrow_stream(
+             offset=0,
+             limit=n,
+             vars=sort_cols,
+             include_obs_no=True,
+             obs_indices=None,
+         )
+
+         import pyarrow as pa
+
+         with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
+             table = reader.read_all()
+
+         self._set_cached_sort_table(dataset_id, sort_cols_key, table)
+         return table
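+     # Note (illustrative): only the requested sort columns plus the "_n"
+     # observation-number column (via include_obs_no=True) are materialized here,
+     # so the cached table stays small relative to the full dataset.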
+
+     def get_channel(self) -> UIChannelInfo:
+         self._ensure_http_server()
+         with self._lock:
+             self._ensure_token()
+             assert self._httpd is not None
+             port = self._httpd.server_address[1]
+             base_url = f"http://{self._host}:{port}"
+             return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)
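+     # Usage sketch (illustrative, names from this module):
+     #     mgr = UIChannelManager(StataClient())
+     #     info = mgr.get_channel()  # lazily starts the HTTP server
+     #     # info.base_url is "http://<host>:<port>"; info.token is the bearer token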
+
+     def capabilities(self) -> dict[str, bool]:
+         return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}
+
+     def current_dataset_id(self) -> str:
+         with self._lock:
+             if self._dataset_id_cache is not None and self._dataset_id_cache_at_version == self._dataset_version:
+                 return self._dataset_id_cache
+
+         state = self._client.get_dataset_state()
+         payload = {
+             "version": self._dataset_version,
+             "frame": state.get("frame"),
+             "n": state.get("n"),
+             "k": state.get("k"),
+             "sortlist": state.get("sortlist"),
+         }
+         digest = _stable_hash(payload)
+
+         with self._lock:
+             self._dataset_id_cache = digest
+             self._dataset_id_cache_at_version = self._dataset_version
+         return digest
+
+     def get_view(self, view_id: str) -> Optional[ViewHandle]:
+         now = time.time()
+         with self._lock:
+             self._evict_expired_locked(now)
+             view = self._views.get(view_id)
+             if view is None:
+                 return None
+             view.last_access = now
+             return view
+
+     def create_view(self, *, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
+         current_id = self.current_dataset_id()
+         if dataset_id != current_id:
+             raise DatasetChangedError(current_id)
+
+         try:
+             obs_indices = self._client.compute_view_indices(filter_expr)
+         except ValueError as e:
+             raise InvalidFilterError(str(e))
+         except RuntimeError as e:
+             msg = str(e) or "No data in memory"
+             if "no data" in msg.lower():
+                 raise NoDataInMemoryError(msg)
+             raise
+         now = time.time()
+         view_id = f"view_{uuid.uuid4().hex}"
+         view = ViewHandle(
+             view_id=view_id,
+             dataset_id=current_id,
+             frame=frame,
+             filter_expr=filter_expr,
+             obs_indices=obs_indices,
+             filtered_n=len(obs_indices),
+             created_at=now,
+             last_access=now,
+         )
+         with self._lock:
+             self._evict_expired_locked(now)
+             self._views[view_id] = view
+         return view
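+     # Lifecycle sketch (illustrative): a client POSTs /v1/views with a filterExpr,
+     # pages through it via /v1/views/{id}/page, and DELETEs /v1/views/{id} when
+     # done; idle views are evicted after view_ttl_s seconds.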
+
+     def delete_view(self, view_id: str) -> bool:
+         with self._lock:
+             return self._views.pop(view_id, None) is not None
+
+     def validate_token(self, header_value: str | None) -> bool:
+         if not header_value:
+             return False
+         if not header_value.startswith("Bearer "):
+             return False
+         token = header_value[len("Bearer ") :].strip()
+         with self._lock:
+             self._ensure_token()
+             if self._token is None:
+                 return False
+             if time.time() * 1000 >= self._expires_at:
+                 return False
+             return secrets.compare_digest(token, self._token)
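+     # Example (illustrative): every request must carry the header
+     #     Authorization: Bearer <token from get_channel()>
+     # compare_digest keeps the token comparison constant-time.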
+
+     def limits(self) -> tuple[int, int, int, int]:
+         return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes
+
+     def _ensure_token(self) -> None:
+         now_ms = int(time.time() * 1000)
+         if self._token is None or now_ms >= self._expires_at:
+             self._token = secrets.token_urlsafe(32)
+             self._expires_at = int((time.time() + self._token_ttl_s) * 1000)
+
+     def _evict_expired_locked(self, now: float) -> None:
+         expired: list[str] = []
+         for key, view in self._views.items():
+             if now - view.last_access >= self._view_ttl_s:
+                 expired.append(key)
+         for key in expired:
+             self._views.pop(key, None)
+
+     def _ensure_http_server(self) -> None:
+         with self._lock:
+             if self._httpd is not None:
+                 return
+
+         # Capture the manager for the nested handler class
+         manager = self
+
+         class Handler(BaseHTTPRequestHandler):
+
+             def _send_json(self, status: int, payload: dict[str, Any]) -> None:
+                 data = json.dumps(payload).encode("utf-8")
+                 self.send_response(status)
+                 self.send_header("Content-Type", "application/json")
+                 self.send_header("Content-Length", str(len(data)))
+                 self.end_headers()
+                 self.wfile.write(data)
+
+             def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
+                 self.send_response(status)
+                 self.send_header("Content-Type", content_type)
+                 self.send_header("Content-Length", str(len(data)))
+                 self.end_headers()
+                 self.wfile.write(data)
+
+             def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
+                 if status >= 500 or code == "internal_error":
+                     # Log the detail server-side but never leak it to the client
+                     logger.error("UI HTTP error %s: %s", code, message)
+                     message = "Internal server error"
+                 body: dict[str, Any] = {"error": {"code": code, "message": message}}
+                 if stata_rc is not None:
+                     body["error"]["stataRc"] = stata_rc
+                 self._send_json(status, body)
+
+             def _require_auth(self) -> bool:
+                 if manager.validate_token(self.headers.get("Authorization")):
+                     return True
+                 self._error(401, "auth_failed", "Unauthorized")
+                 return False
+
+             def _read_json(self) -> dict[str, Any] | None:
+                 # Only the request-size cap matters here
+                 _, _, _, max_bytes = manager.limits()
+
+                 length = int(self.headers.get("Content-Length", "0") or "0")
+                 if length <= 0:
+                     return {}
+                 if length > max_bytes:
+                     self._error(400, "request_too_large", "Request too large")
+                     return None
+                 raw = self.rfile.read(length)
+                 try:
+                     parsed = json.loads(raw.decode("utf-8"))
+                 except Exception:
+                     self._error(400, "invalid_request", "Invalid JSON")
+                     return None
+                 if not isinstance(parsed, dict):
+                     self._error(400, "invalid_request", "Expected JSON object")
+                     return None
+                 return parsed
+
+             def do_GET(self) -> None:
+                 if not self._require_auth():
+                     return
+
+                 if self.path == "/v1/dataset":
+                     try:
+                         state = manager._client.get_dataset_state()
+                         dataset_id = manager.current_dataset_id()
+                         self._send_json(
+                             200,
+                             {
+                                 "dataset": {
+                                     "id": dataset_id,
+                                     "frame": state.get("frame"),
+                                     "n": state.get("n"),
+                                     "k": state.get("k"),
+                                     "changed": state.get("changed"),
+                                 }
+                             },
+                         )
+                         return
+                     except NoDataInMemoryError as e:
+                         self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path == "/v1/vars":
+                     try:
+                         state = manager._client.get_dataset_state()
+                         dataset_id = manager.current_dataset_id()
+                         variables = manager._client.list_variables_rich()
+                         self._send_json(
+                             200,
+                             {
+                                 "dataset": {"id": dataset_id, "frame": state.get("frame")},
+                                 "variables": variables,
+                             },
+                         )
+                         return
+                     except NoDataInMemoryError as e:
+                         self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 self._error(404, "not_found", "Not found")
+
+             def do_POST(self) -> None:
+                 if not self._require_auth():
+                     return
+
+                 if self.path == "/v1/arrow":
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     try:
+                         resp_bytes = handle_arrow_request(manager, body, view_id=None)
+                         self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
+                         return
+                     except HTTPError as e:
+                         self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path == "/v1/page":
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     try:
+                         resp = handle_page_request(manager, body, view_id=None)
+                         self._send_json(200, resp)
+                         return
+                     except HTTPError as e:
+                         self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path == "/v1/views":
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     dataset_id = str(body.get("datasetId", ""))
+                     frame = str(body.get("frame", "default"))
+                     filter_expr = str(body.get("filterExpr", ""))
+                     if not dataset_id or not filter_expr:
+                         self._error(400, "invalid_request", "datasetId and filterExpr are required")
+                         return
+                     try:
+                         view = manager.create_view(dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
+                         self._send_json(
+                             200,
+                             {
+                                 "dataset": {"id": view.dataset_id, "frame": view.frame},
+                                 "view": {"id": view.view_id, "filteredN": view.filtered_n},
+                             },
+                         )
+                         return
+                     except DatasetChangedError:
+                         self._error(409, "dataset_changed", "Dataset changed")
+                         return
+                     except NoDataInMemoryError as e:
+                         # create_view raises NoDataInMemoryError directly; map it to 400
+                         self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
+                         return
+                     except (InvalidFilterError, ValueError) as e:
+                         # create_view wraps bad filters in InvalidFilterError, which is
+                         # not a ValueError subclass, so it must be caught explicitly
+                         self._error(400, "invalid_filter", str(e))
+                         return
+                     except RuntimeError as e:
+                         msg = str(e) or "No data in memory"
+                         if "no data" in msg.lower():
+                             self._error(400, "no_data_in_memory", msg)
+                             return
+                         self._error(500, "internal_error", msg)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
+                     parts = self.path.split("/")
+                     if len(parts) != 5:
+                         self._error(404, "not_found", "Not found")
+                         return
+                     view_id = parts[3]
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     try:
+                         resp = handle_page_request(manager, body, view_id=view_id)
+                         self._send_json(200, resp)
+                         return
+                     except HTTPError as e:
+                         self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
+                     parts = self.path.split("/")
+                     if len(parts) != 5:
+                         self._error(404, "not_found", "Not found")
+                         return
+                     view_id = parts[3]
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     try:
+                         resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
+                         self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
+                         return
+                     except HTTPError as e:
+                         self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 if self.path == "/v1/filters/validate":
+                     body = self._read_json()
+                     if body is None:
+                         return
+                     filter_expr = str(body.get("filterExpr", ""))
+                     if not filter_expr:
+                         self._error(400, "invalid_request", "filterExpr is required")
+                         return
+                     try:
+                         manager._client.validate_filter_expr(filter_expr)
+                         self._send_json(200, {"ok": True})
+                         return
+                     except ValueError as e:
+                         self._error(400, "invalid_filter", str(e))
+                         return
+                     except RuntimeError as e:
+                         msg = str(e) or "No data in memory"
+                         if "no data" in msg.lower():
+                             self._error(400, "no_data_in_memory", msg)
+                             return
+                         self._error(500, "internal_error", msg)
+                         return
+                     except Exception as e:
+                         self._error(500, "internal_error", str(e))
+                         return
+
+                 self._error(404, "not_found", "Not found")
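+                 # Endpoint sketch (illustrative):
+                 #     curl -s -X POST http://<host>:<port>/v1/page \
+                 #          -H "Authorization: Bearer <token>" \
+                 #          -d '{"datasetId": "<id>", "limit": 50, "vars": ["price"]}'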
+
+             def do_DELETE(self) -> None:
+                 if not self._require_auth():
+                     return
+
+                 if self.path.startswith("/v1/views/"):
+                     parts = self.path.split("/")
+                     if len(parts) != 4:
+                         self._error(404, "not_found", "Not found")
+                         return
+                     view_id = parts[3]
+                     if manager.delete_view(view_id):
+                         self._send_json(200, {"ok": True})
+                     else:
+                         self._error(404, "not_found", "Not found")
+                     return
+
+                 self._error(404, "not_found", "Not found")
+
+             def log_message(self, format: str, *args: Any) -> None:
+                 # Silence the default stderr access log
+                 return
+
+         httpd = ThreadingHTTPServer((self._host, self._port), Handler)
+         t = threading.Thread(target=httpd.serve_forever, daemon=True)
+         t.start()
+         self._httpd = httpd
+         self._thread = t
+
+
+ class HTTPError(Exception):
+     def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
+         super().__init__(message)
+         self.status = status
+         self.code = code
+         self.message = message
+         self.stata_rc = stata_rc
+
+
+ class DatasetChangedError(Exception):
+     def __init__(self, current_dataset_id: str):
+         super().__init__("dataset_changed")
+         self.current_dataset_id = current_dataset_id
+
+
+ class NoDataInMemoryError(Exception):
+     def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
+         super().__init__(message)
+         self.stata_rc = stata_rc
+
+
+ class InvalidFilterError(Exception):
+     def __init__(self, message: str, *, stata_rc: int | None = None):
+         super().__init__(message)
+         self.message = message
+         self.stata_rc = stata_rc
+
+
+ def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
+     max_limit, max_vars, max_chars, _ = manager.limits()
+
+     if view_id is None:
+         dataset_id = str(body.get("datasetId", ""))
+         frame = str(body.get("frame", "default"))
+     else:
+         view = manager.get_view(view_id)
+         if view is None:
+             raise HTTPError(404, "not_found", "View not found")
+         dataset_id = view.dataset_id
+         frame = view.frame
+
+     # Parse offset (default 0 is valid since offset >= 0)
+     try:
+         offset = int(body.get("offset") or 0)
+     except (ValueError, TypeError):
+         raise HTTPError(400, "invalid_request", f"offset must be a valid integer, got: {body.get('offset')!r}")
+
+     # Parse limit (no default; it must be provided explicitly)
+     limit_raw = body.get("limit")
+     if limit_raw is None:
+         raise HTTPError(400, "invalid_request", "limit is required")
+     try:
+         limit = int(limit_raw)
+     except (ValueError, TypeError):
+         raise HTTPError(400, "invalid_request", f"limit must be a valid integer, got: {limit_raw!r}")
+
+     vars_req = body.get("vars", [])
+     include_obs_no = bool(body.get("includeObsNo", False))
+
+     # Parse sortBy
+     sort_by = body.get("sortBy", [])
+     if sort_by is not None and not isinstance(sort_by, list):
+         raise HTTPError(400, "invalid_request", f"sortBy must be an array, got: {type(sort_by).__name__}")
+     if sort_by and not all(isinstance(s, str) for s in sort_by):
+         raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")
+
+     # Parse maxChars
+     max_chars_raw = body.get("maxChars", max_chars)
+     try:
+         max_chars_req = int(max_chars_raw or max_chars)
+     except (ValueError, TypeError):
+         raise HTTPError(400, "invalid_request", f"maxChars must be a valid integer, got: {max_chars_raw!r}")
+
+     if offset < 0:
+         raise HTTPError(400, "invalid_request", f"offset must be >= 0, got: {offset}")
+     if limit <= 0:
+         raise HTTPError(400, "invalid_request", f"limit must be > 0, got: {limit}")
+     if limit > max_limit:
+         raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
+     if max_chars_req <= 0:
+         raise HTTPError(400, "invalid_request", "maxChars must be > 0")
+     if max_chars_req > max_chars:
+         raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")
+
+     if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
+         raise HTTPError(400, "invalid_request", "vars must be a list of strings")
+     if len(vars_req) > max_vars:
+         raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")
+
+     current_id = manager.current_dataset_id()
+     if dataset_id != current_id:
+         raise HTTPError(409, "dataset_changed", "Dataset changed")
+
+     if view_id is None:
+         obs_indices = None
+         filtered_n: int | None = None
+     else:
+         assert view is not None
+         obs_indices = view.obs_indices
+         filtered_n = view.filtered_n
+
+     try:
+         # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
+         if sort_by:
+             try:
+                 normalized_sort = manager._normalize_sort_spec(sort_by)
+                 sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
+                 if sorted_indices is None:
+                     sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
+                     descending = [spec.startswith("-") for spec in normalized_sort]
+                     nulls_last = [not desc for desc in descending]
+
+                     table = manager._get_sort_table(current_id, sort_cols)
+                     if table is None:
+                         sorted_indices = []
+                     else:
+                         sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
+                         if sorted_indices is None:
+                             sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
+
+                     manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)
+
+                 if view_id is None:
+                     obs_indices = sorted_indices
+                 else:
+                     assert view is not None
+                     view_set = set(view.obs_indices)
+                     obs_indices = [idx for idx in sorted_indices if idx in view_set]
+                     filtered_n = len(obs_indices)
+             except ValueError as e:
+                 raise HTTPError(400, "invalid_request", f"Invalid sort specification: {e}")
+             except RuntimeError as e:
+                 raise HTTPError(500, "internal_error", f"Failed to apply sort: {e}")
+
+         dataset_state = manager._client.get_dataset_state()
+         page = manager._client.get_page(
+             offset=offset,
+             limit=limit,
+             vars=vars_req,
+             include_obs_no=include_obs_no,
+             max_chars=max_chars_req,
+             obs_indices=obs_indices,
+         )
+     except HTTPError:
+         # Re-raise HTTPError exceptions as-is
+         raise
+     except RuntimeError as e:
+         msg = str(e) or "No data in memory"
+         if "no data" in msg.lower():
+             raise HTTPError(400, "no_data_in_memory", msg)
+         raise HTTPError(500, "internal_error", msg)
+     except ValueError as e:
+         msg = str(e)
+         if msg.lower().startswith("invalid variable"):
+             raise HTTPError(400, "invalid_variable", msg)
+         raise HTTPError(400, "invalid_request", msg)
+     except Exception as e:
+         raise HTTPError(500, "internal_error", str(e))
+
+     view_obj: dict[str, Any] = {
+         "offset": offset,
+         "limit": limit,
+         "returned": page["returned"],
+         "filteredN": filtered_n,
+     }
+     if view_id is not None:
+         view_obj["viewId"] = view_id
+
+     return {
+         "dataset": {
+             "id": current_id,
+             "frame": dataset_state.get("frame"),
+             "n": dataset_state.get("n"),
+             "k": dataset_state.get("k"),
+         },
+         "view": view_obj,
+         "vars": page["vars"],
+         "rows": page["rows"],
+         "display": {
+             "maxChars": max_chars_req,
+             "truncatedCells": page["truncated_cells"],
+             "missing": ".",
+         },
+     }
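+ # Request-body sketch for /v1/page and /v1/views/{id}/page (illustrative values):
+ #     {"datasetId": "<current id>", "offset": 0, "limit": 50,
+ #      "vars": ["make", "price"], "sortBy": ["-price"], "maxChars": 120}
+ # offset defaults to 0; limit is required and capped at MAX_LIMIT.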
+
+
+ def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
+     _, max_vars, _, _ = manager.limits()
+     # Use the dedicated Arrow limit instead of the general UI page limit
+     chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)
+
+     if view_id is None:
+         dataset_id = str(body.get("datasetId", ""))
+         frame = str(body.get("frame", "default"))
+     else:
+         view = manager.get_view(view_id)
+         if view is None:
+             raise HTTPError(404, "not_found", "View not found")
+         dataset_id = view.dataset_id
+         frame = view.frame
+
+     # Parse offset (default 0)
+     try:
+         offset = int(body.get("offset") or 0)
+     except (ValueError, TypeError):
+         raise HTTPError(400, "invalid_request", "offset must be a valid integer")
+
+     # Parse limit (required; large chunk sizes are allowed up to the Arrow limit)
+     limit_raw = body.get("limit")
+     if limit_raw is None:
+         raise HTTPError(400, "invalid_request", "limit is required")
+     try:
+         limit = int(limit_raw)
+     except (ValueError, TypeError):
+         raise HTTPError(400, "invalid_request", "limit must be a valid integer")
+
+     vars_req = body.get("vars", [])
+     include_obs_no = bool(body.get("includeObsNo", False))
+     sort_by = body.get("sortBy", [])
+
+     if offset < 0:
+         raise HTTPError(400, "invalid_request", "offset must be >= 0")
+     if limit <= 0:
+         raise HTTPError(400, "invalid_request", "limit must be > 0")
+     # Arrow streams are efficient, but we still respect a (much larger) max limit
+     if limit > chunk_limit:
+         raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")
+
+     if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
+         raise HTTPError(400, "invalid_request", "vars must be a list of strings")
+     if len(vars_req) > max_vars:
+         raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")
+
+     current_id = manager.current_dataset_id()
+     if dataset_id != current_id:
+         raise HTTPError(409, "dataset_changed", "Dataset changed")
+
+     if view_id is None:
+         obs_indices = None
+     else:
+         assert view is not None
+         obs_indices = view.obs_indices
+
+     try:
+         # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
+         if sort_by:
+             if not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by):
+                 raise HTTPError(400, "invalid_request", "sortBy must be a list of strings")
+             try:
+                 normalized_sort = manager._normalize_sort_spec(sort_by)
+                 sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
+                 if sorted_indices is None:
+                     sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
+                     descending = [spec.startswith("-") for spec in normalized_sort]
+                     nulls_last = [not desc for desc in descending]
+
+                     table = manager._get_sort_table(current_id, sort_cols)
+                     if table is None:
+                         sorted_indices = []
+                     else:
+                         sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
+                         if sorted_indices is None:
+                             sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
+
+                     manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)
+
+                 if view_id is None:
+                     obs_indices = sorted_indices
+                 else:
+                     assert view is not None
+                     view_set = set(view.obs_indices)
+                     obs_indices = [idx for idx in sorted_indices if idx in view_set]
+             except ValueError as e:
+                 raise HTTPError(400, "invalid_request", f"Invalid sort: {e}")
+             except RuntimeError as e:
+                 raise HTTPError(500, "internal_error", f"Sort failed: {e}")
+
+         arrow_bytes = manager._client.get_arrow_stream(
+             offset=offset,
+             limit=limit,
+             vars=vars_req,
+             include_obs_no=include_obs_no,
+             obs_indices=obs_indices,
+         )
+         return arrow_bytes
+
+     except HTTPError:
+         # Re-raise HTTP errors unchanged so a 400 raised above is not
+         # downgraded to a 500 by the generic handlers below
+         raise
+     except RuntimeError as e:
+         msg = str(e) or "No data in memory"
+         if "no data" in msg.lower():
+             raise HTTPError(400, "no_data_in_memory", msg)
+         raise HTTPError(500, "internal_error", msg)
+     except ValueError as e:
+         msg = str(e)
+         if "invalid variable" in msg.lower():
+             raise HTTPError(400, "invalid_variable", msg)
+         raise HTTPError(400, "invalid_request", msg)
+     except Exception as e:
+         raise HTTPError(500, "internal_error", str(e))
+
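+ # Consumer sketch (illustrative; mirrors how this module itself decodes Arrow bytes):
+ #     import io
+ #     import pyarrow as pa
+ #     with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
+ #         table = reader.read_all()  # pyarrow.Table holding the requested rows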