memplex 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. memnex/__init__.py +31 -0
  2. memnex/__main__.py +6 -0
  3. memnex/_plugin/.claude-plugin/plugin.json +24 -0
  4. memnex/_plugin/.mcp.json +9 -0
  5. memnex/_plugin/__init__.py +0 -0
  6. memnex/_plugin/hooks/hooks.json +43 -0
  7. memnex/_plugin/scripts/hook-runner.py +166 -0
  8. memnex/_plugin/skills/mem-explore/SKILL.md +83 -0
  9. memnex/_plugin/skills/mem-manage/SKILL.md +92 -0
  10. memnex/_plugin/skills/mem-search/SKILL.md +85 -0
  11. memnex/_plugin/skills/mem-write/SKILL.md +78 -0
  12. memnex/adapters/__init__.py +14 -0
  13. memnex/adapters/claude_skill.py +169 -0
  14. memnex/adapters/cli.py +525 -0
  15. memnex/adapters/http_api.py +314 -0
  16. memnex/adapters/mcp_server.py +448 -0
  17. memnex/compaction.py +563 -0
  18. memnex/config.py +366 -0
  19. memnex/core/__init__.py +13 -0
  20. memnex/core/associator/__init__.py +8 -0
  21. memnex/core/associator/domain_classifier.py +75 -0
  22. memnex/core/associator/entity_aligner.py +127 -0
  23. memnex/core/associator/ref_linker.py +197 -0
  24. memnex/core/associator/term_mapper.py +77 -0
  25. memnex/core/dictionaries/__init__.py +50 -0
  26. memnex/core/engine.py +667 -0
  27. memnex/core/extractors/__init__.py +15 -0
  28. memnex/core/extractors/docx.py +97 -0
  29. memnex/core/extractors/image.py +233 -0
  30. memnex/core/extractors/markdown.py +139 -0
  31. memnex/core/extractors/pdf.py +133 -0
  32. memnex/core/extractors/vision_mapper.py +131 -0
  33. memnex/core/handlers/__init__.py +7 -0
  34. memnex/core/handlers/clipboard.py +40 -0
  35. memnex/core/handlers/file_handler.py +62 -0
  36. memnex/core/handlers/url_handler.py +132 -0
  37. memnex/llm/__init__.py +25 -0
  38. memnex/llm/enhancer.py +226 -0
  39. memnex/llm/fallback_chain.py +87 -0
  40. memnex/llm/injection_guard.py +178 -0
  41. memnex/llm/provider.py +130 -0
  42. memnex/llm/providers/__init__.py +22 -0
  43. memnex/llm/providers/anthropic.py +135 -0
  44. memnex/llm/providers/local.py +135 -0
  45. memnex/llm/providers/rule_based.py +68 -0
  46. memnex/llm/sanitizer.py +67 -0
  47. memnex/models/__init__.py +68 -0
  48. memnex/models/feedback.py +42 -0
  49. memnex/models/graph.py +33 -0
  50. memnex/models/memory.py +102 -0
  51. memnex/models/misc.py +185 -0
  52. memnex/models/paragraph.py +45 -0
  53. memnex/models/search.py +51 -0
  54. memnex/models/source.py +23 -0
  55. memnex/models/task.py +62 -0
  56. memnex/processing/__init__.py +1 -0
  57. memnex/processing/graph_builder.py +278 -0
  58. memnex/processing/merger/__init__.py +6 -0
  59. memnex/processing/merger/confidence_calculator.py +127 -0
  60. memnex/processing/merger/conflict_resolver.py +116 -0
  61. memnex/retrieval/__init__.py +1 -0
  62. memnex/retrieval/dedup.py +386 -0
  63. memnex/retrieval/embedding.py +289 -0
  64. memnex/retrieval/reranker.py +299 -0
  65. memnex/service.py +902 -0
  66. memnex/storage/__init__.py +65 -0
  67. memnex/storage/base.py +132 -0
  68. memnex/storage/changelog.py +106 -0
  69. memnex/storage/feedback.py +486 -0
  70. memnex/storage/lite/__init__.py +5 -0
  71. memnex/storage/lite/store.py +606 -0
  72. memnex/storage/vector.py +265 -0
  73. memnex/wiki/__init__.py +11 -0
  74. memnex/wiki/community.py +221 -0
  75. memnex/wiki/compiler.py +545 -0
  76. memnex/wiki/generator.py +270 -0
  77. memnex/wiki/search.py +282 -0
  78. memnex/worker.py +412 -0
  79. memplex-3.2.0.dist-info/METADATA +37 -0
  80. memplex-3.2.0.dist-info/RECORD +83 -0
  81. memplex-3.2.0.dist-info/WHEEL +5 -0
  82. memplex-3.2.0.dist-info/entry_points.txt +2 -0
  83. memplex-3.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,87 @@
