superlinear 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apps/__init__.py +4 -0
- apps/cli/__init__.py +8 -0
- apps/cli/bm25_rag.py +471 -0
- apps/cli/chat_repl.py +1497 -0
- apps/cli/client.py +195 -0
- apps/cli/docs_repl.py +2275 -0
- apps/cli/light_rag.py +729 -0
- apps/cli/local_snapshots.py +139 -0
- apps/cli/locks.py +214 -0
- apps/cli/main.py +457 -0
- apps/cli/output.py +32 -0
- apps/cli/server_cmds.py +516 -0
- apps/cli/session_cmds.py +491 -0
- apps/cli/snapshot_cmds.py +303 -0
- apps/cli/state.py +265 -0
- apps/server/__init__.py +4 -0
- apps/server/app.py +1363 -0
- apps/server/main.py +313 -0
- superlinear/__init__.py +114 -0
- superlinear/_version.py +3 -0
- superlinear/engine/__init__.py +10 -0
- superlinear/engine/adapters/__init__.py +12 -0
- superlinear/engine/adapters/base.py +91 -0
- superlinear/engine/adapters/superlinear.py +1233 -0
- superlinear/engine/chat_engine.py +1173 -0
- superlinear/engine/chat_types.py +130 -0
- superlinear/engine/registry.py +51 -0
- superlinear/engine/repetition.py +203 -0
- superlinear/engine/session_snapshots.py +451 -0
- superlinear/engine/tool_parser.py +83 -0
- superlinear/engine/types.py +42 -0
- superlinear/kernels/__init__.py +2 -0
- superlinear/kernels/common/__init__.py +21 -0
- superlinear/kernels/common/adjustment.py +106 -0
- superlinear/kernels/common/power.py +154 -0
- superlinear/kernels/superlinear/__init__.py +10 -0
- superlinear/kernels/superlinear/attention/__init__.py +78 -0
- superlinear/kernels/superlinear/attention/_prefill.py +940 -0
- superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
- superlinear/kernels/superlinear/attention/api.py +433 -0
- superlinear/kernels/superlinear/search/__init__.py +33 -0
- superlinear/kernels/superlinear/search/_reference.py +204 -0
- superlinear/kernels/superlinear/search/_triton.py +488 -0
- superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
- superlinear/kernels/superlinear/search/api.py +200 -0
- superlinear/kernels/superlinear/span/__init__.py +41 -0
- superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
- superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
- superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
- superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
- superlinear/kernels/superlinear/span/api.py +296 -0
- superlinear/kernels/superlinear/span/masks.py +187 -0
- superlinear/py.typed +0 -0
- superlinear/runtime.py +71 -0
- superlinear-0.1.0.dist-info/METADATA +469 -0
- superlinear-0.1.0.dist-info/RECORD +62 -0
- superlinear-0.1.0.dist-info/WHEEL +5 -0
- superlinear-0.1.0.dist-info/entry_points.txt +2 -0
- superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
- superlinear-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Core chat request/streaming event types.
|
|
2
|
+
|
|
3
|
+
These types are internal to the library and are intentionally decoupled from:
|
|
4
|
+
- HTTP transport (FastAPI / SSE)
|
|
5
|
+
- OpenAI request/response JSON envelopes
|
|
6
|
+
|
|
7
|
+
The goal is to keep the core engine reusable for future API surfaces.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Any, Literal
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Tool-selection policy: one of the literal modes, or a dict naming a
# specific tool (dict shape is engine-defined; not validated here).
ToolChoice = Literal["auto", "none", "required"] | dict[str, Any]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class StreamOptions:
    """Chunked streaming policy.

    Flush thresholds for buffered streamed output; the flushing logic that
    consumes these values lives in the engine, not here.
    """

    flush_every_n_tokens: int = 8  # flush after this many tokens accumulate
    flush_every_ms: int = 50  # flush after this many milliseconds elapse
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class ToolCall:
    """A parsed tool call (function name + JSON-serializable arguments)."""

    id: str  # call identifier; correlates with ChatMessage.tool_call_id on tool outputs
    name: str  # function name to invoke
    arguments: dict[str, Any]  # already-parsed arguments (not a raw JSON string)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
