bit-ttt-engine 0.6.2-cp310-cp310-win_amd64.whl → 0.7.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bit_ttt_engine-0.7.0.dist-info/METADATA +136 -0
- bit_ttt_engine-0.7.0.dist-info/RECORD +14 -0
- bit_ttt_engine-0.7.0.dist-info/entry_points.txt +2 -0
- cortex_rust/__init__.py +21 -1
- cortex_rust/__main__.py +4 -0
- cortex_rust/__pycache__/__init__.cpython-310.pyc +0 -0
- cortex_rust/chat.py +196 -0
- cortex_rust/cli.py +381 -0
- cortex_rust/cortex_rust.cp310-win_amd64.pyd +0 -0
- cortex_rust/engine.py +253 -0
- cortex_rust/server.py +493 -0
- bit_ttt_engine-0.6.2.dist-info/METADATA +0 -118
- bit_ttt_engine-0.6.2.dist-info/RECORD +0 -9
- cortex_rust/__init__.pyi +0 -100
- cortex_rust/py.typed +0 -0
- {bit_ttt_engine-0.6.2.dist-info → bit_ttt_engine-0.7.0.dist-info}/WHEEL +0 -0
- {bit_ttt_engine-0.6.2.dist-info → bit_ttt_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
cortex_rust/cli.py
ADDED
@@ -0,0 +1,381 @@
"""Bit-TTT-Engine CLI — interactive chat and text generation.

Usage:
    python -m cortex_rust chat model.gguf [--template llama3] [--system "..."]
    python -m cortex_rust generate model.gguf --prompt "..." [--max-tokens 200]
    python -m cortex_rust info model.gguf

Or if installed with pip:
    bit-ttt chat model.gguf
    bit-ttt generate model.gguf --prompt "Hello"
    bit-ttt info model.gguf
"""

import argparse
import sys
import os
import time
from typing import Optional


def _resolve_model_path(model_path: str) -> str:
    """Resolve model path: local file or HuggingFace repo ID.

    If model_path is a local file, return as-is.
    If it looks like a HF repo ID (contains '/'), auto-download the best GGUF.

    Examples:
        "model.gguf" → local file
        "bartowski/Llama-3-8B-Instruct-GGUF" → download Q4_K_M
        "user/repo:Q8_0" → download Q8_0 variant
    """
    # Local file exists — use directly
    if os.path.exists(model_path):
        return model_path

    # Not a repo ID pattern — treat as local path (will error later)
    if "/" not in model_path:
        return model_path

    # Parse repo ID and optional quantization hint
    # Format: "user/repo" or "user/repo:Q8_0"
    if ":" in model_path:
        repo_id, quant_hint = model_path.rsplit(":", 1)
    else:
        repo_id, quant_hint = model_path, "Q4_K_M"

    print(f"📥 Resolving model from HuggingFace: {repo_id}")

    try:
        from huggingface_hub import HfApi, hf_hub_download
    except ImportError:
        print("❌ huggingface_hub required for auto-download.")
        print(" Install: pip install huggingface_hub")
        sys.exit(1)

    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "bit-ttt")
    os.makedirs(cache_dir, exist_ok=True)

    api = HfApi()

    try:
        # List GGUF files in the repo
        files = api.list_repo_files(repo_id)
        gguf_files = [f for f in files if f.endswith(".gguf")]

        if not gguf_files:
            print(f"❌ No GGUF files found in {repo_id}")
            sys.exit(1)

        # Find best match for quantization hint
        quant_lower = quant_hint.lower().replace("-", "_")
        matched = [f for f in gguf_files if quant_lower in f.lower()]

        if matched:
            target = matched[0]
        else:
            # Prefer Q4_K_M > Q4_K_S > first available
            for pref in ["Q4_K_M", "Q4_K_S", "Q4_0", "Q5_K_M", "Q8_0"]:
                pref_match = [f for f in gguf_files if pref.lower() in f.lower()]
                if pref_match:
                    target = pref_match[0]
                    break
            else:
                target = gguf_files[0]

        print(f"📦 Selected: {target}")
        print(f" Downloading to cache...")

        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=target,
            cache_dir=cache_dir,
            local_dir=os.path.join(cache_dir, repo_id.replace("/", "--")),
        )

        # Also try to download tokenizer.json
        if "tokenizer.json" in files:
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    filename="tokenizer.json",
                    cache_dir=cache_dir,
                    local_dir=os.path.join(cache_dir, repo_id.replace("/", "--")),
                )
            except Exception:
                pass # tokenizer is optional

        print(f"✅ Downloaded: {local_path}")
        return local_path

    except Exception as e:
        print(f"❌ Failed to download from {repo_id}: {e}")
        sys.exit(1)


def _load_model(model_path: str, device: str = "cuda", tokenizer: Optional[str] = None):
    """Load a GGUF model, auto-detecting tokenizer if needed.

    model_path can be a local file or a HuggingFace repo ID.
    """
    from cortex_rust import QGgufModel

    # Resolve HF repo ID → local path
    resolved_path = _resolve_model_path(model_path)

    tok = tokenizer
    if tok is None:
        # Try to find tokenizer.json in same directory
        model_dir = os.path.dirname(os.path.abspath(resolved_path))
        candidate = os.path.join(model_dir, "tokenizer.json")
        if os.path.exists(candidate):
            tok = candidate

    kwargs = {"device": device}
    if tok:
        kwargs["tokenizer"] = tok

    return QGgufModel(resolved_path, **kwargs)


def _apply_feature_flags(model, args):
    """Apply --lora, --q8-cache, --ttt flags to a loaded model."""
    if getattr(args, 'q8_cache', False):
        model.enable_q8_kv_cache(True)
        print(" 🔧 Q8 KV cache enabled")
    if getattr(args, 'ttt', False):
        model.enable_ttt(True)
        print(" 🔧 TTT (Test-Time Training) enabled")
    if getattr(args, 'lora', None):
        model.load_lora(args.lora)
        model.enable_lora(True)
        print(f" 🔧 LoRA loaded: {args.lora}")
    return model


def cmd_info(args):
    """Show model information."""
    model = _load_model(args.model, device="cpu")
    config = model.config

    print(f"📦 Model: {os.path.basename(args.model)}")
    print(f" Architecture: {getattr(config, 'arch', 'unknown')}")
    print(f" Vocab size: {config.vocab_size:,}")
    print(f" Hidden dim: {config.hidden_dim:,}")
    print(f" Layers: {config.num_layers}")
    print(f" Q heads: {config.n_heads}")
    print(f" KV heads: {config.n_kv_heads}")
    print(f" Head dim: {config.hidden_dim // config.n_heads}")
    print(f" GQA ratio: {config.n_heads // config.n_kv_heads}")

    from cortex_rust.chat import detect_template
    template = detect_template(args.model)
    print(f" Chat template: {template}")


def cmd_generate(args):
    """Generate text from a prompt."""
    from cortex_rust.chat import format_simple, detect_template

    print(f"Loading {os.path.basename(args.model)}...", flush=True)
    model = _load_model(args.model, device=args.device, tokenizer=args.tokenizer)
    _apply_feature_flags(model, args)

    template = args.template or detect_template(args.model)

    if args.raw:
        prompt = args.prompt
    else:
        prompt = format_simple(args.prompt, system_message=args.system, template=template)

    print(f"Generating (template={template}, max_tokens={args.max_tokens})...\n")

    start = time.time()

    # Use streaming callback
    token_count = [0]
    def on_token(token_str):
        print(token_str, end="", flush=True)
        token_count[0] += 1

    model.generate_with_callback(
        prompt,
        on_token,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )

    elapsed = time.time() - start
    speed = token_count[0] / elapsed if elapsed > 0 else 0
    print(f"\n\n--- {token_count[0]} tokens in {elapsed:.1f}s ({speed:.1f} tok/s) ---")


def cmd_chat(args):
    """Interactive chat mode."""
    from cortex_rust.chat import format_chat, detect_template

    print(f"Loading {os.path.basename(args.model)}...", flush=True)
    model = _load_model(args.model, device=args.device, tokenizer=args.tokenizer)
    _apply_feature_flags(model, args)

    template = args.template or detect_template(args.model)

    print(f"💬 Chat mode (template={template})")
    print(f" Type 'quit' or Ctrl+C to exit")
    print(f" Type '/reset' to clear conversation")
    print(f" Type '/system <msg>' to change system prompt")
    print()

    system_msg = args.system
    messages = []
    if system_msg:
        messages.append({"role": "system", "content": system_msg})

    while True:
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nBye! 👋")
            break

        if not user_input:
            continue

        if user_input.lower() == "quit":
            print("Bye! 👋")
            break

        if user_input.lower() == "/reset":
            messages = []
            if system_msg:
                messages.append({"role": "system", "content": system_msg})
            model.reset_cache()
            print("🔄 Conversation reset.\n")
            continue

        if user_input.lower().startswith("/system "):
            system_msg = user_input[8:].strip()
            messages = [{"role": "system", "content": system_msg}]
            model.reset_cache()
            print(f"📝 System prompt updated: {system_msg}\n")
            continue

        messages.append({"role": "user", "content": user_input})
        prompt = format_chat(messages, template=template, add_generation_prompt=True)

        # Reset cache and regenerate from full conversation
        model.reset_cache()

        print("Assistant: ", end="", flush=True)

        collected = []
        start = time.time()

        def on_token(token_str):
            print(token_str, end="", flush=True)
            collected.append(token_str)

        model.generate_with_callback(
            prompt,
            on_token,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )

        elapsed = time.time() - start
        speed = len(collected) / elapsed if elapsed > 0 else 0
        print(f"\n [{len(collected)} tok, {speed:.1f} tok/s]\n")

        # Add assistant response to history
        assistant_response = "".join(collected)
        messages.append({"role": "assistant", "content": assistant_response})


def cmd_serve(args):
    """Start OpenAI-compatible API server."""
    from cortex_rust.server import serve
    serve(
        model_path=args.model,
        host=args.host,
        port=args.port,
        device=args.device,
        tokenizer=args.tokenizer,
        lora=getattr(args, 'lora', None),
        q8_cache=getattr(args, 'q8_cache', False),
        ttt=getattr(args, 'ttt', False),
    )


def main():
    parser = argparse.ArgumentParser(
        prog="bit-ttt",
        description="Bit-TTT-Engine — Fast LLM inference with 1.58-bit quantization",
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Common arguments
    def add_common_args(p):
        p.add_argument("model", help="Path to GGUF model file")
        p.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device (default: cuda)")
        p.add_argument("--tokenizer", default=None, help="Path to tokenizer.json (auto-detected if not set)")

    def add_generation_args(p):
        p.add_argument("--max-tokens", type=int, default=512, help="Max tokens to generate (default: 512)")
        p.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (default: 0.7)")
        p.add_argument("--top-k", type=int, default=40, help="Top-K sampling (default: 40)")
        p.add_argument("--top-p", type=float, default=0.9, help="Top-P (nucleus) sampling (default: 0.9)")
        p.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty (default: 1.1)")
        p.add_argument("--template", default=None, help="Chat template (llama3/llama2/gemma2/chatml, auto-detected)")
        p.add_argument("--system", default=None, help="System prompt")

    def add_feature_args(p):
        p.add_argument("--lora", default=None, metavar="PATH", help="Load LoRA adapter from file")
        p.add_argument("--q8-cache", action="store_true", help="Enable Q8 KV cache (saves ~82%% VRAM)")
        p.add_argument("--ttt", action="store_true", help="Enable Test-Time Training (online learning)")

    # info
    p_info = subparsers.add_parser("info", help="Show model information")
    add_common_args(p_info)
    p_info.set_defaults(func=cmd_info)

    # generate
    p_gen = subparsers.add_parser("generate", aliases=["gen"], help="Generate text from a prompt")
    add_common_args(p_gen)
    add_generation_args(p_gen)
    add_feature_args(p_gen)
    p_gen.add_argument("--prompt", "-p", required=True, help="Input prompt")
    p_gen.add_argument("--raw", action="store_true", help="Don't apply chat template")
    p_gen.set_defaults(func=cmd_generate)

    # chat
    p_chat = subparsers.add_parser("chat", help="Interactive chat mode")
    add_common_args(p_chat)
    add_generation_args(p_chat)
    add_feature_args(p_chat)
    p_chat.set_defaults(func=cmd_chat)

    # serve
    p_serve = subparsers.add_parser("serve", help="Start OpenAI-compatible API server")
    add_common_args(p_serve)
    add_feature_args(p_serve)
    p_serve.add_argument("--host", default="0.0.0.0", help="Bind host (default: 0.0.0.0)")
    p_serve.add_argument("--port", type=int, default=8000, help="Bind port (default: 8000)")
    p_serve.set_defaults(func=cmd_serve)

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    args.func(args)


if __name__ == "__main__":
    main()
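A minimal sketch (not part of the package diff) of how the repo-ID resolution added in cli.py could be exercised directly, assuming the 0.7.0 wheel and its native extension are installed; the paths are illustrative and the remote form is left commented out so nothing is downloaded:

# Sketch only: exercises _resolve_model_path from the new cli.py above.
from cortex_rust.cli import _resolve_model_path

# Existing local files and plain filenames (no "/") pass through unchanged.
print(_resolve_model_path("model.gguf"))  # -> "model.gguf"

# A "user/repo[:QUANT]" string triggers a HuggingFace lookup and caches the
# chosen GGUF under ~/.cache/bit-ttt (requires huggingface_hub):
# print(_resolve_model_path("bartowski/Llama-3-8B-Instruct-GGUF:Q8_0"))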
Binary file
cortex_rust/engine.py
ADDED
@@ -0,0 +1,253 @@
"""High-level Python SDK for Bit-TTT-Engine.

Simple, Pythonic interface for LLM inference.

Quick start:
    from cortex_rust.engine import load

    model = load("model.gguf") # local file
    model = load("bartowski/Llama-3-8B-Instruct-GGUF") # HuggingFace auto-download

    # Generate text
    text = model.generate("Once upon a time")

    # Chat (auto-applies chat template)
    response = model.chat([
        {"role": "user", "content": "What is 2+2?"}
    ])

    # Stream tokens
    for token in model.stream("Tell me a story"):
        print(token, end="", flush=True)

    # Chat stream
    for token in model.chat_stream([
        {"role": "system", "content": "You are a pirate."},
        {"role": "user", "content": "Hello!"},
    ]):
        print(token, end="", flush=True)
"""

import os
from typing import Dict, Iterator, List, Optional, Union

Messages = List[Dict[str, str]]


class Model:
    """High-level model wrapper with chat, generate, and stream support."""

    def __init__(self, inner, template: str, model_path: str,
                 tokenizer=None):
        self._inner = inner
        self._template = template
        self._model_path = model_path
        self._tokenizer = tokenizer # tokenizers.Tokenizer instance

    @property
    def config(self):
        """Access model configuration."""
        return self._inner.config

    @property
    def template(self) -> str:
        """Chat template name (llama3, llama2, gemma2, chatml)."""
        return self._template

    # ---- Generation ----

    def generate(self, prompt: str, max_tokens: int = 256,
                 temperature: float = 0.7, top_p: float = 0.9,
                 repetition_penalty: float = 1.1, **kwargs) -> str:
        """Generate text from a prompt (no chat template applied).

        Args:
            prompt: Raw text prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature (0 = greedy).
            top_p: Nucleus sampling threshold.
            repetition_penalty: Repetition penalty.

        Returns:
            Generated text string.
        """
        self._inner.reset_cache()
        return self._inner.generate(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

    def chat(self, messages: Messages, max_tokens: int = 256,
             temperature: float = 0.7, top_p: float = 0.9,
             repetition_penalty: float = 1.1, **kwargs) -> str:
        """Chat completion with automatic template formatting.

        Args:
            messages: List of {"role": ..., "content": ...} dicts.
            max_tokens: Maximum tokens to generate.

        Returns:
            Assistant response text.
        """
        from cortex_rust.chat import format_chat
        prompt = format_chat(messages, template=self._template)
        return self.generate(prompt, max_tokens=max_tokens,
                             temperature=temperature, top_p=top_p,
                             repetition_penalty=repetition_penalty)

    # ---- Streaming ----

    def stream(self, prompt: str, max_tokens: int = 256,
               temperature: float = 0.7, top_p: float = 0.9,
               repetition_penalty: float = 1.1) -> Iterator[str]:
        """Stream tokens from a prompt.

        Uses generate_from_tokens internally and decodes token-by-token.
        If tokenizer is not available, falls back to yielding the full output.

        Yields:
            Individual token strings as they are generated.
        """
        self._inner.reset_cache()

        # Try token-by-token via generate_from_tokens
        if self._tokenizer is not None:
            encoding = self._tokenizer.encode(prompt, add_special_tokens=True)
            prompt_ids = [int(t) for t in encoding.ids]
            if not prompt_ids:
                prompt_ids = [1]

            gen_ids = self._inner.generate_from_tokens(
                prompt_ids,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            # Decode incrementally to preserve whitespace
            all_ids = []
            prev_text = ""
            for tid in gen_ids:
                all_ids.append(int(tid))
                full_text = self._tokenizer.decode(all_ids, skip_special_tokens=True)
                delta = full_text[len(prev_text):]
                prev_text = full_text
                if delta:
                    yield delta
        else:
            # Fallback: yield full output at once
            text = self._inner.generate(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            yield text

    def chat_stream(self, messages: Messages, max_tokens: int = 256,
                    temperature: float = 0.7, top_p: float = 0.9,
                    repetition_penalty: float = 1.1) -> Iterator[str]:
        """Stream chat completion tokens.

        Yields:
            Individual token strings.

        Example:
            for token in model.chat_stream([{"role": "user", "content": "Hi!"}]):
                print(token, end="", flush=True)
        """
        from cortex_rust.chat import format_chat
        prompt = format_chat(messages, template=self._template)
        yield from self.stream(prompt, max_tokens=max_tokens,
                               temperature=temperature, top_p=top_p,
                               repetition_penalty=repetition_penalty)

    # ---- Configuration ----

    def enable_q8_cache(self, enabled: bool = True):
        """Enable/disable Q8 KV cache (saves ~82% VRAM)."""
        self._inner.enable_q8_kv_cache(enabled)
        return self

    def enable_ttt(self, enabled: bool = True, lr: float = 0.01):
        """Enable/disable Test-Time Training."""
        self._inner.enable_ttt(enabled)
        return self

    def create_lora(self, rank: int = 8, alpha: float = 16.0,
                    target_modules: Optional[List[str]] = None):
        """Create LoRA adapters for fine-tuning.

        Args:
            rank: LoRA rank (lower = faster, less capacity).
            alpha: LoRA alpha scaling factor.
            target_modules: List of modules to apply LoRA to.
                Default: ["q_proj", "v_proj"]
        """
        modules = target_modules or ["q_proj", "v_proj"]
        self._inner.create_lora(rank=rank, alpha=alpha, target_modules=modules)
        return self

    def reset(self):
        """Reset KV cache and TTT state."""
        self._inner.reset_cache()
        return self

    def __repr__(self):
        name = os.path.basename(self._model_path)
        return f"Model({name}, template={self._template})"


def load(model_path: str, device: str = "cuda",
         tokenizer: Optional[str] = None) -> Model:
    """Load a model from a local file or HuggingFace repo.

    Args:
        model_path: Local GGUF file path or HuggingFace repo ID.
            Examples:
                "model.gguf"
                "bartowski/Llama-3-8B-Instruct-GGUF"
                "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF:Q4_K_M"
        device: "cuda" or "cpu".
        tokenizer: Optional path to tokenizer.json.

    Returns:
        Model instance ready for generation.
    """
    from cortex_rust import QGgufModel
    from cortex_rust.chat import detect_template
    from cortex_rust.cli import _resolve_model_path

    # Resolve HF repo → local path
    resolved = _resolve_model_path(model_path)

    # Auto-detect tokenizer
    tok = tokenizer
    if tok is None:
        model_dir = os.path.dirname(os.path.abspath(resolved))
        candidate = os.path.join(model_dir, "tokenizer.json")
        if os.path.exists(candidate):
            tok = candidate

    kwargs = {"device": device}
    if tok:
        kwargs["tokenizer"] = tok

    inner = QGgufModel(resolved, **kwargs)
    template = detect_template(resolved)

    # Load tokenizer instance for streaming
    tok_instance = None
    if tok:
        try:
            from tokenizers import Tokenizer
            tok_instance = Tokenizer.from_file(tok)
        except Exception:
            pass # streaming will fall back to non-token mode

    return Model(inner, template=template, model_path=resolved,
                 tokenizer=tok_instance)
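To show how the pieces of the new engine.py fit together, here is a minimal usage sketch (not part of the package diff), assuming the wheel is installed; the model path is illustrative, and streaming only yields incremental deltas when a tokenizer.json sits next to the GGUF, otherwise the fallback branch above yields the full text once:

# Sketch only: drives the engine.py SDK above; "model.gguf" is illustrative.
from cortex_rust.engine import load

model = load("model.gguf", device="cpu")
print(model)  # e.g. Model(model.gguf, template=llama3)

# One-shot chat completion (chat template applied automatically).
reply = model.chat(
    [{"role": "user", "content": "What is 2+2?"}],
    max_tokens=64,
    temperature=0.2,
)
print(reply)

# Streaming: incremental deltas with a tokenizer.json, one chunk without.
for piece in model.chat_stream([{"role": "user", "content": "Hi!"}]):
    print(piece, end="", flush=True)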