h-ai-brain 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- h_ai/__init__.py +11 -2
- h_ai/application/__init__.py +12 -0
- h_ai/application/hai_service.py +12 -18
- h_ai/application/services/__init__.py +11 -0
- h_ai/application/services/base_model_service.py +69 -0
- h_ai/application/services/granite_service.py +139 -0
- h_ai/application/services/nomic_service.py +117 -0
- h_ai/domain/llm_config.py +14 -1
- h_ai/domain/model_factory.py +44 -0
- h_ai/domain/reasoning/llm_chat_repository.py +39 -2
- h_ai/domain/reasoning/llm_embedding_repository.py +20 -0
- h_ai/domain/reasoning/llm_generate_respository.py +21 -4
- h_ai/domain/reasoning/llm_tool_repository.py +24 -1
- h_ai/infrastructure/llm/json_resource_loader.py +97 -0
- h_ai/infrastructure/llm/ollama/factories/__init__.py +1 -0
- h_ai/infrastructure/llm/ollama/factories/granite_factory.py +91 -0
- h_ai/infrastructure/llm/ollama/factories/nomic_factory.py +58 -0
- h_ai/infrastructure/llm/ollama/ollama_chat_repository.py +165 -26
- h_ai/infrastructure/llm/ollama/ollama_embed_repository.py +43 -0
- h_ai/infrastructure/llm/ollama/ollama_generate_repository.py +88 -32
- h_ai/infrastructure/llm/ollama/ollama_http_client.py +54 -0
- h_ai/infrastructure/llm/prompt_loader.py +42 -7
- h_ai/infrastructure/llm/template_loader.py +146 -0
- {h_ai_brain-0.0.23.dist-info → h_ai_brain-0.0.25.dist-info}/METADATA +2 -1
- h_ai_brain-0.0.25.dist-info/RECORD +43 -0
- h_ai_brain-0.0.23.dist-info/RECORD +0 -30
- {h_ai_brain-0.0.23.dist-info → h_ai_brain-0.0.25.dist-info}/WHEEL +0 -0
- {h_ai_brain-0.0.23.dist-info → h_ai_brain-0.0.25.dist-info}/licenses/LICENSE +0 -0
- {h_ai_brain-0.0.23.dist-info → h_ai_brain-0.0.25.dist-info}/licenses/NOTICE.txt +0 -0
- {h_ai_brain-0.0.23.dist-info → h_ai_brain-0.0.25.dist-info}/top_level.txt +0 -0
h_ai/infrastructure/llm/json_resource_loader.py
@@ -0,0 +1,97 @@
+import os
+import json
+import importlib.util
+import importlib.resources
+from typing import Dict, Any
+
+
+class JsonResourceLoader:
+    """
+    Loader for JSON configuration files.
+
+    This class provides functionality to load JSON configuration files from the resources directory.
+    It can load files from either a package or the file system, with fallback behavior.
+
+    Usage:
+        # Create a JSON resource loader
+        loader = JsonResourceLoader()
+
+        # Load a JSON resource
+        config = loader.load_json_resource("autonomous_agent.json")
+    """
+
+    def __init__(self, resources_dir: str = None, package_name: str = None, resources_path: str = None):
+        """
+        Initialize a new JsonResourceLoader.
+
+        Args:
+            resources_dir: Directory containing JSON resources. If None, defaults to 'resources'.
+            package_name: Python package name containing resources. If None, defaults to 'h_ai'.
+            resources_path: Path within the package to resources. If None, defaults to 'resources'.
+        """
+        # Get the base directory of the package
+        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+        # Set default directories if not provided
+        if resources_dir is None:
+            resources_dir = os.path.join(base_dir, 'resources')
+        if package_name is None:
+            package_name = 'h_ai'
+        if resources_path is None:
+            resources_path = 'resources'
+
+        self.resources_dir = resources_dir
+        self.package_name = package_name
+        self.resources_path = resources_path
+
+    def load_json_resource(self, filename: str) -> Dict[str, Any]:
+        """
+        Load a JSON resource file from package or file system.
+
+        Args:
+            filename: Name of the JSON file (with or without .json extension).
+
+        Returns:
+            Dict[str, Any]: The loaded JSON data.
+
+        Raises:
+            ValueError: If the resource file cannot be found in either the package or file system.
+        """
+        # Add .json extension if not present
+        if not filename.endswith('.json'):
+            filename = f"{filename}.json"
+
+        # First try to load from package
+        try:
+            # Construct the resource path within the package
+            resource_path = os.path.join(self.resources_path, filename)
+            resource_path = resource_path.replace('\\', '/')  # Ensure forward slashes for package paths
+
+            # Try to get the resource from the package
+            package_spec = importlib.util.find_spec(self.package_name)
+            if package_spec is not None:
+                # Use importlib.resources to get the resource content
+                resource_package = f"{self.package_name}.{os.path.dirname(resource_path)}"
+                resource_name = os.path.basename(resource_path)
+
+                # Handle different importlib.resources APIs based on Python version
+                try:
+                    # Python 3.9+
+                    with importlib.resources.files(resource_package).joinpath(resource_name).open('r') as f:
+                        return json.load(f)
+                except (AttributeError, ImportError):
+                    # Fallback for older Python versions
+                    resource_text = importlib.resources.read_text(resource_package, resource_name)
+                    return json.loads(resource_text)
+        except (ImportError, ModuleNotFoundError, FileNotFoundError, ValueError):
+            # If package loading fails, fall back to file system
+            pass
+
+        # Fall back to file system
+        file_path = os.path.join(self.resources_dir, filename)
+        if os.path.exists(file_path):
+            with open(file_path, 'r') as f:
+                return json.load(f)
+
+        # If we get here, the resource wasn't found in either location
+        raise ValueError(f"Resource file not found: {filename} (tried package '{self.package_name}' and directory '{self.resources_dir}')")
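For orientation, a minimal sketch of how the new loader might be used, based only on the signatures above; the resource name and local directory are illustrative placeholders, not files guaranteed to ship with the wheel:

```python
from h_ai.infrastructure.llm.json_resource_loader import JsonResourceLoader

# Default behaviour: try the 'resources' path inside the installed 'h_ai' package,
# then fall back to the package's resources directory on the file system.
loader = JsonResourceLoader()
config = loader.load_json_resource("autonomous_agent")  # '.json' is appended automatically

# Explicit override: read JSON resources from a project-local directory instead (hypothetical path).
local_loader = JsonResourceLoader(resources_dir="./config/resources")
```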
h_ai/infrastructure/llm/ollama/factories/__init__.py
@@ -0,0 +1 @@
+# This file is intentionally left empty to mark the directory as a Python package.
h_ai/infrastructure/llm/ollama/factories/granite_factory.py
@@ -0,0 +1,91 @@
+from typing import Optional
+
+from .....domain.model_factory import ModelFactory
+from .....domain.reasoning.llm_chat_repository import LlmChatRepository
+from .....domain.reasoning.llm_embedding_repository import LlmEmbeddingRepository
+from .....domain.reasoning.llm_generate_respository import LlmGenerateRepository
+from ....llm.template_loader import TemplateLoader
+from ..ollama_chat_repository import OllamaChatRepository
+from ..ollama_generate_repository import OllamaGenerateRepository
+
+
+class GraniteModelFactory(ModelFactory):
+    """
+    Factory for creating repositories for the Granite 3.3:8b model.
+    """
+
+    MODEL_NAME = "granite3.3:8b"
+
+    def __init__(self, api_url: str, temperature: float = 0.6, max_tokens: int = 2500,
+                 api_token: str = None, use_templating: bool = True):
+        """
+        Initialize a new GraniteModelFactory.
+
+        Args:
+            api_url: The base URL of the Ollama API.
+            temperature: The temperature to use for generation.
+            max_tokens: The maximum number of tokens to generate.
+            api_token: Optional API token for authentication.
+            use_templating: Whether to use Jinja2 templating for prompt formatting.
+        """
+        self.api_url = api_url
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.api_token = api_token
+        self.use_templating = use_templating
+
+        # Create a template loader if templating is enabled
+        if use_templating:
+            self.template_loader = TemplateLoader()
+
+    def create_chat_repository(self) -> LlmChatRepository:
+        """
+        Creates a chat repository for the Granite model.
+
+        Returns:
+            LlmChatRepository: A repository for chat interactions with the Granite model.
+        """
+        kwargs = {
+            "api_url": self.api_url,
+            "model_name": self.MODEL_NAME,
+            "temperature": self.temperature,
+            "api_token": self.api_token,
+            "use_templating": self.use_templating
+        }
+
+        # Add template loader if templating is enabled
+        if self.use_templating and hasattr(self, 'template_loader'):
+            kwargs["template_loader"] = self.template_loader
+
+        return OllamaChatRepository(**kwargs)
+
+    def create_generate_repository(self) -> LlmGenerateRepository:
+        """
+        Creates a generate repository for the Granite model.
+
+        Returns:
+            LlmGenerateRepository: A repository for text generation with the Granite model.
+        """
+        kwargs = {
+            "api_url": self.api_url,
+            "model_name": self.MODEL_NAME,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "api_token": self.api_token,
+            "use_templating": self.use_templating
+        }
+
+        # Add template loader if templating is enabled
+        if self.use_templating and hasattr(self, 'template_loader'):
+            kwargs["template_loader"] = self.template_loader
+
+        return OllamaGenerateRepository(**kwargs)
+
+    def create_embedding_repository(self) -> Optional[LlmEmbeddingRepository]:
+        """
+        Granite model doesn't support embeddings, so this returns None.
+
+        Returns:
+            Optional[LlmEmbeddingRepository]: None, as Granite doesn't support embeddings.
+        """
+        return None
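A rough usage sketch for the new factory, assuming the module is importable under the file path listed in the RECORD above; the Ollama URL is a placeholder:

```python
from h_ai.infrastructure.llm.ollama.factories.granite_factory import GraniteModelFactory

factory = GraniteModelFactory(api_url="http://localhost:11434/api", temperature=0.6)

chat_repo = factory.create_chat_repository()          # OllamaChatRepository bound to granite3.3:8b
generate_repo = factory.create_generate_repository()  # OllamaGenerateRepository bound to granite3.3:8b
assert factory.create_embedding_repository() is None  # the Granite factory exposes no embedding support
```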
h_ai/infrastructure/llm/ollama/factories/nomic_factory.py
@@ -0,0 +1,58 @@
+from typing import Optional
+
+from .....domain.model_factory import ModelFactory
+from .....domain.reasoning.llm_chat_repository import LlmChatRepository
+from .....domain.reasoning.llm_embedding_repository import LlmEmbeddingRepository
+from .....domain.reasoning.llm_generate_respository import LlmGenerateRepository
+from ..ollama_embed_repository import OllamaEmbeddingRepository
+
+
+class NomicModelFactory(ModelFactory):
+    """
+    Factory for creating repositories for the Nomic Embed Text model.
+    This model is specifically for embeddings and doesn't support chat or generation.
+    """
+
+    MODEL_NAME = "nomic-embed-text:137m-v1.5-fp16"
+
+    def __init__(self, api_url: str, api_token: str = None):
+        """
+        Initialize a new NomicModelFactory.
+
+        Args:
+            api_url: The base URL of the Ollama API.
+            api_token: Optional API token for authentication.
+        """
+        self.api_url = api_url
+        self.api_token = api_token
+
+    def create_chat_repository(self) -> Optional[LlmChatRepository]:
+        """
+        Nomic model doesn't support chat, so this returns None.
+
+        Returns:
+            Optional[LlmChatRepository]: None, as Nomic doesn't support chat.
+        """
+        return None
+
+    def create_generate_repository(self) -> Optional[LlmGenerateRepository]:
+        """
+        Nomic model doesn't support generation, so this returns None.
+
+        Returns:
+            Optional[LlmGenerateRepository]: None, as Nomic doesn't support generation.
+        """
+        return None
+
+    def create_embedding_repository(self) -> LlmEmbeddingRepository:
+        """
+        Creates an embedding repository for the Nomic model.
+
+        Returns:
+            LlmEmbeddingRepository: A repository for creating embeddings with the Nomic model.
+        """
+        return OllamaEmbeddingRepository(
+            api_url=self.api_url,
+            model_name=self.MODEL_NAME,
+            api_token=self.api_token
+        )
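Similarly, a sketch of the embedding path through the Nomic factory (placeholder URL and text, based only on the signatures shown here):

```python
from h_ai.infrastructure.llm.ollama.factories.nomic_factory import NomicModelFactory

factory = NomicModelFactory(api_url="http://localhost:11434/api")
embedder = factory.create_embedding_repository()  # OllamaEmbeddingRepository for nomic-embed-text

vector = embedder.embed("Hello, world!")  # Optional[List[float]]; None if the request fails
```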
h_ai/infrastructure/llm/ollama/ollama_chat_repository.py
@@ -1,56 +1,195 @@
-from typing import Optional, List
-
-import requests
+from typing import Optional, List, Dict, Any
 
 from ....domain.reasoning.llm_chat_repository import LlmChatRepository
 from ....infrastructure.llm.llm_response_cleaner import clean_llm_response
 from ....infrastructure.llm.ollama.models.ollama_chat_message import OllamaChatMessage
 from ....infrastructure.llm.ollama.models.ollama_chat_session import OllamaChatSession
+from ....infrastructure.llm.template_loader import TemplateLoader
+from .ollama_http_client import OllamaHttpClient
 
 
 class OllamaChatRepository(LlmChatRepository):
+    """
+    Repository for chat interactions using the Ollama API.
+    """
+
+    def __init__(self, api_url: str, model_name: str, system_prompts: list[str] = None,
+                 temperature: float = None, seed: int = None, api_token: str = None,
+                 template_loader: TemplateLoader = None, use_templating: bool = True):
+        """
+        Initialize a new OllamaChatRepository.
 
-
-
+        Args:
+            api_url: The base URL of the Ollama API.
+            model_name: The name of the model to use.
+            system_prompts: Optional list of system prompts to use for the chat.
+            temperature: The temperature to use for generation.
+            seed: Optional seed for reproducible generation.
+            api_token: Optional API token for authentication.
+            template_loader: Optional template loader to use for formatting prompts.
+            use_templating: Whether to use Jinja2 templating for prompt formatting.
+        """
         self.model_name = model_name
         self.temperature = temperature
         self.seed = seed
-        self.system_prompts = system_prompts
+        self.system_prompts = system_prompts or []
+        self.http_client = OllamaHttpClient(api_url, api_token)
+        self.use_templating = use_templating
+
+        # Initialize template loader if not provided and templating is enabled
+        if use_templating:
+            self.template_loader = template_loader or TemplateLoader()
+
+    def chat(self, user_message: str, session_id: str, chat_history: List[dict] = None) -> Optional[str]:
+        """
+        Chat with the model using the Ollama API.
 
-
+        Args:
+            user_message: The user's message.
+            session_id: The ID of the chat session.
+            chat_history: Optional chat history to include in the conversation.
 
-
+        Returns:
+            Optional[str]: The model's response, or None if the request failed.
+        """
+        # Start with system prompts and user message
+        messages = []
         for system_prompt in self.system_prompts:
             messages.append(OllamaChatMessage("system", system_prompt))
+
+        # Add chat history if provided
+        if chat_history:
+            for message in chat_history:
+                messages.append(OllamaChatMessage(message.get("role", "user"), message.get("content", "")))
+
+        # Add the current user message
+        messages.append(OllamaChatMessage("user", user_message))
+
+        # Create a session
         session = OllamaChatSession(session_id, messages)
 
-        return self.
+        return self.chat_with_messages(session.messages)
+
+    def chat_with_messages(self, messages: List[Any]) -> Optional[str]:
+        """
+        Chat with the model using a custom list of messages.
+
+        This method allows for more flexibility in message formatting, such as including
+        document messages or other special message types.
+
+        Args:
+            messages: The list of messages to send to the model. Each message should have
+                'role' and 'content' attributes or keys.
+
+        Returns:
+            Optional[str]: The model's response, or None if the request failed.
+        """
+        return self._call_ollama_api(messages)
+
+    def _format_with_template(self, messages: List[Any]) -> str:
+        """
+        Format messages using a Jinja2 template appropriate for the model.
+
+        Args:
+            messages: The messages to format. Each message should have 'role' and 'content' attributes.
+
+        Returns:
+            str: The formatted prompt string.
+        """
+        # Get the appropriate template for this model
+        template_name = self.template_loader.get_model_template(self.model_name)
+        if not template_name:
+            # Fall back to standard formatting if no template is available
+            return self._format_standard(messages)
 
-
-
-
+        # Extract system messages
+        system_content = ""
+        for msg in messages:
+            if msg.role == "system":
+                if system_content:
+                    system_content += "\n\n"
+                system_content += msg.content
+
+        # Prepare context for the template
+        context = {
+            "System": system_content,
+            "Messages": [{"Role": msg.role, "Content": msg.content} for msg in messages],
+            "Tools": []  # Can be extended later to support tools
+        }
+
+        # Render the template
+        return self.template_loader.render_template(template_name, context)
+
+    def _format_standard(self, messages: List[Any]) -> List[Dict[str, str]]:
+        """
+        Format messages in the standard Ollama API format.
+
+        Args:
+            messages: The messages to format. Each message should have 'role' and 'content' attributes.
+
+        Returns:
+            List[Dict[str, str]]: The formatted messages.
+        """
+        result = []
+        for message in messages:
+            # Handle OllamaChatMessage objects
+            if hasattr(message, 'to_dict'):
+                result.append(message.to_dict())
+            # Handle dict-like objects
+            elif hasattr(message, 'get'):
+                result.append({
+                    "role": message.get("role", "user"),
+                    "content": message.get("content", "")
+                })
+            # Handle objects with role and content attributes
+            else:
+                result.append({
+                    "role": getattr(message, "role", "user"),
+                    "content": getattr(message, "content", "")
+                })
+        return result
+
+    def _call_ollama_api(self, messages: List[Any]) -> Optional[str]:
+        """
+        Call the Ollama API with the given messages.
+
+        Args:
+            messages: The messages to send to the API. Each message should have 'role' and 'content' attributes.
+
+        Returns:
+            Optional[str]: The model's response, or None if the request failed.
+        """
         payload = {
             "model": self.model_name,
-            "
-            "stream": False,
-            "temperature": "0.6"
+            "stream": False
         }
+
+        # Add temperature and seed if provided
         if self.temperature:
             payload["temperature"] = self.temperature
         if self.seed:
             payload["seed"] = self.seed
 
-
-
-
-
-
-
-
-
-
-
-        return None
+        # Use templating if enabled and available for this model
+        if self.use_templating and hasattr(self, 'template_loader'):
+            # Format with template and use the prompt API
+            formatted_prompt = self._format_with_template(messages)
+            payload["prompt"] = formatted_prompt
+            endpoint = "generate"
+        else:
+            # Use standard message formatting and the chat API
+            payload["messages"] = self._format_standard(messages)
+            endpoint = "chat"
 
+        # Call the appropriate API endpoint
+        response_data = self.http_client.post(endpoint, payload)
+        if response_data:
+            # Extract response based on the endpoint used
+            if endpoint == "generate":
+                full_response = response_data.get("response")
+            else:
+                full_response = response_data.get("message", {}).get("content")
 
+            return clean_llm_response(full_response)
 
+        return None
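The reworked chat repository now accepts prior history and either renders a model-specific Jinja2 template (posted to the generate endpoint) or sends structured messages to the chat endpoint. A sketch under the assumption that an Ollama server is reachable at the placeholder URL and that the import path matches the file layout above:

```python
from h_ai.infrastructure.llm.ollama.ollama_chat_repository import OllamaChatRepository

repo = OllamaChatRepository(
    api_url="http://localhost:11434/api",          # placeholder
    model_name="granite3.3:8b",
    system_prompts=["You are a helpful assistant."],
    use_templating=False,                            # send structured messages to the chat endpoint
)

history = [
    {"role": "user", "content": "What is Ollama?"},
    {"role": "assistant", "content": "A local runtime for open LLMs."},
]
reply = repo.chat("How do I pull a model?", session_id="session-1", chat_history=history)
```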
h_ai/infrastructure/llm/ollama/ollama_embed_repository.py
@@ -0,0 +1,43 @@
+from typing import Optional, List
+
+from ....domain.reasoning.llm_embedding_repository import LlmEmbeddingRepository
+from .ollama_http_client import OllamaHttpClient
+
+
+class OllamaEmbeddingRepository(LlmEmbeddingRepository):
+    """
+    Repository for creating embeddings using the Ollama API.
+    """
+
+    def __init__(self, api_url: str, model_name: str, api_token: str = None):
+        """
+        Initialize a new OllamaEmbeddingRepository.
+
+        Args:
+            api_url: The base URL of the Ollama API.
+            model_name: The name of the model to use.
+            api_token: Optional API token for authentication.
+        """
+        self.model_name = model_name
+        self.http_client = OllamaHttpClient(api_url, api_token)
+
+    def embed(self, text: str) -> Optional[List[float]]:
+        """
+        Create an embedding for the given text using the Ollama API.
+
+        Args:
+            text: The text to create an embedding for.
+
+        Returns:
+            Optional[List[float]]: The embedding vector, or None if the request failed.
+        """
+        payload = {
+            "model": self.model_name,
+            "prompt": text
+        }
+
+        response_data = self.http_client.post("embeddings", payload)
+        if response_data:
+            return response_data.get("embedding")
+
+        return None
h_ai/infrastructure/llm/ollama/ollama_generate_repository.py
@@ -1,40 +1,98 @@
 import uuid
-
-import requests
+from typing import Optional, Dict, Any, List
 
 from ..llm_response_cleaner import clean_llm_response
+from ..template_loader import TemplateLoader
 from ....domain.reasoning.llm_generate_respository import LlmGenerateRepository
+from .ollama_http_client import OllamaHttpClient
 
 
 class OllamaGenerateRepository(LlmGenerateRepository):
+    """
+    Repository for generating text using the Ollama API.
+    """
+
+    def __init__(self, api_url: str, model_name: str, system_prompt: str = None,
+                 temperature: float = None, seed: int = None, max_tokens: int = 5000,
+                 api_token: str = None, template_loader: TemplateLoader = None,
+                 use_templating: bool = True):
+        """
+        Initialize a new OllamaGenerateRepository.
 
-
+        Args:
+            api_url: The base URL of the Ollama API.
+            model_name: The name of the model to use.
+            system_prompt: Optional system prompt to use for generation.
+            temperature: The temperature to use for generation.
+            seed: Optional seed for reproducible generation.
+            max_tokens: The maximum number of tokens to generate.
+            api_token: Optional API token for authentication.
+            template_loader: Optional template loader to use for formatting prompts.
+            use_templating: Whether to use Jinja2 templating for prompt formatting.
+        """
         self.model_name = model_name
         self.system_prompt = system_prompt
-        self.api_url = api_url
         self.temperature = temperature
         self.seed = seed
         self.max_tokens = max_tokens
-        self.
+        self.http_client = OllamaHttpClient(api_url, api_token)
+        self.use_templating = use_templating
+
+        # Initialize template loader if not provided and templating is enabled
+        if use_templating:
+            self.template_loader = template_loader or TemplateLoader()
+
+    def _format_with_template(self, user_prompt: str, system_prompt: str = None) -> str:
+        """
+        Format the prompt using a Jinja2 template appropriate for the model.
+
+        Args:
+            user_prompt: The user's prompt.
+            system_prompt: Optional system prompt.
+
+        Returns:
+            str: The formatted prompt string.
+        """
+        # Get the appropriate template for this model
+        template_name = self.template_loader.get_model_template(self.model_name)
+        if not template_name:
+            # If no template is available, return the user prompt with system prompt as prefix
+            if system_prompt:
+                return f"{system_prompt}\n\n{user_prompt}"
+            return user_prompt
+
+        # Prepare context for the template
+        context = {
+            "System": system_prompt or "",
+            "Messages": [
+                {"Role": "user", "Content": user_prompt}
+            ],
+            "Tools": []  # Can be extended later to support tools
+        }
+
+        # Render the template
+        return self.template_loader.render_template(template_name, context)
 
+    def generate(self, user_prompt: str, system_prompt: str = None, max_tokens: int = None) -> Optional[str]:
+        """
+        Generate text using the Ollama API.
 
-
-
-
-
+        Args:
+            user_prompt: The prompt to generate text from.
+            system_prompt: Optional system prompt to override the default.
+            max_tokens: Optional maximum number of tokens to generate.
+
+        Returns:
+            Optional[str]: The generated text, or None if the request failed.
+        """
         system_prompt = system_prompt or self.system_prompt
         payload = {
             "model": self.model_name,
-            "prompt": user_prompt,
-            "system": system_prompt,
             "stream": False,
-            "
-            "num_ctx": f"{self.max_tokens}",
-            "temperature": f"{self.temperature}"
+            "num_ctx": f"{self.max_tokens}"
         }
 
-
-        payload["session"] = session_id
+        # Add common parameters
         if self.seed:
             payload["seed"] = f"{self.seed}"
         if self.temperature:
@@ -42,22 +100,20 @@ class OllamaGenerateRepository(LlmGenerateRepository):
         if max_tokens:
             payload["num_ctx"] = f"{max_tokens}"
 
-
-        if self.
-
-
-
-
-
+        # Use templating if enabled and available for this model
+        if self.use_templating and hasattr(self, 'template_loader'):
+            # Format with template
+            formatted_prompt = self._format_with_template(user_prompt, system_prompt)
+            payload["prompt"] = formatted_prompt
+        else:
+            # Use standard formatting
+            payload["prompt"] = user_prompt
+            if system_prompt:
+                payload["system"] = system_prompt
 
-
-
-
-
-        response_content = response.json()["response"]
+        response_data = self.http_client.post("generate", payload)
+        if response_data:
+            response_content = response_data.get("response")
             return clean_llm_response(response_content)
 
-
-        print(f"Error occurred during API call: {e}")
-        return None
-
+        return None
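Finally, a sketch of the generate repository after its move onto OllamaHttpClient; URL and prompt values are placeholders, and the import path is assumed from the file layout above:

```python
from h_ai.infrastructure.llm.ollama.ollama_generate_repository import OllamaGenerateRepository

repo = OllamaGenerateRepository(
    api_url="http://localhost:11434/api",    # placeholder
    model_name="granite3.3:8b",
    system_prompt="Answer concisely.",
    use_templating=False,                     # plain prompt + system fields instead of a rendered template
)

text = repo.generate("Summarise the templating changes in one sentence.", max_tokens=256)
```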