superlinear 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. apps/__init__.py +4 -0
  2. apps/cli/__init__.py +8 -0
  3. apps/cli/bm25_rag.py +471 -0
  4. apps/cli/chat_repl.py +1497 -0
  5. apps/cli/client.py +195 -0
  6. apps/cli/docs_repl.py +2275 -0
  7. apps/cli/light_rag.py +729 -0
  8. apps/cli/local_snapshots.py +139 -0
  9. apps/cli/locks.py +214 -0
  10. apps/cli/main.py +457 -0
  11. apps/cli/output.py +32 -0
  12. apps/cli/server_cmds.py +516 -0
  13. apps/cli/session_cmds.py +491 -0
  14. apps/cli/snapshot_cmds.py +303 -0
  15. apps/cli/state.py +265 -0
  16. apps/server/__init__.py +4 -0
  17. apps/server/app.py +1363 -0
  18. apps/server/main.py +313 -0
  19. superlinear/__init__.py +114 -0
  20. superlinear/_version.py +3 -0
  21. superlinear/engine/__init__.py +10 -0
  22. superlinear/engine/adapters/__init__.py +12 -0
  23. superlinear/engine/adapters/base.py +91 -0
  24. superlinear/engine/adapters/superlinear.py +1233 -0
  25. superlinear/engine/chat_engine.py +1173 -0
  26. superlinear/engine/chat_types.py +130 -0
  27. superlinear/engine/registry.py +51 -0
  28. superlinear/engine/repetition.py +203 -0
  29. superlinear/engine/session_snapshots.py +451 -0
  30. superlinear/engine/tool_parser.py +83 -0
  31. superlinear/engine/types.py +42 -0
  32. superlinear/kernels/__init__.py +2 -0
  33. superlinear/kernels/common/__init__.py +21 -0
  34. superlinear/kernels/common/adjustment.py +106 -0
  35. superlinear/kernels/common/power.py +154 -0
  36. superlinear/kernels/superlinear/__init__.py +10 -0
  37. superlinear/kernels/superlinear/attention/__init__.py +78 -0
  38. superlinear/kernels/superlinear/attention/_prefill.py +940 -0
  39. superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
  40. superlinear/kernels/superlinear/attention/api.py +433 -0
  41. superlinear/kernels/superlinear/search/__init__.py +33 -0
  42. superlinear/kernels/superlinear/search/_reference.py +204 -0
  43. superlinear/kernels/superlinear/search/_triton.py +488 -0
  44. superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
  45. superlinear/kernels/superlinear/search/api.py +200 -0
  46. superlinear/kernels/superlinear/span/__init__.py +41 -0
  47. superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
  48. superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
  49. superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
  50. superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
  51. superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
  52. superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
  53. superlinear/kernels/superlinear/span/api.py +296 -0
  54. superlinear/kernels/superlinear/span/masks.py +187 -0
  55. superlinear/py.typed +0 -0
  56. superlinear/runtime.py +71 -0
  57. superlinear-0.1.0.dist-info/METADATA +469 -0
  58. superlinear-0.1.0.dist-info/RECORD +62 -0
  59. superlinear-0.1.0.dist-info/WHEEL +5 -0
  60. superlinear-0.1.0.dist-info/entry_points.txt +2 -0
  61. superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
  62. superlinear-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,516 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import signal
6
+ import socket
7
+ import subprocess
8
+ import sys
9
+ import tempfile
10
+ import time
11
+ import urllib.parse
12
+ from dataclasses import asdict, dataclass
13
+ from pathlib import Path
14
+
15
+ from apps.cli.client import DEFAULT_URL, HttpError, SuperlinearClient
16
+ from apps.cli.state import CliState, config_dir, load_state, save_state
17
+
18
+
19
class ServerCommandError(RuntimeError):
    """Error raised by the ``server`` sub-commands when a request cannot be carried out."""
