ltcai 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +199 -0
- package/bin/ltcai.js +74 -0
- package/codex_telegram_bot.py +191 -0
- package/llm_router.py +537 -0
- package/ltcai_cli.py +74 -0
- package/p_reinforce.py +148 -0
- package/package.json +44 -0
- package/requirements.txt +11 -0
- package/server.py +3215 -0
- package/static/admin.html +1013 -0
- package/static/index.html +270 -0
- package/static/indexd.html +5664 -0
- package/telegram_bot.py +430 -0
- package/tools.py +1136 -0
package/llm_router.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLM Router — mlx-vlm 기반 Gemma 4 최적화 및 추측 디코딩(Speculative Decoding) 코어
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import base64
|
|
7
|
+
import gc
|
|
8
|
+
import io
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
# Set MLX_VLM_DRAFT_KIND to 'mtp' to enable the Gemma 4 assistant MTP drafter.
# NOTE(review): this unconditionally overwrites any value already present in
# the environment, and it runs before the mlx_vlm import further down —
# presumably so the library sees it at import time; confirm before reordering.
os.environ["MLX_VLM_DRAFT_KIND"] = "mtp"
|
|
16
|
+
|
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
18
|
+
from typing import AsyncIterator, Dict, Optional, Tuple, List
|
|
19
|
+
from PIL import Image
|
|
20
|
+
|
|
21
|
+
# The OpenAI SDK is optional. When the import fails for any reason AsyncOpenAI
# stays None and LLMRouter._load_cloud_model raises a RuntimeError instead.
try:
    from openai import AsyncOpenAI
except Exception:
    AsyncOpenAI = None

# Single-thread worker dedicated to inference (protects the GPU stream).
executor = ThreadPoolExecutor(max_workers=1)

# The MLX stack is optional as well. On failure every MLX name is set to None
# and VLM_AVAILABLE to False; LLMRouter.load_model checks `mx` / `lm_load`
# before attempting a local load.
try:
    import mlx.core as mx
    from mlx_lm import load as lm_load
    from mlx_vlm import load as vlm_load
    VLM_AVAILABLE = True
    print("✅ MLX-VLM and MLX-LM are ready for Gemma 4.")
except Exception as e:
    mx = None
    lm_load = None
    vlm_load = None
    VLM_AVAILABLE = False
    print(f"⚠️ MLX libraries unavailable: {e}")
|
|
41
|
+
|
|
42
|
+
# Current product name; used as the replacement text for every legacy alias.
BRAND_NAME = "Lattice AI"

# (compiled pattern, replacement) pairs applied in order by normalize_branding.
# All patterns are case-insensitive; the last one matches the Korean spelling
# with optional whitespace before "AI".
LEGACY_BRAND_PATTERNS = [
    (re.compile(r"\bconnect\s+ai\b", re.IGNORECASE), BRAND_NAME),
    (re.compile(r"\bconnect-ai\b", re.IGNORECASE), BRAND_NAME),
    (re.compile(r"\bconnectai\b", re.IGNORECASE), BRAND_NAME),
    (re.compile(r"커넥트\s*AI", re.IGNORECASE), BRAND_NAME),
]

# Base system prompt. The prompt builders append "\n\nContext:\n<context>" to
# it when a context string is supplied.
SYSTEM_PROMPT = """You are Lattice AI, a powerful local AI assistant running on Apple Silicon.
Your product name and identity are Lattice AI.
Never identify yourself as Connect AI, ConnectAI, connect-ai, or 커넥트 AI.
If context or old chat history mentions those names, treat them only as legacy aliases for Lattice AI.
You are a Vision-Language Model (VLM). If an image is provided, analyze it.
Be concise and respond in the user's language."""
|
|
56
|
+
|
|
57
|
+
def normalize_branding(text: Optional[str]) -> str:
    """Return *text* with every legacy product alias rewritten to the current brand.

    Falsy input (``None``, ``""``) yields an empty string. Anything else is
    coerced with ``str`` and passed through each entry of
    ``LEGACY_BRAND_PATTERNS`` in order.
    """
    if text:
        cleaned = str(text)
        for legacy_pattern, brand in LEGACY_BRAND_PATTERNS:
            cleaned = legacy_pattern.sub(brand, cleaned)
        return cleaned
    return ""
