memwal 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
memwal/middleware.py ADDED
@@ -0,0 +1,472 @@
1
+ """
2
+ memwal — AI Middleware
3
+
4
+ Wraps LangChain and OpenAI SDK clients with automatic memory management.
5
+ Before each LLM call, relevant memories are recalled and injected.
6
+ After each call, the user message is analyzed for new facts (fire-and-forget).
7
+
8
+ Both integrations are optional: ``langchain-core`` and ``openai`` are only
9
+ imported when the corresponding wrapper is called.
10
+
11
+ Example (LangChain)::
12
+
13
+ from langchain_openai import ChatOpenAI
14
+ from memwal import with_memwal_langchain
15
+
16
+ llm = ChatOpenAI(model="gpt-4o")
17
+ smart_llm = with_memwal_langchain(
18
+ llm,
19
+ key="abcdef...",
20
+ account_id="0x...",
21
+ )
22
+ response = await smart_llm.ainvoke([HumanMessage("What are my allergies?")])
23
+
24
+ Example (OpenAI)::
25
+
26
+ from openai import AsyncOpenAI
27
+ from memwal import with_memwal_openai
28
+
29
+ client = AsyncOpenAI()
30
+ smart_client = with_memwal_openai(
31
+ client,
32
+ key="abcdef...",
33
+ account_id="0x...",
34
+ )
35
+ response = await smart_client.chat.completions.create(
36
+ model="gpt-4o",
37
+ messages=[{"role": "user", "content": "What are my allergies?"}],
38
+ )
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ import asyncio
44
+ import logging
45
+ import threading
46
+ from typing import (
47
+ TYPE_CHECKING,
48
+ Any,
49
+ Callable,
50
+ Dict,
51
+ List,
52
+ Optional,
53
+ )
54
+
55
+ from .client import MemWal
56
+ from .types import RecallMemory
57
+
58
+ if TYPE_CHECKING:
59
+ from langchain_core.language_models.chat_models import BaseChatModel
60
+ from langchain_core.messages import BaseMessage
61
+
62
+ logger = logging.getLogger("memwal")
63
+
64
+
65
def _extract_text_parts(content: Any) -> Optional[str]:
    """Join the text parts of a multimodal content array with single spaces.

    Bare strings and ``{"type": "text", "text": ...}`` dicts are kept; every
    other part (images, audio, ...) is ignored, as are text parts whose
    ``text`` value is not a string.

    Returns:
        The joined text, or ``None`` when the array holds no text part.
    """
    texts = []
    for part in content:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            text = part.get("text", "")
            # A non-string ``text`` would make ``" ".join`` raise TypeError.
            if isinstance(text, str):
                texts.append(text)
    return " ".join(texts) if texts else None


def _find_last_user_message(messages: Any) -> Optional[str]:
    """Extract the text of the last user message from a message list.

    Supports both dict-style messages (OpenAI format, ``role == "user"``) and
    LangChain-style message objects (anything exposing ``type`` and
    ``content`` attributes, with ``type == "human"``). Multimodal content
    arrays are reduced to their text parts for both message styles.

    Args:
        messages: A list or tuple of messages. Any other value yields ``None``.

    Returns:
        The text of the most recent user message, or ``None`` when there is
        no user message or it carries no text.
    """
    if not isinstance(messages, (list, tuple)):
        return None

    for msg in reversed(messages):
        if isinstance(msg, dict):
            # Dict-style (OpenAI format)
            if msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            is_dict_message = True
        elif hasattr(msg, "type") and hasattr(msg, "content"):
            # LangChain BaseMessage
            if msg.type != "human":
                continue
            content = msg.content
            is_dict_message = False
        else:
            continue

        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            # Multimodal content array: keep only the textual parts instead of
            # stringifying the whole structure.
            return _extract_text_parts(content)
        if not is_dict_message:
            # Unknown content type on a message object: fall back to str().
            return str(content)
        # Dict message with unusable content (e.g. None): keep scanning
        # earlier messages.

    return None
96
+
97
+
98
def _format_memories(memories: "List[RecallMemory]") -> str:
    """Render recalled memories as a context block for prompt injection.

    The result is a fixed header followed by one bullet per memory. Each
    bullet shows the memory text and its relevance, computed as
    ``1 - distance`` and printed with two decimals.

    Args:
        memories: Recalled memories; each must expose ``text`` and
            ``distance`` attributes.

    Returns:
        The header plus the newline-separated bullets (header only when
        ``memories`` is empty).
    """
    header = (
        "[Memory Context] The following are known facts about this user "
        "from their personal memory store. Use these facts to answer the "
        "user's question:\n"
    )
    bullets = []
    for memory in memories:
        relevance = 1 - memory.distance
        bullets.append(f"- {memory.text} (relevance: {relevance:.2f})")
    return header + "\n".join(bullets)
109
+
110
+
111
# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: "set[asyncio.Task[Any]]" = set()


def _on_background_task_done(task: "asyncio.Task[Any]") -> None:
    """Release a finished background task and log its failure, if any."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception also prevents asyncio's
    # "Task exception was never retrieved" report.
    exc = task.exception()
    if exc is not None:
        logger.debug("Fire-and-forget analyze() failed", exc_info=exc)