1
+ """Chain-of-responsibility fallback for LLM providers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING
7
+
8
+ from memnex.models import IntentType
9
+
10
+ if TYPE_CHECKING:
11
+ from memnex.llm.provider import LLMProvider
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class FallbackChain:
    """Try providers in order; the first call that does not raise wins.

    When every provider fails (or none is configured) each method returns a
    deterministic default instead of raising:

    * ``classify_intent`` delegates to a ``RuleBasedProvider``;
    * ``summarize`` returns the first ``max_tokens`` characters of the content;
    * ``generate_hypothetical`` returns the query unchanged;
    * ``complete`` returns ``""``;
    * ``extract_structured`` and ``complete_json`` return ``{}``.

    Parameters
    ----------
    providers:
        Ordered list of LLMProvider implementations. They are tried
        sequentially; the first one that returns without raising wins.
    """

    def __init__(self, providers: list[LLMProvider] | None = None) -> None:
        self._providers: list[LLMProvider] = providers or []

    def _fallback(self) -> LLMProvider:
        """Create the RuleBasedProvider used as the last resort for intents."""
        # Imported at call time, so the module is only loaded when needed.
        from memnex.llm.providers.rule_based import RuleBasedProvider

        return RuleBasedProvider()

    async def _first_success(self, method: str, *args: object) -> tuple[bool, object]:
        """Call *method* on each provider until one returns without raising.

        Returns
        -------
        ``(True, result)`` for the first provider that succeeds, otherwise
        ``(False, None)``. The flag is needed because a provider may
        legitimately return a falsy value such as ``""`` or ``{}``.
        """
        for provider in self._providers:
            try:
                # The attribute lookup stays inside the try so a provider
                # that lacks the method is skipped like any other failure.
                return True, await getattr(provider, method)(*args)
            except Exception as exc:
                logger.debug(
                    "%s fallback: %s: %s", method, provider.__class__.__name__, exc
                )
        if self._providers:
            # Total failure is otherwise invisible at default log levels.
            logger.warning(
                "%s: all %d provider(s) failed; using built-in default",
                method,
                len(self._providers),
            )
        return False, None

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify *query*; fall back to the rule-based classifier."""
        ok, result = await self._first_success("classify_intent", query, context)
        if ok:
            return result
        return await self._fallback().classify_intent(query, context)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize *content*; fall back to a plain character truncation."""
        ok, result = await self._first_success("summarize", content, max_tokens)
        # NOTE: the fallback slices characters, not tokens.
        return result if ok else content[:max_tokens]

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data; fall back to an empty dict."""
        ok, result = await self._first_success("extract_structured", prompt, schema)
        return result if ok else {}

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a HyDE answer; fall back to the query itself."""
        ok, result = await self._first_success("generate_hypothetical", query)
        return result if ok else query

    async def complete(self, prompt: str) -> str:
        """Complete *prompt*; fall back to an empty string."""
        ok, result = await self._first_success("complete", prompt)
        return result if ok else ""

    async def complete_json(self, prompt: str) -> dict:
        """Complete *prompt* as JSON; fall back to an empty dict."""
        ok, result = await self._first_success("complete_json", prompt)
        return result if ok else {}
@@ -0,0 +1,178 @@
1
+ """Indirect prompt injection guard for memory recall contexts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import re
7
+ from typing import TYPE_CHECKING, List
8
+
9
+ if TYPE_CHECKING:
10
+ from memnex.models import SearchResult
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class IndirectInjectionGuard:
    """Detect and mitigate indirect prompt injection in recalled memories.

    While ``LLMPromptSanitizer`` protects against *direct* injection (user
    input), this class handles *indirect* injection: an attacker embeds
    malicious instructions in a source document that survives extraction
    and gets injected into the LLM context when recalled via RAG.

    Defense layers:
    1. Content scanning -- regex-based detection of system-role keywords.
    2. Protective wrapping -- memory content is wrapped in ``[MEMORY ...]``
       tags so the LLM treats them as data, not instructions.
    3. Trust-level labelling -- each memory is annotated with a trust level
       derived from its ``source_type``.
    """

    # Multi-language injection patterns (compiled once at class load)
    INJECTION_PATTERNS: list[str] = [
        r"ignore\s+(all\s+)?previous\s+instructions?",
        r"disregard\s+(all\s+)?prior\s+instructions?",
        r"system\s*:\s*you\s+are",
        r"<\|endoftext\|>",
        r"<\|im_start\|>",
        r"忽略(之前|前面|上面)的(所有|全部)?指令",
        r"忽略系统提示",
        r"你现在是",
        r"新的系统提示",
        r"assistant\s*:\s*sure",
    ]

    _compiled: list[re.Pattern] = [re.compile(p, re.IGNORECASE) for p in INJECTION_PATTERNS]

    # Matches the opening of a wrapper tag so recalled text cannot forge one.
    _MARKER_RE: re.Pattern = re.compile(r"\[\s*MEMORY\s+(START|END)", re.IGNORECASE)

    # Trust level mapping: source_type value -> trust level label
    TRUST_LEVELS: dict[str, str] = {
        "requirement": "HIGH",
        "meeting": "MEDIUM",
        "code": "MEDIUM",
        "wiki": "LOW",
    }

    # Attributes scanned for each flat (non-"function") memory type.
    _SCAN_FIELDS: dict[str, tuple[str, ...]] = {
        "fact": ("subject", "predicate", "object_"),
        "preference": ("aspect", "preference"),
        "observation": ("event", "context"),
    }

    @classmethod
    def scan(cls, content: str) -> bool:
        """Scan content for suspected injection attacks.

        Returns
        -------
        True if the content is suspected to contain an injection payload;
        the caller should discard or isolate the memory entry. Empty or
        ``None`` content is never flagged.
        """
        if not content:
            return False
        return any(pattern.search(content) for pattern in cls._compiled)

    @classmethod
    def wrap_for_context(
        cls,
        memories: list[SearchResult],
        store: object,
    ) -> str:
        """Wrap recalled memories in protective tags for LLM context injection.

        Each memory is enclosed in ``[MEMORY START | trust=LEVEL | id=...]``
        / ``[MEMORY END]`` markers so the LLM treats the content as
        data, not as system instructions. Any ``[MEMORY START`` /
        ``[MEMORY END`` sequence inside the memory text has its opening
        bracket replaced with ``(`` so the text cannot close the wrapper
        early.

        Parameters
        ----------
        memories:
            Search results to be injected into the LLM context.
        store:
            A MemoryStore-like object with a ``get(id)`` method that
            returns a MemoryNode (or None). Results whose node is missing
            are omitted.
        """
        parts: list[str] = []
        for r in memories:
            func = store.get(r.func_id) if store else None
            if not func:
                continue
            source_type = func.source_type
            # source_type may be an enum (use .value) or a plain string.
            if hasattr(source_type, "value"):
                source_type_val = source_type.value
            else:
                source_type_val = source_type or "wiki"
            trust = cls.TRUST_LEVELS.get(source_type_val, "LOW")
            summary = cls._MARKER_RE.sub(r"(MEMORY \1", str(r.summary or func.name))
            parts.append(
                f"[MEMORY START | trust={trust} | id={r.func_id}]\n"
                f"{summary}\n"
                f"[MEMORY END]"
            )
        return "\n\n".join(parts)

    @classmethod
    def filter_and_wrap(
        cls,
        memories: list[SearchResult],
        store: object,
    ) -> str:
        """Filter out injection-suspected memories, then wrap the rest.

        Both the stored structured fields and the text that is actually
        emitted into the context (the result summary, or the node name when
        there is no summary) are scanned. Memories that trigger the scanner
        are logged as warnings and excluded from the output.
        """
        safe: list[SearchResult] = []
        for r in memories:
            func = store.get(r.func_id) if store else None
            if func:
                memory_type = getattr(func, "memory_type", "function")
                # Normalise an enum-valued memory_type for matching and logging.
                memory_type = getattr(memory_type, "value", memory_type)
                stored_text = cls._extract_scan_text(func, memory_type)
                emitted = getattr(r, "summary", None) or getattr(func, "name", None)
                emitted_text = str(emitted) if emitted else ""
                if cls.scan(stored_text) or cls.scan(emitted_text):
                    logger.warning(
                        "Indirect injection detected in memory %s (type=%s), skipped.",
                        r.func_id,
                        memory_type,
                    )
                    continue
            safe.append(r)
        return cls.wrap_for_context(safe, store)

    @classmethod
    def _extract_scan_text(cls, func: object, memory_type: str) -> str:
        """Extract the relevant text fields for injection scanning by memory type.

        Missing attributes, ``None`` values and empty strings are skipped
        instead of raising.
        """
        memory_type = getattr(memory_type, "value", memory_type)

        if memory_type == "function":
            pieces: list[str] = []
            for role in ("trigger", "condition", "action", "benefit"):
                for fv in getattr(func, role, None) or []:
                    desc = getattr(fv, "desc", None)
                    if desc:
                        pieces.append(str(desc))
            return " ".join(pieces)

        fields = cls._SCAN_FIELDS.get(memory_type) if isinstance(memory_type, str) else None
        if fields is not None:
            values = (getattr(func, name, "") for name in fields)
            return " ".join(str(v) for v in values if v)

        # Unknown type: scan all string attributes. Objects without a
        # __dict__ (e.g. slotted classes) contribute nothing.
        attrs = getattr(func, "__dict__", None) or {}
        return " ".join(v for v in attrs.values() if isinstance(v, str))