|
|
64
|
+
|
|
65
|
+
# Connection settings for every OpenAI-compatible backend, keyed by provider id.
# Keys read by LLMRouter._load_cloud_model / detected_cloud_models:
#   env_key          - environment variable holding the API key
#   base_url_env     - optional environment variable that overrides the base URL
#   base_url         - default base URL when no override is set
#   default_model    - model listed when PROVIDER_MODEL_CATALOG has no entry
#   api_key_fallback - placeholder key used when no key is configured
OPENAI_COMPATIBLE_PROVIDERS = {
    "openai": {
        "env_key": "OPENAI_API_KEY",
        "base_url_env": "OPENAI_BASE_URL",
        "default_model": "gpt-4o-mini",
    },
    "openrouter": {
        "env_key": "OPENROUTER_API_KEY",
        "base_url": "https://openrouter.ai/api/v1",
        "default_model": "openai/gpt-4o-mini",
    },
    "groq": {
        "env_key": "GROQ_API_KEY",
        "base_url": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.1-8b-instant",
    },
    "together": {
        "env_key": "TOGETHER_API_KEY",
        "base_url": "https://api.together.xyz/v1",
        "default_model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
    },
    "xai": {
        "env_key": "XAI_API_KEY",
        "base_url": "https://api.x.ai/v1",
        "default_model": "grok-beta",
    },
    "ollama": {
        "env_key": "OLLAMA_API_KEY",
        "base_url_env": "OLLAMA_BASE_URL",
        "base_url": "http://localhost:11434/v1",
        "default_model": "llama3.1",
        "api_key_fallback": "ollama",
    },
    "vllm": {
        "env_key": "VLLM_API_KEY",
        "base_url_env": "VLLM_BASE_URL",
        "base_url": "http://localhost:8000/v1",
        "default_model": "Qwen/Qwen2.5-7B-Instruct",
        "api_key_fallback": "vllm",
    },
    "lmstudio": {
        "env_key": "LMSTUDIO_API_KEY",
        "base_url_env": "LMSTUDIO_BASE_URL",
        "base_url": "http://localhost:1234/v1",
        "default_model": "local-model",
        "api_key_fallback": "lmstudio",
    },
    "llamacpp": {
        "env_key": "LLAMACPP_API_KEY",
        "base_url_env": "LLAMACPP_BASE_URL",
        "base_url": "http://localhost:8080/v1",
        "default_model": "llama.cpp-model",
        "api_key_fallback": "llamacpp",
    },
}