21
+
22
+
23
@dataclass(frozen=True)
class ServerInstancePaths:
    """Immutable bundle of the three bookkeeping files kept for one server instance."""

    # Text file holding the PID of the detached server process.
    pid_path: Path
    # File the detached server's stdout/stderr are appended to.
    log_path: Path
    # JSON file holding the instance metadata written at start-up.
    meta_path: Path
28
+
29
+
30
@dataclass(frozen=True)
class ServerInstanceMeta:
    """Immutable record describing one launched server; serialized to the instance's JSON file."""

    # PID of the spawned server process.
    pid: int
    # Base URL clients use to reach the server.
    url: str
    # Host and port components of that URL.
    host: str
    port: int
    # Model argument exactly as it was passed to the start command.
    model: str
    # Log file location, stored as a plain string so the record is JSON-serializable.
    log_path: str
    # Launch time, in whole seconds since the Unix epoch.
    started_at_unix_s: int
39
+
40
+
41
+ def _normalize_base_url(url: str) -> str:
42
+ url = url.strip()
43
+ if not url:
44
+ raise ServerCommandError("Server URL is empty")
45
+ if "://" not in url:
46
+ url = "http://" + url
47
+ parsed = urllib.parse.urlparse(url)
48
+ if parsed.scheme not in {"http", "https"}:
49
+ raise ServerCommandError(f"Unsupported URL scheme: {parsed.scheme!r}")
50
+ if not parsed.netloc:
51
+ raise ServerCommandError(f"Invalid server URL: {url!r}")
52
+ if parsed.path not in {"", "/"} or parsed.params or parsed.query or parsed.fragment:
53
+ raise ServerCommandError("Server URL must not include a path/query/fragment")
54
+
55
+ host = parsed.hostname
56
+ port = parsed.port
57
+ if host is None:
58
+ raise ServerCommandError(f"Invalid server host in URL: {url!r}")
59
+ if host.lower() == "localhost":
60
+ host = "127.0.0.1"
61
+ if ":" in host and not host.startswith("["):
62
+ host = f"[{host}]"
63
+ netloc = f"{host}:{port}" if port is not None else host
64
+ return f"{parsed.scheme}://{netloc}".rstrip("/")
65
+
66
+
67
def _parse_host_port(url: str) -> tuple[str, int]:
    """Return ``(host, port)`` for *url*, filling in a default port when none is given.

    The URL is normalized first. A missing port becomes 443 for https and
    8787 otherwise.

    Raises:
        ServerCommandError: If normalization fails or no host can be extracted.
    """
    parts = urllib.parse.urlparse(_normalize_base_url(url))
    hostname, port_number = parts.hostname, parts.port
    if hostname is None:
        raise ServerCommandError(f"Invalid server host in URL: {url!r}")
    if port_number is None:
        if parts.scheme == "https":
            port_number = 443
        else:
            port_number = 8787
    return hostname, int(port_number)
76
+
77
+
78
def _server_dir() -> Path:
    """Return the ``server`` sub-directory of the CLI configuration directory."""
    return config_dir().joinpath("server")
80
+
81
+
82
def _instance_paths(*, host: str, port: int) -> ServerInstancePaths:
    """Return the pid/log/meta file locations for the instance at *host*:*port*.

    The host is stripped and lower-cased, ``localhost`` is mapped to
    ``127.0.0.1``, and ``:`` / ``/`` are replaced with ``_`` so the result is
    usable as a file-name prefix.
    """
    normalized = host.strip().lower()
    if normalized == "localhost":
        normalized = "127.0.0.1"

    stem = "{}_{}".format(normalized.translate(str.maketrans(":/", "__")), int(port))
    directory = _server_dir()
    pid_file, log_file, meta_file = (
        directory / f"{stem}.{extension}" for extension in ("pid", "log", "json")
    )
    return ServerInstancePaths(pid_path=pid_file, log_path=log_file, meta_path=meta_file)
