patera-aiinterface 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ """
2
+ Ai interface module
3
+ """
4
+
5
+ from .ai_interface import (
6
+ AiInterface,
7
+ ChatContextNotFound,
8
+ tool,
9
+ FailedToRunAiToolMethod,
10
+ AiConfig,
11
+ )
12
+
13
# Public API of the package, re-exported from ``ai_interface``.
__all__ = [
    "AiInterface", "ChatContextNotFound", "tool",
    "FailedToRunAiToolMethod", "AiConfig",
]
@@ -0,0 +1,381 @@
1
+ """
2
+ AI interface for Patera app
3
+ Makes connecting to LLM's easy
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ import inspect
8
+ from functools import wraps
9
+ from typing import (
10
+ List,
11
+ Dict,
12
+ Any,
13
+ Optional,
14
+ get_type_hints,
15
+ Callable,
16
+ cast,
17
+ TypedDict,
18
+ NotRequired,
19
+ )
20
+ import docstring_parser
21
+ from openai import AsyncOpenAI
22
+ from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
23
+ from pydantic import BaseModel, Field
24
+
25
+ from patera import Patera, Request, HttpStatus, Response
26
+ from patera.utilities import run_sync_or_async
27
+ from patera.exceptions import BaseHttpException
28
+ from patera.base_extension import BaseExtension
29
+
30
+
31
class ChatContextNotFound(BaseHttpException):
    """
    Raised when no chat context could be loaded for a request.

    :param msg: error message.
    :param status_code: HTTP status to report; defaults to ``HttpStatus.NOT_FOUND``.
        Stored on the instance as a plain ``int``.
    """

    def __init__(self, msg: str, status_code: int | HttpStatus = HttpStatus.NOT_FOUND):
        super().__init__(msg, status_code=status_code)
        # Normalise to a plain int regardless of which form the caller passed.
        self.status_code = (
            status_code.value if isinstance(status_code, HttpStatus) else status_code
        )
37
+
38
+
39
class FailedToRunAiToolMethod(BaseHttpException):
    """
    Raised when a registered AI tool method fails while being run.

    :param msg: error message.
    :param method_name: name of the tool method that failed.
    :param args: positional arguments the tool was called with.
    :param status_code: HTTP status to report; defaults to
        ``HttpStatus.UNPROCESSABLE_ENTITY``. Stored as a plain ``int``.
    :param kwargs: keyword arguments the tool was called with.
    """

    def __init__(
        self,
        msg: str,
        method_name: str,
        *args,
        status_code: int | HttpStatus = HttpStatus.UNPROCESSABLE_ENTITY,
        **kwargs,
    ):
        super().__init__(msg, status_code=status_code)
        # Normalise to a plain int regardless of which form the caller passed.
        self.status_code = (
            status_code.value if isinstance(status_code, HttpStatus) else status_code
        )
        self.method_name = method_name
        # NOTE(review): this overwrites BaseException.args (normally the message
        # tuple) with the tool's positional arguments, which can change how the
        # exception renders via str()/repr(); kept for callers that read .args.
        self.args = args
        self.kwargs = kwargs
55
+
56
+
57
class _AiInterfaceConfigs(BaseModel):
    """
    Validation model for the AI interface configuration entry.

    Only ``API_KEY`` and ``MODEL`` are required; every other field has a default.
    """

    # --- provider connection ---
    API_KEY: str = Field(description="API key for the AI provider")
    API_BASE_URL: Optional[str] = Field(
        default="https://api.openai.com/v1",
        description="Base URL for the AI provider API",
    )
    ORGANIZATION_ID: Optional[str] = Field(
        default=None,
        description="Organization ID for the AI provider",
    )
    PROJECT_ID: Optional[str] = Field(
        default=None,
        description="Project ID for the AI provider",
    )
    TIMEOUT: Optional[int] = Field(
        default=30,
        description="Timeout for AI provider requests in seconds",
    )

    # --- model / request behaviour ---
    MODEL: str = Field(description="Model name to use for AI requests")
    TEMPERATURE: Optional[float] = Field(
        default=1.0,
        description="Temperature for AI model responses",
    )
    RESPONSE_FORMAT: Optional[dict[str, str]] = Field(
        default={"type": "json_object"},
        description="Desired response format from the AI model",
    )
    TOOL_CHOICE: Optional[bool] = Field(
        default=False,
        description="Whether to enable tool choice for the AI model",
    )
    MAX_RETRIES: Optional[int] = Field(
        default=0,
        description="Maximum number of retries for AI provider requests",
    )

    # --- route-handler injection ---
    CHAT_CONTEXT_NAME: Optional[str] = Field(
        default="chat_context",
        description="Name of the chat context model for injection",
    )
91
+
92
+
93
+ class AiConfig(TypedDict):
94
+ """Admin configurations typed dictionary"""
95
+
96
+ API_KEY: str
97
+ API_BASE_URL: NotRequired[str]
98
+ ORGANIZATION_ID: NotRequired[str]
99
+ PROJECT_ID: NotRequired[str]
100
+ TIMEOUT: NotRequired[int]
101
+ MODEL: str
102
+ TEMPERATURE: NotRequired[float]
103
+ RESPONSE_FORMAT: NotRequired[dict[str, str]]
104
+ TOOL_CHOICE: NotRequired[bool]
105
+ MAX_RETRIES: NotRequired[int]
106
+ CHAT_CONTEXT_NAME: NotRequired[str]
107
+
108
+
109
class AiInterface(BaseExtension, ABC):
    """
    Main AI interface extension.

    Subclasses implement :meth:`chat_context_loader` and can mark methods with the
    :func:`tool` decorator. During :meth:`init_app` the marked methods are
    collected and converted to OpenAI function schemas. These schemas are sent
    when a request passes ``use_tools=True``. The methods can also be run by name
    through :meth:`envoke_ai_tool`.
    """

    def __init__(self, configs_name: Optional[str] = "AI_INTERFACE"):
        """
        Extension init method.

        Only stores the configuration entry name; the settings themselves are read
        from the app in :meth:`init_app`.

        :param configs_name: name of the app configuration entry with the AI settings.
        """
        self._app: Patera
        self._configs_name: str = cast(str, configs_name)
        self._configs: dict[str, Any] = {}
        # Declared here, assigned from the validated configuration in init_app.
        self._api_key: str
        self._api_base_url: str
        self._organization_id: str
        self._project_id: str
        self._timeout: int
        self._model: str
        self._temperature: float
        self._response_format: dict[str, str]
        self._tool_choice: bool
        self._max_retries: int
        # Tool schemas (as sent to the API) and tool name -> bound method.
        self._tools: list = []
        self._tools_mapping: dict[str, Callable] = {}
        self._chat_context_name: str = "chat_context"

    def init_app(self, app: Patera):
        """
        Initializer method for the extension.

        Reads and validates the configuration entry named by ``configs_name`` and
        copies the settings onto the instance. It then collects the ``@tool``
        methods and registers the extension with the app.

        :param app: the Patera application.
        :raises ValueError: if the configuration entry is missing.
        """
        self._app = app  # type: ignore
        self._configs = cast(dict[str, Any], app.get_conf(self._configs_name, None))
        if self._configs is None:
            raise ValueError(
                f"Configurations for {self._configs_name} not found in app configurations."
            )
        # NOTE(review): assumes validate_configs returns a dict with the model's
        # defaults filled in, since every key below is indexed directly.
        self._configs = self.validate_configs(self._configs, _AiInterfaceConfigs)

        self._api_key = self._configs["API_KEY"]
        self._api_base_url = self._configs["API_BASE_URL"]
        self._organization_id = self._configs["ORGANIZATION_ID"]
        self._project_id = self._configs["PROJECT_ID"]
        self._timeout = self._configs["TIMEOUT"]
        self._model = self._configs["MODEL"]
        self._temperature = self._configs["TEMPERATURE"]
        self._response_format = self._configs["RESPONSE_FORMAT"]
        self._tool_choice = self._configs["TOOL_CHOICE"]
        self._max_retries = self._configs["MAX_RETRIES"]
        self._chat_context_name = self._configs["CHAT_CONTEXT_NAME"]
        # Collect tools only once the settings are assigned: _get_tools calls
        # getattr() on every attribute name, which also evaluates properties.
        self._get_tools()
        self._app.add_extension(self)

    @property
    def configs(self) -> dict[str, Any]:
        """
        Returns the validated configuration dict set by :meth:`init_app`.
        """
        return self._configs

    async def provider(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> tuple[
        str | None,
        Any | list[ChatCompletionMessageToolCall] | None,
        ChatCompletion | None,
    ]:
        """
        Default provider method. Uses AsyncOpenAI from the openai package.

        Every setting defaults to the configured value and can be overridden per call
        with the keyword arguments ``api_key``, ``organization``, ``project``,
        ``timeout``, ``api_base_url``, ``max_retries``, ``model``, ``temperature``
        and ``response_format``. Pass ``use_tools=True`` to send the registered tool
        schemas with the request.

        :param messages: chat messages in the OpenAI chat-completions format.
        :returns: ``(assistant message content, tool calls, raw ChatCompletion)``;
            content and tool calls are ``None`` when the first choice has none.
        """
        client_options = {
            "api_key": kwargs.get("api_key", self._api_key),
            "organization": kwargs.get("organization", self._organization_id),
            "project": kwargs.get("project", self._project_id),
            "timeout": kwargs.get("timeout", self._timeout),
            "base_url": kwargs.get("api_base_url", self._api_base_url),
            "max_retries": kwargs.get("max_retries", self._max_retries),
        }
        request: dict = {
            "messages": messages,
            "model": kwargs.get("model", self._model),
            "temperature": kwargs.get("temperature", self._temperature),
            "response_format": kwargs.get("response_format", self._response_format),
        }
        # The chat-completions API rejects an empty ``tools`` array, so only attach
        # tools when at least one is registered.
        if kwargs.get("use_tools", False) and self._tools:
            request["tools"] = self._tools

        # Use the client as an async context manager so its HTTP connection pool is
        # closed after the request instead of leaking one client per call.
        async with AsyncOpenAI(**client_options) as client:
            chat: ChatCompletion = await client.chat.completions.create(**request)

        message = chat.choices[0].message
        return message.content or None, message.tool_calls or None, chat

    async def create_chat_completion(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> tuple[
        str | None, list[ChatCompletionMessageToolCall] | None, ChatCompletion | None
    ]:
        """
        Makes a prompt with the provider method.

        Delegates to :meth:`provider`, whose default implementation is OpenAI
        compatible; override ``provider`` to use a different backend.

        :param messages: chat messages to send.
        :param kwargs: per-call overrides, forwarded unchanged to :meth:`provider`.
        :returns: the ``(content, tool_calls, chat_completion)`` tuple from ``provider``.
        """
        return await self.provider(messages, **kwargs)

    async def envoke_ai_tool(self, tool_name, *args, **kwargs) -> Any:
        """
        Runs a registered AI tool method.

        :param tool_name: registered tool name (the ``name`` given to ``@tool``,
            or the method name when none was given).
        :param args: positional arguments for the tool.
        :param kwargs: keyword arguments for the tool.
        :returns: whatever the tool method returns.
        :raises ValueError: if no tool with that name is registered.
        :raises FailedToRunAiToolMethod: if the tool raises; the original error is
            chained as ``__cause__``.
        """
        tool_method: Optional[Callable] = self._tools_mapping.get(tool_name)
        if tool_method is None:
            raise ValueError(
                f"Tool method named {tool_name} is not registered with the AI interface"
            )
        try:
            return await run_sync_or_async(tool_method, *args, **kwargs)
        except Exception as exc:
            # method_name is a required positional parameter of the exception; pass
            # the tool name so the tool's own arguments land in args/kwargs.
            raise FailedToRunAiToolMethod(
                f"Failed to run AI tool {tool_name}", tool_name, *args, **kwargs
            ) from exc

    @staticmethod
    def _json_schema_type(annotation: Any) -> str:
        """Map a parameter annotation to a JSON-schema primitive type name."""
        import types
        import typing

        # Unwrap Optional[X] / X | None to X so the underlying type is used.
        origin = typing.get_origin(annotation)
        if origin is typing.Union or origin is types.UnionType:
            members = [arg for arg in typing.get_args(annotation) if arg is not type(None)]
            if len(members) == 1:
                annotation = members[0]
        if annotation is bool:
            return "boolean"
        if annotation is int:
            return "integer"
        if annotation is float:
            return "number"
        # Anything else (including str and unannotated params) is sent as a string.
        return "string"

    def build_function_schema(
        self,
        func: Callable,
        func_name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> dict[str, Any]:
        """
        Automatically builds an OpenAI function schema from a Python function.

        Parameter types come from the type hints. Parameter descriptions come from
        the docstring (parsed with ``docstring_parser``). Parameters without a
        default are marked as required. ``*args``/``**kwargs`` are skipped, since
        they cannot be expressed as named properties.

        :param func: the function (or bound method) to describe.
        :param func_name: schema name; defaults to ``func.__name__``.
        :param description: schema description; defaults to the docstring's short
            description, or ``""``.
        :returns: a ``{"type": "function", "function": {...}}`` schema dict.
        """
        doc = docstring_parser.parse(func.__doc__ or "")

        # First documented description per parameter; the API expects strings.
        param_docs: dict[str, str] = {}
        for doc_param in doc.params:
            param_docs.setdefault(doc_param.arg_name, doc_param.description or "")

        hints = get_type_hints(func)
        properties: dict[str, Any] = {}
        required: list[str] = []
        for param_name, param in inspect.signature(func).parameters.items():
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            properties[param_name] = {
                "type": self._json_schema_type(hints.get(param_name, str)),
                "description": param_docs.get(param_name, ""),
            }
            # If no default value is detected, the parameter is required
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        return {
            "type": "function",
            "function": {
                "name": func_name or func.__name__,
                "description": description or doc.short_description or "",
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": False,
                },
            },
        }

    @abstractmethod
    async def chat_context_loader(
        self, req: Request, *args: Any, **kwargs: Any
    ) -> Optional[Any]:
        """Should load and return a chat session object (e.g. a db model) or None"""

    @property
    def chat_context_name(self) -> str:
        """Keyword argument name used to inject the chat context into handlers."""
        return self._chat_context_name

    @property
    def with_chat_context(self) -> Callable:
        """
        Decorator for injecting the chat context into a route handler.

        The decorated handler is called as ``handler(self, request, ...)``. The
        wrapper loads the context with :meth:`chat_context_loader` and passes it as
        the keyword argument named by :attr:`chat_context_name`.

        Raised when the handler is called, not when it is decorated:

        :raises ValueError: if the first argument after ``self`` is not a ``Request``.
        :raises ChatContextNotFound: if the loader returns ``None``.
        """
        interface: AiInterface = self

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(self, *args, **kwargs) -> "Response":
                # Also guard against no positional args, so callers get the
                # descriptive ValueError rather than an IndexError.
                if not args or not isinstance(args[0], Request):
                    raise ValueError(
                        "Missing Request object at @with_chat_context decorator. The request object"
                        " must be the first argument of the route handler. Please check if you have "
                        "changed the argument sequence."
                    )
                req: Request = args[0]
                chat_context = await run_sync_or_async(
                    interface.chat_context_loader, req
                )
                if chat_context is None:
                    raise ChatContextNotFound("Chat session not found")
                kwargs[interface.chat_context_name] = chat_context
                return await run_sync_or_async(func, self, *args, **kwargs)

            return wrapper

        return decorator

    def _get_tools(self):
        """
        Collect every ``@tool``-marked method into the schema list and name map.

        Both collections are cleared first, so calling this more than once (e.g. a
        repeated ``init_app``) does not register duplicate tools.
        """
        self._tools.clear()
        self._tools_mapping.clear()
        for name in dir(self):
            method = getattr(self, name)
            if not callable(method):
                continue
            tool_info = getattr(method, "__ai_tool", None)
            if not tool_info:
                continue
            self._tools.append(
                self.build_function_schema(
                    method, tool_info["name"], tool_info["description"]
                )
            )
            self._tools_mapping[tool_info["name"]] = method
363
+
364
+
365
def tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator factory that marks a method as an AI interface tool.

    The metadata is stored on the function under the ``"__ai_tool"`` attribute as
    ``{"name": ..., "description": ...}``.

    :param name: tool name; falls back to the function's ``__name__`` when falsy.
    :param description: tool description; falls back to the function's docstring
        when falsy.
    :returns: a decorator that returns the same function object, now tagged.
    """

    def mark_as_tool(func: Callable):
        """Attach the tool metadata to ``func`` and return it unchanged."""
        metadata = {
            "name": name if name else func.__name__,
            "description": description if description else func.__doc__,
        }
        setattr(func, "__ai_tool", metadata)
        return func

    return mark_as_tool