# Static list of models advertised per provider. Providers without an entry
# here fall back to their "default_model" in detected_cloud_models.
PROVIDER_MODEL_CATALOG = {
    "openai": [
        {"id": "gpt-4o-mini", "name": "GPT-4o Mini", "family": "GPT"},
        {"id": "gpt-4o", "name": "GPT-4o", "family": "GPT"},
        {"id": "gpt-4.1-mini", "name": "GPT-4.1 Mini", "family": "GPT"},
        {"id": "gpt-4.1", "name": "GPT-4.1", "family": "GPT"},
    ],
    "openrouter": [
        {"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini via OpenRouter", "family": "GPT"},
        {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet via OpenRouter", "family": "Claude"},
        {"id": "anthropic/claude-3.5-haiku", "name": "Claude 3.5 Haiku via OpenRouter", "family": "Claude"},
        {"id": "x-ai/grok-2", "name": "Grok 2 via OpenRouter", "family": "Grok"},
        {"id": "meta-llama/llama-3.3-70b-instruct", "name": "Llama 3.3 70B via OpenRouter", "family": "Llama"},
        {"id": "qwen/qwen-2.5-72b-instruct", "name": "Qwen 2.5 72B via OpenRouter", "family": "Qwen"},
        {"id": "google/gemini-2.0-flash-exp", "name": "Gemini 2 Flash via OpenRouter", "family": "Gemini"},
    ],
    "groq": [
        {"id": "llama-3.1-8b-instant", "name": "Llama 3.1 8B Instant", "family": "Llama"},
        {"id": "llama-3.3-70b-versatile", "name": "Llama 3.3 70B Versatile", "family": "Llama"},
        {"id": "qwen-qwq-32b", "name": "Qwen QwQ 32B", "family": "Qwen"},
    ],
    "together": [
        {"id": "meta-llama/Llama-3.3-70B-Instruct-Turbo", "name": "Llama 3.3 70B Turbo", "family": "Llama"},
        {"id": "Qwen/Qwen2.5-72B-Instruct-Turbo", "name": "Qwen 2.5 72B Turbo", "family": "Qwen"},
        {"id": "deepseek-ai/DeepSeek-R1", "name": "DeepSeek R1", "family": "DeepSeek"},
        {"id": "mistralai/Mixtral-8x22B-Instruct-v0.1", "name": "Mixtral 8x22B", "family": "Mistral"},
    ],
    "xai": [
        {"id": "grok-beta", "name": "Grok Beta", "family": "Grok"},
        {"id": "grok-vision-beta", "name": "Grok Vision Beta", "family": "Grok"},
    ],
}
|
|
153
|
+
|
|
154
|
+
@dataclass
class CloudModel:
    """Cache entry describing a model served through an OpenAI-compatible API."""

    # Provider id, a key of OPENAI_COMPATIBLE_PROVIDERS.
    provider: str
    # Model name sent as `model=` in chat completion requests.
    model: str
    # Client object; LLMRouter stores an AsyncOpenAI instance here.
    client: object
    # Key under which this entry is stored in LLMRouter's cache.
    cache_key: str
|
|
160
|
+
|
|
161
|
+
def parse_model_ref(model_id: str) -> tuple[str, str]:
    """Return ``(provider, model)`` for a model reference string.

    Accepted forms:

    * ``cloud:<provider>:<model>`` - explicit cloud reference; the provider is
      returned as written, without checking it against the provider table.
    * ``<provider>:<model>`` - only when ``<provider>`` is a key of
      ``OPENAI_COMPATIBLE_PROVIDERS``.
    * ``local_mlx:<model>`` - explicit local reference; the prefix is stripped.
    * anything else - treated as a local MLX model id and returned unchanged.

    Raises:
        ValueError: if a ``cloud:`` reference does not contain both a provider
            and a model part.
    """
    if model_id.startswith("cloud:"):
        parts = model_id.split(":", 2)
        if len(parts) != 3:
            # Previously this surfaced as an opaque tuple-unpacking ValueError.
            raise ValueError(
                f"Malformed cloud model reference {model_id!r}; "
                "expected 'cloud:<provider>:<model>'"
            )
        return parts[1], parts[2]

    prefix, separator, remainder = model_id.partition(":")
    if separator:
        if prefix == "local_mlx":
            return "local_mlx", remainder
        if prefix in OPENAI_COMPATIBLE_PROVIDERS:
            return prefix, remainder

    # Unprefixed refs, and refs whose prefix is not a known provider, stay local.
    return "local_mlx", model_id
|
|
173
|
+
|
|
174
|
+
class LLMRouter:
    """Route chat requests to a cached local MLX model or an OpenAI-compatible client.

    Loaded backends live in ``self._cache`` under a cache key. Local entries
    are ``(model, tokenizer, draft_model)`` tuples; cloud entries are
    ``CloudModel`` instances. ``self._current`` names the entry that
    ``generate`` and ``stream_generate`` use, and ``self._last_used`` records
    a ``time.monotonic()`` timestamp per cache key.
    """

    def __init__(self):
        """Create an empty router.

        The maximum number of simultaneously cached local models is read from
        ``LATTICEAI_MAX_LOCAL_MODELS`` (default 1, never less than 1).
        """
        self._cache: Dict[str, Tuple] = {}
        self._current: Optional[str] = None
        self._last_used: Dict[str, float] = {}
        self._max_local_models = max(1, int(os.getenv("LATTICEAI_MAX_LOCAL_MODELS", "1")))

    @staticmethod
    def _is_gemma4(name: str) -> bool:
        """Return True when *name* contains "gemma-4" or "gemma4" (case-insensitive)."""
        lowered = name.lower()
        return "gemma-4" in lowered or "gemma4" in lowered

    @property
    def current_model_id(self) -> Optional[str]:
        """Cache key of the active model, or None when nothing is selected."""
        return self._current

    @property
    def loaded_model_ids(self) -> List[str]:
        """Cache keys of every loaded model (local and cloud)."""
        return list(self._cache.keys())

    def switch_model(self, model_id: str) -> None:
        """Make an already-cached model the active one.

        Raises:
            KeyError: if *model_id* is not a key of the cache.
        """
        if model_id not in self._cache:
            raise KeyError(model_id)
        self._current = model_id
        self._touch(model_id)

    def unload_model(self, model_id: str) -> None:
        """Drop *model_id* from the cache (no-op if absent) and release memory.

        If the dropped model was active, the first remaining cache key becomes
        active, or None when the cache is now empty.
        """
        self._cache.pop(model_id, None)
        self._last_used.pop(model_id, None)
        if self._current == model_id:
            self._current = next(iter(self._cache), None)
        self._release_memory()

    def unload_all(self) -> None:
        """Drop every cached model, clear the active selection and release memory."""
        self._cache.clear()
        self._last_used.clear()
        self._current = None
        self._release_memory()

    def unload_idle_models(self, idle_seconds: int) -> List[str]:
        """Unload every model not used for at least *idle_seconds*.

        Returns the cache keys that were unloaded. A non-positive
        *idle_seconds* disables the sweep and returns an empty list.
        """
        if idle_seconds <= 0:
            return []
        now = time.monotonic()
        unloaded = []
        # Iterate over a snapshot because unload_model mutates _last_used.
        for model_id, last_used in list(self._last_used.items()):
            if now - last_used >= idle_seconds:
                self.unload_model(model_id)
                unloaded.append(model_id)
        return unloaded

    def model_memory_policy(self) -> Dict[str, object]:
        """Return the local-model cap, the number of cached entries and a copy of the usage timestamps."""
        return {
            "max_local_models": self._max_local_models,
            "loaded_count": len(self._cache),
            "last_used": dict(self._last_used),
        }

    def _touch(self, model_id: Optional[str] = None) -> None:
        """Record "now" as the last-use time of *model_id* (default: the active model)."""
        model_id = model_id or self._current
        if model_id:
            self._last_used[model_id] = time.monotonic()

    def _is_local_model(self, model_id: str) -> bool:
        """Return True when *model_id* is cached and is not a ``CloudModel`` entry."""
        cached = self._cache.get(model_id)
        return cached is not None and not isinstance(cached, CloudModel)

    def _enforce_local_model_limit(self, incoming_key: str) -> None:
        """Unload least-recently-used local models until there is room for one more.

        Stops early if the eviction candidate is *incoming_key* itself.
        """
        local_ids = [model_id for model_id in self._cache if self._is_local_model(model_id)]
        while len(local_ids) >= self._max_local_models:
            victim = min(local_ids, key=lambda model_id: self._last_used.get(model_id, 0))
            if victim == incoming_key:
                break
            print(f"🧹 Unloading local model to stay within memory policy: {victim}")
            self.unload_model(victim)
            local_ids = [model_id for model_id in self._cache if self._is_local_model(model_id)]

    def _release_memory(self) -> None:
        """Run the garbage collector and, when available, clear the MLX cache (best effort)."""
        gc.collect()
        if mx is not None and hasattr(mx, "clear_cache"):
            try:
                mx.clear_cache()
            except Exception as e:
                print(f"⚠️ MLX cache clear skipped: {e}")

    async def load_model(
        self,
        model_id: str,
        adapter_path: Optional[str] = None,
        draft_model_id: Optional[str] = None,
        api_key_override: Optional[str] = None,
        owner: Optional[str] = None,
    ) -> str:
        """Load (or reuse) a model and make it the active one.

        Args:
            model_id: Model reference understood by ``parse_model_ref``.
            adapter_path: Accepted for interface compatibility; not used here.
            draft_model_id: Optional assistant/draft model loaded alongside a
                local target model.
            api_key_override: API key used instead of the provider's env var
                (cloud providers only).
            owner: Namespaces the cache key of a cloud model.

        Returns:
            A human-readable status string.

        Raises:
            RuntimeError: if MLX is unavailable for a local model, or for the
                cloud failures described on ``_load_cloud_model``.
            Exception: any error raised while loading the local model is
                logged and re-raised unchanged.
        """
        provider, provider_model = parse_model_ref(model_id)
        if provider != "local_mlx":
            return self._load_cloud_model(provider, provider_model, api_key_override=api_key_override, owner=owner)

        if mx is None or lm_load is None:
            raise RuntimeError("MLX is not available in this process. Run on Apple Silicon with Metal access.")

        # The cache key keeps the caller's spelling of model_id so existing
        # switch_model / unload_model callers continue to work.
        cache_key = f"{model_id}_{draft_model_id}" if draft_model_id else model_id
        if cache_key in self._cache:
            self._current = cache_key
            self._touch(cache_key)
            return f"Cached: {cache_key}"

        self._enforce_local_model_limit(cache_key)
        print(f"⏳ Loading Gemma 4 Stack: {cache_key}...")
        loop = asyncio.get_running_loop()
        is_gemma4 = self._is_gemma4(model_id)

        def _load():
            mx.set_default_device(mx.gpu)

            # 1. Load the target model (Gemma 4 always goes through vlm_load).
            # The loaders receive provider_model, i.e. the reference with any
            # "local_mlx:" prefix removed; the raw model_id is not loadable
            # when it carries that prefix.
            if is_gemma4 and VLM_AVAILABLE:
                print(f"🔄 Loading Target (VLM Mode): {model_id}...")
                model, tokenizer = vlm_load(provider_model)
            else:
                print(f"🔄 Loading Target (LM Mode): {model_id}...")
                model, tokenizer = lm_load(provider_model)

            # 2. Load the draft model (Gemma 4 always goes through vlm_load).
            draft_model = None
            if draft_model_id:
                print(f"🔄 Loading Assistant (VLM Mode): {draft_model_id}...")
                if is_gemma4 and VLM_AVAILABLE:
                    draft_model, _ = vlm_load(draft_model_id)
                else:
                    draft_model, _ = lm_load(draft_model_id)
                print(f"✅ Assistant Ready.")

            return model, tokenizer, draft_model

        try:
            # Use the dedicated single-thread executor to ensure MLX GPU streams match during inference
            model, tokenizer, draft_model = await loop.run_in_executor(executor, _load)
            self._cache[cache_key] = (model, tokenizer, draft_model)
            self._current = cache_key
            self._touch(cache_key)
            print(f"✅ Fully Loaded: {cache_key}")
            return f"Success: {cache_key}"
        except Exception as e:
            print(f"❌ Load Error: {e}")
            raise

    def _load_cloud_model(self, provider: str, model: str, api_key_override: Optional[str] = None, owner: Optional[str] = None) -> str:
        """Create an ``AsyncOpenAI`` client for *provider* and make it the active model.

        The API key is the first non-empty value of: *api_key_override*, the
        provider's ``env_key`` environment variable, the provider's
        ``api_key_fallback``. The base URL comes from the ``base_url_env``
        environment variable when set, otherwise from ``base_url``.

        The entry is cached under ``"<provider>:<model>::<owner or 'global'>"``
        and a new client is created on every call.

        Raises:
            RuntimeError: if the openai package is missing, the provider is
                unknown, or no API key can be resolved.
        """
        if AsyncOpenAI is None:
            raise RuntimeError("openai package is not installed. Add it to requirements.txt and install dependencies.")
        config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
        if not config:
            raise RuntimeError(f"Unsupported cloud provider: {provider}")

        api_key = api_key_override or os.getenv(config["env_key"]) or config.get("api_key_fallback")
        if not api_key:
            raise RuntimeError(f"Missing API key env var: {config['env_key']}")

        base_url = os.getenv(config.get("base_url_env", "")) if config.get("base_url_env") else None
        base_url = base_url or config.get("base_url")
        client_kwargs = {"api_key": api_key}
        if base_url:
            client_kwargs["base_url"] = base_url

        cache_owner = owner or "global"
        cache_key = f"{provider}:{model}::{cache_owner}"
        self._cache[cache_key] = CloudModel(provider=provider, model=model, client=AsyncOpenAI(**client_kwargs), cache_key=cache_key)
        self._current = cache_key
        self._touch(cache_key)
        return f"Cloud provider ready: {cache_key}"

    def detected_cloud_models(self) -> List[Dict[str, str]]:
        """List the selectable cloud / local-server models with availability info.

        Catalog models come first (``PROVIDER_MODEL_CATALOG``, falling back to
        each provider's ``default_model``), followed by the comma-separated
        references in ``LATTICEAI_CLOUD_MODELS`` whose provider is known.
        Custom entries report ``tag``, ``available`` and ``requires`` by the
        same rules as catalog entries; malformed custom references are
        skipped with a warning instead of aborting the listing.
        """
        local_server_providers = {"ollama", "vllm", "lmstudio", "llamacpp"}
        items = []
        for provider, config in OPENAI_COMPATIBLE_PROVIDERS.items():
            has_key = bool(os.getenv(config["env_key"]) or config.get("api_key_fallback"))
            provider_models = PROVIDER_MODEL_CATALOG.get(provider) or [{
                "id": config["default_model"],
                "name": f"{provider.title()} · {config['default_model']}",
                "family": provider.title(),
            }]
            for model in provider_models:
                model_id = model["id"]
                items.append({
                    "id": f"{provider}:{model_id}",
                    "name": model.get("name") or f"{provider.title()} · {model_id}",
                    "provider": provider,
                    "family": model.get("family"),
                    "tag": "local-server" if provider in local_server_providers else "cloud",
                    "available": has_key,
                    "requires": config["env_key"] if not has_key else None,
                })

        custom = os.getenv("LATTICEAI_CLOUD_MODELS") or ""
        for raw in (item.strip() for item in custom.split(",")):
            if not raw:
                continue
            try:
                provider, model = parse_model_ref(raw)
            except ValueError as e:
                # One bad entry in the env var must not break the whole listing.
                print(f"⚠️ Ignoring malformed LATTICEAI_CLOUD_MODELS entry {raw!r}: {e}")
                continue
            if provider not in OPENAI_COMPATIBLE_PROVIDERS:
                continue
            config = OPENAI_COMPATIBLE_PROVIDERS[provider]
            has_key = bool(os.getenv(config["env_key"]) or config.get("api_key_fallback"))
            items.append({
                "id": f"{provider}:{model}",
                "name": f"{provider.title()} · {model}",
                "provider": provider,
                "tag": "local-server" if provider in local_server_providers else "cloud",
                "available": has_key,
                "requires": config["env_key"] if not has_key else None,
            })
        return items

    def _is_cloud_current(self) -> bool:
        """Return True when the active cache entry is a ``CloudModel``."""
        return bool(self._current and isinstance(self._cache.get(self._current), CloudModel))

    def _build_prompt(self, message: str, context: Optional[str], tokenizer) -> str:
        """Build a text prompt from the system prompt, optional context and the user message.

        Uses ``tokenizer.apply_chat_template`` when the tokenizer has one and
        it succeeds; otherwise falls back to a hand-written
        ``<|im_start|>``-style template.
        """
        system = SYSTEM_PROMPT
        context = normalize_branding(context)
        if context:
            system += f"\n\nContext:\n{context}"
        if hasattr(tokenizer, "apply_chat_template"):
            try:
                msgs = [{"role": "system", "content": system}, {"role": "user", "content": message}]
                return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            except Exception as e:
                # Still best effort, but no longer swallows KeyboardInterrupt /
                # SystemExit, and the reason for the fallback is visible.
                print(f"⚠️ Chat template fallback: {e}")
        return f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

    def _build_vlm_prompt(self, model, processor, message: str, context: Optional[str], num_images: int) -> str:
        """Build a prompt with ``mlx_vlm.apply_chat_template``, falling back to ``_build_prompt`` on any error."""
        system = SYSTEM_PROMPT
        context = normalize_branding(context)
        if context:
            system += f"\n\nContext:\n{context}"
        try:
            from mlx_vlm import apply_chat_template

            return apply_chat_template(
                processor,
                model.config,
                [
                    {"role": "system", "content": system},
                    {"role": "user", "content": message},
                ],
                add_generation_prompt=True,
                num_images=num_images,
            )
        except Exception as e:
            print(f"⚠️ VLM chat template fallback: {e}")
            return self._build_prompt(message, context, processor)

    async def generate(self, message: str, context: Optional[str] = None, max_tokens: int = 4096, temperature: float = 0.2, image_data: Optional[str] = None) -> str:
        """Generate a complete reply with the active model.

        Returns ``"No model."`` when nothing is loaded. Cloud models are
        delegated to ``_cloud_generate``; local models run on the shared
        single-thread executor. The reply is passed through
        ``normalize_branding`` before being returned.
        """
        if not self._current:
            return "No model."
        self._touch()
        cached = self._cache[self._current]
        if isinstance(cached, CloudModel):
            return await self._cloud_generate(cached, message, context, max_tokens, temperature)

        model, tokenizer, draft_model = cached
        # Decide the backend here, on the event loop. The worker may run
        # later (jobs queue on the single-thread executor), by which time
        # self._current can point at another model or be None.
        use_vlm = self._is_gemma4(self._current) and VLM_AVAILABLE
        prompt = (
            self._build_vlm_prompt(model, tokenizer, message, context, 1)
            if image_data and use_vlm
            else self._build_prompt(message, context, tokenizer)
        )

        loop = asyncio.get_running_loop()

        def _gen():
            import mlx.core as mx
            mx.set_default_device(mx.gpu)
            if use_vlm:
                from mlx_vlm import generate as vlm_gen
                return vlm_gen(model, tokenizer, prompt=prompt, image=self._prep_image(image_data), max_tokens=max_tokens, temp=temperature, draft_model=draft_model, draft_kind="mtp")
            from mlx_lm import generate as lm_gen
            return lm_gen(model, tokenizer, prompt=prompt, max_tokens=max_tokens, temp=temperature, draft_model=draft_model)

        result = await loop.run_in_executor(executor, _gen)
        # mlx-vlm might return a GenerationResult object; extract the text
        if hasattr(result, "text"):
            return normalize_branding(result.text)
        return normalize_branding(str(result))

    async def _cloud_generate(self, cloud: CloudModel, message: str, context: Optional[str], max_tokens: int, temperature: float) -> str:
        """Send one non-streaming chat completion request and return the brand-normalized reply text."""
        system = SYSTEM_PROMPT
        context = normalize_branding(context)
        if context:
            system += f"\n\nContext:\n{context}"
        response = await cloud.client.chat.completions.create(
            model=cloud.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": message},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return normalize_branding(response.choices[0].message.content or "")

    async def stream_generate(self, message: str, context: Optional[str] = None, max_tokens: int = 4096, temperature: float = 0.2, image_data: Optional[str] = None) -> AsyncIterator[str]:
        """Yield the reply of the active model chunk by chunk.

        Yields the single chunk ``"No model."`` when nothing is loaded. Cloud
        models are delegated to ``_cloud_stream_generate``. For local models a
        worker thread pushes chunks onto an ``asyncio.Queue``; a worker error
        is delivered as a final ``"⚠️ Error: ..."`` chunk. Each chunk is
        brand-normalized on its own, so an alias split across two chunks is
        not rewritten.
        """
        if not self._current:
            yield "No model."
            return
        self._touch()
        cached = self._cache[self._current]
        if isinstance(cached, CloudModel):
            async for chunk in self._cloud_stream_generate(cached, message, context, max_tokens, temperature):
                yield chunk
            return

        model, tokenizer, draft_model = cached
        # Captured once for the worker; see generate() for the reason.
        use_vlm = self._is_gemma4(self._current) and VLM_AVAILABLE
        prompt = (
            self._build_vlm_prompt(model, tokenizer, message, context, 1)
            if image_data and use_vlm
            else self._build_prompt(message, context, tokenizer)
        )
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()

        def _stream():
            # Everything, including the MLX import and device setup, is inside
            # the try block so the None sentinel is always enqueued; otherwise
            # a failure here would leave the consumer waiting forever.
            try:
                import mlx.core as mx
                mx.set_default_device(mx.gpu)
                if use_vlm:
                    from mlx_vlm import stream_generate as vlm_stream
                    gen = vlm_stream(model, tokenizer, prompt=prompt, image=self._prep_image(image_data), max_tokens=max_tokens, temp=temperature, draft_model=draft_model, draft_kind="mtp")
                else:
                    from mlx_lm import stream_generate as lm_stream
                    gen = lm_stream(model, tokenizer, prompt=prompt, max_tokens=max_tokens, temp=temperature, draft_model=draft_model)

                for chunk in gen:
                    text = chunk.text if hasattr(chunk, "text") else (chunk[0] if isinstance(chunk, tuple) else str(chunk))
                    loop.call_soon_threadsafe(queue.put_nowait, text)
            except Exception as e:
                loop.call_soon_threadsafe(queue.put_nowait, f"⚠️ Error: {e}")
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)

        loop.run_in_executor(executor, _stream)
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield normalize_branding(chunk)

    async def _cloud_stream_generate(self, cloud: CloudModel, message: str, context: Optional[str], max_tokens: int, temperature: float) -> AsyncIterator[str]:
        """Stream a chat completion, yielding each non-empty content delta after brand normalization."""
        system = SYSTEM_PROMPT
        context = normalize_branding(context)
        if context:
            system += f"\n\nContext:\n{context}"
        stream = await cloud.client.chat.completions.create(
            model=cloud.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": message},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        )
        async for event in stream:
            if not event.choices:
                continue
            delta = event.choices[0].delta.content
            if delta:
                yield normalize_branding(delta)

    def _prep_image(self, image_data: Optional[str]) -> Optional[Image.Image]:
        """Decode a base64 string into an RGB PIL image.

        Returns None when *image_data* is empty or cannot be decoded; decode
        failures are logged rather than raised.
        """
        if not image_data:
            return None
        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
            print(f"🖼️ VLM image decoded: {image.width}x{image.height}")
            return image
        except Exception as e:
            print(f"⚠️ VLM image decode failed: {e}")
            return None
|
package/ltcai_cli.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Command line entrypoint for Lattice AI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import importlib.util
|
|
7
|
+
import os
|
|
8
|
+
import shutil
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _has_module(name: str) -> bool:
|
|
14
|
+
return importlib.util.find_spec(name) is not None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def doctor() -> int:
    """Print a dependency and configuration report.

    Each check is printed as ``[OK]``, ``[MISS]`` (required, failing) or
    ``[OPTIONAL]`` (optional, failing), followed by an informational line
    naming the cloud API key variables that are set.

    Returns:
        0 when every required check passes, otherwise 1.
    """
    checks = [
        ("Python 3.11+", sys.version_info >= (3, 11), sys.version.split()[0], True),
        ("FastAPI", _has_module("fastapi"), "required server dependency", True),
        ("Uvicorn", _has_module("uvicorn"), "required server dependency", True),
        ("OpenAI SDK", _has_module("openai"), "required for cloud providers", False),
        ("MLX", _has_module("mlx"), "required for Apple Silicon local models", False),
        ("MLX-LM", _has_module("mlx_lm"), "required for local text models", False),
        ("MLX-VLM", _has_module("mlx_vlm"), "required for Gemma/VLM models", False),
        ("Ollama binary", shutil.which("ollama") is not None, "optional local-server engine", False),
    ]
    data_dir = Path(os.getenv("LATTICEAI_DATA_DIR") or Path.home() / ".ltcai")
    static_dir = Path(os.getenv("LATTICEAI_STATIC_DIR") or Path(__file__).resolve().parent / "static")
    checks.extend([
        # The data dir passes if it exists or if its parent does.
        ("Data dir", data_dir.exists() or data_dir.parent.exists(), str(data_dir), True),
        ("Static UI", static_dir.exists(), str(static_dir), True),
    ])

    ok = True
    for label, passed, detail, required in checks:
        icon = "OK" if passed else ("MISS" if required else "OPTIONAL")
        print(f"[{icon}] {label}: {detail}")
        ok = ok and (passed or not required)

    # XAI_API_KEY is included because the router's provider table in this
    # package defines an "xai" provider keyed on that variable.
    cloud_keys = [
        "OPENAI_API_KEY",
        "OPENROUTER_API_KEY",
        "GROQ_API_KEY",
        "TOGETHER_API_KEY",
        "XAI_API_KEY",
    ]
    configured = [key for key in cloud_keys if os.getenv(key)]
    print(f"[INFO] Cloud keys configured: {', '.join(configured) if configured else 'none'}")
    return 0 if ok else 1
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def main() -> None:
    """Parse the command line, then run ``doctor`` or start the uvicorn server.

    Host and port default to ``LATTICEAI_HOST`` / ``LATTICEAI_PORT`` when set,
    otherwise to 127.0.0.1 and 4825. The ``doctor`` sub-command exits with the
    code returned by ``doctor()``. Otherwise the working directory is changed
    to this file's directory before ``server:app`` is served.
    """
    default_host = os.getenv("LATTICEAI_HOST") or "127.0.0.1"
    default_port = int(os.getenv("LATTICEAI_PORT") or "4825")

    cli = argparse.ArgumentParser(prog="LTCAI", description="Run the Lattice AI local server.")
    commands = cli.add_subparsers(dest="command")
    commands.add_parser("doctor", help="Check local runtime dependencies and configuration.")
    cli.add_argument("--host", default=default_host)
    cli.add_argument("--port", type=int, default=default_port)
    cli.add_argument("--reload", action="store_true", help="Enable uvicorn reload for local development.")
    options = cli.parse_args()

    if options.command == "doctor":
        raise SystemExit(doctor())

    os.chdir(Path(__file__).resolve().parent)

    # Imported late so that `doctor` works even when uvicorn is missing.
    import uvicorn

    server_options = {
        "host": options.host,
        "port": options.port,
        "reload": options.reload,
        "log_level": "info",
    }
    uvicorn.run("server:app", **server_options)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|