98
+
99
+
100
+ def _is_pid_running(pid: int) -> bool:
101
+ if pid <= 0:
102
+ return False
103
+ try:
104
+ os.kill(pid, 0)
105
+ except ProcessLookupError:
106
+ return False
107
+ except PermissionError:
108
+ return True
109
+ except OSError:
110
+ return False
111
+ return True
112
+
113
+
114
+ def _read_pid(path: Path) -> int | None:
115
+ try:
116
+ raw = path.read_text(encoding="utf-8").strip()
117
+ except FileNotFoundError:
118
+ return None
119
+ except Exception:
120
+ return None
121
+ try:
122
+ return int(raw)
123
+ except Exception:
124
+ return None
125
+
126
+
127
+ def _write_atomic_text(path: Path, text: str) -> None:
128
+ path.parent.mkdir(parents=True, exist_ok=True)
129
+ tmp_path: Path | None = None
130
+ try:
131
+ with tempfile.NamedTemporaryFile(
132
+ mode="w",
133
+ encoding="utf-8",
134
+ dir=str(path.parent),
135
+ delete=False,
136
+ prefix=path.name + ".",
137
+ suffix=".tmp",
138
+ ) as f:
139
+ f.write(text)
140
+ tmp_path = Path(f.name)
141
+ tmp_path.replace(path)
142
+ finally:
143
+ if tmp_path is not None:
144
+ try:
145
+ tmp_path.unlink(missing_ok=True)
146
+ except Exception:
147
+ pass
148
+
149
+
150
+ def _port_open(host: str, port: int, *, timeout_s: float = 0.2) -> bool:
151
+ check_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host
152
+ try:
153
+ with socket.create_connection((check_host, int(port)), timeout=timeout_s):
154
+ return True
155
+ except OSError:
156
+ return False
157
+
158
+
159
+ def _get_model_id(client: SuperlinearClient) -> str | None:
160
+ try:
161
+ models = client.list_models()
162
+ except HttpError:
163
+ return None
164
+ if not models:
165
+ return None
166
+ first = models[0]
167
+ if isinstance(first, dict) and isinstance(first.get("id"), str):
168
+ return first["id"]
169
+ return None
170
+
171
+
172
def server_status(*, url: str) -> int:
    """Print whether the server at *url* is running and return an exit code.

    Returns:
        0 when the health check succeeds (prints ``running ...``).
        1 when it fails with ``HttpError`` (prints ``stopped ...``).

    When the server is down, a PID file pointing at a dead process is removed.
    """
    url = _normalize_base_url(url)
    client = SuperlinearClient(base_url=url, timeout_s=5.0)

    try:
        client.health()
    except HttpError:
        host, port = _parse_host_port(url)
        pid_file = _instance_paths(host=host, port=port).pid_path
        recorded_pid = _read_pid(pid_file)
        if recorded_pid is not None and not _is_pid_running(recorded_pid):
            # Stale PID file: the recorded process is gone, so drop the file.
            try:
                pid_file.unlink()
            except Exception:
                pass
            recorded_pid = None

        print(f"stopped url={url}")
        if recorded_pid is not None:
            # The process is alive but did not answer the health check.
            print(f"pid={recorded_pid} (not responding)")
        return 1

    model_id = _get_model_id(client) or "unknown"

    host, port = _parse_host_port(url)
    recorded_pid = _read_pid(_instance_paths(host=host, port=port).pid_path)

    fields = [f"running url={url}", f"model={model_id}"]
    # Only report a PID we can confirm is alive.
    if recorded_pid is not None and _is_pid_running(recorded_pid):
        fields.append(f"pid={recorded_pid}")
    print(" ".join(fields))
    return 0