@@ -0,0 +1,5 @@
1
+ """Database types"""
2
+
3
+ from .vector_column import Vector # noqa: F401
4
+
5
# Public API: the pgvector column type.
__all__ = [
    "Vector",
]
@@ -0,0 +1,80 @@
1
+ """
2
+ Vector column for PostgreSQL database.
3
+ """
4
+
5
+ from typing import Callable, List, Self, Type
6
+ from sqlalchemy.sql.elements import ColumnElement
7
+ from sqlalchemy.types import UserDefinedType
8
+
9
+
10
class Vector(UserDefinedType):
    """
    Custom Vector type for pgvector in PostgreSQL.

    Requires pgvector extension for the PostgreSQL database.

    Values are handled as follows:

    - Python values are sequences of numbers (lists, tuples, numpy arrays, ...).
      They are bound in pgvector's text form (``[1.0,2.0,3.0]``), so they work
      both as column values and as operands of the distance operators.
    - String results read back from the database are parsed into ``list[float]``.
    - Values that are already strings (when binding) or not strings (when reading,
      e.g. if a driver-level pgvector adapter is registered) pass through unchanged.

    :param dimensions: optional fixed dimension count, rendered as ``VECTOR(n)``.
    """

    cache_ok = True

    def __init__(self, dimensions: int | None = None):
        self.dimensions: int | None = dimensions

    @staticmethod
    def _to_vector_text(value) -> str:
        """Render a numeric sequence in pgvector's text format, e.g. ``[1.0,2.0]``."""
        return "[" + ",".join(str(float(x)) for x in value) + "]"

    # pylint: disable-next=C0116
    def get_col_spec(self, **kw) -> str:
        if self.dimensions:
            return f"VECTOR({int(self.dimensions)})"
        return "VECTOR"

    # pylint: disable-next=C0116
    def bind_expression(self, bindvalue):
        return bindvalue

    def bind_processor(self, dialect) -> Callable:
        """Convert numeric sequences to pgvector text before they reach the driver."""

        def process(value):
            if value is None or isinstance(value, str):
                return value
            return self._to_vector_text(value)

        return process

    def result_processor(self, dialect, coltype) -> Callable:
        """Parse pgvector text results such as ``"[1,2,3]"`` into ``list[float]``."""

        def process(value):
            if not isinstance(value, str):
                return value
            inner = value.strip().strip("[]")
            if not inner.strip():
                return []
            return [float(part) for part in inner.split(",")]

        return process

    # pylint: disable-next=C0116
    def coerce_compared_value(self, op, value):
        # Lists/tuples compared against this column are bound as vectors too.
        if isinstance(value, (list, tuple)):
            return self
        return super().coerce_compared_value(op, value)

    # pylint: disable-next=C0116
    def literal_processor(self, dialect) -> Callable:
        def process(value):
            if value is None:
                return "NULL"
            return f"'{self._to_vector_text(value)}'::vector"

        return process

    # pylint: disable-next=C0116
    def compare_values(self, x: List, y: List) -> bool:
        if x is y:
            return True
        if x is None or y is None:
            return x is y
        if len(x) != len(y):
            return False
        # Element-wise comparison with a small absolute tolerance for floats.
        eps = 1e-9
        return all(abs(float(a) - float(b)) <= eps for a, b in zip(x, y))

    class comparator_factory(UserDefinedType.Comparator):
        # distance operators exposed as methods for nicer query syntax
        def l2_distance(self, other):
            return self.expr.op("<->")(other)

        def inner_product(self, other):
            return self.expr.op("<#>")(other)

        def cosine_distance(self, other):
            return self.expr.op("<=>")(other)

    # pylint: disable-next=C0116
    def column_expression(self, colexpr) -> ColumnElement:
        return colexpr

    # pylint: disable-next=C0116
    def _with_collation(self, collation) -> Self:
        return self

    @property
    def python_type(self) -> Type[List]:
        return list
