memplex 3.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memnex/__init__.py +31 -0
- memnex/__main__.py +6 -0
- memnex/_plugin/.claude-plugin/plugin.json +24 -0
- memnex/_plugin/.mcp.json +9 -0
- memnex/_plugin/__init__.py +0 -0
- memnex/_plugin/hooks/hooks.json +43 -0
- memnex/_plugin/scripts/hook-runner.py +166 -0
- memnex/_plugin/skills/mem-explore/SKILL.md +83 -0
- memnex/_plugin/skills/mem-manage/SKILL.md +92 -0
- memnex/_plugin/skills/mem-search/SKILL.md +85 -0
- memnex/_plugin/skills/mem-write/SKILL.md +78 -0
- memnex/adapters/__init__.py +14 -0
- memnex/adapters/claude_skill.py +169 -0
- memnex/adapters/cli.py +525 -0
- memnex/adapters/http_api.py +314 -0
- memnex/adapters/mcp_server.py +448 -0
- memnex/compaction.py +563 -0
- memnex/config.py +366 -0
- memnex/core/__init__.py +13 -0
- memnex/core/associator/__init__.py +8 -0
- memnex/core/associator/domain_classifier.py +75 -0
- memnex/core/associator/entity_aligner.py +127 -0
- memnex/core/associator/ref_linker.py +197 -0
- memnex/core/associator/term_mapper.py +77 -0
- memnex/core/dictionaries/__init__.py +50 -0
- memnex/core/engine.py +667 -0
- memnex/core/extractors/__init__.py +15 -0
- memnex/core/extractors/docx.py +97 -0
- memnex/core/extractors/image.py +233 -0
- memnex/core/extractors/markdown.py +139 -0
- memnex/core/extractors/pdf.py +133 -0
- memnex/core/extractors/vision_mapper.py +131 -0
- memnex/core/handlers/__init__.py +7 -0
- memnex/core/handlers/clipboard.py +40 -0
- memnex/core/handlers/file_handler.py +62 -0
- memnex/core/handlers/url_handler.py +132 -0
- memnex/llm/__init__.py +25 -0
- memnex/llm/enhancer.py +226 -0
- memnex/llm/fallback_chain.py +87 -0
- memnex/llm/injection_guard.py +178 -0
- memnex/llm/provider.py +130 -0
- memnex/llm/providers/__init__.py +22 -0
- memnex/llm/providers/anthropic.py +135 -0
- memnex/llm/providers/local.py +135 -0
- memnex/llm/providers/rule_based.py +68 -0
- memnex/llm/sanitizer.py +67 -0
- memnex/models/__init__.py +68 -0
- memnex/models/feedback.py +42 -0
- memnex/models/graph.py +33 -0
- memnex/models/memory.py +102 -0
- memnex/models/misc.py +185 -0
- memnex/models/paragraph.py +45 -0
- memnex/models/search.py +51 -0
- memnex/models/source.py +23 -0
- memnex/models/task.py +62 -0
- memnex/processing/__init__.py +1 -0
- memnex/processing/graph_builder.py +278 -0
- memnex/processing/merger/__init__.py +6 -0
- memnex/processing/merger/confidence_calculator.py +127 -0
- memnex/processing/merger/conflict_resolver.py +116 -0
- memnex/retrieval/__init__.py +1 -0
- memnex/retrieval/dedup.py +386 -0
- memnex/retrieval/embedding.py +289 -0
- memnex/retrieval/reranker.py +299 -0
- memnex/service.py +902 -0
- memnex/storage/__init__.py +65 -0
- memnex/storage/base.py +132 -0
- memnex/storage/changelog.py +106 -0
- memnex/storage/feedback.py +486 -0
- memnex/storage/lite/__init__.py +5 -0
- memnex/storage/lite/store.py +606 -0
- memnex/storage/vector.py +265 -0
- memnex/wiki/__init__.py +11 -0
- memnex/wiki/community.py +221 -0
- memnex/wiki/compiler.py +545 -0
- memnex/wiki/generator.py +270 -0
- memnex/wiki/search.py +282 -0
- memnex/worker.py +412 -0
- memplex-3.2.0.dist-info/METADATA +37 -0
- memplex-3.2.0.dist-info/RECORD +83 -0
- memplex-3.2.0.dist-info/WHEEL +5 -0
- memplex-3.2.0.dist-info/entry_points.txt +2 -0
- memplex-3.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Chain-of-responsibility fallback for LLM providers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from memnex.models import IntentType
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from memnex.llm.provider import LLMProvider
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FallbackChain:
    """Try providers in order; the first one that returns without raising wins.

    Every LLMProvider method is forwarded to the configured providers in
    order. When a provider raises, the failure is logged at DEBUG level and
    the next provider is tried. If no provider succeeds (or none were
    configured), each method falls back to its own default:

    * ``classify_intent`` -- delegates to a freshly created ``RuleBasedProvider``.
    * ``summarize`` -- the first ``max_tokens`` *characters* of ``content``
      (a crude truncation, not a real token count).
    * ``extract_structured`` / ``complete_json`` -- an empty dict.
    * ``generate_hypothetical`` -- the query itself.
    * ``complete`` -- an empty string.

    Parameters
    ----------
    providers:
        Ordered list of LLMProvider implementations. The list is copied, so
        later changes to the caller's list do not affect the chain.
    """

    def __init__(self, providers: list[LLMProvider] | None = None) -> None:
        # Copy so the chain's order cannot be changed from outside.
        self._providers: list[LLMProvider] = list(providers or [])

    def _fallback(self) -> LLMProvider:
        """Create a RuleBasedProvider as the ultimate fallback (imported lazily)."""
        from memnex.llm.providers.rule_based import RuleBasedProvider

        return RuleBasedProvider()

    async def _first_success(self, method: str, *args: object) -> tuple[bool, object]:
        """Call ``method`` on each provider in turn and return the first result.

        Returns ``(True, result)`` for the first provider whose call does not
        raise, or ``(False, None)`` when every provider raised or there are no
        providers. A flag is returned alongside the value so that a provider
        legitimately returning ``None`` still counts as a success.
        """
        for p in self._providers:
            try:
                return True, await getattr(p, method)(*args)
            except Exception as exc:
                # asyncio.CancelledError derives from BaseException, so
                # cancellation is not swallowed here.
                logger.debug("%s fallback: %s: %s", method, p.__class__.__name__, exc)
        return False, None

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify *query*; falls back to RuleBasedProvider if all providers fail."""
        ok, result = await self._first_success("classify_intent", query, context)
        if ok:
            return result
        return await self._fallback().classify_intent(query, context)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize *content*; falls back to character truncation."""
        ok, result = await self._first_success("summarize", content, max_tokens)
        return result if ok else content[:max_tokens]

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data; falls back to an empty dict."""
        ok, result = await self._first_success("extract_structured", prompt, schema)
        return result if ok else {}

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a HyDE answer; falls back to the query itself."""
        ok, result = await self._first_success("generate_hypothetical", query)
        return result if ok else query

    async def complete(self, prompt: str) -> str:
        """General-purpose completion; falls back to an empty string."""
        ok, result = await self._first_success("complete", prompt)
        return result if ok else ""

    async def complete_json(self, prompt: str) -> dict:
        """JSON completion; falls back to an empty dict."""
        ok, result = await self._first_success("complete_json", prompt)
        return result if ok else {}
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Indirect prompt injection guard for memory recall contexts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
from typing import TYPE_CHECKING, List
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from memnex.models import SearchResult
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class IndirectInjectionGuard:
    """Detect and mitigate indirect prompt injection in recalled memories.

    While ``LLMPromptSanitizer`` protects against *direct* injection (user
    input), this class handles *indirect* injection: an attacker embeds
    malicious instructions in a source document that survives extraction
    and gets injected into the LLM context when recalled via RAG.

    Defense layers:
    1. Content scanning -- regex-based detection of system-role keywords,
       applied after Unicode NFKC normalization and removal of invisible
       characters, so trivially obfuscated payloads still match.
    2. Protective wrapping -- memory content is wrapped in ``[MEMORY ...]``
       tags so the LLM treats them as data, not instructions. Marker-like
       text inside a memory is defanged so it cannot close its own wrapper.
    3. Trust-level labelling -- each memory is annotated with a trust level
       derived from its ``source_type``.
    """

    # Multi-language injection patterns (compiled once at class load)
    INJECTION_PATTERNS: list[str] = [
        r"ignore\s+(all\s+)?previous\s+instructions?",
        r"disregard\s+(all\s+)?prior\s+instructions?",
        r"system\s*:\s*you\s+are",
        r"<\|endoftext\|>",
        r"<\|im_start\|>",
        r"忽略(之前|前面|上面)的(所有|全部)?指令",  # "ignore the previous/above (all) instructions"
        r"忽略系统提示",  # "ignore the system prompt"
        r"你现在是",  # "you are now ..."
        r"新的系统提示",  # "new system prompt"
        r"assistant\s*:\s*sure",
    ]

    _compiled: list[re.Pattern] = [re.compile(p, re.IGNORECASE) for p in INJECTION_PATTERNS]

    # Invisible characters that can be inserted to split keywords
    # (e.g. "ig<ZWSP>nore"); deleted before pattern matching.
    _INVISIBLE_CHARS: dict = str.maketrans("", "", "\u200b\u200c\u200d\u2060\ufeff\u00ad")

    # Text imitating one of our wrapper markers. A memory containing
    # "[MEMORY END]" could otherwise close its own wrapper early, and the rest
    # of its content would appear outside the data section.
    _MARKER_RE: re.Pattern = re.compile(r"\[(\s*MEMORY\s+(?:START|END))", re.IGNORECASE)

    # Attributes scanned per structured memory type ("function" is handled separately).
    _SCAN_FIELDS: dict[str, tuple[str, ...]] = {
        "fact": ("subject", "predicate", "object_"),
        "preference": ("aspect", "preference"),
        "observation": ("event", "context"),
    }

    # Trust level mapping: source_type value -> trust level label
    TRUST_LEVELS: dict[str, str] = {
        "requirement": "HIGH",
        "meeting": "MEDIUM",
        "code": "MEDIUM",
        "wiki": "LOW",
    }

    @classmethod
    def scan(cls, content: str) -> bool:
        """Scan content for suspected injection attacks.

        The text is NFKC-normalized (folding e.g. full-width letters to ASCII)
        and stripped of invisible characters before the patterns are applied.

        Returns
        -------
        True if the content is suspected to contain an injection payload;
        the caller should discard or isolate the memory entry. Empty content
        returns False.
        """
        import unicodedata

        if not content:
            return False
        normalized = unicodedata.normalize("NFKC", content).translate(cls._INVISIBLE_CHARS)
        return any(pattern.search(normalized) for pattern in cls._compiled)

    @classmethod
    def _defang_markers(cls, text: str) -> str:
        """Neutralize text imitating ``[MEMORY START`` / ``[MEMORY END``.

        The opening bracket is replaced with ``(`` so that the text can no
        longer be mistaken for a real wrapper marker.
        """
        return cls._MARKER_RE.sub(r"(\1", text)

    @classmethod
    def wrap_for_context(
        cls,
        memories: list[SearchResult],
        store: object,
    ) -> str:
        """Wrap recalled memories in protective tags for LLM context injection.

        Each memory is enclosed in ``[MEMORY START | trust=LEVEL | id=...]``
        / ``[MEMORY END]`` markers so the LLM treats the content as
        data, not as system instructions. Marker look-alikes inside the
        memory text are defanged first.

        Parameters
        ----------
        memories:
            Search results to be injected into the LLM context. Results whose
            node cannot be found in *store* are skipped.
        store:
            A MemoryStore-like object with a ``get(id)`` method that
            returns a MemoryNode (or None).
        """
        parts: list[str] = []
        for r in memories:
            func = store.get(r.func_id) if store else None
            if not func:
                continue
            source_type_val = (
                func.source_type.value
                if hasattr(func.source_type, "value")
                else (func.source_type or "wiki")
            )
            trust = cls.TRUST_LEVELS.get(source_type_val, "LOW")
            summary = cls._defang_markers(str(r.summary or func.name))
            parts.append(
                f"[MEMORY START | trust={trust} | id={r.func_id}]\n"
                f"{summary}\n"
                f"[MEMORY END]"
            )
        return "\n\n".join(parts)

    @classmethod
    def filter_and_wrap(
        cls,
        memories: list[SearchResult],
        store: object,
    ) -> str:
        """Filter out injection-suspected memories, then wrap the rest.

        Both the node's type-specific text fields and the text that
        ``wrap_for_context`` will actually emit (the result summary, falling
        back to the node name) are scanned. Memories that trigger the scanner
        are logged as warnings and excluded from the output.
        """
        safe: list[SearchResult] = []
        for r in memories:
            func = store.get(r.func_id) if store else None
            if func:
                memory_type = getattr(func, "memory_type", "function")
                # Accept enum members as well as plain strings.
                memory_type = getattr(memory_type, "value", memory_type)
                field_text = cls._extract_scan_text(func, memory_type)
                # The emitted summary/name is not among the scanned fields,
                # so it must be scanned separately to close that bypass.
                emitted_text = str(r.summary or getattr(func, "name", "") or "")
                if cls.scan(field_text) or cls.scan(emitted_text):
                    logger.warning(
                        "Indirect injection detected in memory %s (type=%s), skipped.",
                        r.func_id,
                        memory_type,
                    )
                    continue
            safe.append(r)
        return cls.wrap_for_context(safe, store)

    @classmethod
    def _extract_scan_text(cls, func: object, memory_type: str) -> str:
        """Extract the relevant text fields for injection scanning by memory type.

        ``memory_type`` may be a plain string or an enum member whose
        ``.value`` is that string. Missing attributes and empty/None values
        are skipped; non-string values are converted with ``str``. For an
        unknown type every string attribute of the object is scanned.
        """
        memory_type = getattr(memory_type, "value", memory_type)
        if memory_type == "function":
            return " ".join(
                str(fv.desc)
                for role in ("trigger", "condition", "action", "benefit")
                for fv in (getattr(func, role, None) or ())
                if getattr(fv, "desc", None)
            )
        fields = cls._SCAN_FIELDS.get(memory_type) if isinstance(memory_type, str) else None
        if fields is not None:
            values = (getattr(func, name, "") for name in fields)
            return " ".join(str(v) for v in values if v)
        # Unknown type: scan all string attributes. Objects without a __dict__
        # fall back to their str() form so they are still scanned (fail closed).
        try:
            attrs = vars(func)
        except TypeError:
            return str(func)
        return " ".join(v for v in attrs.values() if isinstance(v, str))