207
+
208
+
209
def server_start(
    *,
    url: str,
    model: str,
    host: str | None = None,
    port: int | None = None,
    chunk_size: int | None = None,
    attn_implementation: str | None = None,
    decode_kernel: str | None = None,
    device: str | None = None,
    dtype: str | None = None,
    max_prompt_tokens: int | None = None,
    disable_cuda_graph: bool = False,
    disable_shared_fused_moe: bool = False,
    foreground: bool = False,
) -> int:
    """Launch ``apps.server.main`` for *model* and wait until it passes a health check.

    Args:
        url: Base URL the server should be reachable at (http only).
        model: Value forwarded as ``--model``.
        host: Bind address forwarded as ``--host``; defaults to the URL's host.
        port: Overrides the port taken from *url*.
        chunk_size: Forwarded as ``--chunk-size`` when given.
        attn_implementation: Forwarded as ``--attn-implementation`` when non-blank.
        decode_kernel: Forwarded as ``--decode-kernel`` when non-blank.
        device: Forwarded as ``--device`` when non-blank.
        dtype: Forwarded as ``--dtype`` when non-blank.
        max_prompt_tokens: Forwarded as ``--max-prompt-tokens`` when given.
        disable_cuda_graph: Adds ``--disable-cuda-graph``.
        disable_shared_fused_moe: Adds ``--disable-shared-fused-moe``.
        foreground: Run the server attached to this terminal and return its
            exit code instead of detaching.

    Returns:
        0 when the server becomes ready, is already running, or the wait is
        interrupted with Ctrl-C (the server keeps starting in the background).
        1 when the detached process exits or the wait times out.
        In foreground mode, the server's own exit code.

    Raises:
        ServerCommandError: For a non-http URL, an invalid port, a port held
            by something that is not a Superlinear server, or a server process
            that exits immediately after launch.
    """
    url = _normalize_base_url(url)
    parsed = urllib.parse.urlparse(url)
    if parsed.scheme != "http":
        raise ServerCommandError("spl server start only supports http URLs (use --url http://...)")
    url_host, url_port = _parse_host_port(url)

    target_host = url_host
    target_port = int(port if port is not None else url_port)
    if target_port <= 0 or target_port > 65535:
        raise ServerCommandError(f"Invalid port: {target_port}")

    bind_host = host or url_host
    bind_port = target_port

    # The parsed hostname has no brackets around IPv6 literals; restore them so
    # the rebuilt URL stays valid (e.g. http://[::1]:8787).
    netloc_host = f"[{target_host}]" if ":" in target_host else target_host
    base_url = f"http://{netloc_host}:{target_port}"

    if _port_open(target_host, target_port):
        conflict_msg = (
            f"Port {target_port} is in use at host {target_host!r} (not a Superlinear server). "
            "Use --port or stop the conflicting process."
        )
        client = SuperlinearClient(base_url=base_url, timeout_s=2.0)
        try:
            client.health()
        except HttpError as exc:
            raise ServerCommandError(conflict_msg) from exc

        model_id = _get_model_id(client)
        if model_id is None:
            raise ServerCommandError(conflict_msg)
        print(f"already running url={base_url} model={model_id}")
        return 0

    paths = _instance_paths(host=target_host, port=target_port)
    paths.pid_path.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable,
        "-m",
        "apps.server.main",
        "--model",
        model,
        "--host",
        bind_host,
        "--port",
        str(bind_port),
    ]

    if chunk_size is not None:
        cmd.extend(["--chunk-size", str(int(chunk_size))])
    # Optional string flags are forwarded only when they contain non-blank text.
    for flag, value in (
        ("--device", device),
        ("--dtype", dtype),
        ("--attn-implementation", attn_implementation),
        ("--decode-kernel", decode_kernel),
    ):
        if value is not None and str(value).strip():
            cmd.extend([flag, str(value).strip()])

    if max_prompt_tokens is not None:
        cmd.extend(["--max-prompt-tokens", str(int(max_prompt_tokens))])

    if disable_cuda_graph:
        cmd.append("--disable-cuda-graph")
    if disable_shared_fused_moe:
        cmd.append("--disable-shared-fused-moe")

    expected_model_id = os.path.basename(model.rstrip("/")) or "superlinear"

    if foreground:
        print(f"starting (foreground) url={base_url} model={expected_model_id}")
        return subprocess.call(cmd)

    # Note where the log currently ends BEFORE spawning, so output written in
    # the first moments of this run is not skipped when tailing it below.
    try:
        log_pos = paths.log_path.stat().st_size
    except OSError:
        log_pos = 0

    with open(paths.log_path, "ab") as logf:
        proc = subprocess.Popen(
            cmd,
            stdin=subprocess.DEVNULL,
            stdout=logf,
            stderr=logf,
            start_new_session=True,
        )

    # Quick sanity: if the process exited immediately, surface a helpful error.
    time.sleep(0.2)
    if proc.poll() is not None:
        raise ServerCommandError(
            f"Server process exited immediately (code={proc.returncode}). Check logs: {paths.log_path}"
        )

    meta = ServerInstanceMeta(
        pid=int(proc.pid),
        url=base_url,
        host=target_host,
        port=target_port,
        model=model,
        log_path=str(paths.log_path),
        started_at_unix_s=int(time.time()),
    )
    _write_atomic_text(paths.pid_path, f"{proc.pid}\n")
    _write_atomic_text(paths.meta_path, json.dumps(asdict(meta), ensure_ascii=False, indent=2) + "\n")

    print(f"starting url={base_url} model={expected_model_id} logs={paths.log_path}")

    def _read_new_log_lines(*, fp, pos: int) -> tuple[int, list[str]]:
        """Read everything after *pos*; return the new position and the non-blank lines."""
        try:
            fp.seek(pos)
            chunk = fp.read()
        except Exception:
            return pos, []
        if not chunk:
            return fp.tell(), []
        # Splitlines keeps output readable even if the server writes partial lines.
        return fp.tell(), [ln for ln in chunk.splitlines() if ln.strip()]

    def _is_startup_relevant(line: str) -> bool:
        """Keep the CLI output high-signal; the full log remains in the log file."""
        s = line.strip()
        if not s:
            return False
        if s.startswith(("[server]", "[warmup]")):
            return True
        if "Loading checkpoint shards" in s:
            return True
        return s.startswith("Traceback") or "ERROR" in s or "Exception" in s

    # Wait for server to become ready
    client = SuperlinearClient(base_url=base_url, timeout_s=5.0)
    spinner = ["|", "/", "-", "\\"]
    spin_idx = 0
    poll_interval = 1.0
    max_wait_s = 600  # 10 minutes max wait
    start_wait = time.monotonic()

    # Tailing the log is optional; readiness detection works without it.
    log_fp = None
    try:
        log_fp = open(paths.log_path, "r", encoding="utf-8", errors="replace")
    except OSError:
        log_fp = None

    try:
        while True:
            elapsed = time.monotonic() - start_wait
            if elapsed > max_wait_s:
                print(f"\rtimeout after {int(elapsed)}s waiting for server. check logs: {paths.log_path}", file=sys.stderr)
                return 1

            # Check if process died
            if proc.poll() is not None:
                print(f"\rserver process exited (code={proc.returncode}). check logs: {paths.log_path}", file=sys.stderr)
                return 1

            # Try health check
            try:
                client.health()
                # Server is ready!
                print(f"\rserver ready ({int(elapsed)}s) ")
                print(f"openai api-compatible endpoint: {base_url}/v1/chat/completions")
                return 0
            except Exception:
                # Not ready yet (or a transient failure); keep polling.
                pass

            # Stream high-signal server logs during startup so users see warmup/load phases
            # even when running detached.
            if log_fp is not None:
                log_pos, new_lines = _read_new_log_lines(fp=log_fp, pos=log_pos)
                out_lines = [ln for ln in new_lines if _is_startup_relevant(ln)]
                if out_lines:
                    # Clear spinner line before printing logs.
                    sys.stdout.write("\r" + (" " * 120) + "\r")
                    for ln in out_lines[-20:]:
                        # Avoid flooding if the server emits many lines at once.
                        print(ln)

            # Show spinner
            sys.stdout.write(f"\rwaiting for model to load... {spinner[spin_idx]} ({int(elapsed)}s)")
            sys.stdout.flush()
            spin_idx = (spin_idx + 1) % len(spinner)
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        print("\n(cancelled) server still starting in background")
        print(f"check: `spl --url {base_url} server status`")
        return 0
    finally:
        if log_fp is not None:
            try:
                log_fp.close()
            except OSError:
                pass
