optichat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
app/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # OptiChat App Package
app/connect_models.py ADDED
@@ -0,0 +1,340 @@
1
+ """OptiChat – AI model connection layer.
2
+
3
+ Responsibilities
4
+ ────────────────
5
+ • Validate API keys against each provider.
6
+ • List available models from cloud providers (OpenAI / Anthropic / Gemini).
7
+ • Detect locally-installed Ollama models.
8
+ • Instantiate LangChain chat model objects for actual inference.
9
+ """
10
+
11
from __future__ import annotations

import logging
from collections.abc import AsyncIterator
from typing import Any

from langchain_core.language_models.chat_models import BaseChatModel
16
+
17
+
18
+ # ══════════════════════════════════════════════
19
+ # Provider registry
20
+ # ══════════════════════════════════════════════
21
# Cloud providers supported by validate_api_key() and list_cloud_models().
# Local Ollama models are discovered separately (detect_ollama_models).
PROVIDERS = (
    "openai",
    "anthropic",
    "gemini",
)
22
+
23
+
24
+ # ══════════════════════════════════════════════
25
+ # API key validation
26
+ # ══════════════════════════════════════════════
27
def validate_api_key(provider: str, api_key: str) -> bool:
    """Return True if *api_key* is accepted by *provider*.

    Each provider check is a lightweight call (list models) wrapped in a
    try/except, so a rejected key, a network error, or an unknown provider
    all yield False rather than raising.

    Parameters
    ----------
    provider:
        One of ``"openai"``, ``"anthropic"``, ``"gemini"``.
    api_key:
        The key to test. Empty or whitespace-only keys are rejected without
        contacting the provider.
    """
    # An empty key can never be valid. Short-circuit so we neither waste a
    # network round-trip nor let an SDK substitute a key from the environment
    # (the OpenAI SDK, for one, reads OPENAI_API_KEY when api_key is None).
    if not api_key or not api_key.strip():
        return False
    try:
        if provider == "openai":
            return _validate_openai(api_key)
        if provider == "anthropic":
            return _validate_anthropic(api_key)
        if provider == "gemini":
            return _validate_gemini(api_key)
        return False
    except Exception:
        # A bad key is an expected outcome, so log quietly; never log the key.
        logging.getLogger(__name__).debug(
            "API key validation failed for provider %s", provider, exc_info=True
        )
        return False
44
+
45
+
46
def _validate_openai(api_key: str) -> bool:
    """Probe OpenAI with *api_key*; return True if the models listing succeeds.

    Errors (authentication, network, or StopIteration on an empty listing)
    propagate to the caller; ``validate_api_key`` maps them to False.
    """
    from openai import OpenAI

    # The context manager closes the client's HTTP connection pool on exit.
    with OpenAI(api_key=api_key) as client:
        # A successful models.list() call proves the key is valid; pulling one
        # item confirms the listing actually returned data.
        next(iter(client.models.list()))
    return True
55
+
56
+
57
def _validate_anthropic(api_key: str) -> bool:
    """Probe Anthropic with *api_key*; return True if the models listing succeeds.

    Errors (authentication, network, or StopIteration on an empty listing)
    propagate to the caller; ``validate_api_key`` maps them to False.
    """
    # Use the official SDK (a dependency of langchain-anthropic): the LangChain
    # ChatAnthropic wrapper has no ``models`` listing API, so it cannot be used
    # to probe a key.
    from anthropic import Anthropic

    # The context manager closes the client's HTTP connection pool on exit.
    with Anthropic(api_key=api_key) as client:
        next(iter(client.models.list()))
    return True
65
+
66
+
67
def _validate_gemini(api_key: str) -> bool:
    """Probe Gemini with *api_key*; return True if at least one model is listed.

    SDK errors propagate to the caller; ``validate_api_key`` maps them to False.
    """
    from google import genai

    client = genai.Client(api_key=api_key)
    # One item is enough to prove the key works. Materialising the pager with
    # list() would page through the entire model catalogue.
    return next(iter(client.models.list()), None) is not None
73
+
74
+
75
+ # ══════════════════════════════════════════════
76
+ # List cloud models
77
+ # ══════════════════════════════════════════════
78
def list_cloud_models(provider: str, api_key: str) -> list[dict[str, str]]:
    """Return a list of ``{id, name}`` dicts for available models.

    Only returns chat/completion-capable models where possible. This is a
    best-effort call: an empty/whitespace key, an unknown provider, or any
    provider/SDK error yields ``[]`` (errors are logged, not raised).
    """
    # Same guard as validate_api_key: an empty key cannot list anything.
    if not api_key or not api_key.strip():
        return []
    try:
        if provider == "openai":
            return _list_openai(api_key)
        if provider == "anthropic":
            return _list_anthropic(api_key)
        if provider == "gemini":
            return _list_gemini(api_key)
    except Exception:
        # Keep the best-effort contract, but record why so SDK/provider
        # problems are diagnosable instead of silently producing no models.
        logging.getLogger(__name__).warning(
            "Listing models for provider %s failed", provider, exc_info=True
        )
    return []
