memu-py 1.2.0__cp313-abi3-win_amd64.whl → 1.3.0__cp313-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memu/_core.pyd +0 -0
- memu/app/service.py +13 -0
- memu/app/settings.py +24 -1
- memu/database/models.py +4 -2
- memu/database/postgres/models.py +3 -0
- memu/database/sqlite/models.py +3 -1
- memu/integrations/__init__.py +3 -0
- memu/integrations/langgraph.py +163 -0
- memu/llm/backends/__init__.py +3 -1
- memu/llm/backends/grok.py +11 -0
- memu/llm/backends/openrouter.py +70 -0
- memu/llm/http_client.py +19 -0
- memu/llm/lazyllm_client.py +134 -0
- memu_py-1.3.0.dist-info/METADATA +634 -0
- {memu_py-1.2.0.dist-info → memu_py-1.3.0.dist-info}/RECORD +18 -13
- memu_py-1.2.0.dist-info/METADATA +0 -476
- {memu_py-1.2.0.dist-info → memu_py-1.3.0.dist-info}/WHEEL +0 -0
- {memu_py-1.2.0.dist-info → memu_py-1.3.0.dist-info}/entry_points.txt +0 -0
- {memu_py-1.2.0.dist-info → memu_py-1.3.0.dist-info}/licenses/LICENSE.txt +0 -0
memu/_core.pyd
CHANGED
Binary file
memu/app/service.py
CHANGED
@@ -117,6 +117,19 @@ class MemoryService(MemorizeMixin, RetrieveMixin, CRUDMixin):
                 endpoint_overrides=cfg.endpoint_overrides,
                 embed_model=cfg.embed_model,
             )
+        elif backend == "lazyllm_backend":
+            from memu.llm.lazyllm_client import LazyLLMClient
+
+            return LazyLLMClient(
+                llm_source=cfg.lazyllm_source.llm_source or cfg.lazyllm_source.source,
+                vlm_source=cfg.lazyllm_source.vlm_source or cfg.lazyllm_source.source,
+                embed_source=cfg.lazyllm_source.embed_source or cfg.lazyllm_source.source,
+                stt_source=cfg.lazyllm_source.stt_source or cfg.lazyllm_source.source,
+                chat_model=cfg.chat_model,
+                embed_model=cfg.embed_model,
+                vlm_model=cfg.lazyllm_source.vlm_model,
+                stt_model=cfg.lazyllm_source.stt_model,
+            )
         else:
             msg = f"Unknown llm_client_backend '{cfg.client_backend}'"
             raise ValueError(msg)
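The new branch resolves each modality's source with a per-modality override that falls back to the shared `lazyllm_source.source`. A minimal sketch of that fallback (the field names come from the settings diff below; the concrete source/model values are illustrative, and it assumes the remaining `LLMConfig` fields have usable defaults):

from memu.app.settings import LazyLLMSource, LLMConfig

cfg = LLMConfig(
    client_backend="lazyllm_backend",
    chat_model="qwen-plus",          # illustrative
    lazyllm_source=LazyLLMSource(
        source="qwen",               # shared fallback for every modality
        embed_source="doubao",       # overrides the fallback for embeddings only
    ),
)

# The service resolves each modality as "<modality>_source or source":
assert (cfg.lazyllm_source.llm_source or cfg.lazyllm_source.source) == "qwen"
assert (cfg.lazyllm_source.embed_source or cfg.lazyllm_source.source) == "doubao"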
memu/app/settings.py
CHANGED
@@ -89,6 +89,16 @@ def _default_memory_categories() -> list[CategoryConfig]:
     ]


+class LazyLLMSource(BaseModel):
+    source: str | None = Field(default=None, description="default source for lazyllm client backend")
+    llm_source: str | None = Field(default=None, description="LLM source for lazyllm client backend")
+    embed_source: str | None = Field(default=None, description="Embedding source for lazyllm client backend")
+    vlm_source: str | None = Field(default=None, description="VLM source for lazyllm client backend")
+    stt_source: str | None = Field(default=None, description="STT source for lazyllm client backend")
+    vlm_model: str = Field(default="qwen-vl-plus", description="Vision language model for lazyllm client backend")
+    stt_model: str = Field(default="qwen-audio-turbo", description="Speech-to-text model for lazyllm client backend")
+
+
 class LLMConfig(BaseModel):
     provider: str = Field(
         default="openai",
@@ -99,8 +109,9 @@ class LLMConfig(BaseModel):
     chat_model: str = Field(default="gpt-4o-mini")
     client_backend: str = Field(
         default="sdk",
-        description="Which LLM client backend to use: 'httpx' (httpx)
+        description="Which LLM client backend to use: 'httpx' (httpx), 'sdk' (official OpenAI), or 'lazyllm_backend' (for more LLM source like Qwen, Doubao, SIliconflow, etc.)",
     )
+    lazyllm_source: LazyLLMSource = Field(default=LazyLLMSource())
     endpoint_overrides: dict[str, str] = Field(
         default_factory=dict,
         description="Optional overrides for HTTP endpoints (keys: 'chat'/'summary').",
@@ -114,6 +125,18 @@ class LLMConfig(BaseModel):
         description="Maximum batch size for embedding API calls (used by SDK client backends).",
     )

+    @model_validator(mode="after")
+    def set_provider_defaults(self) -> "LLMConfig":
+        if self.provider == "grok":
+            # If values match the OpenAI defaults, switch them to Grok defaults
+            if self.base_url == "https://api.openai.com/v1":
+                self.base_url = "https://api.x.ai/v1"
+            if self.api_key == "OPENAI_API_KEY":
+                self.api_key = "XAI_API_KEY"
+            if self.chat_model == "gpt-4o-mini":
+                self.chat_model = "grok-2-latest"
+        return self
+

 class BlobConfig(BaseModel):
     provider: str = Field(default="local")
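The `set_provider_defaults` validator means `provider="grok"` works without touching the other fields, as long as they still hold the OpenAI values the validator compares against. A small sketch (assumes `base_url`, `api_key`, and `chat_model` default to exactly those compared values; `grok-beta` is an illustrative override):

from memu.app.settings import LLMConfig

cfg = LLMConfig(provider="grok")
assert cfg.base_url == "https://api.x.ai/v1"
assert cfg.api_key == "XAI_API_KEY"
assert cfg.chat_model == "grok-2-latest"

# Explicitly set values are left untouched by the validator:
cfg2 = LLMConfig(provider="grok", chat_model="grok-beta")
assert cfg2.chat_model == "grok-beta"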
memu/database/models.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

 import uuid
 from datetime import datetime
-from typing import Literal
+from typing import Any, Literal

 import pendulum
 from pydantic import BaseModel, ConfigDict, Field
@@ -28,9 +28,11 @@ class Resource(BaseRecord):

 class MemoryItem(BaseRecord):
     resource_id: str | None
-    memory_type:
+    memory_type: str
     summary: str
     embedding: list[float] | None = None
+    happened_at: datetime | None = None
+    extra: dict[str, Any] = {}


 class MemoryCategory(BaseRecord):
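The new `happened_at` and `extra` fields let a memory item carry an event timestamp and free-form metadata alongside its summary. A hedged sketch of a populated item (values are illustrative, and it assumes the inherited `BaseRecord` fields are defaulted):

from datetime import datetime

from memu.database.models import MemoryItem

item = MemoryItem(
    resource_id="res-123",                          # illustrative
    memory_type="event",
    summary="Alice mentioned she is moving to Berlin in March.",
    happened_at=datetime(2025, 3, 1),               # new: when the remembered event occurred
    extra={"speaker": "alice", "channel": "chat"},  # new: free-form metadata
)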
memu/database/postgres/models.py
CHANGED
@@ -14,6 +14,7 @@ except ImportError as exc:

 from pydantic import BaseModel
 from sqlalchemy import ForeignKey, MetaData, String, Text
+from sqlalchemy.dialects.postgresql import JSONB
 from sqlmodel import Column, DateTime, Field, Index, SQLModel, func

 from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource
@@ -55,6 +56,8 @@ class MemoryItemModel(BaseModelMixin, MemoryItem):
     memory_type: MemoryType = Field(sa_column=Column(String, nullable=False))
     summary: str = Field(sa_column=Column(Text, nullable=False))
     embedding: list[float] | None = Field(default=None, sa_column=Column(Vector(), nullable=True))
+    happened_at: datetime | None = Field(default=None, sa_column=Column(DateTime, nullable=True))
+    extra: dict[str, Any] = Field(default={}, sa_column=Column(JSONB, nullable=True))


 class MemoryCategoryModel(BaseModelMixin, MemoryCategory):
memu/database/sqlite/models.py
CHANGED
@@ -10,7 +10,7 @@ from typing import Any

 import pendulum
 from pydantic import BaseModel
-from sqlalchemy import MetaData, String, Text
+from sqlalchemy import JSON, MetaData, String, Text
 from sqlmodel import Column, DateTime, Field, Index, SQLModel, func

 from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource
@@ -83,6 +83,8 @@ class SQLiteMemoryItemModel(SQLiteBaseModelMixin, MemoryItem):
     summary: str = Field(sa_column=Column(Text, nullable=False))
     # Store embedding as JSON string since SQLite doesn't have native vector type
     embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
+    happened_at: datetime | None = Field(default=None, sa_column=Column(DateTime, nullable=True))
+    extra: dict[str, Any] = Field(default={}, sa_column=Column(JSON, nullable=True))

     @property
     def embedding(self) -> list[float] | None:
memu/integrations/langgraph.py
ADDED
@@ -0,0 +1,163 @@
+"""LangGraph integration for MemU."""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import os
+import tempfile
+import uuid
+from typing import Any
+
+# MUST explicitly import langgraph to satisfy DEP002
+import langgraph
+from pydantic import BaseModel, Field
+
+from memu.app.service import MemoryService
+
+try:
+    from langchain_core.tools import BaseTool, StructuredTool
+except ImportError as e:
+    msg = "Please install 'langchain-core' (and 'langgraph') to use the LangGraph integration."
+    raise ImportError(msg) from e
+
+
+# Setup logger
+logger = logging.getLogger("memu.integrations.langgraph")
+
+
+class MemUIntegrationError(Exception):
+    """Base exception for MemU integration issues."""
+
+
+class SaveRecallInput(BaseModel):
+    """Input schema for the save_memory tool."""
+
+    content: str = Field(description="The text content or information to save/remember.")
+    user_id: str = Field(description="The unique identifier of the user.")
+    metadata: dict[str, Any] | None = Field(default=None, description="Additional metadata related to the memory.")
+
+
+class SearchRecallInput(BaseModel):
+    """Input schema for the search_memory tool."""
+
+    query: str = Field(description="The search query to retrieve relevant memories.")
+    user_id: str = Field(description="The unique identifier of the user.")
+    limit: int = Field(default=5, description="Number of memories to retrieve.")
+    metadata_filter: dict[str, Any] | None = Field(
+        default=None, description="Optional filter for memory metadata (e.g., {'category': 'work'})."
+    )
+    min_relevance_score: float = Field(default=0.0, description="Minimum relevance score (0.0 to 1.0) for results.")
+
+
+class MemULangGraphTools:
+    """Adapter to expose MemU as a set of Tools for LangGraph/LangChain agents.
+
+    This class provides a bridge between the MemU MemoryService and LangChain's
+    tooling ecosystem.
+    """
+
+    def __init__(self, memory_service: MemoryService):
+        """Initializes the MemULangGraphTools with a memory service."""
+        self.memory_service = memory_service
+        # Expose the langgraph module to ensure it's "used" even if just by reference in this class
+        self._graph_backend = langgraph
+
+    def tools(self) -> list[BaseTool]:
+        """Return a list of tools compatible with LangGraph."""
+        return [
+            self.save_memory_tool(),
+            self.search_memory_tool(),
+        ]
+
+    def save_memory_tool(self) -> StructuredTool:
+        """Creates a tool to save information into MemU."""
+
+        async def _save(content: str, user_id: str, metadata: dict | None = None) -> str:
+            logger.info("Entering save_memory_tool for user_id: %s", user_id)
+            filename = f"memu_input_{uuid.uuid4()}.txt"
+            temp_dir = tempfile.gettempdir()
+            file_path = os.path.join(temp_dir, filename)
+
+            try:
+                with open(file_path, "w", encoding="utf-8") as f:
+                    f.write(content)
+
+                logger.debug("Calling memory_service.memorize with temporary file: %s", file_path)
+                await self.memory_service.memorize(
+                    resource_url=file_path,
+                    modality="conversation",
+                    user={"user_id": user_id, **(metadata or {})},
+                )
+                logger.info("Successfully saved memory for user_id: %s", user_id)
+            except Exception as e:
+                error_msg = f"Failed to save memory for user {user_id}: {e!s}"
+                logger.exception(error_msg)
+                return str(MemUIntegrationError(error_msg))
+            finally:
+                if os.path.exists(file_path):
+                    with contextlib.suppress(OSError):
+                        os.remove(file_path)
+                        logger.debug("Cleaned up temporary file: %s", file_path)
+
+            return "Memory saved successfully."
+
+        return StructuredTool.from_function(
+            func=None,
+            coroutine=_save,
+            name="save_memory",
+            description="Save a piece of information, conversation snippet, or memory for a user.",
+            args_schema=SaveRecallInput,
+        )
+
+    def search_memory_tool(self) -> StructuredTool:
+        """Creates a tool to search for information in MemU."""
+
+        async def _search(
+            query: str,
+            user_id: str,
+            limit: int = 5,
+            metadata_filter: dict | None = None,
+            min_relevance_score: float = 0.0,
+        ) -> str:
+            logger.info("Entering search_memory_tool for user_id: %s, query: '%s'", user_id, query)
+            try:
+                queries = [{"role": "user", "content": query}]
+                where_filter = {"user_id": user_id}
+                if metadata_filter:
+                    where_filter.update(metadata_filter)
+
+                logger.debug("Calling memory_service.retrieve with where_filter: %s", where_filter)
+                result = await self.memory_service.retrieve(
+                    queries=queries,
+                    where=where_filter,
+                )
+                logger.info("Successfully retrieved memories for user_id: %s", user_id)
+            except Exception as e:
+                error_msg = f"Failed to search memory for user {user_id}: {e!s}"
+                logger.exception(error_msg)
+                return str(MemUIntegrationError(error_msg))
+
+            items = result.get("items", [])
+            if min_relevance_score > 0:
+                items = [item for item in items if item.get("score", 1.0) >= min_relevance_score]
+
+            if not items:
+                logger.info("No memories found for user_id: %s", user_id)
+                return "No relevant memories found."
+
+            response_text = "Retrieved Memories:\n"
+            for idx, item in enumerate(items[:limit]):
+                summary = item.get("summary", "")
+                score = item.get("score", "N/A")
+                response_text += f"{idx + 1}. [Score: {score}] {summary}\n"
+
+            return response_text
+
+        return StructuredTool.from_function(
+            func=None,
+            coroutine=_search,
+            name="search_memory",
+            description="Search for relevant memories or information for a user based on a query.",
+            args_schema=SearchRecallInput,
+        )
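A hedged sketch of plugging these tools into a LangGraph agent. `MemULangGraphTools` and `tools()` come from the file above; the no-argument `MemoryService()` construction, the `create_react_agent` helper, and the model string are assumptions about a typical setup, not part of this diff:

from langgraph.prebuilt import create_react_agent

from memu.app.service import MemoryService
from memu.integrations.langgraph import MemULangGraphTools

memory_service = MemoryService()  # assumption: configure as your application requires
memu_tools = MemULangGraphTools(memory_service).tools()

# Any LangChain-compatible chat model can be used; the string below is illustrative.
agent = create_react_agent("openai:gpt-4o-mini", tools=memu_tools)

result = agent.invoke(
    {"messages": [{"role": "user", "content": "Remember that I prefer dark mode. My user id is u-42."}]}
)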
memu/llm/backends/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
+from memu.llm.backends.openrouter import OpenRouterLLMBackend

-__all__ = ["DoubaoLLMBackend", "LLMBackend", "OpenAILLMBackend"]
+__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
memu/llm/backends/grok.py
ADDED
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from memu.llm.backends.openai import OpenAILLMBackend
+
+
+class GrokBackend(OpenAILLMBackend):
+    """Backend for Grok (xAI) LLM API."""
+
+    name = "grok"
+    # Grok uses the same payload structure as OpenAI
+    # We inherits build_summary_payload, parse_summary_response, etc.
memu/llm/backends/openrouter.py
ADDED
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+from typing import Any, cast
+
+from memu.llm.backends.base import LLMBackend
+
+
+class OpenRouterLLMBackend(LLMBackend):
+    """Backend for OpenRouter LLM API (OpenAI-compatible)."""
+
+    name = "openrouter"
+    summary_endpoint = "/api/v1/chat/completions"
+
+    def build_summary_payload(
+        self, *, text: str, system_prompt: str | None, chat_model: str, max_tokens: int | None
+    ) -> dict[str, Any]:
+        """Build payload for OpenRouter chat completions (OpenAI-compatible)."""
+        prompt = system_prompt or "Summarize the text in one short paragraph."
+        payload: dict[str, Any] = {
+            "model": chat_model,
+            "messages": [
+                {"role": "system", "content": prompt},
+                {"role": "user", "content": text},
+            ],
+            "temperature": 0.2,
+        }
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        return payload
+
+    def parse_summary_response(self, data: dict[str, Any]) -> str:
+        """Parse OpenRouter response (OpenAI-compatible format)."""
+        return cast(str, data["choices"][0]["message"]["content"])
+
+    def build_vision_payload(
+        self,
+        *,
+        prompt: str,
+        base64_image: str,
+        mime_type: str,
+        system_prompt: str | None,
+        chat_model: str,
+        max_tokens: int | None,
+    ) -> dict[str, Any]:
+        """Build payload for OpenRouter Vision API (OpenAI-compatible)."""
+        messages: list[dict[str, Any]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        messages.append({
+            "role": "user",
+            "content": [
+                {"type": "text", "text": prompt},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:{mime_type};base64,{base64_image}",
+                    },
+                },
+            ],
+        })
+
+        payload: dict[str, Any] = {
+            "model": chat_model,
+            "messages": messages,
+            "temperature": 0.2,
+        }
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        return payload
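Because the payload builders are pure functions of their arguments, the new backend can be exercised directly. A quick sketch of the chat payload it produces (the model slug is illustrative):

from memu.llm.backends.openrouter import OpenRouterLLMBackend

backend = OpenRouterLLMBackend()
payload = backend.build_summary_payload(
    text="MemU 1.3.0 adds Grok, OpenRouter and LazyLLM backends.",
    system_prompt=None,               # falls back to the built-in summarization prompt
    chat_model="openai/gpt-4o-mini",  # illustrative OpenRouter model slug
    max_tokens=256,
)
# payload["messages"][0] == {"role": "system", "content": "Summarize the text in one short paragraph."}
# payload["temperature"] == 0.2 and payload["max_tokens"] == 256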
memu/llm/http_client.py
CHANGED
@@ -10,7 +10,9 @@ import httpx

 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
+from memu.llm.backends.openrouter import OpenRouterLLMBackend


 # Minimal embedding backend support (moved from embedding module)
@@ -47,11 +49,26 @@ class _DoubaoEmbeddingBackend(_EmbeddingBackend):
         return [cast(list[float], d["embedding"]) for d in data["data"]]


+class _OpenRouterEmbeddingBackend(_EmbeddingBackend):
+    """OpenRouter uses OpenAI-compatible embedding API."""
+
+    name = "openrouter"
+    embedding_endpoint = "/api/v1/embeddings"
+
+    def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
+        return {"model": embed_model, "input": inputs}
+
+    def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
+        return [cast(list[float], d["embedding"]) for d in data["data"]]
+
+
 logger = logging.getLogger(__name__)

 LLM_BACKENDS: dict[str, Callable[[], LLMBackend]] = {
     OpenAILLMBackend.name: OpenAILLMBackend,
     DoubaoLLMBackend.name: DoubaoLLMBackend,
+    GrokBackend.name: GrokBackend,
+    OpenRouterLLMBackend.name: OpenRouterLLMBackend,
 }


@@ -229,6 +246,8 @@ class HTTPLLMClient:
         backends: dict[str, type[_EmbeddingBackend]] = {
             _OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
             _DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
+            "grok": _OpenAIEmbeddingBackend,
+            _OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
         }
         factory = backends.get(provider)
         if not factory:
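Provider resolution stays a plain dictionary lookup. A short sketch against the registry shown above (the embedding registry and `_EmbeddingBackend` classes are module-private, so they are only described in comments):

from memu.llm.http_client import LLM_BACKENDS

# Chat/summary backends registered in 1.3.0:
grok_backend = LLM_BACKENDS["grok"]()              # GrokBackend, OpenAI-style payloads
openrouter_backend = LLM_BACKENDS["openrouter"]()  # OpenRouterLLMBackend

print(openrouter_backend.summary_endpoint)  # /api/v1/chat/completions

# For embeddings, "grok" reuses the OpenAI-compatible embedding backend, while
# "openrouter" gets its own backend pointed at /api/v1/embeddings.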
memu/llm/lazyllm_client.py
ADDED
@@ -0,0 +1,134 @@
+import asyncio
+import functools
+from typing import Any, cast
+
+import lazyllm  # type: ignore[import-untyped]
+from lazyllm import LOG
+
+
+class LazyLLMClient:
+    """LAZYLLM client that relies on the LazyLLM framework."""
+
+    DEFAULT_SOURCE = "qwen"
+
+    def __init__(
+        self,
+        *,
+        llm_source: str | None = None,
+        vlm_source: str | None = None,
+        embed_source: str | None = None,
+        stt_source: str | None = None,
+        chat_model: str | None = None,
+        vlm_model: str | None = None,
+        embed_model: str | None = None,
+        stt_model: str | None = None,
+    ):
+        self.llm_source = llm_source or self.DEFAULT_SOURCE
+        self.vlm_source = vlm_source or self.DEFAULT_SOURCE
+        self.embed_source = embed_source or self.DEFAULT_SOURCE
+        self.stt_source = stt_source or self.DEFAULT_SOURCE
+        self.chat_model = chat_model
+        self.vlm_model = vlm_model
+        self.embed_model = embed_model
+        self.stt_model = stt_model
+
+    async def _call_async(self, client: Any, *args: Any, **kwargs: Any) -> Any:
+        """
+        Asynchronously call a LazyLLM client with given arguments and keyword arguments.
+        """
+        if kwargs:
+            return await asyncio.to_thread(functools.partial(client, *args, **kwargs))
+        else:
+            return await asyncio.to_thread(client, *args)
+
+    async def summarize(
+        self,
+        text: str,
+        *,
+        max_tokens: int | None = None,
+        system_prompt: str | None = None,
+    ) -> str:
+        """
+        Generate a summary or response for the input text using the configured LLM backend.
+
+        Args:
+            text: The input text to summarize or process.
+            max_tokens: (Optional) Maximum number of tokens to generate.
+            system_prompt: (Optional) System instruction to guide the LLM behavior.
+        Return:
+            The generated summary text as a string.
+        """
+        client = lazyllm.namespace("MEMU").OnlineModule(source=self.llm_source, model=self.chat_model, type="llm")
+        prompt = system_prompt or "Summarize the text in one short paragraph."
+        full_prompt = f"{prompt}\n\ntext:\n{text}"
+        LOG.debug(f"Summarizing text with {self.llm_source}/{self.chat_model}")
+        response = await self._call_async(client, full_prompt)
+        return cast(str, response)
+
+    async def vision(
+        self,
+        prompt: str,
+        image_path: str,
+        *,
+        max_tokens: int | None = None,
+        system_prompt: str | None = None,
+    ) -> tuple[str, Any]:
+        """
+        Process an image with a text prompt using the configured VLM (Vision-Language Model).
+
+
+        Args:
+            prompt: Text prompt describing the request or question about the image.
+            image_path: Path to the image file to be analyzed.
+            max_tokens: (Optional) Maximum number of tokens to generate.
+            system_prompt: (Optional) System instruction to guide the VLM behavior.
+        Return:
+            A tuple containing the generated text response and None (reserved for metadata).
+        """
+        client = lazyllm.namespace("MEMU").OnlineModule(source=self.vlm_source, model=self.vlm_model, type="vlm")
+        LOG.debug(f"Processing image with {self.vlm_source}/{self.vlm_model}: {image_path}")
+        # LazyLLM VLM accepts prompt as first positional argument and image_path as keyword argument
+        response = await self._call_async(client, prompt, lazyllm_files=image_path)
+        return response, None
+
+    async def embed(
+        self,
+        texts: list[str],
+        batch_size: int = 10,
+    ) -> list[list[float]]:
+        """
+        Generate vector embeddings for a list of text strings.
+
+        Args:
+            texts: List of text strings to embed.
+            batch_size: (Optional) Batch size for processing embeddings (default: 10).
+        Return:
+            A list of embedding vectors (list of floats), one for each input text.
+        """
+        client = lazyllm.namespace("MEMU").OnlineModule(
+            source=self.embed_source, model=self.embed_model, type="embed", batch_size=batch_size
+        )
+        LOG.debug(f"embed {len(texts)} texts with {self.embed_source}/{self.embed_model}")
+        response = await self._call_async(client, texts)
+        return cast(list[list[float]], response)
+
+    async def transcribe(
+        self,
+        audio_path: str,
+        language: str | None = None,
+        prompt: str | None = None,
+    ) -> str:
+        """
+        Transcribe audio content to text using the configured STT (Speech-to-Text) backend.
+
+        Args:
+            audio_path: Path to the audio file to transcribe.
+            language: (Optional) Language code of the audio content.
+            prompt: (Optional) Text prompt to guide the transcription or translation.
+        Return:
+            The transcribed text as a string.
+        """
+        client = lazyllm.namespace("MEMU").OnlineModule(source=self.stt_source, model=self.stt_model, type="stt")
+        LOG.debug(f"Transcribing audio with {self.stt_source}/{self.stt_model}: {audio_path}")
+        response = await self._call_async(client, audio_path)
+        return cast(str, response)
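A hedged usage sketch of the new client. The constructor keywords and method signatures come from the file above; the source and model names are illustrative, and whatever credentials the chosen LazyLLM source needs are assumed to be configured the way LazyLLM expects:

import asyncio

from memu.llm.lazyllm_client import LazyLLMClient

client = LazyLLMClient(
    llm_source="qwen",               # omitted sources fall back to DEFAULT_SOURCE ("qwen")
    chat_model="qwen-plus",          # illustrative model names
    embed_model="text-embedding-v3",
    vlm_model="qwen-vl-plus",
    stt_model="qwen-audio-turbo",
)


async def main() -> None:
    summary = await client.summarize("MemU 1.3.0 adds a LazyLLM-backed client.")
    vectors = await client.embed(["hello", "world"], batch_size=2)
    print(summary, len(vectors))


asyncio.run(main())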