memnex/llm/provider.py ADDED
@@ -0,0 +1,130 @@
1
+ """LLM Provider protocol definition."""
2
+
3
+ from typing import Protocol, runtime_checkable
4
+
5
+ from memnex.models import IntentType
6
+
7
+
8
@runtime_checkable
class LLMProvider(Protocol):
    """Structural interface every LLM backend has to implement.

    The protocol covers intent classification, summarization, structured
    extraction, HyDE generation and general-purpose completion. It is
    decorated with ``runtime_checkable``, so ``isinstance`` can be used to
    check that an object exposes these methods (signatures are not
    verified by that check).
    """

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Return the intent of *query*.

        The result is one of: IMMEDIATE, SYNTHESIS, RELATION, ALL.
        """
        ...

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Return a concise summary of *content*."""
        ...

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Return data extracted for *prompt*, shaped by the JSON *schema*."""
        ...

    async def generate_hypothetical(self, query: str) -> str:
        """Return a hypothetical answer to *query* for HyDE retrieval."""
        ...

    async def complete(self, prompt: str) -> str:
        """Return a free-form text completion for *prompt*."""
        ...

    async def complete_json(self, prompt: str) -> dict:
        """Return the completion for *prompt* parsed as JSON."""
        ...
45
+
46
+
47
def create_provider(
    provider: str | None = "auto",
    *,
    anthropic_api_key: str | None = None,
    anthropic_model: str = "claude-sonnet-4-6",
    local_endpoint: str = "http://localhost:11434/v1",
    local_model: str = "qwen2.5",
    fallback_chain: list[str] | None = None,
) -> LLMProvider:
    """Factory: create an LLM provider instance by name.

    Parameters
    ----------
    provider:
        One of "auto", "anthropic", "local", "rule-based". ``None`` is
        accepted as a synonym for "auto".
    anthropic_api_key:
        Anthropic API key (required for anthropic provider).
    anthropic_model:
        Model name for Anthropic.
    local_endpoint:
        OpenAI-compatible endpoint URL (e.g., Ollama / LM Studio).
    local_model:
        Model name for the local provider.
    fallback_chain:
        Ordered list of provider names for FallbackChain. Only used when
        *provider* is "auto" or ``None``; defaults to
        ["anthropic", "local", "rule-based"].

    Returns
    -------
    An object satisfying the LLMProvider protocol.

    Raises
    ------
    ValueError
        If *provider* is not a known name, or "anthropic" is requested
        without an API key. In chain mode these errors are not raised:
        the affected chain member is skipped and the reason is logged at
        debug level.
    """
    if provider is None or provider == "auto":
        import logging

        from memnex.llm.fallback_chain import FallbackChain

        log = logging.getLogger(__name__)
        chain_names = fallback_chain or ["anthropic", "local", "rule-based"]
        members = []
        for name in chain_names:
            try:
                member = create_provider(
                    provider=name,
                    anthropic_api_key=anthropic_api_key,
                    anthropic_model=anthropic_model,
                    local_endpoint=local_endpoint,
                    local_model=local_model,
                )
            except Exception as exc:
                # Best effort: a missing optional package, a missing API key
                # or an unknown name must not prevent building the chain.
                log.debug("Skipping LLM provider %r in fallback chain: %s", name, exc)
                continue
            members.append(member)
        return FallbackChain(members)

    if provider == "anthropic":
        from memnex.llm.providers.anthropic import AnthropicProvider

        if not anthropic_api_key:
            raise ValueError("anthropic_api_key is required for the Anthropic provider")
        return AnthropicProvider(api_key=anthropic_api_key, model=anthropic_model)

    if provider == "local":
        from memnex.llm.providers.local import LocalProvider

        return LocalProvider(endpoint=local_endpoint, model=local_model)

    if provider == "rule-based":
        from memnex.llm.providers.rule_based import RuleBasedProvider

        return RuleBasedProvider()

    raise ValueError(
        f"Unknown provider: {provider!r}. "
        f"Choose from: auto, anthropic, local, rule-based"
    )
@@ -0,0 +1,22 @@
1
+ """LLM provider implementations."""
2
+
3
+ from memnex.llm.providers.rule_based import RuleBasedProvider
4
+
5
+ # Anthropic and Local providers are optional -- they require external
6
+ # packages that may not be installed. Import them lazily or guard with
7
+ # try/except at the call site.
8
+
9
+ __all__ = [
10
+ "RuleBasedProvider",
11
+ ]
12
+
13
+
14
def __getattr__(name: str):
    """Resolve the optional provider classes on first attribute access.

    ``AnthropicProvider`` and ``LocalProvider`` live in modules that need
    external packages, so those modules are imported only when the
    attribute is requested. Any other name raises ``AttributeError``.
    """
    if name not in ("AnthropicProvider", "LocalProvider"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name == "LocalProvider":
        from memnex.llm.providers.local import LocalProvider

        return LocalProvider

    from memnex.llm.providers.anthropic import AnthropicProvider

    return AnthropicProvider
@@ -0,0 +1,135 @@
1
+ """Anthropic LLM provider implementation."""
2
+
3
+ import json
4
+ import logging
5
+
6
+ from memnex.models import IntentType
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ try:
11
+ import anthropic
12
+ except ImportError:
13
+ raise ImportError(
14
+ "The 'anthropic' package is required for AnthropicProvider. "
15
+ "Install it with: pip install anthropic"
16
+ )
17
+
18
+
19
class AnthropicProvider:
    """LLM provider backed by the Anthropic SDK (Claude).

    Parameters
    ----------
    api_key:
        Anthropic API key.
    model:
        Model identifier (default: claude-sonnet-4-6).
    max_tokens:
        Default maximum tokens for completions.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "claude-sonnet-4-6",
        max_tokens: int = 1024,
    ) -> None:
        self._client = anthropic.AsyncAnthropic(api_key=api_key)
        self._model = model
        self._max_tokens = max_tokens

    # -- helpers --------------------------------------------------------

    async def _raw_complete(self, prompt: str, max_tokens: int | None = None) -> str:
        """Send a single-turn user message and return the assistant text.

        Returns the text of the first content block that carries text.

        Raises
        ------
        RuntimeError
            If the response has no text block. Raising (rather than
            returning ``""``) lets a caller such as a fallback chain move
            on to the next provider.
        """
        resp = await self._client.messages.create(
            model=self._model,
            max_tokens=max_tokens or self._max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        # The first block is not guaranteed to be a text block, and the
        # content list may be empty, so look for the first block with text.
        for block in resp.content or []:
            text = getattr(block, "text", None)
            if isinstance(text, str):
                return text
        raise RuntimeError("Anthropic response contained no text content block")

    async def _raw_complete_json(self, prompt: str) -> dict:
        """Complete with JSON response expectation and parse the result."""
        text = await self._raw_complete(prompt)
        return self._parse_json(text)

    @staticmethod
    def _parse_json(text: str) -> dict:
        """Best-effort extraction of a JSON object from LLM output.

        Candidates are tried in order: the whole text, the body of the
        first markdown code fence, then the span from the first ``{`` to
        the last ``}``. Only a JSON *object* is accepted; valid JSON of
        another type (list, string, number) is skipped so the declared
        ``dict`` return type always holds. Returns ``{}`` when nothing
        parses.
        """
        import re

        text = (text or "").strip()
        candidates = [text]
        fenced = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1).strip())
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start : end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        logger.warning("Failed to parse JSON from LLM response, returning empty dict")
        return {}

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify user query intent using Claude.

        *context* is accepted for protocol compatibility and is not used.
        An unrecognised, missing or non-string intent maps to
        ``IntentType.IMMEDIATE``.
        """
        result = await self.complete_json(
            f"Classify the intent of the following query. "
            f'Respond with a JSON object: {{"intent": "search|understand|compare|relation"}}\n\nQuery: {query}'
        )
        raw_intent = result.get("intent", "search")
        # Models do not reliably echo the exact lowercase label, and the
        # value may not even be a string (which could be unhashable).
        intent_key = raw_intent.strip().lower() if isinstance(raw_intent, str) else ""
        mapping = {
            "search": IntentType.IMMEDIATE,
            "understand": IntentType.SYNTHESIS,
            "compare": IntentType.RELATION,
            "relation": IntentType.RELATION,
        }
        return mapping.get(intent_key, IntentType.IMMEDIATE)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize content using Claude, capped at *max_tokens* output tokens."""
        prompt = (
            f"Summarize the following content concisely in at most {max_tokens} tokens:\n\n{content}"
        )
        return await self._raw_complete(prompt, max_tokens=max_tokens)

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data according to a JSON schema.

        The schema is appended to the prompt as JSON text; the reply is
        parsed with ``_parse_json`` and is not validated against it.
        """
        full_prompt = (
            f"{prompt}\n\n"
            f"Respond with valid JSON matching this schema:\n"
            f"{json.dumps(schema, ensure_ascii=False)}"
        )
        return await self._raw_complete_json(full_prompt)

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a hypothetical answer for HyDE."""
        prompt = (
            f"Given the query below, write a brief hypothetical answer (2-3 sentences) "
            f"as if a comprehensive knowledge base entry existed for it.\n\nQuery: {query}"
        )
        return await self._raw_complete(prompt, max_tokens=256)

    async def complete(self, prompt: str) -> str:
        """General-purpose text completion."""
        return await self._raw_complete(prompt)

    async def complete_json(self, prompt: str) -> dict:
        """Complete and parse response as JSON."""
        return await self._raw_complete_json(prompt)
