superlinear-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. apps/__init__.py +4 -0
  2. apps/cli/__init__.py +8 -0
  3. apps/cli/bm25_rag.py +471 -0
  4. apps/cli/chat_repl.py +1497 -0
  5. apps/cli/client.py +195 -0
  6. apps/cli/docs_repl.py +2275 -0
  7. apps/cli/light_rag.py +729 -0
  8. apps/cli/local_snapshots.py +139 -0
  9. apps/cli/locks.py +214 -0
  10. apps/cli/main.py +457 -0
  11. apps/cli/output.py +32 -0
  12. apps/cli/server_cmds.py +516 -0
  13. apps/cli/session_cmds.py +491 -0
  14. apps/cli/snapshot_cmds.py +303 -0
  15. apps/cli/state.py +265 -0
  16. apps/server/__init__.py +4 -0
  17. apps/server/app.py +1363 -0
  18. apps/server/main.py +313 -0
  19. superlinear/__init__.py +114 -0
  20. superlinear/_version.py +3 -0
  21. superlinear/engine/__init__.py +10 -0
  22. superlinear/engine/adapters/__init__.py +12 -0
  23. superlinear/engine/adapters/base.py +91 -0
  24. superlinear/engine/adapters/superlinear.py +1233 -0
  25. superlinear/engine/chat_engine.py +1173 -0
  26. superlinear/engine/chat_types.py +130 -0
  27. superlinear/engine/registry.py +51 -0
  28. superlinear/engine/repetition.py +203 -0
  29. superlinear/engine/session_snapshots.py +451 -0
  30. superlinear/engine/tool_parser.py +83 -0
  31. superlinear/engine/types.py +42 -0
  32. superlinear/kernels/__init__.py +2 -0
  33. superlinear/kernels/common/__init__.py +21 -0
  34. superlinear/kernels/common/adjustment.py +106 -0
  35. superlinear/kernels/common/power.py +154 -0
  36. superlinear/kernels/superlinear/__init__.py +10 -0
  37. superlinear/kernels/superlinear/attention/__init__.py +78 -0
  38. superlinear/kernels/superlinear/attention/_prefill.py +940 -0
  39. superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
  40. superlinear/kernels/superlinear/attention/api.py +433 -0
  41. superlinear/kernels/superlinear/search/__init__.py +33 -0
  42. superlinear/kernels/superlinear/search/_reference.py +204 -0
  43. superlinear/kernels/superlinear/search/_triton.py +488 -0
  44. superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
  45. superlinear/kernels/superlinear/search/api.py +200 -0
  46. superlinear/kernels/superlinear/span/__init__.py +41 -0
  47. superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
  48. superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
  49. superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
  50. superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
  51. superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
  52. superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
  53. superlinear/kernels/superlinear/span/api.py +296 -0
  54. superlinear/kernels/superlinear/span/masks.py +187 -0
  55. superlinear/py.typed +0 -0
  56. superlinear/runtime.py +71 -0
  57. superlinear-0.1.0.dist-info/METADATA +469 -0
  58. superlinear-0.1.0.dist-info/RECORD +62 -0
  59. superlinear-0.1.0.dist-info/WHEEL +5 -0
  60. superlinear-0.1.0.dist-info/entry_points.txt +2 -0
  61. superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
  62. superlinear-0.1.0.dist-info/top_level.txt +2 -0
superlinear/engine/chat_types.py
@@ -0,0 +1,130 @@
+ """Core chat request/streaming event types.
+
+ These types are internal to the library and are intentionally decoupled from:
+ - HTTP transport (FastAPI / SSE)
+ - OpenAI request/response JSON envelopes
+
+ The goal is to keep the core engine reusable for future API surfaces.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Literal
+
+
+ ToolChoice = Literal["auto", "none", "required"] | dict[str, Any]
+
+
+ @dataclass(frozen=True)
+ class StreamOptions:
+     """Chunked streaming policy."""
+
+     flush_every_n_tokens: int = 8
+     flush_every_ms: int = 50
+
+
+ @dataclass(frozen=True)
+ class ToolCall:
+     """A parsed tool call (function name + JSON-serializable arguments)."""
+
+     id: str
+     name: str
+     arguments: dict[str, Any]
+
+
+ @dataclass(frozen=True)
+ class ChatMessage:
+     """A normalized chat message.
+
+     Notes:
+         - `tool_calls` is only meaningful for assistant messages.
+         - `tool_call_id` is only meaningful for tool messages (tool outputs).
+     """
+
+     role: Literal["system", "user", "assistant", "tool"]
+     content: str | None = None
+     tool_calls: list[ToolCall] = field(default_factory=list)
+     tool_call_id: str | None = None
+
+
+ @dataclass(frozen=True)
+ class ChatRequest:
+     """Normalized internal chat request."""
+
+     messages: list[ChatMessage]
+     tools: list[dict[str, Any]] = field(default_factory=list)
+     tool_choice: ToolChoice | None = None
+     max_tokens: int = 4096
+     temperature: float = 0.0
+     top_p: float = 1.0
+     stop: list[str] = field(default_factory=list)
+     stream: bool = False
+     stream_options: StreamOptions = field(default_factory=StreamOptions)
+     chat_template_kwargs: dict[str, Any] | None = None  # Additional kwargs for chat template
+     reasoning_budget: int | None = None  # Max tokens for thinking phase (enables thinking when set)
+     discard_thinking: bool | None = None  # If set, discard <think>...</think> from persisted session state
+     stream_thinking: bool | None = None  # If True, stream <think>...</think> content as separate deltas
+     session_id: str | None = None  # Session ID for stateful chat (KV cache reuse)
+     session_append_from_pos: int | None = None  # If set, only append prompt tokens from this position
+     extra: dict[str, Any] = field(default_factory=dict)  # Engine-specific per-request overrides
+
+
+ @dataclass(frozen=True)
+ class Usage:
+     prompt_tokens: int
+     completion_tokens: int
+
+     @property
+     def total_tokens(self) -> int:
+         return self.prompt_tokens + self.completion_tokens
+
+
+ @dataclass(frozen=True)
+ class Timing:
+     prefill_s: float | None = None
+     decode_s: float | None = None
+     total_s: float | None = None
+     tok_per_s: float | None = None
+
+
+ @dataclass(frozen=True)
+ class DeltaEvent:
+     """Chunked text delta."""
+
+     text: str
+
+
+ @dataclass(frozen=True)
+ class ThinkingDeltaEvent:
+     """Chunked thinking delta (text inside <think>...</think>)."""
+
+     text: str
+
+
+ @dataclass(frozen=True)
+ class ToolCallEvent:
+     """A completed tool call (or batch of tool calls)."""
+
+     tool_calls: list[ToolCall]
+
+
+ @dataclass(frozen=True)
+ class FinalEvent:
+     """Terminal event for a generation."""
+
+     finish_reason: Literal["stop", "length", "tool_calls", "cancelled", "error", "repetition"]
+     usage: Usage
+     timing: Timing
+     raw_content: str | None = None  # Unstripped content (includes <think> if any) when discard_thinking=False
+
+
+ @dataclass(frozen=True)
+ class ErrorEvent:
+     """Non-terminal or terminal error event."""
+
+     message: str
+     retryable: bool = False
+
+
+ StreamEvent = DeltaEvent | ThinkingDeltaEvent | ToolCallEvent | FinalEvent | ErrorEvent
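
Note: the following usage sketch is not part of the package; it only illustrates how the chat_types.py types above compose. The event list is built by hand here, whereas in practice events would be produced by the chat engine (chat_engine.py, not shown in this excerpt).

from superlinear.engine.chat_types import (
    ChatMessage,
    ChatRequest,
    DeltaEvent,
    FinalEvent,
    StreamOptions,
    StreamEvent,
    Timing,
    Usage,
)

# A normalized request: only `messages` is required, everything else has defaults.
request = ChatRequest(
    messages=[
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="Summarize the release notes."),
    ],
    stream=True,
    stream_options=StreamOptions(flush_every_n_tokens=16, flush_every_ms=100),
)
print(request.max_tokens, request.stream_options.flush_every_n_tokens)  # 4096 16


def collect_text(events: list[StreamEvent]) -> str:
    """Concatenate text deltas, stopping at the terminal event."""
    parts: list[str] = []
    for event in events:
        if isinstance(event, DeltaEvent):
            parts.append(event.text)
        elif isinstance(event, FinalEvent):
            break
    return "".join(parts)


final = FinalEvent(
    finish_reason="stop",
    usage=Usage(prompt_tokens=12, completion_tokens=4),
    timing=Timing(total_s=0.2),
)
events: list[StreamEvent] = [DeltaEvent(text="Hello"), DeltaEvent(text=", world."), final]
print(collect_text(events))        # Hello, world.
print(final.usage.total_tokens)    # 16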
superlinear/engine/registry.py
@@ -0,0 +1,51 @@
+ """Model adapter registry.
+
+ Maps model identifiers to their corresponding adapter classes.
+ """
+
+ from typing import Type
+
+ from .adapters.base import BaseAdapter
+ from .adapters.superlinear import SuperlinearAdapter
+
+ # Registry mapping model family names to adapter classes
+ _ADAPTER_REGISTRY: dict[str, Type[BaseAdapter]] = {
+     "superlinear": SuperlinearAdapter,
+ }
+
+
+ def get_adapter(model_family: str) -> BaseAdapter:
+     """
+     Get an adapter instance for the given model family.
+
+     Args:
+         model_family: Name of the model family (e.g., "superlinear").
+
+     Returns:
+         An adapter instance for the model family.
+
+     Raises:
+         ValueError: If the model family is not registered.
+     """
+     if model_family not in _ADAPTER_REGISTRY:
+         available = ", ".join(_ADAPTER_REGISTRY.keys())
+         raise ValueError(
+             f"Unknown model family: {model_family!r}. Available: {available}"
+         )
+     return _ADAPTER_REGISTRY[model_family]()
+
+
+ def register_adapter(model_family: str, adapter_cls: Type[BaseAdapter]) -> None:
+     """
+     Register a new adapter for a model family.
+
+     Args:
+         model_family: Name of the model family.
+         adapter_cls: Adapter class (must inherit from BaseAdapter).
+     """
+     _ADAPTER_REGISTRY[model_family] = adapter_cls
+
+
+ def list_model_families() -> list[str]:
+     """Return list of registered model family names."""
+     return list(_ADAPTER_REGISTRY.keys())
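
Note: a hypothetical sketch of how the registry.py API above might be used. MyAdapter is invented for illustration only; the real BaseAdapter (adapters/base.py, not shown in this excerpt) may declare abstract methods that a concrete adapter must implement before it can be instantiated.

from superlinear.engine.adapters.base import BaseAdapter
from superlinear.engine.registry import get_adapter, list_model_families, register_adapter


class MyAdapter(BaseAdapter):
    """Hypothetical adapter for a new model family (illustration only)."""
    # A real subclass would implement whatever methods BaseAdapter requires.


# Register the new family alongside the built-in "superlinear" entry.
register_adapter("my-family", MyAdapter)
print(list_model_families())  # e.g. ['superlinear', 'my-family']

# get_adapter() instantiates the registered class with no arguments.
adapter = get_adapter("my-family")

# Unknown families raise ValueError listing the registered names.
try:
    get_adapter("does-not-exist")
except ValueError as exc:
    print(exc)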
superlinear/engine/repetition.py
@@ -0,0 +1,203 @@
+ """Token-level repetition detection for early stopping.
+
+ This module implements a high-precision detector for exact token periodicity
+ in the recent tail of generated token IDs.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any
+
+
+ @dataclass(frozen=True)
+ class RepeatHit:
+     period: int
+     repeats: int
+     checked_tail_len: int
+
+
+ @dataclass(frozen=True)
+ class RepetitionDetectionConfig:
+     """Configuration for repetition early-stop.
+
+     Notes:
+         - Defaults are tuned for typical long-context Q&A.
+         - `enabled` defaults to True to catch repetition loops early.
+         - Settings are conservative to avoid false positives on legitimate content.
+     """
+
+     enabled: bool = True
+     tail_len: int = 1024
+     check_every: int = 32
+     min_generated_tokens: int = 256
+     min_repeats: int = 3
+     max_period: int = 512
+     min_unique_tokens: int = 5
+
+     def validate(self) -> None:
+         if self.tail_len <= 0:
+             raise ValueError("'repetition_detection.tail_len' must be > 0.")
+         if self.check_every <= 0:
+             raise ValueError("'repetition_detection.check_every' must be > 0.")
+         if self.min_generated_tokens < 0:
+             raise ValueError("'repetition_detection.min_generated_tokens' must be >= 0.")
+         if self.min_repeats < 2:
+             raise ValueError("'repetition_detection.min_repeats' must be >= 2.")
+         if self.max_period <= 0:
+             raise ValueError("'repetition_detection.max_period' must be > 0.")
+         if self.min_unique_tokens <= 0:
+             raise ValueError("'repetition_detection.min_unique_tokens' must be > 0.")
+
+     def merged(self, override: Any | None) -> "RepetitionDetectionConfig":
+         """Merge a request-level override (typically request.extra['repetition_detection'])."""
+         if override is None:
+             return self
+         if isinstance(override, RepetitionDetectionConfig):
+             override.validate()
+             return override
+         if not isinstance(override, dict):
+             raise ValueError("'repetition_detection' must be an object.")
+
+         data: dict[str, Any] = dict(override)
+         if "min_unique_tokens" not in data and "min_unique_tokens_in_period" in data:
+             data["min_unique_tokens"] = data["min_unique_tokens_in_period"]
+
+         enabled = self.enabled
+         if "enabled" in data:
+             raw_enabled = data["enabled"]
+             if not isinstance(raw_enabled, bool):
+                 raise ValueError("'repetition_detection.enabled' must be a boolean.")
+             enabled = raw_enabled
+         tail_len = self.tail_len if "tail_len" not in data else _coerce_int(data["tail_len"], "tail_len")
+         check_every = (
+             self.check_every
+             if "check_every" not in data
+             else _coerce_int(data["check_every"], "check_every")
+         )
+         min_generated_tokens = (
+             self.min_generated_tokens
+             if "min_generated_tokens" not in data
+             else _coerce_int(data["min_generated_tokens"], "min_generated_tokens", min_value=0)
+         )
+         min_repeats = (
+             self.min_repeats
+             if "min_repeats" not in data
+             else _coerce_int(data["min_repeats"], "min_repeats", min_value=2)
+         )
+         max_period = (
+             self.max_period
+             if "max_period" not in data
+             else _coerce_int(data["max_period"], "max_period")
+         )
+         min_unique_tokens = (
+             self.min_unique_tokens
+             if "min_unique_tokens" not in data
+             else _coerce_int(data["min_unique_tokens"], "min_unique_tokens", min_value=1)
+         )
+
+         merged = RepetitionDetectionConfig(
+             enabled=enabled,
+             tail_len=tail_len,
+             check_every=check_every,
+             min_generated_tokens=min_generated_tokens,
+             min_repeats=min_repeats,
+             max_period=max_period,
+             min_unique_tokens=min_unique_tokens,
+         )
+         merged.validate()
+         return merged
+
+
+ def _coerce_int(value: Any, name: str, *, min_value: int = 1) -> int:
+     if isinstance(value, bool):
+         raise ValueError(f"'repetition_detection.{name}' must be an integer.")
+     try:
+         out = int(value)
+     except Exception as exc:
+         raise ValueError(f"'repetition_detection.{name}' must be an integer.") from exc
+     if out < min_value:
+         raise ValueError(f"'repetition_detection.{name}' must be >= {min_value}.")
+     return out
+
+
+ def prefix_function(seq: list[int]) -> list[int]:
+     """Classic KMP prefix-function (pi array) for a sequence of ints.
+
+     pi[i] = length of the longest proper prefix of seq[:i+1]
+     that is also a suffix of seq[:i+1].
+     """
+     n = len(seq)
+     pi = [0] * n
+     j = 0
+     for i in range(1, n):
+         while j > 0 and seq[i] != seq[j]:
+             j = pi[j - 1]
+         if j < n and seq[i] == seq[j]:
+             j += 1
+         pi[i] = j
+     return pi
+
+
+ def _nontrivial_period(period_tokens: list[int], *, min_unique_tokens: int) -> bool:
+     # Avoid stopping on junk like a single token or whitespace/punctuation-only loops.
+     return len(set(period_tokens)) >= min_unique_tokens
+
+
+ def detect_repetition_kmp_tail(
+     tokens: list[int],
+     *,
+     tail_len: int = 1024,
+     min_generated_tokens: int = 256,
+     min_repeats: int = 3,
+     max_period: int = 512,
+     min_unique_tokens: int = 4,
+ ) -> RepeatHit | None:
+     """Detect exact periodic repetition using only the last `tail_len` tokens.
+
+     Strategy:
+         - Compute KMP prefix function on the tail.
+         - Walk the border chain to derive candidate periods.
+         - Validate candidate periods by checking the last `min_repeats` blocks match.
+     """
+     if tail_len <= 0:
+         raise ValueError("'tail_len' must be > 0.")
+     if min_generated_tokens < 0:
+         raise ValueError("'min_generated_tokens' must be >= 0.")
+     if min_repeats < 2:
+         raise ValueError("'min_repeats' must be >= 2.")
+     if max_period <= 0:
+         raise ValueError("'max_period' must be > 0.")
+     if min_unique_tokens <= 0:
+         raise ValueError("'min_unique_tokens' must be > 0.")
+
+     if len(tokens) < min_generated_tokens:
+         return None
+
+     tail = tokens[-tail_len:] if len(tokens) > tail_len else tokens
+     L = len(tail)
+     if L < min_repeats:
+         return None
+
+     pi = prefix_function(tail)
+     if not pi:
+         return None
+     b = pi[-1]
+
+     # Walk border chain: b -> pi[b-1] -> ...
+     while b > 0:
+         p = L - b
+         if 1 <= p <= max_period:
+             need = p * min_repeats
+             if L >= need:
+                 a = tail[-p:]
+                 ok = True
+                 for r in range(2, min_repeats + 1):
+                     if tail[-r * p : -(r - 1) * p] != a:
+                         ok = False
+                         break
+                 if ok and _nontrivial_period(a, min_unique_tokens=min_unique_tokens):
+                     return RepeatHit(period=p, repeats=min_repeats, checked_tail_len=L)
+         b = pi[b - 1]
+
+     return None
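
Note: a small usage sketch (not part of the package) exercising repetition.py's public entry points as defined above: the detector flags an exactly periodic tail, ignores a non-repeating one, and per-request dict overrides are merged onto the defaults with validation.

from superlinear.engine.repetition import (
    RepetitionDetectionConfig,
    detect_repetition_kmp_tail,
)

# A generation that loops over the same 7-token block: flagged once the tail
# contains at least min_repeats (default 3) exact copies of the period.
looping = [11, 12, 13, 14, 15, 16, 17] * 40  # 280 tokens >= min_generated_tokens (256)
print(detect_repetition_kmp_tail(looping))
# RepeatHit(period=7, repeats=3, checked_tail_len=280)

# A stream of distinct token IDs has no border, so nothing is flagged.
print(detect_repetition_kmp_tail(list(range(300))))  # None

# Request-level overrides (e.g. request.extra["repetition_detection"]) arrive
# as plain dicts and are merged onto the defaults with validation.
cfg = RepetitionDetectionConfig().merged({"min_repeats": 4, "max_period": 128})
print(cfg.min_repeats, cfg.max_period, cfg.tail_len)  # 4 128 1024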