93
+
94
+
95
+ def _list_openai(api_key: str) -> list[dict[str, str]]:
96
+ from openai import OpenAI
97
+
98
+ client = OpenAI(api_key=api_key)
99
+ models = client.models.list()
100
+ result: list[dict[str, str]] = []
101
+ for m in models:
102
+ mid = m.id
103
+ # Filter to chat-capable models (gpt- prefix)
104
+ if mid.startswith(("gpt-", "o", "chatgpt")):
105
+ result.append({"id": f"openai/{mid}", "name": mid})
106
+ result.sort(key=lambda x: x["name"])
107
+ return result
108
+
109
+
110
def _list_anthropic(api_key: str) -> list[dict[str, str]]:
    """List Anthropic models as ``{id, name}`` dicts sorted by display name.

    SDK errors propagate; ``list_cloud_models`` converts them to ``[]``.
    """
    # Use the official SDK (a dependency of langchain-anthropic): the LangChain
    # ChatAnthropic wrapper has no ``models`` listing API.
    from anthropic import Anthropic

    # The context manager closes the client's HTTP connection pool on exit;
    # iterate inside it because the page object may fetch further pages.
    with Anthropic(api_key=api_key) as client:
        result = [
            {"id": f"anthropic/{m.id}", "name": m.display_name or m.id}
            for m in client.models.list()
        ]
    result.sort(key=lambda x: x["name"])
    return result
120
+
121
+
122
+ def _list_gemini(api_key: str) -> list[dict[str, str]]:
123
+ from google import genai
124
+
125
+ client = genai.Client(api_key=api_key)
126
+ result: list[dict[str, str]] = []
127
+ for m in client.models.list():
128
+ name = getattr(m, "name", "")
129
+ display = getattr(m, "display_name", name)
130
+ # Only include generative models
131
+ if "gemini" in name.lower():
132
+ result.append({"id": f"gemini/{name}", "name": display})
133
+ result.sort(key=lambda x: x["name"])
134
+ return result
135
+
136
+
137
+ # ══════════════════════════════════════════════
138
+ # Ollama – local model detection
139
+ # ══════════════════════════════════════════════
140
+ def detect_ollama_models() -> list[dict[str, str]]:
141
+ """Detect locally installed Ollama models.
142
+
143
+ Returns a list of ``{id, name, size}`` dicts, or an empty list
144
+ if Ollama is not running / not installed.
145
+ """
146
+ try:
147
+ from ollama import Client
148
+ client = Client(host='http://127.0.0.1:11434')
149
+
150
+ response = client.list()
151
+ result: list[dict[str, str]] = []
152
+ for m in response.models:
153
+ model_name = m.model if hasattr(m, "model") else m.name
154
+ size_bytes = getattr(m, "size", 0)
155
+ size_gb = f"{size_bytes / (1024 ** 3):.1f} GB" if size_bytes else "?"
156
+ result.append({
157
+ "id": f"ollama/{model_name}",
158
+ "name": model_name,
159
+ "size": size_gb,
160
+ })
161
+ return result
162
+ except Exception:
163
+ return []
164
+
165
+
166
+ # ══════════════════════════════════════════════
167
+ # Create a LangChain chat model instance
168
+ # ══════════════════════════════════════════════
169
def get_chat_model(model_id: str, **model_kwargs: Any) -> BaseChatModel:
    """Instantiate and return a LangChain chat model for *model_id*.

    *model_id* format: ``provider/model_name``
    e.g. ``openai/gpt-4o``, ``anthropic/claude-sonnet-4-20250514``,
    ``gemini/gemini-2.0-flash``, ``ollama/llama3``. Only the first ``/``
    separates the provider, so the model name may itself contain ``/``.

    Parameters
    ----------
    model_id:
        ``provider/model_name`` identifier.
    **model_kwargs:
        Extra keyword arguments forwarded to the provider's LangChain
        constructor (e.g. ``temperature``). The OpenAI, Anthropic and Gemini
        models default to ``streaming=True`` unless overridden here.

    Raises
    ------
    ValueError
        If *model_id* is not ``provider/model`` with both parts non-empty,
        or if the provider is not openai/anthropic/gemini/ollama.
    """
    provider, sep, model_name = model_id.partition("/")
    if not (sep and provider and model_name):
        raise ValueError(f"Invalid model_id format: {model_id!r}. Expected 'provider/model'.")

    if provider in ("openai", "anthropic", "gemini"):
        # Cloud models stream by default; callers may override via kwargs.
        model_kwargs.setdefault("streaming", True)

    if provider == "openai":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model_name, **model_kwargs)

    if provider == "anthropic":
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(model=model_name, **model_kwargs)

    if provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(model=model_name, **model_kwargs)

    if provider == "ollama":
        from langchain_community.chat_models import ChatOllama

        return ChatOllama(model=model_name, **model_kwargs)

    raise ValueError(f"Unknown provider: {provider!r}")