@@ -0,0 +1,107 @@
1
+ """
2
+ Helpers for creating embeddings
3
+ """
4
+
5
import functools

import numpy as np
from torch import Tensor
from sentence_transformers import SentenceTransformer
from pgvector.sqlalchemy import Vector as VectorColumn
9
+
10
# Public API of the embeddings helpers.
__all__ = [
    "l2_distance", "cosine_similarity", "cosine_distance",
    "create_embedding", "chunkify_text", "VectorColumn",
]
18
+
19
+
20
+ def l2_distance(
21
+ vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
22
+ ) -> float:
23
+ """
24
+ Calculates l2 distance between two vectors
25
+
26
+ :param vec1: first vector.
27
+ :param vec2: second vector.
28
+ """
29
+ if isinstance(vec1, list):
30
+ vec1 = np.array(vec1)
31
+ if isinstance(vec2, list):
32
+ vec2 = np.array(vec2)
33
+ return np.sqrt(np.sum((np.array(vec1) - np.array(vec2)) ** 2))
34
+
35
+
36
+ def cosine_similarity(
37
+ vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
38
+ ) -> float:
39
+ """
40
+ Calculates cosine similarity between two vectors
41
+
42
+ :param vec1: first vector.
43
+ :param vec2: second vector.
44
+ """
45
+ if isinstance(vec1, list):
46
+ vec1 = np.array(vec1)
47
+ if isinstance(vec2, list):
48
+ vec2 = np.array(vec2)
49
+ dot_product = np.dot(vec1, vec2)
50
+ norm_vec1 = np.linalg.norm(vec1)
51
+ norm_vec2 = np.linalg.norm(vec2)
52
+ return dot_product / (norm_vec1 * norm_vec2)
53
+
54
+
55
def cosine_distance(
    vec1: list[float] | np.ndarray, vec2: list[float] | np.ndarray
) -> float:
    """
    Return the cosine distance between two vectors.

    :param vec1: first vector.
    :param vec2: second vector.
    :returns: cosine distance, defined as ``1 - cosine_similarity(vec1, vec2)``.
    """
    return 1 - cosine_similarity(vec1, vec2)