@@ -0,0 +1,135 @@
1
+ """Local LLM provider using OpenAI-compatible API (Ollama / LM Studio)."""
2
+
3
+ import json
4
+ import logging
5
+
6
+ from memnex.models import IntentType
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ try:
11
+ from openai import AsyncOpenAI
12
+ except ImportError:
13
+ raise ImportError(
14
+ "The 'openai' package is required for LocalProvider. "
15
+ "Install it with: pip install openai"
16
+ )
17
+
18
+
19
class LocalProvider:
    """LLM provider backed by an OpenAI-compatible local API.

    Works with Ollama, LM Studio, vLLM, and any server that exposes
    the ``/v1/chat/completions`` endpoint.

    Parameters
    ----------
    endpoint:
        Base URL of the OpenAI-compatible API.
    model:
        Model identifier served by the local endpoint.
    max_tokens:
        Default maximum tokens for completions.
    """

    def __init__(
        self,
        endpoint: str = "http://localhost:11434/v1",
        model: str = "qwen2.5",
        max_tokens: int = 1024,
    ) -> None:
        # A placeholder key is passed because the client is constructed
        # with an api_key argument even for unauthenticated local servers.
        self._client = AsyncOpenAI(base_url=endpoint, api_key="not-needed")
        self._model = model
        self._max_tokens = max_tokens

    # -- helpers --------------------------------------------------------

    async def _raw_complete(self, prompt: str, max_tokens: int | None = None) -> str:
        """Send a single-turn chat completion request.

        Returns the first choice's message content, or ``""`` when that
        content is empty or ``None``.
        """
        resp = await self._client.chat.completions.create(
            model=self._model,
            max_tokens=max_tokens or self._max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.choices[0].message.content or ""

    async def _raw_complete_json(self, prompt: str) -> dict:
        """Complete with JSON response expectation and parse the result."""
        text = await self._raw_complete(prompt)
        return self._parse_json(text)

    @staticmethod
    def _parse_json(text: str) -> dict:
        """Best-effort extraction of a JSON object from LLM output.

        Candidates are tried in order: the whole text, the body of the
        first markdown code fence, then the span from the first ``{`` to
        the last ``}``. Only a JSON *object* is accepted; valid JSON of
        another type (list, string, number) is skipped so the declared
        ``dict`` return type always holds. Returns ``{}`` when nothing
        parses.
        """
        import re

        text = (text or "").strip()
        candidates = [text]
        fenced = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1).strip())
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start : end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        logger.warning("Failed to parse JSON from local LLM response, returning empty dict")
        return {}

    # -- LLMProvider interface ------------------------------------------

    async def classify_intent(
        self, query: str, context: dict | None = None
    ) -> IntentType:
        """Classify user query intent using local LLM.

        *context* is accepted for protocol compatibility and is not used.
        An unrecognised, missing or non-string intent maps to
        ``IntentType.IMMEDIATE``.
        """
        result = await self.complete_json(
            f"Classify the intent of the following query. "
            f'Respond with a JSON object: {{"intent": "search|understand|compare|relation"}}\n\nQuery: {query}'
        )
        raw_intent = result.get("intent", "search")
        # Models do not reliably echo the exact lowercase label, and the
        # value may not even be a string (which could be unhashable).
        intent_key = raw_intent.strip().lower() if isinstance(raw_intent, str) else ""
        mapping = {
            "search": IntentType.IMMEDIATE,
            "understand": IntentType.SYNTHESIS,
            "compare": IntentType.RELATION,
            "relation": IntentType.RELATION,
        }
        return mapping.get(intent_key, IntentType.IMMEDIATE)

    async def summarize(self, content: str, max_tokens: int = 256) -> str:
        """Summarize content using local LLM, capped at *max_tokens* output tokens."""
        prompt = (
            f"Summarize the following content concisely in at most {max_tokens} tokens:\n\n{content}"
        )
        return await self._raw_complete(prompt, max_tokens=max_tokens)

    async def extract_structured(self, prompt: str, schema: dict) -> dict:
        """Extract structured data according to a JSON schema.

        The schema is appended to the prompt as JSON text; the reply is
        parsed with ``_parse_json`` and is not validated against it.
        """
        full_prompt = (
            f"{prompt}\n\n"
            f"Respond with valid JSON matching this schema:\n"
            f"{json.dumps(schema, ensure_ascii=False)}"
        )
        return await self._raw_complete_json(full_prompt)

    async def generate_hypothetical(self, query: str) -> str:
        """Generate a hypothetical answer for HyDE."""
        prompt = (
            f"Given the query below, write a brief hypothetical answer (2-3 sentences) "
            f"as if a comprehensive knowledge base entry existed for it.\n\nQuery: {query}"
        )
        return await self._raw_complete(prompt, max_tokens=256)

    async def complete(self, prompt: str) -> str:
        """General-purpose text completion."""
        return await self._raw_complete(prompt)

    async def complete_json(self, prompt: str) -> dict:
        """Complete and parse response as JSON."""
        return await self._raw_complete_json(prompt)