203
+
204
+
205
+ # ══════════════════════════════════════════════
206
+ # Pipeline-aware message sending (Phase 4)
207
+ # ══════════════════════════════════════════════
208
async def send_message_via_pipeline(
    model_id: str,
    user_input: str,
    chat_name: str,
    chat_id: str,
    *,
    websearch_enabled: bool = False,
) -> dict[str, str]:
    """Answer *user_input* by delegating to ``app.pipeline.run_pipeline``.

    Per its design, the pipeline covers classification, memory retrieval,
    prompt assembly, the LLM call, and post-processing (DB + memory storage).

    Parameters
    ----------
    model_id:
        ``provider/model`` identifier forwarded to the pipeline.
    user_input, chat_name, chat_id:
        Forwarded unchanged to the pipeline.
    websearch_enabled:
        When True the pipeline's classifier node queries DuckDuckGo for
        the top-2 results and injects them into the final prompt
        (Phase 5 feature).

    Returns
    -------
    dict
        ``response``: the assistant reply, or the pipeline's error wrapped in
        ``*...*`` when the result carries a truthy ``error``.
        ``trace_log``: the chain-of-thought trace (empty on error).
    """
    from app.pipeline import run_pipeline

    outcome = await run_pipeline(
        user_input=user_input,
        chat_name=chat_name,
        chat_id=chat_id,
        model_id=model_id,
        websearch_enabled=websearch_enabled,
    )

    failure = outcome.get("error")
    if not failure:
        return {
            "response": outcome.get("response", ""),
            "trace_log": outcome.get("trace_log", ""),
        }
    # Surface the error as the reply text, emphasised with asterisks
    # (presumably rendered as Markdown by the chat UI).
    return {"response": f"*{failure}*", "trace_log": ""}
249
+
250
+
251
+ # ══════════════════════════════════════════════
252
+ # Legacy: direct send (no pipeline, for fallback)
253
+ # ══════════════════════════════════════════════
254
+ async def send_message(
255
+ model_id: str,
256
+ messages: list[dict[str, str]],
257
+ chat_name: str | None = None,
258
+ chat_id: str | None = None,
259
+ ) -> str:
260
+ """Send a list of {role, content} dicts and return the assistant reply.
261
+
262
+ If *chat_name* and *chat_id* are provided, the user message and AI
263
+ response are automatically fed through the memory pipeline.
264
+
265
+ NOTE: For Phase 4+, prefer ``send_message_via_pipeline()`` which runs
266
+ the full prompt construction pipeline.
267
+ """
268
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
269
+
270
+ _type_map = {
271
+ "system": SystemMessage,
272
+ "user": HumanMessage,
273
+ "assistant": AIMessage,
274
+ }
275
+ lc_messages = [_type_map[m["role"]](content=m["content"]) for m in messages]
276
+
277
+ model = get_chat_model(model_id)
278
+ response = await model.ainvoke(lc_messages)
279
+ reply = str(response.content)
280
+
281
+ # ── Memory integration (Phase 3) ────────
282
+ if chat_name and chat_id:
283
+ try:
284
+ from app.memory import process_message
285
+
286
+ # Store the last user message in memory
287
+ user_msgs = [m for m in messages if m["role"] == "user"]
288
+ if user_msgs:
289
+ await process_message(chat_name, chat_id, "user", user_msgs[-1]["content"])
290
+ # Store the assistant reply in memory
291
+ await process_message(chat_name, chat_id, "assistant", reply)
292
+ except Exception:
293
+ pass # Memory errors must not block the response
294
+
295
+ return reply
296
+
297
+
298
+ async def stream_message(
299
+ model_id: str,
300
+ messages: list[dict[str, str]],
301
+ chat_name: str | None = None,
302
+ chat_id: str | None = None,
303
+ ):
304
+ """Yield token chunks as an async generator.
305
+
306
+ After streaming completes, the accumulated response is fed through
307
+ the memory pipeline if *chat_name* and *chat_id* are provided.
308
+ """
309
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
310
+
311
+ _type_map = {
312
+ "system": SystemMessage,
313
+ "user": HumanMessage,
314
+ "assistant": AIMessage,
315
+ }
316
+ lc_messages = [_type_map[m["role"]](content=m["content"]) for m in messages]
317
+
318
+ model = get_chat_model(model_id)
319
+ full_response: list[str] = []
320
+ async for chunk in model.astream(lc_messages):
321
+ text = chunk.content if hasattr(chunk, "content") else str(chunk)
322
+ if text:
323
+ full_response.append(text)
324
+ yield text
325
+
326
+ # ── Memory integration (Phase 3) ────────
327
+ if chat_name and chat_id:
328
+ try:
329
+ from app.memory import process_message
330
+
331
+ # Store the last user message in memory
332
+ user_msgs = [m for m in messages if m["role"] == "user"]
333
+ if user_msgs:
334
+ await process_message(chat_name, chat_id, "user", user_msgs[-1]["content"])
335
+ # Store the full accumulated assistant response in memory
336
+ accumulated = "".join(full_response)
337
+ if accumulated:
338
+ await process_message(chat_name, chat_id, "assistant", accumulated)
339
+ except Exception:
340
+ pass # Memory errors must not block the response