|
memnex/llm/provider.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""LLM Provider protocol definition."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from memnex.models import IntentType
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class LLMProvider(Protocol):
    """Structural interface that every LLM backend implements.

    Any object exposing these six coroutine methods is accepted as a
    provider; because the protocol is ``runtime_checkable``, ``isinstance``
    can be used to check that the methods are present. The operations cover
    intent classification, summarization, schema-guided extraction, HyDE
    generation, and free-form / JSON completion.
    """

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Determine what kind of request *query* is.

        The result is one of IMMEDIATE, SYNTHESIS, RELATION, ALL.
        """
        ...

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Produce a short summary of *content*."""
        ...

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Return data extracted from *prompt* shaped by the JSON *schema*."""
        ...

    async def generate_hypothetical(self, query: str) -> str:
        """Write a hypothetical answer to *query* for HyDE retrieval."""
        ...

    async def complete(self, prompt: str) -> str:
        """Return a free-form completion of *prompt*."""
        ...

    async def complete_json(self, prompt: str) -> dict:
        """Return the completion of *prompt* parsed as a JSON object."""
        ...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def create_provider(
    provider: str | None = "auto",
    *,
    anthropic_api_key: str | None = None,
    anthropic_model: str = "claude-sonnet-4-6",
    local_endpoint: str = "http://localhost:11434/v1",
    local_model: str = "qwen2.5",
    fallback_chain: list[str] | None = None,
) -> LLMProvider:
    """Factory: create an LLM provider instance by name.

    Parameters
    ----------
    provider:
        One of "auto", "anthropic", "local", "rule-based". ``"auto"`` and
        ``None`` both build a FallbackChain from *fallback_chain*.
    anthropic_api_key:
        Anthropic API key (required for the anthropic provider).
    anthropic_model:
        Model name for Anthropic.
    local_endpoint:
        OpenAI-compatible endpoint URL (e.g., Ollama / LM Studio).
    local_model:
        Model name for the local provider.
    fallback_chain:
        Ordered list of provider names for the FallbackChain built when
        *provider* is "auto" or None. Defaults to
        ["anthropic", "local", "rule-based"]. Entries that cannot be built
        (e.g. missing API key or optional package) are skipped and logged at
        DEBUG level; unrecognized names are skipped with a warning.

    Returns
    -------
    An object satisfying the LLMProvider protocol.

    Raises
    ------
    ValueError
        For an unknown *provider* name, or "anthropic" without an API key.
    ImportError
        When the optional package behind "anthropic" or "local" is missing.
    """
    if provider == "auto" or provider is None:
        import logging

        from memnex.llm.fallback_chain import FallbackChain

        log = logging.getLogger(__name__)
        known = ("auto", "anthropic", "local", "rule-based")
        providers: list[LLMProvider] = []
        for name in fallback_chain or ["anthropic", "local", "rule-based"]:
            if isinstance(name, str) and name not in known:
                # Previously a typo here disappeared silently; make it visible.
                log.warning("Ignoring unknown provider %r in fallback chain", name)
                continue
            try:
                providers.append(
                    create_provider(
                        provider=name,
                        anthropic_api_key=anthropic_api_key,
                        anthropic_model=anthropic_model,
                        local_endpoint=local_endpoint,
                        local_model=local_model,
                    )
                )
            except Exception as exc:
                # Best effort: a backend that cannot be built is simply skipped.
                log.debug("Skipping provider %r in fallback chain: %s", name, exc)
        return FallbackChain(providers)

    if provider == "anthropic":
        # Validate before importing so a missing key is reported even when
        # the optional 'anthropic' package is not installed.
        if not anthropic_api_key:
            raise ValueError("anthropic_api_key is required for the Anthropic provider")
        from memnex.llm.providers.anthropic import AnthropicProvider

        return AnthropicProvider(api_key=anthropic_api_key, model=anthropic_model)

    if provider == "local":
        from memnex.llm.providers.local import LocalProvider

        return LocalProvider(endpoint=local_endpoint, model=local_model)

    if provider == "rule-based":
        from memnex.llm.providers.rule_based import RuleBasedProvider

        return RuleBasedProvider()

    raise ValueError(
        f"Unknown provider: {provider!r}. "
        "Choose from: auto, anthropic, local, rule-based"
    )
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""LLM provider implementations."""
|
|
2
|
+
|
|
3
|
+
from memnex.llm.providers.rule_based import RuleBasedProvider
|
|
4
|
+
|
|
5
|
+
# Anthropic and Local providers are optional -- they require external
|
|
6
|
+
# packages that may not be installed. Import them lazily or guard with
|
|
7
|
+
# try/except at the call site.
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"RuleBasedProvider",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def __getattr__(name: str):
    """Resolve the optional provider classes on first access (PEP 562).

    ``AnthropicProvider`` and ``LocalProvider`` live in modules that need
    third-party packages, so they are only imported when actually requested.
    Any other attribute name raises AttributeError.
    """
    if name not in ("AnthropicProvider", "LocalProvider"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if name == "LocalProvider":
        from memnex.llm.providers.local import LocalProvider

        return LocalProvider
    from memnex.llm.providers.anthropic import AnthropicProvider

    return AnthropicProvider
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Anthropic LLM provider implementation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from memnex.models import IntentType
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import anthropic
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"The 'anthropic' package is required for AnthropicProvider. "
|
|
15
|
+
"Install it with: pip install anthropic"
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AnthropicProvider:
    """LLM provider backed by the Anthropic SDK (Claude).

    Parameters
    ----------
    api_key:
        Anthropic API key.
    model:
        Model identifier (default: claude-sonnet-4-6).
    max_tokens:
        Default maximum tokens for completions.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "claude-sonnet-4-6",
        max_tokens: int = 1024,
    ) -> None:
        self._client = anthropic.AsyncAnthropic(api_key=api_key)
        self._model = model
        self._max_tokens = max_tokens

    # -- helpers --------------------------------------------------------

    async def _raw_complete(self, prompt: str, max_tokens: int | None = None) -> str:
        """Send a single-turn user message and return the assistant text.

        All text blocks of the reply are concatenated; blocks without a
        string ``text`` attribute are ignored.

        Raises
        ------
        RuntimeError
            If the reply has no text block at all, so callers such as a
            fallback chain treat the call as failed rather than getting an
            opaque IndexError/AttributeError.
        """
        resp = await self._client.messages.create(
            model=self._model,
            # A falsy max_tokens (None or 0) means "use the instance default".
            max_tokens=max_tokens or self._max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        # The reply is a list of content blocks that may be empty or start
        # with a non-text block, so content[0].text cannot be assumed.
        texts = [getattr(block, "text", None) for block in resp.content or ()]
        texts = [t for t in texts if isinstance(t, str)]
        if not texts:
            raise RuntimeError(
                "Anthropic response contained no text content "
                f"(stop_reason={getattr(resp, 'stop_reason', None)!r})"
            )
        return "".join(texts)

    async def _raw_complete_json(self, prompt: str) -> dict:
        """Complete with JSON response expectation and parse the result."""
        text = await self._raw_complete(prompt)
        return self._parse_json(text)

    @staticmethod
    def _parse_json(text: str) -> dict:
        """Best-effort extraction of a JSON *object* from LLM output.

        Candidates are tried in order: the whole text, the body of the first
        markdown code fence, then the span from the first ``{`` to the last
        ``}``. Only a dict is accepted, because callers use mapping access on
        the result; if no candidate yields one, ``{}`` is returned.
        """
        import re

        text = text.strip()
        candidates = [text]
        fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
        if fence:
            candidates.append(fence.group(1).strip())
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start : end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        logger.warning("Failed to parse JSON from LLM response, returning empty dict")
        return {}

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify user query intent using Claude.

        ``context`` is accepted for protocol compatibility but is not used.
        Missing or unrecognized labels map to ``IntentType.IMMEDIATE``.
        """
        result = await self.complete_json(
            f"Classify the intent of the following query. "
            f'Respond with a JSON object: {{"intent": "search|understand|compare|relation"}}\n\nQuery: {query}'
        )
        intent = result.get("intent", "search")
        # Normalize case/whitespace ("Understand "); non-string values are
        # treated as unrecognized instead of raising on dict lookup.
        label = intent.strip().lower() if isinstance(intent, str) else ""
        mapping = {
            "search": IntentType.IMMEDIATE,
            "understand": IntentType.SYNTHESIS,
            "compare": IntentType.RELATION,
            "relation": IntentType.RELATION,
        }
        return mapping.get(label, IntentType.IMMEDIATE)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize content using Claude."""
        prompt = (
            f"Summarize the following content concisely in at most {max_tokens} tokens:\n\n{content}"
        )
        return await self._raw_complete(prompt, max_tokens=max_tokens)

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data according to a JSON schema."""
        full_prompt = (
            f"{prompt}\n\n"
            f"Respond with valid JSON matching this schema:\n"
            f"{json.dumps(schema, ensure_ascii=False)}"
        )
        return await self._raw_complete_json(full_prompt)

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a hypothetical answer for HyDE."""
        prompt = (
            f"Given the query below, write a brief hypothetical answer (2-3 sentences) "
            f"as if a comprehensive knowledge base entry existed for it.\n\nQuery: {query}"
        )
        return await self._raw_complete(prompt, max_tokens=256)

    async def complete(self, prompt: str) -> str:
        """General-purpose text completion."""
        return await self._raw_complete(prompt)

    async def complete_json(self, prompt: str) -> dict:
        """Complete and parse response as JSON."""
        return await self._raw_complete_json(prompt)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Local LLM provider using OpenAI-compatible API (Ollama / LM Studio)."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from memnex.models import IntentType
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from openai import AsyncOpenAI
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"The 'openai' package is required for LocalProvider. "
|
|
15
|
+
"Install it with: pip install openai"
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LocalProvider:
    """LLM provider backed by an OpenAI-compatible local API.

    Works with Ollama, LM Studio, vLLM, and any server that exposes
    the ``/v1/chat/completions`` endpoint.

    Parameters
    ----------
    endpoint:
        Base URL of the OpenAI-compatible API.
    model:
        Model identifier served by the local endpoint.
    max_tokens:
        Default maximum tokens for completions.
    """

    def __init__(
        self,
        endpoint: str = "http://localhost:11434/v1",
        model: str = "qwen2.5",
        max_tokens: int = 1024,
    ) -> None:
        # Local servers generally ignore the key, but the client requires one.
        self._client = AsyncOpenAI(base_url=endpoint, api_key="not-needed")
        self._model = model
        self._max_tokens = max_tokens

    # -- helpers --------------------------------------------------------

    async def _raw_complete(self, prompt: str, max_tokens: int | None = None) -> str:
        """Send a single-turn chat completion request and return the reply text.

        A ``None`` message content is returned as ``""``.

        Raises
        ------
        RuntimeError
            If the server returns no choices, so callers such as a fallback
            chain treat the call as failed rather than getting an IndexError.
        """
        resp = await self._client.chat.completions.create(
            model=self._model,
            # A falsy max_tokens (None or 0) means "use the instance default".
            max_tokens=max_tokens or self._max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        if not resp.choices:
            raise RuntimeError("Local LLM response contained no choices")
        return resp.choices[0].message.content or ""

    async def _raw_complete_json(self, prompt: str) -> dict:
        """Complete with JSON response expectation and parse the result."""
        text = await self._raw_complete(prompt)
        return self._parse_json(text)

    @staticmethod
    def _parse_json(text: str) -> dict:
        """Best-effort extraction of a JSON *object* from LLM output.

        Candidates are tried in order: the whole text, the body of the first
        markdown code fence, then the span from the first ``{`` to the last
        ``}``. Only a dict is accepted, because callers use mapping access on
        the result; if no candidate yields one, ``{}`` is returned.
        """
        import re

        text = text.strip()
        candidates = [text]
        fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
        if fence:
            candidates.append(fence.group(1).strip())
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start : end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        logger.warning("Failed to parse JSON from local LLM response, returning empty dict")
        return {}

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify user query intent using local LLM.

        ``context`` is accepted for protocol compatibility but is not used.
        Missing or unrecognized labels map to ``IntentType.IMMEDIATE``.
        """
        result = await self.complete_json(
            f"Classify the intent of the following query. "
            f'Respond with a JSON object: {{"intent": "search|understand|compare|relation"}}\n\nQuery: {query}'
        )
        intent = result.get("intent", "search")
        # Normalize case/whitespace; non-string values are treated as
        # unrecognized instead of raising on dict lookup.
        label = intent.strip().lower() if isinstance(intent, str) else ""
        mapping = {
            "search": IntentType.IMMEDIATE,
            "understand": IntentType.SYNTHESIS,
            "compare": IntentType.RELATION,
            "relation": IntentType.RELATION,
        }
        return mapping.get(label, IntentType.IMMEDIATE)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize content using local LLM."""
        prompt = (
            f"Summarize the following content concisely in at most {max_tokens} tokens:\n\n{content}"
        )
        return await self._raw_complete(prompt, max_tokens=max_tokens)

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data according to a JSON schema."""
        full_prompt = (
            f"{prompt}\n\n"
            f"Respond with valid JSON matching this schema:\n"
            f"{json.dumps(schema, ensure_ascii=False)}"
        )
        return await self._raw_complete_json(full_prompt)

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a hypothetical answer for HyDE."""
        prompt = (
            f"Given the query below, write a brief hypothetical answer (2-3 sentences) "
            f"as if a comprehensive knowledge base entry existed for it.\n\nQuery: {query}"
        )
        return await self._raw_complete(prompt, max_tokens=256)

    async def complete(self, prompt: str) -> str:
        """General-purpose text completion."""
        return await self._raw_complete(prompt)

    async def complete_json(self, prompt: str) -> dict:
        """Complete and parse response as JSON."""
        return await self._raw_complete_json(prompt)
|