424
+
425
+
426
def server_stop(*, url: str, force: bool = False) -> int:
    """Stop the managed server at *url* by signalling the PID in its PID file.

    If the server answers the health check and reports active sessions, the
    stop is refused unless *force* is set. Otherwise the process is sent
    SIGTERM, given up to 10 seconds to exit, then sent SIGKILL.

    Args:
        url: Base URL of the server to stop.
        force: Stop even when active sessions exist.

    Returns:
        0 when the server is stopped (or was not running).
        1 when it is reachable but has no managed PID file.
        2 when the stop is refused because sessions are active.

    Raises:
        ServerCommandError: If listing sessions fails, or the current user is
            not permitted to signal the server process.
    """
    url = _normalize_base_url(url)
    host, port = _parse_host_port(url)
    paths = _instance_paths(host=host, port=port)
    pid = _read_pid(paths.pid_path)

    client = SuperlinearClient(base_url=url, timeout_s=5.0)
    reachable = True
    try:
        client.health()
    except HttpError:
        reachable = False

    if reachable:
        try:
            payload = client.request_json("GET", "/v1/sessions", timeout_s=5.0)
        except HttpError as exc:
            raise ServerCommandError(str(exc)) from exc

        sessions = payload.get("sessions") if isinstance(payload, dict) else None
        if not isinstance(sessions, list):
            sessions = []
        active_session_ids = [s for s in sessions if isinstance(s, str)]
        if active_session_ids and not force:
            msg = (
                f"refusing to stop url={url}: {len(active_session_ids)} active session(s) exist; "
                "in-memory sessions may be lost on stop.\n"
                "next steps: `spl session ls`, then `spl snapshot save --session <id>` or `spl session rm <id>`, "
                "then retry; rerun with `--force` to stop anyway."
            )
            print(msg, file=sys.stderr)
            return 2

    if pid is None:
        if reachable:
            print(
                f"cannot stop url={url}: server is reachable but no managed PID file found at {paths.pid_path}",
                file=sys.stderr,
            )
            return 1
        print(f"stopped url={url} (not running)")
        return 0

    if not _is_pid_running(pid):
        try:
            paths.pid_path.unlink()
        except OSError:
            pass
        print(f"stopped url={url} (stale pid={pid})")
        return 0

    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        # The process exited between the liveness check and the signal.
        pass
    except PermissionError as exc:
        raise ServerCommandError(f"not permitted to signal server process pid={pid}: {exc}") from exc

    # Monotonic clock: the grace period must not be affected by wall-clock changes.
    deadline = time.monotonic() + 10.0
    while time.monotonic() < deadline and _is_pid_running(pid):
        time.sleep(0.1)

    if _is_pid_running(pid):
        # Still alive after the grace period: escalate.
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        time.sleep(0.1)

    try:
        paths.pid_path.unlink()
    except OSError:
        pass

    print(f"stopped url={url} pid={pid}")
    return 0
495
+
496
+
497
def server_connect(*, url: str) -> int:
    """Persist *url* as the default server for later CLI commands.

    The URL is normalized and health-checked first. If the check raises, an
    error is printed to stderr and nothing is saved.

    Returns:
        0 once the URL has been saved, 1 if the server could not be reached.
    """
    target = _normalize_base_url(url)

    # Refuse to store a URL that does not answer a health check.
    probe = SuperlinearClient(base_url=target, timeout_s=10.0)
    try:
        probe.health()
    except Exception as exc:
        print(f"error: cannot reach server at {target}: {exc}", file=sys.stderr)
        return 1

    # Record the URL in the persisted CLI state.
    state = load_state()
    state.server_url = target
    save_state(state)

    for message in (f"connected to {target}", "future commands will use this server by default"):
        print(message)
    return 0