patera-aiinterface 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- patera/aiinterface/__init__.py +19 -0
- patera/aiinterface/ai_interface.py +381 -0
- patera/aiinterface/database_types/__init__.py +5 -0
- patera/aiinterface/database_types/vector_column.py +80 -0
- patera/aiinterface/embeddings.py +107 -0
- patera_aiinterface-0.1.0.dist-info/METADATA +11 -0
- patera_aiinterface-0.1.0.dist-info/RECORD +8 -0
- patera_aiinterface-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ai interface module
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .ai_interface import (
|
|
6
|
+
AiInterface,
|
|
7
|
+
ChatContextNotFound,
|
|
8
|
+
tool,
|
|
9
|
+
FailedToRunAiToolMethod,
|
|
10
|
+
AiConfig,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"AiInterface",
|
|
15
|
+
"ChatContextNotFound",
|
|
16
|
+
"tool",
|
|
17
|
+
"FailedToRunAiToolMethod",
|
|
18
|
+
"AiConfig",
|
|
19
|
+
]
|
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AI interface for Patera app
|
|
3
|
+
Makes connecting to LLM's easy
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
import inspect
|
|
8
|
+
from functools import wraps
|
|
9
|
+
from typing import (
|
|
10
|
+
List,
|
|
11
|
+
Dict,
|
|
12
|
+
Any,
|
|
13
|
+
Optional,
|
|
14
|
+
get_type_hints,
|
|
15
|
+
Callable,
|
|
16
|
+
cast,
|
|
17
|
+
TypedDict,
|
|
18
|
+
NotRequired,
|
|
19
|
+
)
|
|
20
|
+
import docstring_parser
|
|
21
|
+
from openai import AsyncOpenAI
|
|
22
|
+
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
|
|
23
|
+
from pydantic import BaseModel, Field
|
|
24
|
+
|
|
25
|
+
from patera import Patera, Request, HttpStatus, Response
|
|
26
|
+
from patera.utilities import run_sync_or_async
|
|
27
|
+
from patera.exceptions import BaseHttpException
|
|
28
|
+
from patera.base_extension import BaseExtension
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ChatContextNotFound(BaseHttpException):
    """
    HTTP exception raised when no chat context could be loaded for a request.

    Defaults to ``HttpStatus.NOT_FOUND``.
    """

    def __init__(self, msg: str, status_code: int | HttpStatus = HttpStatus.NOT_FOUND):
        """
        :param msg: error message forwarded to the base exception.
        :param status_code: HTTP status as an ``HttpStatus`` member or plain int.
        """
        super().__init__(msg, status_code=status_code)
        # The instance always exposes the plain integer value, even when an
        # HttpStatus member was supplied.
        self.status_code = (
            status_code.value if isinstance(status_code, HttpStatus) else status_code
        )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class FailedToRunAiToolMethod(BaseHttpException):
    """
    HTTP exception describing a failed AI tool invocation.

    Defaults to ``HttpStatus.UNPROCESSABLE_ENTITY`` and records the tool method
    name together with the arguments it was called with.
    """

    def __init__(
        self,
        msg: str,
        method_name: str,
        *args,
        status_code: int | HttpStatus = HttpStatus.UNPROCESSABLE_ENTITY,
        **kwargs,
    ):
        """
        :param msg: error message forwarded to the base exception.
        :param method_name: name of the tool method that failed.
        :param args: positional arguments the tool was called with.
        :param status_code: HTTP status as an ``HttpStatus`` member or plain int.
        :param kwargs: keyword arguments the tool was called with.
        """
        super().__init__(msg, status_code=status_code)
        numeric_status = (
            status_code.value if isinstance(status_code, HttpStatus) else status_code
        )
        self.status_code = numeric_status
        self.method_name = method_name
        # NOTE(review): this replaces the standard exception ``args`` tuple
        # (normally the constructor arguments) with the tool's positional
        # arguments - confirm nothing relies on the message being in ``args``.
        self.args = args
        self.kwargs = kwargs
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class _AiInterfaceConfigs(BaseModel):
    """
    Validation model for the AI interface configuration.

    ``API_KEY`` and ``MODEL`` are required; every other field has a default.
    """

    # --- provider connection -------------------------------------------------
    API_KEY: str = Field(description="API key for the AI provider")
    API_BASE_URL: Optional[str] = Field(
        default="https://api.openai.com/v1",
        description="Base URL for the AI provider API",
    )
    ORGANIZATION_ID: Optional[str] = Field(
        default=None, description="Organization ID for the AI provider"
    )
    PROJECT_ID: Optional[str] = Field(
        default=None, description="Project ID for the AI provider"
    )
    TIMEOUT: Optional[int] = Field(
        default=30, description="Timeout for AI provider requests in seconds"
    )

    # --- request defaults ----------------------------------------------------
    MODEL: str = Field(description="Model name to use for AI requests")
    TEMPERATURE: Optional[float] = Field(
        default=1.0, description="Temperature for AI model responses"
    )
    RESPONSE_FORMAT: Optional[dict[str, str]] = Field(
        default={"type": "json_object"},
        description="Desired response format from the AI model",
    )
    TOOL_CHOICE: Optional[bool] = Field(
        default=False, description="Whether to enable tool choice for the AI model"
    )
    MAX_RETRIES: Optional[int] = Field(
        default=0, description="Maximum number of retries for AI provider requests"
    )

    # --- route handler injection ---------------------------------------------
    CHAT_CONTEXT_NAME: Optional[str] = Field(
        default="chat_context",
        description="Name of the chat context model for injection",
    )
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class AiConfig(TypedDict):
|
|
94
|
+
"""Admin configurations typed dictionary"""
|
|
95
|
+
|
|
96
|
+
API_KEY: str
|
|
97
|
+
API_BASE_URL: NotRequired[str]
|
|
98
|
+
ORGANIZATION_ID: NotRequired[str]
|
|
99
|
+
PROJECT_ID: NotRequired[str]
|
|
100
|
+
TIMEOUT: NotRequired[int]
|
|
101
|
+
MODEL: str
|
|
102
|
+
TEMPERATURE: NotRequired[float]
|
|
103
|
+
RESPONSE_FORMAT: NotRequired[dict[str, str]]
|
|
104
|
+
TOOL_CHOICE: NotRequired[bool]
|
|
105
|
+
MAX_RETRIES: NotRequired[int]
|
|
106
|
+
CHAT_CONTEXT_NAME: NotRequired[str]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AiInterface(BaseExtension, ABC):
    """
    Main AI interface extension.

    Reads its settings from the application configuration, registers every
    method marked with the ``tool`` decorator as a function tool and sends chat
    completion requests through ``AsyncOpenAI``.

    Subclasses must implement :meth:`chat_context_loader`.
    """

    def __init__(self, configs_name: Optional[str] = "AI_INTERFACE"):
        """
        Extension init method.

        :param configs_name: key under which the configuration is looked up in
            the app configuration when :meth:`init_app` runs.
        """
        self._app: Patera
        self._configs_name: str = cast(str, configs_name)
        self._configs: dict[str, Any] = {}
        # The attributes below are only declared here; init_app assigns them.
        self._api_key: str
        self._api_base_url: str
        self._organization_id: str
        self._project_id: str
        self._timeout: int
        self._model: str
        self._temperature: float
        self._response_format: dict[str, str]
        self._tool_choice: bool
        self._max_retries: int
        self._tools: list = []
        self._tools_mapping: dict[str, Callable] = {}
        self._chat_context_name: str = "chat_context"

    def init_app(self, app: Patera):
        """
        Initializer method for the extension.

        Loads and validates the configuration, registers the tool methods,
        copies the settings onto the instance and adds the extension to the app.

        :param app: application the extension is attached to.
        :raises ValueError: if no configuration is stored under the configured name.
        """
        self._app = app  # type: ignore
        self._configs = cast(dict[str, Any], app.get_conf(self._configs_name, None))
        if self._configs is None:
            raise ValueError(
                f"Configurations for {self._configs_name} not found in app configurations."
            )
        # validate_configs is not defined in this class (presumably inherited
        # from BaseExtension); its result is read as a mapping keyed by the
        # field names of _AiInterfaceConfigs.
        self._configs = self.validate_configs(self._configs, _AiInterfaceConfigs)
        self._get_tools()

        self._api_key = self._configs["API_KEY"]
        self._api_base_url = self._configs["API_BASE_URL"]
        self._organization_id = self._configs["ORGANIZATION_ID"]
        self._project_id = self._configs["PROJECT_ID"]
        self._timeout = self._configs["TIMEOUT"]
        self._model = self._configs["MODEL"]
        self._temperature = self._configs["TEMPERATURE"]
        self._response_format = self._configs["RESPONSE_FORMAT"]
        # NOTE(review): TOOL_CHOICE is stored but not used by provider().
        self._tool_choice = self._configs["TOOL_CHOICE"]
        self._max_retries = self._configs["MAX_RETRIES"]
        self._chat_context_name = self._configs["CHAT_CONTEXT_NAME"]
        self._app.add_extension(self)

    @property
    def configs(self) -> dict[str, Any]:
        """
        Returns the configuration held by the extension (the validated
        configuration once :meth:`init_app` has run).
        """
        return self._configs

    async def provider(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> tuple[
        str | None,
        Any | list[ChatCompletionMessageToolCall] | None,
        ChatCompletion | None,
    ]:
        """
        Default provider method. Uses AsyncOpenAI from the openai package.

        Per-call overrides may be passed as keyword arguments: ``api_key``,
        ``organization``, ``project``, ``timeout``, ``api_base_url``,
        ``max_retries``, ``model``, ``temperature``, ``response_format`` and
        ``use_tools``. Anything not supplied falls back to the configuration.

        :param messages: chat messages sent to the model.
        :returns: tuple of (assistant message content or None, tool calls or
            None, the full chat completion object).
        """
        client_options: dict = {
            "api_key": kwargs.get("api_key", self._api_key),
            "organization": kwargs.get("organization", self._organization_id),
            "project": kwargs.get("project", self._project_id),
            "timeout": kwargs.get("timeout", self._timeout),
            "base_url": kwargs.get("api_base_url", self._api_base_url),
            "max_retries": kwargs.get("max_retries", self._max_retries),
        }

        request: dict = {
            "messages": messages,
            "model": kwargs.get("model", self._model),
            "temperature": kwargs.get("temperature", self._temperature),
            "response_format": kwargs.get("response_format", self._response_format),
        }
        # Attach tools only when requested AND at least one tool is registered,
        # so an empty tools list is never sent to the provider.
        if kwargs.get("use_tools", False) and self._tools:
            request["tools"] = self._tools

        # A client is created per call; the context manager closes it (and its
        # underlying HTTP resources) once the request has completed.
        async with AsyncOpenAI(**client_options) as client:
            chat: ChatCompletion = await client.chat.completions.create(**request)

        message = chat.choices[0].message
        # Empty values (empty string / empty list) are normalised to None.
        tool_calls = message.tool_calls or None
        assistant_message_content = message.content or None
        return assistant_message_content, tool_calls, chat

    async def create_chat_completion(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> tuple[
        str | None, list[ChatCompletionMessageToolCall] | None, ChatCompletion | None
    ]:
        """
        Creates a chat completion by delegating to :meth:`provider`.

        :param messages: chat messages sent to the model.
        :param kwargs: forwarded unchanged to :meth:`provider`.
        :returns: whatever :meth:`provider` returns (content, tool calls, chat completion).
        """
        return await self.provider(messages, **kwargs)

    async def envoke_ai_tool(self, tool_name, *args, **kwargs) -> Any:
        """
        Runs a registered AI tool method.

        :param tool_name: name the tool was registered under.
        :param args: positional arguments passed to the tool.
        :param kwargs: keyword arguments passed to the tool.
        :returns: the value returned by the tool.
        :raises ValueError: if no tool is registered under ``tool_name``.
        :raises FailedToRunAiToolMethod: if the tool raises; the original
            exception is chained as the cause.
        """
        tool_method: Optional[Callable] = self._tools_mapping.get(tool_name, None)
        if tool_method is None:
            raise ValueError(
                f"Tool method named {tool_name} is not registered with the AI interface"
            )
        try:
            return await run_sync_or_async(tool_method, *args, **kwargs)
        except Exception as exc:
            # tool_name fills the required ``method_name`` parameter; without it
            # the first tool argument would be taken as the method name (or a
            # TypeError would hide the real failure when there are none).
            raise FailedToRunAiToolMethod(
                f"Failed to run AI tool {tool_name}", tool_name, *args, **kwargs
            ) from exc

    def build_function_schema(
        self,
        func: Callable,
        func_name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> dict[str, Any]:
        """
        Automatically builds an OpenAI function schema from a Python function.
        Assumes docstring and type hints follow some basic conventions.

        ``str``, ``bool``, ``int`` and ``float`` hints map to the JSON schema
        types "string", "boolean", "integer" and "number"; any other (or
        missing) hint falls back to "string". Parameters without a default
        value are listed as required. ``*args`` / ``**kwargs`` parameters are
        left out because they cannot be described as named properties.

        :param func: function (or bound method) to describe.
        :param func_name: name to advertise; defaults to ``func.__name__``.
        :param description: description to advertise; defaults to the short
            description parsed from the docstring.
        :returns: the function schema dictionary.
        """
        # Parse the docstring
        doc = docstring_parser.parse(func.__doc__ or "")
        func_description = description or doc.short_description or ""

        # First documented entry wins for each parameter name; a missing
        # description becomes an empty string rather than None.
        param_docs: dict[str, str] = {}
        for doc_param in doc.params:
            param_docs.setdefault(doc_param.arg_name, doc_param.description or "")

        properties: dict[str, Any] = {}
        required: list[str] = []

        # Identity comparison keeps bool and int apart (bool subclasses int).
        type_names = (
            (str, "string"),
            (bool, "boolean"),
            (int, "integer"),
            (float, "number"),
        )
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

        # Collect parameter info
        sig = inspect.signature(func)
        hints = get_type_hints(func)

        for param_name, param in sig.parameters.items():
            if param.kind in variadic_kinds:
                continue

            # Derive param type from hints
            param_type = hints.get(param_name, str)
            schema_type = next(
                (
                    json_type
                    for python_type, json_type in type_names
                    if param_type is python_type
                ),
                "string",
            )

            properties[param_name] = {
                "type": schema_type,
                "description": param_docs.get(param_name, ""),
            }
            # If no default value is detected, the parameter is required
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        return {
            "type": "function",
            "function": {
                "name": func_name or func.__name__,
                "description": func_description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": False,
                },
            },
        }

    @abstractmethod
    async def chat_context_loader(
        self, req: Request, *args: Any, **kwargs: Any
    ) -> Optional[Any]:
        """Should load and return a chat session object (ie db model) or none"""

    @property
    def chat_context_name(self) -> str:
        """Keyword argument name under which the chat context is injected."""
        return self._chat_context_name

    @property
    def with_chat_context(self) -> Callable:
        """
        Decorator for injecting the chat context into a route handler.

        Loads the chat context with :meth:`chat_context_loader` and injects it
        as a keyword argument named by :attr:`chat_context_name`.

        The wrapped handler raises ``ValueError`` when its first positional
        argument (after ``self``) is not a ``Request`` and
        ``ChatContextNotFound`` when the loader returns ``None``.
        """
        interface: AiInterface = self

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(self, *args, **kwargs) -> "Response":
                # Guard against an empty args tuple so a missing request
                # produces the descriptive error below instead of IndexError.
                req = args[0] if args else None
                if not isinstance(req, Request):
                    raise ValueError(
                        "Missing Request object at @with_chat_context decorator. The request object"
                        " must be the first argument of the route handler. Please check if you have "
                        "changed the argument sequence."
                    )
                chat_context = await run_sync_or_async(
                    interface.chat_context_loader, req
                )
                if chat_context is None:
                    raise ChatContextNotFound("Chat session not found")
                kwargs[interface.chat_context_name] = chat_context
                return await run_sync_or_async(func, self, *args, **kwargs)

            return wrapper

        return decorator

    def _get_tools(self):
        """
        Registers every callable attribute carrying ``__ai_tool`` metadata.

        For each such attribute a function schema is appended to ``_tools`` and
        the callable is stored in ``_tools_mapping`` under its tool name. Names
        that are already registered are skipped, so calling this method again
        (e.g. a second ``init_app``) does not duplicate tool schemas.
        """
        for name in dir(self):
            method = getattr(self, name)
            if not callable(method):
                continue

            tool_info = getattr(method, "__ai_tool", None)
            if not tool_info:
                continue

            tool_name = tool_info["name"]
            if tool_name in self._tools_mapping:
                continue

            self._tools.append(
                self.build_function_schema(method, tool_name, tool_info["description"])
            )
            self._tools_mapping[tool_name] = method
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator factory that marks a method as an AI interface tool.

    :param name: tool name; falls back to the function name.
    :param description: tool description; falls back to the function docstring.
    :returns: decorator that tags the function and returns it unchanged.
    """

    def decorator(func: Callable):
        """
        Stores the tool metadata on the function under ``__ai_tool``
        """
        metadata = {
            "name": name or func.__name__,
            "description": description or func.__doc__,
        }
        setattr(func, "__ai_tool", metadata)
        return func

    return decorator
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Vector column for PostgreSQL database.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Callable, List, Self, Type
|
|
6
|
+
from sqlalchemy.sql.elements import ColumnElement
|
|
7
|
+
from sqlalchemy.types import UserDefinedType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Vector(UserDefinedType):
    """
    Custom Vector type for pgvector in PostgreSQL.

    Requires pgvector extension for the PostgreSQL database.
    """

    cache_ok = True

    def __init__(self, dimensions: int | None = None):
        """
        :param dimensions: optional vector size used in the column definition.
        """
        self.dimensions: int | None = dimensions

    # pylint: disable-next=C0116
    def get_col_spec(self, **kw) -> str:
        # Without a (truthy) dimension the column is declared as plain VECTOR.
        if not self.dimensions:
            return "VECTOR"
        return f"VECTOR({int(self.dimensions)})"

    # pylint: disable-next=C0116
    def bind_expression(self, bindvalue):
        # Bind values are passed through untouched.
        return bindvalue

    # pylint: disable-next=C0116
    def coerce_compared_value(self, op, value):
        # Lists and tuples compared against the column are treated as vectors.
        if not isinstance(value, (list, tuple)):
            return super().coerce_compared_value(op, value)
        return self

    # pylint: disable-next=C0116
    def literal_processor(self, dialect) -> Callable:
        def process(value):
            # Renders a literal such as '[1.0,2.0]'::vector, or NULL for None.
            if value is None:
                return "NULL"
            inside = ",".join([str(float(component)) for component in value])
            return f"'[{inside}]'::vector"

        return process

    # pylint: disable-next=C0116
    def compare_values(self, x: List, y: List) -> bool:
        # Same object (including both None) counts as equal.
        if x is y:
            return True
        # Exactly one side is None at this point, so the values differ.
        if x is None or y is None:
            return False
        if len(x) != len(y):
            return False
        # Element-wise comparison with a small absolute tolerance.
        eps = 1e-9
        for left, right in zip(x, y):
            if not abs(float(left) - float(right)) <= eps:
                return False
        return True

    class comparator_factory(UserDefinedType.Comparator):
        # distance operators exposed as methods for nicer query syntax
        def l2_distance(self, other):
            return self.expr.op("<->")(other)

        def inner_product(self, other):
            return self.expr.op("<#>")(other)

        def cosine_distance(self, other):
            return self.expr.op("<=>")(other)

    # pylint: disable-next=C0116
    def column_expression(self, colexpr) -> ColumnElement:
        # Column expressions are returned unchanged.
        return colexpr

    # pylint: disable-next=C0116
    def _with_collation(self, collation) -> Self:
        # The collation argument is ignored and the same type is returned.
        return self

    @property
    def python_type(self) -> Type[List]:
        """Python type used for values of this column."""
        return list
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helpers for creating embeddings
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
from sentence_transformers import SentenceTransformer
|
|
8
|
+
from pgvector.sqlalchemy import Vector as VectorColumn
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"l2_distance",
|
|
12
|
+
"cosine_similarity",
|
|
13
|
+
"cosine_distance",
|
|
14
|
+
"create_embedding",
|
|
15
|
+
"chunkify_text",
|
|
16
|
+
"VectorColumn",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def l2_distance(
|
|
21
|
+
vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
|
|
22
|
+
) -> float:
|
|
23
|
+
"""
|
|
24
|
+
Calculates l2 distance between two vectors
|
|
25
|
+
|
|
26
|
+
:param vec1: first vector.
|
|
27
|
+
:param vec2: second vector.
|
|
28
|
+
"""
|
|
29
|
+
if isinstance(vec1, list):
|
|
30
|
+
vec1 = np.array(vec1)
|
|
31
|
+
if isinstance(vec2, list):
|
|
32
|
+
vec2 = np.array(vec2)
|
|
33
|
+
return np.sqrt(np.sum((np.array(vec1) - np.array(vec2)) ** 2))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def cosine_similarity(
|
|
37
|
+
vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
|
|
38
|
+
) -> float:
|
|
39
|
+
"""
|
|
40
|
+
Calculates cosine similarity between two vectors
|
|
41
|
+
|
|
42
|
+
:param vec1: first vector.
|
|
43
|
+
:param vec2: second vector.
|
|
44
|
+
"""
|
|
45
|
+
if isinstance(vec1, list):
|
|
46
|
+
vec1 = np.array(vec1)
|
|
47
|
+
if isinstance(vec2, list):
|
|
48
|
+
vec2 = np.array(vec2)
|
|
49
|
+
dot_product = np.dot(vec1, vec2)
|
|
50
|
+
norm_vec1 = np.linalg.norm(vec1)
|
|
51
|
+
norm_vec2 = np.linalg.norm(vec2)
|
|
52
|
+
return dot_product / (norm_vec1 * norm_vec2)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def cosine_distance(
    vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
) -> float:
    """
    Calculates cosine distance between two vectors

    :param vec1: first vector.
    :param vec2: second vector.
    :returns: cosine distance as 1 - cosine_similarity(vec1, vec2)
    """
    return 1 - cosine_similarity(vec1, vec2)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def create_embedding(
    text: str,
    transformer: str = "infgrad/stella-base-en-v2",
    trust_remote_code: bool = True,
    **kwargs,
) -> Tensor:
    """
    Creates an embedding for the provided text with the specified transformer

    :param text: text for which the embedding is created.
    :param transformer: name of the sentence transformer model that is used.
    :param trust_remote_code: passed to SentenceTransformer; whether code from the
        model repository may be executed locally
    :param kwargs: any keyword arguments that are accepted by the SentenceTransformer class
    :returns: the result of ``SentenceTransformer.encode`` for ``text``
    """
    # NOTE(review): a new SentenceTransformer is constructed on every call,
    # which presumably reloads the model each time - confirm this is intended.
    model_options = {"trust_remote_code": trust_remote_code, **kwargs}
    # NOTE(review): encode() may return a numpy array rather than a Tensor with
    # its default settings - verify against the return annotation.
    return SentenceTransformer(transformer, **model_options).encode(text)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def chunkify_text(text: str, chunk_size: int) -> list[str]:
    """
    Takes a text and creates a list of chunks with words <= chunk_size

    Words are the whitespace-separated pieces of ``text``; each chunk joins its
    words with a single space. The last chunk may hold fewer words.

    :param text: the text to be chunkified.
    :param chunk_size: max number of words in chunk; must be at least 1.
    :returns: list of chunks (empty when the text contains no words).
    :raises ValueError: if ``chunk_size`` is smaller than 1.
    """
    # A non-positive size can never be reached by a growing chunk, so the
    # whole text used to come back as one chunk; reject it explicitly instead.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    words: list[str] = text.split()
    return [
        " ".join(words[start : start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: patera-aiinterface
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Requires-Dist: docstring-parser>=0.17.0
|
|
5
|
+
Requires-Dist: numpy>=2.4.3
|
|
6
|
+
Requires-Dist: openai>=2.26.0
|
|
7
|
+
Requires-Dist: pgvector>=0.4.2
|
|
8
|
+
Requires-Dist: patera
|
|
9
|
+
Requires-Dist: sentence-transformers>=5.2.3
|
|
10
|
+
Requires-Dist: torch>=2.10.0
|
|
11
|
+
Requires-Python: >=3.12
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
patera/aiinterface/__init__.py,sha256=c1TjsLl22osoSYW_3jCFkjwtoYb-AdzqkL2b7jyacu0,274
|
|
2
|
+
patera/aiinterface/ai_interface.py,sha256=N0mXwkeLhhe4Hi4XeV6apmBV4KWLbbe2zxUrkSOSGLs,13420
|
|
3
|
+
patera/aiinterface/database_types/__init__.py,sha256=hdMgFLqsj6UPocSi0T83cC7Qy99oXPkdHV94Xs8iSXA,92
|
|
4
|
+
patera/aiinterface/database_types/vector_column.py,sha256=sSRS62RSClIr64ThrJ9xaSIT9zbDDNfWn2xJA9boSf4,2295
|
|
5
|
+
patera/aiinterface/embeddings.py,sha256=tS_tfDzHBflq1QInbW5i5lQfiKjqhP24f494qR3Wh9M,3025
|
|
6
|
+
patera_aiinterface-0.1.0.dist-info/WHEEL,sha256=mydTeHxOpFHo-DnYhAd_3ATePms-g4rrYvM7wJK8P-U,80
|
|
7
|
+
patera_aiinterface-0.1.0.dist-info/METADATA,sha256=lc-M237qf4snlAoI_kXt7vfCBRZ2qTPXkvnXu78IxCw,310
|
|
8
|
+
patera_aiinterface-0.1.0.dist-info/RECORD,,
|