class ChatMessage:
    """A normalized chat message, independent of any wire format.

    Notes:
    - `tool_calls` is only meaningful for assistant messages.
    - `tool_call_id` is only meaningful for tool messages (tool outputs).
    """

    role: Literal["system", "user", "assistant", "tool"]
    content: str | None = None  # None is allowed, e.g. assistant turns that only carry tool_calls
    tool_calls: list[ToolCall] = field(default_factory=list)  # assistant-issued calls, if any
    tool_call_id: str | None = None  # links a tool-result message back to the originating call
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class ChatRequest:
    """Normalized internal chat request.

    Deliberately decoupled from HTTP transport and OpenAI JSON envelopes
    (see module docstring); adapters consume this type directly.
    """

    messages: list[ChatMessage]  # ordered conversation history
    tools: list[dict[str, Any]] = field(default_factory=list)  # tool schemas offered to the model (shape engine-defined)
    tool_choice: ToolChoice | None = None  # None defers to the engine's default policy
    max_tokens: int = 4096  # generation budget (completion tokens)
    temperature: float = 0.0
    top_p: float = 1.0
    stop: list[str] = field(default_factory=list)  # stop sequences
    stream: bool = False
    stream_options: StreamOptions = field(default_factory=StreamOptions)
    chat_template_kwargs: dict[str, Any] | None = None  # Additional kwargs for chat template
    reasoning_budget: int | None = None  # Max tokens for thinking phase (enables thinking when set)
    discard_thinking: bool | None = None  # If set, discard <think>...</think> from persisted session state
    stream_thinking: bool | None = None  # If True, stream <think>...</think> content as separate deltas
    session_id: str | None = None  # Session ID for stateful chat (KV cache reuse)
    session_append_from_pos: int | None = None  # If set, only append prompt tokens from this position
    extra: dict[str, Any] = field(default_factory=dict)  # Engine-specific per-request overrides
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass(frozen=True)
class Usage:
    """Token accounting for a single generation."""

    prompt_tokens: int  # tokens consumed by the prompt
    completion_tokens: int  # tokens produced by the generation

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion token count."""
        return self.prompt_tokens + self.completion_tokens
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass(frozen=True)
class Timing:
    """Timing breakdown for a generation; a field is None when not measured."""

    prefill_s: float | None = None  # seconds spent in prefill
    decode_s: float | None = None  # seconds spent decoding
    total_s: float | None = None  # end-to-end seconds
    tok_per_s: float | None = None  # throughput in tokens per second
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass(frozen=True)
class DeltaEvent:
    """Chunked text delta (visible, non-thinking output)."""

    text: str
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(frozen=True)
class ThinkingDeltaEvent:
    """Chunked thinking delta (text inside <think>...</think>).

    Only emitted when the request opts in via `stream_thinking`.
    """

    text: str
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass(frozen=True)
class ToolCallEvent:
    """A completed tool call (or batch of tool calls)."""

    tool_calls: list[ToolCall]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass(frozen=True)
class FinalEvent:
    """Terminal event for a generation; exactly one ends each stream."""

    finish_reason: Literal["stop", "length", "tool_calls", "cancelled", "error", "repetition"]
    usage: Usage  # token accounting for this generation
    timing: Timing  # timing breakdown (fields may be None)
    raw_content: str | None = None  # Unstripped content (includes <think> if any) when discard_thinking=False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass(frozen=True)
class ErrorEvent:
    """Non-terminal or terminal error event."""

    message: str  # human-readable description of the failure
    retryable: bool = False  # True if the caller may reasonably retry the request
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Union of every event kind an engine may yield while streaming a generation.
StreamEvent = DeltaEvent | ThinkingDeltaEvent | ToolCallEvent | FinalEvent | ErrorEvent
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Model adapter registry.
|
|
2
|
+
|
|
3
|
+
Maps model identifiers to their corresponding adapter classes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Type
|
|
7
|
+
|
|
8
|
+
from .adapters.base import BaseAdapter
|
|
9
|
+
from .adapters.superlinear import SuperlinearAdapter
|
|
10
|
+
|
|
11
|
+
# Registry mapping model family names to adapter classes.
# Extend at runtime via register_adapter() rather than editing this literal.
_ADAPTER_REGISTRY: dict[str, Type[BaseAdapter]] = {
    "superlinear": SuperlinearAdapter,
}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_adapter(model_family: str) -> BaseAdapter:
    """
    Get an adapter instance for the given model family.

    Args:
        model_family: Name of the model family (e.g., "superlinear").

    Returns:
        An adapter instance for the model family.

    Raises:
        ValueError: If the model family is not registered.
    """
    adapter_cls = _ADAPTER_REGISTRY.get(model_family)
    if adapter_cls is None:
        available = ", ".join(_ADAPTER_REGISTRY.keys())
        raise ValueError(
            f"Unknown model family: {model_family!r}. Available: {available}"
        )
    return adapter_cls()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def register_adapter(model_family: str, adapter_cls: Type[BaseAdapter]) -> None:
    """
    Register a new adapter for a model family.

    An existing registration under the same name is silently replaced.

    Args:
        model_family: Name of the model family.
        adapter_cls: Adapter class (must inherit from BaseAdapter).
    """
    _ADAPTER_REGISTRY[model_family] = adapter_cls
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def list_model_families() -> list[str]:
    """Return the names of all currently registered model families."""
    return [*_ADAPTER_REGISTRY]
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Token-level repetition detection for early stopping.
|
|
2
|
+
|
|
3
|
+
This module implements a high-precision detector for exact token periodicity
|
|
4
|
+
in the recent tail of generated token IDs.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
class RepeatHit:
    """Evidence of exact token periodicity in the generated tail."""

    period: int  # length of the repeating token block
    repeats: int  # number of consecutive matching blocks that were verified
    checked_tail_len: int  # how many trailing tokens were inspected
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class RepetitionDetectionConfig:
    """Configuration for repetition early-stop.

    Notes:
    - Defaults are tuned for typical long-context Q&A.
    - `enabled` defaults to True to catch repetition loops early.
    - Settings are conservative to avoid false positives on legitimate content.
    """

    enabled: bool = True
    tail_len: int = 1024
    check_every: int = 32
    min_generated_tokens: int = 256
    min_repeats: int = 3
    max_period: int = 512
    min_unique_tokens: int = 5

    def validate(self) -> None:
        """Raise ValueError for any field outside its allowed range."""
        checks: tuple[tuple[bool, str], ...] = (
            (self.tail_len <= 0, "'repetition_detection.tail_len' must be > 0."),
            (self.check_every <= 0, "'repetition_detection.check_every' must be > 0."),
            (self.min_generated_tokens < 0, "'repetition_detection.min_generated_tokens' must be >= 0."),
            (self.min_repeats < 2, "'repetition_detection.min_repeats' must be >= 2."),
            (self.max_period <= 0, "'repetition_detection.max_period' must be > 0."),
            (self.min_unique_tokens <= 0, "'repetition_detection.min_unique_tokens' must be > 0."),
        )
        for failed, message in checks:
            if failed:
                raise ValueError(message)

    def merged(self, override: Any | None) -> "RepetitionDetectionConfig":
        """Merge a request-level override (typically request.extra['repetition_detection'])."""
        if override is None:
            return self
        if isinstance(override, RepetitionDetectionConfig):
            override.validate()
            return override
        if not isinstance(override, dict):
            raise ValueError("'repetition_detection' must be an object.")

        data: dict[str, Any] = dict(override)
        # Accept the alternate key spelling for min_unique_tokens.
        if "min_unique_tokens" not in data and "min_unique_tokens_in_period" in data:
            data["min_unique_tokens"] = data["min_unique_tokens_in_period"]

        enabled = self.enabled
        if "enabled" in data:
            raw_enabled = data["enabled"]
            if not isinstance(raw_enabled, bool):
                raise ValueError("'repetition_detection.enabled' must be a boolean.")
            enabled = raw_enabled

        def pick(name: str, current: int, *, min_value: int = 1) -> int:
            # Keep the current value unless the override supplies this key.
            if name not in data:
                return current
            return _coerce_int(data[name], name, min_value=min_value)

        result = RepetitionDetectionConfig(
            enabled=enabled,
            tail_len=pick("tail_len", self.tail_len),
            check_every=pick("check_every", self.check_every),
            min_generated_tokens=pick("min_generated_tokens", self.min_generated_tokens, min_value=0),
            min_repeats=pick("min_repeats", self.min_repeats, min_value=2),
            max_period=pick("max_period", self.max_period),
            min_unique_tokens=pick("min_unique_tokens", self.min_unique_tokens, min_value=1),
        )
        result.validate()
        return result
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _coerce_int(value: Any, name: str, *, min_value: int = 1) -> int:
|
|
113
|
+
if isinstance(value, bool):
|
|
114
|
+
raise ValueError(f"'repetition_detection.{name}' must be an integer.")
|
|
115
|
+
try:
|
|
116
|
+
out = int(value)
|
|
117
|
+
except Exception as exc:
|
|
118
|
+
raise ValueError(f"'repetition_detection.{name}' must be an integer.") from exc
|
|
119
|
+
if out < min_value:
|
|
120
|
+
raise ValueError(f"'repetition_detection.{name}' must be >= {min_value}.")
|
|
121
|
+
return out
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def prefix_function(seq: list[int]) -> list[int]:
    """Classic KMP prefix-function (pi array) for a sequence of ints.

    pi[i] = length of the longest proper prefix of seq[:i+1]
    that is also a suffix of seq[:i+1].

    Runs in O(n): `j` grows at most once per iteration and only shrinks
    along the border chain.
    """
    n = len(seq)
    pi = [0] * n
    j = 0  # length of the current matched border
    for i in range(1, n):
        # Fall back to shorter borders until the next element can extend one.
        while j > 0 and seq[i] != seq[j]:
            j = pi[j - 1]
        # NOTE: the original additionally guarded this with `j < n`, which is
        # always true (j <= i < n); the dead check is removed for clarity.
        if seq[i] == seq[j]:
            j += 1
        pi[i] = j
    return pi
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _nontrivial_period(period_tokens: list[int], *, min_unique_tokens: int) -> bool:
|
|
143
|
+
# Avoid stopping on junk like a single token or whitespace/punctuation-only loops.
|
|
144
|
+
return len(set(period_tokens)) >= min_unique_tokens
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def detect_repetition_kmp_tail(
    tokens: list[int],
    *,
    tail_len: int = 1024,
    min_generated_tokens: int = 256,
    min_repeats: int = 3,
    max_period: int = 512,
    min_unique_tokens: int = 4,
) -> RepeatHit | None:
    """Detect exact periodic repetition using only the last `tail_len` tokens.

    Strategy:
    - Compute the KMP prefix function on the tail.
    - Walk the border chain of the whole tail to enumerate candidate periods.
    - Accept a candidate only if the last `min_repeats` period-sized blocks
      are identical and the period passes the uniqueness filter.

    Returns a RepeatHit describing the first accepted candidate, else None.
    """
    if tail_len <= 0:
        raise ValueError("'tail_len' must be > 0.")
    if min_generated_tokens < 0:
        raise ValueError("'min_generated_tokens' must be >= 0.")
    if min_repeats < 2:
        raise ValueError("'min_repeats' must be >= 2.")
    if max_period <= 0:
        raise ValueError("'max_period' must be > 0.")
    if min_unique_tokens <= 0:
        raise ValueError("'min_unique_tokens' must be > 0.")

    if len(tokens) < min_generated_tokens:
        return None

    tail = tokens if len(tokens) <= tail_len else tokens[-tail_len:]
    window = len(tail)
    if window < min_repeats:
        return None

    pi = prefix_function(tail)
    if not pi:
        return None

    # Each border of the full tail implies the candidate period window - border.
    border = pi[-1]
    while border > 0:
        period = window - border
        if 1 <= period <= max_period and window >= period * min_repeats:
            block = tail[-period:]
            # Verify the trailing min_repeats blocks are exact copies.
            blocks_match = all(
                tail[-r * period : -(r - 1) * period] == block
                for r in range(2, min_repeats + 1)
            )
            if blocks_match and _nontrivial_period(block, min_unique_tokens=min_unique_tokens):
                return RepeatHit(period=period, repeats=min_repeats, checked_tail_len=window)
        border = pi[border - 1]

    return None
|