def _fire_and_forget(coro: Any) -> None:
    """Schedule an async coroutine as fire-and-forget.

    Works whether or not an event loop is already running:

    - With a running loop in the current thread, the coroutine is scheduled
      as a task on that loop. The task is held in ``_BACKGROUND_TASKS``
      until it completes.
    - Without a running loop, the coroutine is executed with
      ``asyncio.run`` on a daemon thread.

    In both cases a failure of the coroutine is logged at debug level and
    never propagated to the caller.

    Args:
        coro: The coroutine object to run.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is not None:
        task = loop.create_task(coro)
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_on_background_task_done)
        return

    # No running loop -- run in a background thread
    def _run() -> None:
        try:
            asyncio.run(coro)
        except Exception:
            logger.debug("Fire-and-forget analyze() failed", exc_info=True)

    threading.Thread(target=_run, daemon=True).start()
129
+
130
+
131
+ # ============================================================
132
+ # LangChain Integration
133
+ # ============================================================
134
+
135
+
136
def with_memwal_langchain(
    llm: "BaseChatModel",
    key: str,
    account_id: str,
    server_url: str = "https://relayer.memwal.ai",
    namespace: str = "default",
    max_memories: int = 5,
    auto_save: bool = True,
    min_relevance: float = 0.3,
    debug: bool = False,
) -> "BaseChatModel":
    """Wrap a LangChain ``BaseChatModel`` with MemWal memory management.

    The instance's ``_generate`` and ``_agenerate`` hooks are replaced in
    place. Each hook receives the flat list of messages for ONE conversation.

    Before each call:
      - Recall relevant memories for the last user message
      - Inject them as a system message before the last ``HumanMessage``

    After each call:
      - Analyze the user message to extract and store new facts (fire-and-forget)

    Recall and auto-save failures are logged and never fail the LLM call.
    When the sync hook runs inside an already running event loop, memory
    injection is skipped because ``asyncio.run`` cannot be used there.

    Args:
        llm: A LangChain ``BaseChatModel`` instance.
        key: Ed25519 delegate key (hex).
        account_id: MemWalAccount object ID.
        server_url: MemWal server URL.
        namespace: Default namespace.
        max_memories: Max memories to inject per request.
        auto_save: Auto-save new facts from conversation.
        min_relevance: Minimum similarity score (0-1) to include a memory.
        debug: Log MemWal activity at WARNING level instead of DEBUG.

    Returns:
        The same ``llm`` instance, now patched to use MemWal memory.

    Raises:
        ImportError: If ``langchain-core`` is not installed.
    """
    try:
        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: F811
        from langchain_core.outputs import ChatResult  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "LangChain integration requires langchain-core. "
            "Install with: pip install memwal[langchain]"
        ) from e

    memwal = MemWal.create(
        key=key,
        account_id=account_id,
        server_url=server_url,
        namespace=namespace,
    )

    log = logger.warning if debug else logger.debug

    original_agenerate = llm._agenerate
    original_generate = llm._generate

    async def _inject_memories(messages: List[BaseMessage]) -> List[BaseMessage]:
        """Recall memories for one conversation and inject them as a system message.

        Returns the input list untouched when there is no user text, no
        relevant memory, or the recall fails; otherwise returns a new list.
        """
        user_text = _find_last_user_message(messages)
        if not user_text:
            return messages

        try:
            recall_result = await memwal.recall(user_text, max_memories, namespace)
            relevant = [
                m for m in recall_result.results
                if (1 - m.distance) >= min_relevance
            ]
            if not relevant:
                return messages

            memory_context = _format_memories(relevant)
            log("[MemWal] Found %d relevant memories", len(relevant))

            # Insert the memory system message right before the last human
            # message; fall back to the front when there is none.
            result = list(messages)
            insert_at = next(
                (
                    idx
                    for idx in reversed(range(len(result)))
                    if isinstance(result[idx], HumanMessage)
                ),
                0,
            )
            result.insert(insert_at, SystemMessage(content=memory_context))
            return result
        except Exception as e:
            # Memory is best-effort: never fail the LLM call because of it.
            log("[MemWal] Memory search failed: %s", e)
            return messages

    async def _post_analyze(messages: List[BaseMessage]) -> None:
        """Analyze the last user message of one conversation for new facts."""
        user_text = _find_last_user_message(messages)
        if not user_text:
            return
        try:
            await memwal.analyze(user_text, namespace)
        except Exception as e:
            log("[MemWal] Auto-save failed: %s", e)

    def _schedule_auto_save(messages: List[BaseMessage]) -> None:
        """Kick off the fire-and-forget analysis when auto-save is enabled."""
        if auto_save:
            _fire_and_forget(_post_analyze(messages))

    async def patched_agenerate(
        messages: List[BaseMessage], *args: Any, **kwargs: Any
    ) -> ChatResult:
        # ``_agenerate`` gets the flat message list of a single conversation
        # (not a batch of conversations), so it is enriched as a whole.
        enriched = await _inject_memories(messages)
        result = await original_agenerate(enriched, *args, **kwargs)
        # Analyze the caller's original messages, not the enriched copy.
        _schedule_auto_save(messages)
        return result

    def patched_generate(
        messages: List[BaseMessage], *args: Any, **kwargs: Any
    ) -> ChatResult:
        # For sync generate, memories are injected synchronously via
        # asyncio.run, which is only possible outside a running event loop.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            enriched = asyncio.run(_inject_memories(messages))
        else:
            # Already in async context -- cannot use asyncio.run
            log("[MemWal] Memory search skipped: sync call inside a running event loop")
            enriched = messages

        result = original_generate(enriched, *args, **kwargs)
        _schedule_auto_save(messages)
        return result

    # Monkey-patch the LLM instance
    llm._agenerate = patched_agenerate  # type: ignore[assignment]
    llm._generate = patched_generate  # type: ignore[assignment]

    return llm
282
+
283
+
284
+ # ============================================================
285
+ # OpenAI SDK Integration
286
+ # ============================================================
287
+
288
+
289
def with_memwal_openai(
    client: Any,
    key: str,
    account_id: str,
    server_url: str = "https://relayer.memwal.ai",
    namespace: str = "default",
    max_memories: int = 5,
    auto_save: bool = True,
    min_relevance: float = 0.3,
    debug: bool = False,
) -> Any:
    """Wrap an OpenAI client with MemWal memory management.

    Works with both ``openai.OpenAI`` (sync) and ``openai.AsyncOpenAI`` (async),
    including subclasses of ``AsyncOpenAI``.

    Before each ``chat.completions.create`` call:
      - Recall relevant memories for the last user message
      - Inject them as a system message

    After each call:
      - Analyze the user message to extract and store new facts (fire-and-forget)

    Args:
        client: An ``openai.OpenAI`` or ``openai.AsyncOpenAI`` instance.
        key: Ed25519 delegate key (hex).
        account_id: MemWalAccount object ID.
        server_url: MemWal server URL.
        namespace: Default namespace.
        max_memories: Max memories to inject per request.
        auto_save: Auto-save new facts from conversation.
        min_relevance: Minimum similarity score (0-1) to include a memory.
        debug: Log MemWal activity at WARNING level instead of DEBUG.

    Returns:
        The same client, with ``chat.completions.create`` wrapped to use MemWal.
    """
    memwal = MemWal.create(
        key=key,
        account_id=account_id,
        server_url=server_url,
        namespace=namespace,
    )

    log = logger.warning if debug else logger.debug

    # Detect async clients by walking the class hierarchy rather than comparing
    # the exact class name, so subclasses of AsyncOpenAI are not mistaken for
    # sync clients. The check is name-based because ``openai`` is an optional
    # dependency that is not imported here.
    is_async = hasattr(client, "_async_client") or any(
        cls.__name__ == "AsyncOpenAI" for cls in type(client).__mro__
    )

    wrap = _wrap_async_openai if is_async else _wrap_sync_openai
    wrap(client, memwal, namespace, max_memories, auto_save, min_relevance, log)

    return client
342
+
343
+
344
def _wrap_async_openai(
    client: Any,
    memwal: "MemWal",
    namespace: str,
    max_memories: int,
    auto_save: bool,
    min_relevance: float,
    log: Callable[..., Any],
) -> None:
    """Replace an async client's ``chat.completions.create`` with a memory-aware version.

    The replacement recalls memories for the last user message, injects the
    relevant ones into a copy of the message list, awaits the original
    ``create`` and, when ``auto_save`` is on, schedules a fire-and-forget
    analysis of the user text. Recall and analysis errors are only logged.
    """
    original_create = client.chat.completions.create

    async def patched_create(*args: Any, **kwargs: Any) -> Any:
        conversation = kwargs.get("messages") or (args[0] if args else None)
        if conversation is None:
            # Nothing to enrich -- delegate untouched.
            return await original_create(*args, **kwargs)

        query = _find_last_user_message(conversation)

        # Inject memories
        if query:
            try:
                recalled = await memwal.recall(query, max_memories, namespace)
                hits = []
                for memory in recalled.results:
                    if (1 - memory.distance) >= min_relevance:
                        hits.append(memory)
                if hits:
                    context = _format_memories(hits)
                    log(f"[MemWal] Found {len(hits)} relevant memories")
                    # Work on a copy so the caller's list is left alone.
                    conversation = _inject_openai_memory(list(conversation), context)
                    if "messages" in kwargs:
                        kwargs["messages"] = conversation
                    elif args:
                        args = (conversation, *args[1:])
            except Exception as err:
                log(f"[MemWal] Memory search failed: {err}")

        response = await original_create(*args, **kwargs)

        if not (auto_save and query):
            return response

        # Fire-and-forget analyze
        async def _store_facts() -> None:
            try:
                await memwal.analyze(query, namespace)
            except Exception as err:
                log(f"[MemWal] Auto-save failed: {err}")

        _fire_and_forget(_store_facts())
        return response

    client.chat.completions.create = patched_create
396
+
397
+
398
def _wrap_sync_openai(
    client: Any,
    memwal: "MemWal",
    namespace: str,
    max_memories: int,
    auto_save: bool,
    min_relevance: float,
    log: Callable[..., Any],
) -> None:
    """Wrap a sync OpenAI client's ``chat.completions.create``.

    The replacement recalls memories for the last user message with
    ``asyncio.run``, injects the relevant ones into a copy of the message
    list, calls the original ``create`` and, when ``auto_save`` is on,
    schedules a fire-and-forget analysis of the user text.

    ``asyncio.run`` cannot be used while an event loop is running in the
    current thread, so in that situation memory recall is skipped (and
    logged) instead of being attempted. Recall and analysis errors are
    logged and never fail the completion call.
    """
    original_create = client.chat.completions.create

    def _in_running_loop() -> bool:
        """Return True when the current thread already runs an event loop."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    def patched_create(*args: Any, **kwargs: Any) -> Any:
        messages = kwargs.get("messages") or (args[0] if args else None)
        if messages is None:
            return original_create(*args, **kwargs)

        user_text = _find_last_user_message(messages)

        # Inject memories (sync)
        if user_text and _in_running_loop():
            # Check BEFORE creating the recall coroutine: otherwise
            # asyncio.run raises and the coroutine is left never awaited.
            log("[MemWal] Memory search skipped: sync client called inside a running event loop")
        elif user_text:
            try:
                recall_result = asyncio.run(
                    memwal.recall(user_text, max_memories, namespace)
                )
                relevant = [
                    m for m in recall_result.results
                    if (1 - m.distance) >= min_relevance
                ]
                if relevant:
                    memory_context = _format_memories(relevant)
                    log(f"[MemWal] Found {len(relevant)} relevant memories")
                    # Work on a copy so the caller's list is left alone.
                    messages = _inject_openai_memory(list(messages), memory_context)
                    if "messages" in kwargs:
                        kwargs["messages"] = messages
                    elif args:
                        args = (messages,) + args[1:]
            except Exception as e:
                # Memory is best-effort: never fail the completion call.
                log(f"[MemWal] Memory search failed: {e}")

        result = original_create(*args, **kwargs)

        # Fire-and-forget analyze
        if auto_save and user_text:
            async def _analyze() -> None:
                try:
                    await memwal.analyze(user_text, namespace)
                except Exception as e:
                    log(f"[MemWal] Auto-save failed: {e}")

            _fire_and_forget(_analyze())

        return result

    client.chat.completions.create = patched_create
452
+
453
+
454
def _inject_openai_memory(
    messages: List[Dict[str, Any]],
    memory_context: str,
) -> List[Dict[str, Any]]:
    """Insert a memory system message before the last user message.

    The list is modified in place and also returned. When there is no
    dict message with ``role == "user"``, the memory message goes to the
    front of the list.

    Args:
        messages: OpenAI-style message dicts (mutated).
        memory_context: Text used as the content of the new system message.

    Returns:
        The same ``messages`` list, with the memory message inserted.
    """
    # Index of the most recent user message; 0 when none exists.
    insert_at = next(
        (
            idx
            for idx in reversed(range(len(messages)))
            if isinstance(messages[idx], dict)
            and messages[idx].get("role") == "user"
        ),
        0,
    )
    memory_entry: Dict[str, Any] = {"role": "system", "content": memory_context}
    messages.insert(insert_at, memory_entry)
    return messages