67
+
68
+
69
@functools.lru_cache(maxsize=8)
def _load_sentence_transformer(
    transformer: str, trust_remote_code: bool, frozen_kwargs: tuple
) -> SentenceTransformer:
    """Construct and memoize a SentenceTransformer for one set of constructor options."""
    return SentenceTransformer(
        transformer, trust_remote_code=trust_remote_code, **dict(frozen_kwargs)
    )


def create_embedding(
    text: str,
    transformer: str = "infgrad/stella-base-en-v2",
    trust_remote_code: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Creates an embedding for the provided text with the specified transformer.

    Loaded models are cached (up to 8 distinct configurations), so repeated calls
    do not reload the model each time. The cache is bypassed when ``kwargs``
    contains unhashable values.

    :param text: text for which the embedding is created.
    :param transformer: name/path of the sentence transformer model to use.
    :param trust_remote_code: whether code from the Hugging Face Hub repo may be
        executed locally.
    :param kwargs: any keyword arguments accepted by the SentenceTransformer class.
    :returns: the result of ``SentenceTransformer.encode(text)`` (a 1-D numpy array
        for a single string with encode's default settings).
    """
    # NOTE(review): trust_remote_code defaults to True, which lets the model repo
    # run arbitrary code locally; consider passing False for models that don't need it.
    frozen_kwargs = tuple(sorted(kwargs.items()))
    try:
        hash(frozen_kwargs)
    except TypeError:
        # Unhashable constructor options (e.g. a dict value) cannot key the cache.
        embedding_model = SentenceTransformer(
            transformer, trust_remote_code=trust_remote_code, **kwargs
        )
    else:
        embedding_model = _load_sentence_transformer(
            transformer, trust_remote_code, frozen_kwargs
        )
    return embedding_model.encode(text)
87
+
88
+
89
def chunkify_text(text: str, chunk_size: int) -> list[str]:
    """
    Takes a text and creates a list of chunks with words <= chunk_size.

    Words are split on any whitespace and re-joined with single spaces; the last
    chunk may hold fewer than ``chunk_size`` words.

    :param text: the text to be chunkified.
    :param chunk_size: max number of words in a chunk; must be at least 1.
    :returns: list of chunks (empty for text with no words).
    :raises ValueError: if ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    words = text.split()
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.3
2
+ Name: patera-aiinterface
3
+ Version: 0.1.0
4
+ Requires-Dist: docstring-parser>=0.17.0
5
+ Requires-Dist: numpy>=2.4.3
6
+ Requires-Dist: openai>=2.26.0
7
+ Requires-Dist: pgvector>=0.4.2
8
+ Requires-Dist: patera
9
+ Requires-Dist: sentence-transformers>=5.2.3
10
+ Requires-Dist: torch>=2.10.0
11
+ Requires-Python: >=3.12
@@ -0,0 +1,8 @@
1
+ patera/aiinterface/__init__.py,sha256=c1TjsLl22osoSYW_3jCFkjwtoYb-AdzqkL2b7jyacu0,274
2
+ patera/aiinterface/ai_interface.py,sha256=N0mXwkeLhhe4Hi4XeV6apmBV4KWLbbe2zxUrkSOSGLs,13420
3
+ patera/aiinterface/database_types/__init__.py,sha256=hdMgFLqsj6UPocSi0T83cC7Qy99oXPkdHV94Xs8iSXA,92
4
+ patera/aiinterface/database_types/vector_column.py,sha256=sSRS62RSClIr64ThrJ9xaSIT9zbDDNfWn2xJA9boSf4,2295
5
+ patera/aiinterface/embeddings.py,sha256=tS_tfDzHBflq1QInbW5i5lQfiKjqhP24f494qR3Wh9M,3025
6
+ patera_aiinterface-0.1.0.dist-info/WHEEL,sha256=mydTeHxOpFHo-DnYhAd_3ATePms-g4rrYvM7wJK8P-U,80
7
+ patera_aiinterface-0.1.0.dist-info/METADATA,sha256=lc-M237qf4snlAoI_kXt7vfCBRZ2qTPXkvnXu78IxCw,310
8
+ patera_aiinterface-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: uv 0.10.9
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any