mcp-stata 1.18.0__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/ui_http.py ADDED
@@ -0,0 +1,998 @@
1
+ import hashlib
2
+ import io
3
+ import json
4
+ import secrets
5
+ import threading
6
+ import time
7
+ import uuid
8
+ import logging
9
+ from dataclasses import dataclass
10
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
11
+ from typing import Any, Callable, Optional
12
+
13
+ from .stata_client import StataClient
14
+ from .config import (
15
+ DEFAULT_HOST,
16
+ DEFAULT_PORT,
17
+ MAX_ARROW_LIMIT,
18
+ MAX_CHARS,
19
+ MAX_LIMIT,
20
+ MAX_REQUEST_BYTES,
21
+ MAX_VARS,
22
+ TOKEN_TTL_S,
23
+ VIEW_TTL_S,
24
+ )
25
+
26
+
27
+ logger = logging.getLogger("mcp_stata")
28
+
29
+ try:
30
+ from .native_ops import argsort_numeric as _native_argsort_numeric
31
+ from .native_ops import argsort_mixed as _native_argsort_mixed
32
+ except Exception:
33
+ _native_argsort_numeric = None
34
+ _native_argsort_mixed = None
35
+
36
+
37
def _try_native_argsort(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int] | None:
    """Attempt a multi-key argsort via the native (Rust) extension.

    Returns 0-based observation indices in sorted order, or ``None`` whenever
    the native helpers are unavailable, a sort column has an unsupported type,
    the ``_n`` column is absent, or anything unexpected happens — callers fall
    back to the Polars path in that case (best effort by design).
    """
    if _native_argsort_numeric is None and _native_argsort_mixed is None:
        return None
    try:
        import pyarrow as pa
        import numpy as np

        string_flags: list[bool] = []
        key_columns: list[object] = []
        for name in sort_cols:
            column = table.column(name).combine_chunks()
            col_type = column.type
            if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
                string_flags.append(True)
                key_columns.append(column.to_pylist())
            elif pa.types.is_floating(col_type) or pa.types.is_integer(col_type):
                string_flags.append(False)
                values = column.to_numpy(zero_copy_only=False)
                if values.dtype != np.float64:
                    values = values.astype(np.float64, copy=False)
                # Normalize Stata missing values (huge sentinels) to NaN.
                values = np.where(values > 8.0e307, np.nan, values)
                key_columns.append(values)
            else:
                # Unsupported dtype for the native sorter.
                return None

        if "_n" not in table.column_names:
            return None
        obs_numbers = table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)

        if _native_argsort_numeric is not None and not any(string_flags):
            order = _native_argsort_numeric(key_columns, descending, nulls_last)
        elif _native_argsort_mixed is not None:
            order = _native_argsort_mixed(key_columns, string_flags, descending, nulls_last)
        else:
            return None

        # _n is 1-based; convert to 0-based indices.
        return [int(v) for v in (obs_numbers[order] - 1).tolist()]
    except Exception:
        return None
81
+
82
+
83
def _get_sorted_indices_polars(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int]:
    """Sort with Polars and return 0-based observation indices in sorted order."""
    import polars as pl

    frame = pl.from_arrow(table)

    def _normalize(name: str, dtype: object) -> "pl.Expr":
        # Map Stata numeric missings (huge sentinel values) to null; the
        # 1-based observation-number column "_n" is passed through untouched.
        if name != "_n" and dtype in (pl.Float32, pl.Float64):
            return (
                pl.when(pl.col(name) > 8.0e307)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
            )
        return pl.col(name)

    frame = frame.select(
        [_normalize(name, dtype) for name, dtype in zip(frame.columns, frame.dtypes)]
    )

    try:
        # Expression-based arithmetic avoids eager Series-scalar conversion
        # issues observed in some environments with Int64 dtypes.
        ordered = frame.select(
            idx=pl.arg_sort_by(
                [pl.col(name) for name in sort_cols],
                descending=descending,
                nulls_last=nulls_last,
            ),
            zero_based_n=pl.col("_n") - 1,
        )
        return ordered["zero_based_n"].take(ordered["idx"]).to_list()
    except Exception:
        # Eager sort fallback if arg_sort_by fails or has issues.
        sorted_frame = frame.sort(
            by=sort_cols, descending=descending, nulls_last=nulls_last
        )
        return sorted_frame.select(pl.col("_n") - 1).to_series().to_list()
129
+
130
+
131
+
132
+
133
+ def _stable_hash(payload: dict[str, Any]) -> str:
134
+ return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()
135
+
136
+
137
@dataclass
class UIChannelInfo:
    """Connection details for the local UI HTTP channel returned to clients."""

    base_url: str  # e.g. "http://<host>:<port>" of the local UI server
    token: str  # bearer token expected in the Authorization header
    expires_at: int  # token expiry, milliseconds since the Unix epoch
142
+
143
+
144
@dataclass
class ViewHandle:
    """A server-side filtered view over the current dataset."""

    view_id: str  # "view_<uuid4 hex>"
    dataset_id: str  # dataset fingerprint the view was built against
    frame: str  # Stata frame name the view belongs to
    filter_expr: str  # filter expression used to compute the view
    obs_indices: list[int]  # observation indices matching the filter
    filtered_n: int  # len(obs_indices) at creation time
    created_at: float  # time.time() when the view was created
    last_access: float  # time.time() of last use; drives TTL eviction
154
+
155
+
156
class UIChannelManager:
    """Local HTTP channel serving dataset pages, Arrow streams, and filtered views.

    Owns a token-authenticated ThreadingHTTPServer (started lazily), a registry
    of TTL-evicted ViewHandles, and two small LRU caches: sorted-index results
    and the Arrow tables the sorts were computed from. All shared state is
    guarded by ``self._lock``.
    """

    def __init__(
        self,
        client: StataClient,
        *,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        token_ttl_s: int = TOKEN_TTL_S,
        view_ttl_s: int = VIEW_TTL_S,
        max_limit: int = MAX_LIMIT,
        max_vars: int = MAX_VARS,
        max_chars: int = MAX_CHARS,
        max_request_bytes: int = MAX_REQUEST_BYTES,
        max_arrow_limit: int = MAX_ARROW_LIMIT,
    ):
        self._client = client
        self._host = host
        self._port = port
        self._token_ttl_s = token_ttl_s
        self._view_ttl_s = view_ttl_s
        self._max_limit = max_limit
        self._max_vars = max_vars
        self._max_chars = max_chars
        self._max_request_bytes = max_request_bytes
        self._max_arrow_limit = max_arrow_limit

        # Guards all mutable state below.
        self._lock = threading.Lock()
        self._httpd: ThreadingHTTPServer | None = None
        self._thread: threading.Thread | None = None

        # Bearer token and its expiry (ms since epoch); rotated lazily.
        self._token: str | None = None
        self._expires_at: int = 0

        # Monotonic version bumped on any potential dataset change; the
        # cached dataset id is only valid while versions match.
        self._dataset_version: int = 0
        self._dataset_id_cache: str | None = None
        self._dataset_id_cache_at_version: int = -1

        self._views: dict[str, ViewHandle] = {}
        self._last_sort_spec: tuple[str, ...] | None = None
        self._last_sort_dataset_id: str | None = None
        # LRU of computed sort orders keyed by (dataset_id, normalized spec).
        self._sort_index_cache: dict[tuple[str, tuple[str, ...]], list[int]] = {}
        self._sort_cache_order: list[tuple[str, tuple[str, ...]]] = []
        self._sort_cache_max_entries: int = 4
        # LRU of Arrow tables pulled for sorting, keyed by (dataset_id, cols).
        self._sort_table_cache: dict[tuple[str, tuple[str, ...]], Any] = {}
        self._sort_table_order: list[tuple[str, tuple[str, ...]]] = []
        self._sort_table_max_entries: int = 2

    def notify_potential_dataset_change(self) -> None:
        """Invalidate everything derived from the dataset (views, sort caches)."""
        with self._lock:
            self._dataset_version += 1
            self._dataset_id_cache = None
            self._views.clear()
            self._last_sort_spec = None
            self._last_sort_dataset_id = None
            self._sort_index_cache.clear()
            self._sort_cache_order.clear()
            self._sort_table_cache.clear()
            self._sort_table_order.clear()

    @staticmethod
    def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
        """Canonicalize "[+|-]varname" specs to explicit "+var"/"-var" tuples.

        Raises ValueError on non-string, empty, or sign-only entries.
        """
        normalized: list[str] = []
        for spec in sort_spec:
            if not isinstance(spec, str) or not spec:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            raw = spec.strip()
            if not raw:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            sign = "-" if raw.startswith("-") else "+"
            varname = raw.lstrip("+-")
            if not varname:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            normalized.append(f"{sign}{varname}")
        return tuple(normalized)

    def _get_cached_sort_indices(
        self, dataset_id: str, sort_spec: tuple[str, ...]
    ) -> list[int] | None:
        """Return the cached sort order for (dataset, spec), refreshing LRU recency."""
        key = (dataset_id, sort_spec)
        with self._lock:
            cached = self._sort_index_cache.get(key)
            if cached is None:
                return None
            # refresh LRU order
            if key in self._sort_cache_order:
                self._sort_cache_order.remove(key)
            self._sort_cache_order.append(key)
            return cached

    def _set_cached_sort_indices(
        self, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
    ) -> None:
        """Store a sort order, evicting least-recently-used entries past the cap."""
        key = (dataset_id, sort_spec)
        with self._lock:
            if key in self._sort_index_cache:
                self._sort_cache_order.remove(key)
            self._sort_index_cache[key] = indices
            self._sort_cache_order.append(key)
            while len(self._sort_cache_order) > self._sort_cache_max_entries:
                evict = self._sort_cache_order.pop(0)
                self._sort_index_cache.pop(evict, None)

    def _get_cached_sort_table(
        self, dataset_id: str, sort_cols: tuple[str, ...]
    ) -> Any | None:
        """Return a cached Arrow table for (dataset, cols), refreshing LRU recency."""
        key = (dataset_id, sort_cols)
        with self._lock:
            cached = self._sort_table_cache.get(key)
            if cached is None:
                return None
            if key in self._sort_table_order:
                self._sort_table_order.remove(key)
            self._sort_table_order.append(key)
            return cached

    def _set_cached_sort_table(
        self, dataset_id: str, sort_cols: tuple[str, ...], table: Any
    ) -> None:
        """Store an Arrow table, evicting least-recently-used entries past the cap."""
        key = (dataset_id, sort_cols)
        with self._lock:
            if key in self._sort_table_cache:
                self._sort_table_order.remove(key)
            self._sort_table_cache[key] = table
            self._sort_table_order.append(key)
            while len(self._sort_table_order) > self._sort_table_max_entries:
                evict = self._sort_table_order.pop(0)
                self._sort_table_cache.pop(evict, None)

    def _get_sort_table(self, dataset_id: str, sort_cols: list[str]) -> Any:
        """Fetch (or reuse) an Arrow table of the sort columns plus "_n".

        Returns None when the dataset is empty. The whole column set is pulled
        in one Arrow stream so a sort never needs to page through Stata.
        """
        sort_cols_key = tuple(sort_cols)
        cached = self._get_cached_sort_table(dataset_id, sort_cols_key)
        if cached is not None:
            return cached

        state = self._client.get_dataset_state()
        n = int(state.get("n", 0) or 0)
        if n <= 0:
            return None

        # Pull full columns once via Arrow stream (Stata -> Arrow), then sort in Polars.
        arrow_bytes = self._client.get_arrow_stream(
            offset=0,
            limit=n,
            vars=sort_cols,
            include_obs_no=True,
            obs_indices=None,
        )

        import pyarrow as pa

        with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
            table = reader.read_all()

        self._set_cached_sort_table(dataset_id, sort_cols_key, table)
        return table

    def get_channel(self) -> UIChannelInfo:
        """Start the server if needed and return connection info with a live token."""
        self._ensure_http_server()
        with self._lock:
            self._ensure_token()
            assert self._httpd is not None
            # server_address[1] yields the actual bound port (matters when port=0).
            port = self._httpd.server_address[1]
            base_url = f"http://{self._host}:{port}"
            return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)

    def capabilities(self) -> dict[str, bool]:
        """Static feature flags advertised to the UI."""
        return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}

    def current_dataset_id(self) -> str:
        """Return a fingerprint of the current dataset, cached per version.

        NOTE(review): the lock is released between the cache check and the
        recompute, so concurrent callers may hash the state redundantly; the
        result converges because the version is part of the payload — confirm
        this is acceptable.
        """
        with self._lock:
            if self._dataset_id_cache is not None and self._dataset_id_cache_at_version == self._dataset_version:
                return self._dataset_id_cache

        state = self._client.get_dataset_state()
        payload = {
            "version": self._dataset_version,
            "frame": state.get("frame"),
            "n": state.get("n"),
            "k": state.get("k"),
            "sortlist": state.get("sortlist"),
        }
        digest = _stable_hash(payload)

        with self._lock:
            self._dataset_id_cache = digest
            self._dataset_id_cache_at_version = self._dataset_version
            return digest

    def get_view(self, view_id: str) -> Optional[ViewHandle]:
        """Look up a view by id, touching its last_access; None if absent/expired."""
        now = time.time()
        with self._lock:
            self._evict_expired_locked(now)
            view = self._views.get(view_id)
            if view is None:
                return None
            view.last_access = now
            return view

    def create_view(self, *, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
        """Create a filtered view against the current dataset.

        Raises DatasetChangedError if dataset_id is stale, InvalidFilterError
        for a bad expression, NoDataInMemoryError when Stata reports no data.
        """
        current_id = self.current_dataset_id()
        if dataset_id != current_id:
            raise DatasetChangedError(current_id)

        try:
            obs_indices = self._client.compute_view_indices(filter_expr)
        except ValueError as e:
            raise InvalidFilterError(str(e))
        except RuntimeError as e:
            msg = str(e) or "No data in memory"
            if "no data" in msg.lower():
                raise NoDataInMemoryError(msg)
            raise
        now = time.time()
        view_id = f"view_{uuid.uuid4().hex}"
        view = ViewHandle(
            view_id=view_id,
            dataset_id=current_id,
            frame=frame,
            filter_expr=filter_expr,
            obs_indices=obs_indices,
            filtered_n=len(obs_indices),
            created_at=now,
            last_access=now,
        )
        with self._lock:
            self._evict_expired_locked(now)
            self._views[view_id] = view
        return view

    def delete_view(self, view_id: str) -> bool:
        """Remove a view; True if it existed."""
        with self._lock:
            return self._views.pop(view_id, None) is not None

    def validate_token(self, header_value: str | None) -> bool:
        """Check an Authorization header against the current bearer token.

        Uses a constant-time comparison; an expired token is rotated by
        _ensure_token, so a stale presented token then fails the compare.
        """
        if not header_value:
            return False
        if not header_value.startswith("Bearer "):
            return False
        token = header_value[len("Bearer ") :].strip()
        with self._lock:
            self._ensure_token()
            if self._token is None:
                return False
            if time.time() * 1000 >= self._expires_at:
                return False
            return secrets.compare_digest(token, self._token)

    def limits(self) -> tuple[int, int, int, int]:
        """Return (max_limit, max_vars, max_chars, max_request_bytes)."""
        return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes

    def _ensure_token(self) -> None:
        """Mint a fresh token when none exists or the current one has expired.

        Caller must hold self._lock.
        """
        now_ms = int(time.time() * 1000)
        if self._token is None or now_ms >= self._expires_at:
            self._token = secrets.token_urlsafe(32)
            self._expires_at = int((time.time() + self._token_ttl_s) * 1000)

    def _evict_expired_locked(self, now: float) -> None:
        """Drop views idle for at least view_ttl_s. Caller must hold self._lock."""
        expired: list[str] = []
        for key, view in self._views.items():
            if now - view.last_access >= self._view_ttl_s:
                expired.append(key)
        for key in expired:
            self._views.pop(key, None)

    def _ensure_http_server(self) -> None:
        """Start the background HTTP server once (lazy, idempotent in practice).

        NOTE(review): the lock is released after the "already running" check,
        so two first-time callers could race to create two servers; the second
        bind would fail for a fixed port. Confirm whether creation should stay
        inside the lock.
        """
        with self._lock:
            if self._httpd is not None:
                return

        # Closed over by the handler class so request methods can reach the manager.
        manager = self

        class Handler(BaseHTTPRequestHandler):

            def _send_json(self, status: int, payload: dict[str, Any]) -> None:
                """Serialize payload and send it with proper JSON headers."""
                data = json.dumps(payload).encode("utf-8")
                self.send_response(status)
                self.send_header("Content-Type", "application/json")
                self.send_header("Content-Length", str(len(data)))
                self.end_headers()
                self.wfile.write(data)

            def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
                """Send raw bytes with the given content type."""
                self.send_response(status)
                self.send_header("Content-Type", content_type)
                self.send_header("Content-Length", str(len(data)))
                self.end_headers()
                self.wfile.write(data)

            def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
                """Send a JSON error body; 5xx details are logged, not leaked."""
                if status >= 500 or code == "internal_error":
                    logger.error("UI HTTP error %s: %s", code, message)
                    # Replace the message so internals are not exposed to clients.
                    message = "Internal server error"
                body: dict[str, Any] = {"error": {"code": code, "message": message}}
                if stata_rc is not None:
                    body["error"]["stataRc"] = stata_rc
                self._send_json(status, body)

            def _require_auth(self) -> bool:
                """Validate the bearer token; emits a 401 and returns False on failure."""
                if manager.validate_token(self.headers.get("Authorization")):
                    return True
                self._error(401, "auth_failed", "Unauthorized")
                return False

            def _read_json(self) -> dict[str, Any] | None:
                """Read and parse the request body as a JSON object.

                Returns {} for an empty body, or None after having already sent
                an error response (too large / bad JSON / not an object).
                """
                max_limit, max_vars, max_chars, max_bytes = manager.limits()
                # Only the byte cap is used here; discard the rest explicitly.
                _ = (max_limit, max_vars, max_chars)

                length = int(self.headers.get("Content-Length", "0") or "0")
                if length <= 0:
                    return {}
                if length > max_bytes:
                    self._error(400, "request_too_large", "Request too large")
                    return None
                raw = self.rfile.read(length)
                try:
                    parsed = json.loads(raw.decode("utf-8"))
                except Exception:
                    self._error(400, "invalid_request", "Invalid JSON")
                    return None
                if not isinstance(parsed, dict):
                    self._error(400, "invalid_request", "Expected JSON object")
                    return None
                return parsed

            def do_GET(self) -> None:
                # GET /v1/dataset -> dataset summary; GET /v1/vars -> variable metadata.
                if not self._require_auth():
                    return

                if self.path == "/v1/dataset":
                    try:
                        state = manager._client.get_dataset_state()
                        dataset_id = manager.current_dataset_id()
                        self._send_json(
                            200,
                            {
                                "dataset": {
                                    "id": dataset_id,
                                    "frame": state.get("frame"),
                                    "n": state.get("n"),
                                    "k": state.get("k"),
                                    "changed": state.get("changed"),
                                }
                            },
                        )
                        return
                    except NoDataInMemoryError as e:
                        self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/vars":
                    try:
                        state = manager._client.get_dataset_state()
                        dataset_id = manager.current_dataset_id()
                        variables = manager._client.list_variables_rich()
                        self._send_json(
                            200,
                            {
                                "dataset": {"id": dataset_id, "frame": state.get("frame")},
                                "variables": variables,
                            },
                        )
                        return
                    except NoDataInMemoryError as e:
                        self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                self._error(404, "not_found", "Not found")

            def do_POST(self) -> None:
                # POST routes: /v1/arrow, /v1/page, /v1/views,
                # /v1/views/<id>/page, /v1/views/<id>/arrow, /v1/filters/validate.
                if not self._require_auth():
                    return

                if self.path == "/v1/arrow":
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp_bytes = handle_arrow_request(manager, body, view_id=None)
                        self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/page":
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp = handle_page_request(manager, body, view_id=None)
                        self._send_json(200, resp)
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/views":
                    body = self._read_json()
                    if body is None:
                        return
                    dataset_id = str(body.get("datasetId", ""))
                    frame = str(body.get("frame", "default"))
                    filter_expr = str(body.get("filterExpr", ""))
                    if not dataset_id or not filter_expr:
                        self._error(400, "invalid_request", "datasetId and filterExpr are required")
                        return
                    try:
                        view = manager.create_view(dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
                        self._send_json(
                            200,
                            {
                                "dataset": {"id": view.dataset_id, "frame": view.frame},
                                "view": {"id": view.view_id, "filteredN": view.filtered_n},
                            },
                        )
                        return
                    except DatasetChangedError as e:
                        self._error(409, "dataset_changed", "Dataset changed")
                        return
                    except ValueError as e:
                        # InvalidFilterError is not a ValueError; this relies on
                        # create_view's own error mapping for filter problems.
                        self._error(400, "invalid_filter", str(e))
                        return
                    except RuntimeError as e:
                        msg = str(e) or "No data in memory"
                        if "no data" in msg.lower():
                            self._error(400, "no_data_in_memory", msg)
                            return
                        self._error(500, "internal_error", msg)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
                    # /v1/views/<view_id>/page
                    parts = self.path.split("/")
                    if len(parts) != 5:
                        self._error(404, "not_found", "Not found")
                        return
                    view_id = parts[3]
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp = handle_page_request(manager, body, view_id=view_id)
                        self._send_json(200, resp)
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
                    # /v1/views/<view_id>/arrow
                    parts = self.path.split("/")
                    if len(parts) != 5:
                        self._error(404, "not_found", "Not found")
                        return
                    view_id = parts[3]
                    body = self._read_json()
                    if body is None:
                        return
                    try:
                        resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
                        self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                        return
                    except HTTPError as e:
                        self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                if self.path == "/v1/filters/validate":
                    body = self._read_json()
                    if body is None:
                        return
                    filter_expr = str(body.get("filterExpr", ""))
                    if not filter_expr:
                        self._error(400, "invalid_request", "filterExpr is required")
                        return
                    try:
                        manager._client.validate_filter_expr(filter_expr)
                        self._send_json(200, {"ok": True})
                        return
                    except ValueError as e:
                        self._error(400, "invalid_filter", str(e))
                        return
                    except RuntimeError as e:
                        msg = str(e) or "No data in memory"
                        if "no data" in msg.lower():
                            self._error(400, "no_data_in_memory", msg)
                            return
                        self._error(500, "internal_error", msg)
                        return
                    except Exception as e:
                        self._error(500, "internal_error", str(e))
                        return

                self._error(404, "not_found", "Not found")

            def do_DELETE(self) -> None:
                # DELETE /v1/views/<view_id>
                if not self._require_auth():
                    return

                if self.path.startswith("/v1/views/"):
                    parts = self.path.split("/")
                    if len(parts) != 4:
                        self._error(404, "not_found", "Not found")
                        return
                    view_id = parts[3]
                    if manager.delete_view(view_id):
                        self._send_json(200, {"ok": True})
                    else:
                        self._error(404, "not_found", "Not found")
                    return

                self._error(404, "not_found", "Not found")

            def log_message(self, format: str, *args: Any) -> None:
                # Silence BaseHTTPRequestHandler's default stderr access log.
                return

        httpd = ThreadingHTTPServer((self._host, self._port), Handler)
        # Daemon thread so the server never blocks interpreter shutdown.
        t = threading.Thread(target=httpd.serve_forever, daemon=True)
        t.start()
        self._httpd = httpd
        self._thread = t
696
+
697
+
698
+ class HTTPError(Exception):
699
+ def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
700
+ super().__init__(message)
701
+ self.status = status
702
+ self.code = code
703
+ self.message = message
704
+ self.stata_rc = stata_rc
705
+
706
+
707
class DatasetChangedError(Exception):
    """Raised when a request references a dataset id that is no longer current.

    Carries the dataset id that is current now so callers can resynchronize.
    """

    def __init__(self, current_dataset_id: str):
        self.current_dataset_id = current_dataset_id
        super().__init__("dataset_changed")
711
+
712
+
713
+ class NoDataInMemoryError(Exception):
714
+ def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
715
+ super().__init__(message)
716
+ self.stata_rc = stata_rc
717
+
718
+
719
+ class InvalidFilterError(Exception):
720
+ def __init__(self, message: str, *, stata_rc: int | None = None):
721
+ super().__init__(message)
722
+ self.message = message
723
+ self.stata_rc = stata_rc
724
+
725
+
726
def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
    """Serve one page of dataset rows as JSON, optionally through a view.

    Validates offset/limit/vars/sortBy/maxChars against the manager's limits,
    rejects requests against a stale dataset id (409), applies an optional
    non-mutating sort (native argsort with Polars fallback, cached), and maps
    client errors onto HTTPError codes. Raises HTTPError for every failure.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        # NOTE(review): frame is parsed here but not used below — the response
        # frame comes from dataset_state; confirm this is intentional.
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(view_id)
        if view is None:
            raise HTTPError(404, "not_found", "View not found")
        dataset_id = view.dataset_id
        frame = view.frame

    # Parse offset (default 0 is valid since offset >= 0)
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"offset must be a valid integer, got: {body.get('offset')!r}")

    # Parse limit (no default - must be explicitly provided)
    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"limit must be a valid integer, got: {limit_raw!r}")

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))

    # Parse sortBy parameter
    sort_by = body.get("sortBy", [])
    if sort_by is not None and not isinstance(sort_by, list):
        raise HTTPError(400, "invalid_request", f"sortBy must be an array, got: {type(sort_by).__name__}")
    if sort_by and not all(isinstance(s, str) for s in sort_by):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")

    # Parse maxChars
    max_chars_raw = body.get("maxChars", max_chars)
    try:
        # A falsy value (0, "", null) falls back to the server default.
        max_chars_req = int(max_chars_raw or max_chars)
    except (ValueError, TypeError) as e:
        raise HTTPError(400, "invalid_request", f"maxChars must be a valid integer, got: {max_chars_raw!r}")

    # Range checks against server-side limits.
    if offset < 0:
        raise HTTPError(400, "invalid_request", f"offset must be >= 0, got: {offset}")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0, got: {limit}")
    if limit > max_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
    if max_chars_req <= 0:
        raise HTTPError(400, "invalid_request", "maxChars must be > 0")
    if max_chars_req > max_chars:
        raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    # Reject requests built against a dataset that has since changed.
    current_id = manager.current_dataset_id()
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", "Dataset changed")

    if view_id is None:
        obs_indices = None
        filtered_n: int | None = None
    else:
        assert view is not None
        obs_indices = view.obs_indices
        filtered_n = view.filtered_n

    try:
        # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
        if sort_by:
            try:
                normalized_sort = manager._normalize_sort_spec(sort_by)
                sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
                if sorted_indices is None:
                    # Decompose "+var"/"-var" specs into columns and directions.
                    sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
                    descending = [spec.startswith("-") for spec in normalized_sort]
                    nulls_last = [not desc for desc in descending]

                    table = manager._get_sort_table(current_id, sort_cols)
                    if table is None:
                        # Empty dataset: nothing to sort.
                        sorted_indices = []
                    else:
                        sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
                        if sorted_indices is None:
                            sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)

                    manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)

                if view_id is None:
                    obs_indices = sorted_indices
                else:
                    assert view is not None
                    # Keep the sort order but restrict to the view's members.
                    view_set = set(view.obs_indices)
                    obs_indices = [idx for idx in sorted_indices if idx in view_set]
                    filtered_n = len(obs_indices)
            except ValueError as e:
                raise HTTPError(400, "invalid_request", f"Invalid sort specification: {e}")
            except RuntimeError as e:
                raise HTTPError(500, "internal_error", f"Failed to apply sort: {e}")

        dataset_state = manager._client.get_dataset_state()
        page = manager._client.get_page(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            max_chars=max_chars_req,
            obs_indices=obs_indices,
        )
    except HTTPError:
        # Re-raise HTTPError exceptions as-is
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg)
        raise HTTPError(500, "internal_error", msg)
    except ValueError as e:
        msg = str(e)
        if msg.lower().startswith("invalid variable"):
            raise HTTPError(400, "invalid_variable", msg)
        raise HTTPError(400, "invalid_request", msg)
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e))

    view_obj: dict[str, Any] = {
        "offset": offset,
        "limit": limit,
        "returned": page["returned"],
        "filteredN": filtered_n,
    }
    if view_id is not None:
        view_obj["viewId"] = view_id

    return {
        "dataset": {
            "id": current_id,
            "frame": dataset_state.get("frame"),
            "n": dataset_state.get("n"),
            "k": dataset_state.get("k"),
        },
        "view": view_obj,
        "vars": page["vars"],
        "rows": page["rows"],
        "display": {
            "maxChars": max_chars_req,
            "truncatedCells": page["truncated_cells"],
            "missing": ".",
        },
    }
882
+
883
+
884
def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
    """Serve a slice of the dataset as an Arrow IPC stream, optionally via a view.

    Mirrors handle_page_request's validation (offset/limit/vars/sortBy) but
    uses the larger Arrow-specific row limit. Raises HTTPError for every
    failure; returns raw Arrow stream bytes on success.

    Fix: the outer ``try`` previously had no ``except HTTPError: raise`` guard
    (unlike handle_page_request), so HTTPErrors raised inside it — e.g. the
    400 for a malformed sortBy — were caught by ``except Exception`` and
    remapped to 500 internal_error. The guard below preserves their status.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()
    # Use the specific Arrow limit instead of the general UI page limit
    chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)

    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
        # NOTE(review): frame is parsed but unused below — confirm intentional.
        frame = str(body.get("frame", "default"))
    else:
        view = manager.get_view(view_id)
        if view is None:
            raise HTTPError(404, "not_found", "View not found")
        dataset_id = view.dataset_id
        frame = view.frame

    # Parse offset (default 0)
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "offset must be a valid integer")

    # Parse limit (required)
    limit_raw = body.get("limit")
    if limit_raw is None:
        # Default to the max arrow limit if not specified?
        # The previous code required it. Let's keep it required but allow large values.
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "limit must be a valid integer")

    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])

    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", "limit must be > 0")
    # Arrow streams are efficient, but we still respect a (much larger) max limit
    if limit > chunk_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")

    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")

    # Reject requests built against a dataset that has since changed.
    current_id = manager.current_dataset_id()
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", "Dataset changed")

    if view_id is None:
        obs_indices = None
    else:
        assert view is not None
        obs_indices = view.obs_indices

    try:
        # Apply sorting if requested (Rust native sorter with Polars fallback; no dataset mutation)
        if sort_by:
            if not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by):
                raise HTTPError(400, "invalid_request", "sortBy must be a list of strings")
            try:
                normalized_sort = manager._normalize_sort_spec(sort_by)
                sorted_indices = manager._get_cached_sort_indices(current_id, normalized_sort)
                if sorted_indices is None:
                    # Decompose "+var"/"-var" specs into columns and directions.
                    sort_cols = [spec.lstrip("+-") for spec in normalized_sort]
                    descending = [spec.startswith("-") for spec in normalized_sort]
                    nulls_last = [not desc for desc in descending]

                    table = manager._get_sort_table(current_id, sort_cols)
                    if table is None:
                        # Empty dataset: nothing to sort.
                        sorted_indices = []
                    else:
                        sorted_indices = _try_native_argsort(table, sort_cols, descending, nulls_last)
                        if sorted_indices is None:
                            sorted_indices = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)

                    manager._set_cached_sort_indices(current_id, normalized_sort, sorted_indices)

                if view_id is None:
                    obs_indices = sorted_indices
                else:
                    assert view is not None
                    # Keep the sort order but restrict to the view's members.
                    view_set = set(view.obs_indices)
                    obs_indices = [idx for idx in sorted_indices if idx in view_set]
            except ValueError as e:
                raise HTTPError(400, "invalid_request", f"Invalid sort: {e}")
            except RuntimeError as e:
                raise HTTPError(500, "internal_error", f"Sort failed: {e}")

        arrow_bytes = manager._client.get_arrow_stream(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            obs_indices=obs_indices,
        )
        return arrow_bytes

    except HTTPError:
        # Re-raise HTTPError exceptions as-is so 4xx statuses are preserved
        # (HTTPError subclasses Exception and would otherwise be remapped to
        # 500 by the generic handler below). Matches handle_page_request.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg)
        raise HTTPError(500, "internal_error", msg)
    except ValueError as e:
        msg = str(e)
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        raise HTTPError(400, "invalid_request", msg)
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e))
998
+