select-ai 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of select-ai might be problematic. Click here for more details.
- select_ai/__init__.py +59 -0
- select_ai/_abc.py +77 -0
- select_ai/_enums.py +14 -0
- select_ai/action.py +21 -0
- select_ai/async_profile.py +528 -0
- select_ai/base_profile.py +183 -0
- select_ai/conversation.py +274 -0
- select_ai/credential.py +135 -0
- select_ai/db.py +186 -0
- select_ai/errors.py +73 -0
- select_ai/profile.py +454 -0
- select_ai/provider.py +288 -0
- select_ai/sql.py +105 -0
- select_ai/synthetic_data.py +90 -0
- select_ai/vector_index.py +552 -0
- select_ai/version.py +8 -0
- select_ai-1.0.0b1.dist-info/METADATA +117 -0
- select_ai-1.0.0b1.dist-info/RECORD +21 -0
- select_ai-1.0.0b1.dist-info/WHEEL +5 -0
- select_ai-1.0.0b1.dist-info/licenses/LICENSE.txt +35 -0
- select_ai-1.0.0b1.dist-info/top_level.txt +1 -0
select_ai/db.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import contextlib
|
|
9
|
+
import os
|
|
10
|
+
from threading import get_ident
|
|
11
|
+
from typing import Dict, Hashable
|
|
12
|
+
|
|
13
|
+
import oracledb
|
|
14
|
+
|
|
15
|
+
from select_ai.errors import DatabaseNotConnectedError
|
|
16
|
+
|
|
17
|
+
# Per-(process id, thread id) registries of open connections. Every helper in
# this module builds its key as (os.getpid(), get_ident()), so each thread of
# each process only ever sees the connection it registered itself.
__conn__: Dict[Hashable, oracledb.Connection] = dict()
__async_conn__: Dict[Hashable, oracledb.AsyncConnection] = dict()

# Public API of this module; synchronous helpers are paired with their
# asynchronous counterparts.
__all__ = [
    "connect", "async_connect",
    "is_connected", "async_is_connected",
    "get_connection", "async_get_connection",
    "cursor", "async_cursor",
    "disconnect", "async_disconnect",
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def connect(user: str, password: str, dsn: str, *args, **kwargs):
    """Create an ``oracledb.Connection`` and register it for reuse.

    The connection is stored in the module-level ``__conn__`` registry under
    the key ``(os.getpid(), threading.get_ident())``, so in a multithreaded
    application each thread only sees the connection it created itself.
    Calling this again from the same thread replaces the registered
    connection; the previous one is not closed.

    :param user: database user name
    :param password: database password
    :param dsn: data source name / connect string
    :param args: extra positional arguments forwarded to ``oracledb.connect``
    :param kwargs: extra keyword arguments forwarded to ``oracledb.connect``;
        ``connection_id_prefix`` defaults to ``"python-select-ai"`` but may
        be overridden by the caller
    :return: None
    """
    # setdefault instead of a hard-coded keyword: a caller-supplied prefix
    # would otherwise fail with "got multiple values for keyword argument".
    kwargs.setdefault("connection_id_prefix", "python-select-ai")
    conn = oracledb.connect(
        *args,
        user=user,
        password=password,
        dsn=dsn,
        **kwargs,
    )
    _set_connection(conn=conn)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def async_connect(user: str, password: str, dsn: str, *args, **kwargs):
    """Create an ``oracledb.AsyncConnection`` and register it for reuse.

    The connection is stored in the module-level ``__async_conn__`` registry
    under the key ``(os.getpid(), threading.get_ident())``. Coroutines that
    run on the same thread therefore share it, while other threads and
    processes do not see it. Calling this again from the same thread
    replaces the registered connection; the previous one is not closed.

    :param user: database user name
    :param password: database password
    :param dsn: data source name / connect string
    :param args: extra positional arguments forwarded to
        ``oracledb.connect_async``
    :param kwargs: extra keyword arguments forwarded to
        ``oracledb.connect_async``; ``connection_id_prefix`` defaults to
        ``"python-select-ai"`` but may be overridden by the caller
    :return: None
    """
    # Tag async connections with the same prefix used for synchronous ones,
    # while still letting the caller supply their own.
    kwargs.setdefault("connection_id_prefix", "python-select-ai")
    async_conn = await oracledb.connect_async(
        *args, user=user, password=password, dsn=dsn, **kwargs
    )
    _set_connection(async_conn=async_conn)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def is_connected() -> bool:
    """Report whether this thread holds a usable synchronous connection.

    The connection registered for the current (process id, thread id) pair
    is pinged. A missing registration, or a ping that raises
    ``oracledb.DatabaseError`` / ``oracledb.InterfaceError``, yields False.
    """
    registered = __conn__.get((os.getpid(), get_ident()))
    if registered is None:
        return False
    try:
        outcome = registered.ping()
    except (oracledb.DatabaseError, oracledb.InterfaceError):
        return False
    return outcome is None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
async def async_is_connected() -> bool:
    """Report whether this thread holds a usable asynchronous connection.

    The ``AsyncConnection`` registered for the current (process id,
    thread id) pair is pinged. A missing registration, or a ping that
    raises ``oracledb.DatabaseError`` / ``oracledb.InterfaceError``, yields
    False.
    """
    registered = __async_conn__.get((os.getpid(), get_ident()))
    if registered is None:
        return False
    try:
        outcome = await registered.ping()
    except (oracledb.DatabaseError, oracledb.InterfaceError):
        return False
    return outcome is None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _set_connection(
    conn: oracledb.Connection = None,
    async_conn: oracledb.AsyncConnection = None,
):
    """Register existing connection(s) for the select_ai Python API to reuse.

    Entries are keyed by the current (process id, thread id) pair and
    overwrite whatever that pair had registered before.

    :param conn: python-oracledb ``Connection``; stored only when truthy
    :param async_conn: python-oracledb ``AsyncConnection``; stored only when
        truthy
    :return: None
    """
    owner = (os.getpid(), get_ident())
    # The registries are mutated in place, so no ``global`` is required.
    for registry, candidate in ((__conn__, conn), (__async_conn__, async_conn)):
        if candidate:
            registry[owner] = candidate
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_connection() -> oracledb.Connection:
    """Return the synchronous connection registered for this thread.

    A health check (``is_connected``, which pings the server) runs first.

    :return: the registered ``oracledb.Connection``
    :raises DatabaseNotConnectedError: nothing is registered for the current
        process/thread, or the ping failed
    """
    if is_connected():
        return __conn__[(os.getpid(), get_ident())]
    raise DatabaseNotConnectedError()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
async def async_get_connection() -> oracledb.AsyncConnection:
    """Return the asynchronous connection registered for this thread.

    A health check (``async_is_connected``, which pings the server) runs
    first.

    :return: the registered ``oracledb.AsyncConnection``
    :raises DatabaseNotConnectedError: nothing is registered for the current
        process/thread, or the ping failed
    """
    if await async_is_connected():
        return __async_conn__[(os.getpid(), get_ident())]
    raise DatabaseNotConnectedError()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@contextlib.contextmanager
def cursor():
    """Yield a cursor on this thread's connection and always close it.

    Typical usage::

        with select_ai.cursor() as cr:
            cr.execute(<QUERY>)

    The cursor is closed when the ``with`` block exits, whether or not an
    exception was raised inside it.

    :raises DatabaseNotConnectedError: no healthy connection is registered
    """
    with contextlib.closing(get_connection().cursor()) as cr:
        yield cr
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@contextlib.asynccontextmanager
async def async_cursor():
    """Yield a cursor on this thread's async connection and always close it.

    Typical usage::

        async with select_ai.async_cursor() as cr:
            await cr.execute(<QUERY>)

    The cursor's (synchronous) ``close()`` is called when the block exits,
    whether or not an exception was raised inside it.

    :raises DatabaseNotConnectedError: no healthy async connection is
        registered
    """
    connection = await async_get_connection()
    with contextlib.closing(connection.cursor()) as cr:
        yield cr
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def disconnect():
    """Close and unregister this thread's synchronous connection.

    The connection is removed from the registry even if it is no longer
    healthy, so a dropped connection is still released. Calling this when
    no connection is registered is a no-op.

    :return: None
    """
    # Pop directly instead of going through get_connection(): that helper
    # pings first, so an unhealthy connection would never be closed or
    # removed from the registry.
    conn = __conn__.pop((os.getpid(), get_ident()), None)
    if conn is None:
        return
    try:
        conn.close()
    except (oracledb.DatabaseError, oracledb.InterfaceError):
        # Best effort: the connection may already be broken or closed; it
        # has been unregistered either way.
        pass
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def async_disconnect():
    """Close and unregister this thread's asynchronous connection.

    The connection is removed from the registry even if it is no longer
    healthy, so a dropped connection is still released. Calling this when
    no connection is registered is a no-op.

    :return: None
    """
    # Pop directly instead of going through async_get_connection(): that
    # helper pings first, so an unhealthy connection would never be closed
    # or removed from the registry.
    conn = __async_conn__.pop((os.getpid(), get_ident()), None)
    if conn is None:
        return
    try:
        await conn.close()
    except (oracledb.DatabaseError, oracledb.InterfaceError):
        # Best effort: the connection may already be broken or closed; it
        # has been unregistered either way.
        pass
|
select_ai/errors.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SelectAIError(Exception):
    """Root of the select_ai exception hierarchy.

    Catch this to handle any of the select_ai-specific errors defined in
    this module.
    """
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DatabaseNotConnectedError(SelectAIError):
    """Raised when no healthy database connection is available.

    Its message points the user at the helpers that open a connection.
    """

    def __str__(self):
        parts = (
            "Not connected to the Database. ",
            "Use select_ai.connect() or select_ai.async_connect() ",
            "to establish connection",
        )
        return "".join(parts)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ConversationNotFoundError(SelectAIError):
    """Conversation not found in the database

    :param conversation_id: id of the conversation that was looked up
    """

    def __init__(self, conversation_id: str):
        # Forward to Exception.__init__ so ``args`` is populated. Without it
        # ``args`` is empty and unpickling (which calls ``cls(*args)``, e.g.
        # when the error crosses a multiprocessing boundary) fails with a
        # missing-argument TypeError.
        super().__init__(conversation_id)
        self.conversation_id = conversation_id

    def __str__(self):
        return f"Conversation with id {self.conversation_id} not found"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ProfileNotFoundError(SelectAIError):
    """Profile not found in the database

    :param profile_name: name of the profile that was looked up
    """

    def __init__(self, profile_name: str):
        # Populate ``args`` so the exception has a useful repr and survives
        # pickling (unpickling re-creates it as ``cls(*args)``).
        super().__init__(profile_name)
        self.profile_name = profile_name

    def __str__(self):
        return f"Profile {self.profile_name} not found"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ProfileExistsError(SelectAIError):
    """Profile already exists in the database

    :param profile_name: name of the profile that already exists
    """

    def __init__(self, profile_name: str):
        # Populate ``args`` so the exception has a useful repr and survives
        # pickling (unpickling re-creates it as ``cls(*args)``).
        super().__init__(profile_name)
        self.profile_name = profile_name

    def __str__(self):
        return (
            f"Profile {self.profile_name} already exists. "
            "Use either replace=True or merge=True"
        )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class VectorIndexNotFoundError(SelectAIError):
    """VectorIndex not found in the database

    :param index_name: name of the vector index that was looked up
    :param profile_name: optional AI profile the lookup was scoped to
    """

    def __init__(self, index_name: str, profile_name: str = None):
        # Populate ``args`` with both constructor arguments so the exception
        # has a useful repr and survives pickling (``cls(*args)``).
        super().__init__(index_name, profile_name)
        self.index_name = index_name
        self.profile_name = profile_name

    def __str__(self):
        if self.profile_name:
            return (
                f"VectorIndex {self.index_name} "
                f"not found for profile {self.profile_name}"
            )
        return f"VectorIndex {self.index_name} not found"
|
select_ai/profile.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from contextlib import contextmanager
|
|
10
|
+
from dataclasses import replace as dataclass_replace
|
|
11
|
+
from typing import Iterator, Mapping, Optional, Union
|
|
12
|
+
|
|
13
|
+
import oracledb
|
|
14
|
+
import pandas
|
|
15
|
+
|
|
16
|
+
from select_ai import Conversation
|
|
17
|
+
from select_ai.action import Action
|
|
18
|
+
from select_ai.base_profile import BaseProfile, ProfileAttributes
|
|
19
|
+
from select_ai.db import cursor
|
|
20
|
+
from select_ai.errors import ProfileExistsError, ProfileNotFoundError
|
|
21
|
+
from select_ai.provider import Provider
|
|
22
|
+
from select_ai.sql import (
|
|
23
|
+
GET_USER_AI_PROFILE,
|
|
24
|
+
GET_USER_AI_PROFILE_ATTRIBUTES,
|
|
25
|
+
LIST_USER_AI_PROFILES,
|
|
26
|
+
)
|
|
27
|
+
from select_ai.synthetic_data import SyntheticDataAttributes
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Profile(BaseProfile):
    """Profile class represents an AI Profile. It defines
    attributes and methods to interact with the underlying
    AI Provider. All methods in this class are synchronous
    or blocking
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reconcile with the database immediately; depending on the options
        # this reads, creates or replaces the stored profile.
        self._init_profile()

    def _init_profile(self) -> None:
        """Initializes AI profile based on the passed attributes

        Reconciles this object with the profile stored in the database:

        * ``profile_name`` is empty/None: nothing is read; supplying
          ``attributes`` or ``description`` raises ``ValueError``.
        * The profile exists: without ``replace``/``merge``, supplying
          ``attributes`` or ``description`` raises ``ProfileExistsError``
          when ``raise_error_if_exists`` is set. A missing ``description``
          or ``attributes`` is loaded from the database. With ``merge``, the
          non-null supplied attributes are overlaid on the saved ones and
          ``replace`` is forced to True, so the profile is re-created.
        * The profile does not exist: ``ProfileNotFoundError`` propagates
          when neither ``attributes`` nor ``description`` were supplied;
          otherwise the profile is created.

        :return: None
        :raises ProfileExistsError: see above
        :raises ProfileNotFoundError: see above
        :raises ValueError: attributes/description given without a name
        :raises: oracledb.DatabaseError
        """
        if self.profile_name:
            profile_exists = False
            try:
                saved_attributes = self._get_attributes(
                    profile_name=self.profile_name
                )
                profile_exists = True
                if not self.replace and not self.merge:
                    if (
                        self.attributes is not None
                        or self.description is not None
                    ):
                        if self.raise_error_if_exists:
                            raise ProfileExistsError(self.profile_name)

                if self.description is None:
                    self.description = self._get_profile_description(
                        profile_name=self.profile_name
                    )
            except ProfileNotFoundError:
                # Nothing to create the profile from: report "not found".
                if self.attributes is None and self.description is None:
                    raise
            else:
                if self.attributes is None:
                    self.attributes = saved_attributes
                if self.merge:
                    self.replace = True
                    if self.attributes is not None:
                        # Supplied non-null values win over saved ones.
                        self.attributes = dataclass_replace(
                            saved_attributes,
                            **self.attributes.dict(exclude_null=True),
                        )
            if self.replace or not profile_exists:
                self.create(replace=self.replace)
        else:  # profile name is None
            if self.attributes is not None or self.description is not None:
                raise ValueError(
                    "Attribute 'profile_name' cannot be empty or None"
                )

    @staticmethod
    def _get_profile_description(profile_name) -> Union[str, None]:
        """Get description of profile from USER_CLOUD_AI_PROFILES

        :param str profile_name: name of the profile (looked up upper-cased)
        :return: Union[str, None] profile description (LOB content read into
            a str), or None when the profile has no description
        :raises: ProfileNotFoundError
        """
        with cursor() as cr:
            cr.execute(GET_USER_AI_PROFILE, profile_name=profile_name.upper())
            profile = cr.fetchone()
            if profile:
                if profile[1] is not None:
                    return profile[1].read()
                else:
                    return None
            else:
                raise ProfileNotFoundError(profile_name)

    @staticmethod
    def _get_attributes(profile_name) -> ProfileAttributes:
        """Get AI profile attributes from the Database

        :param str profile_name: Name of the profile (looked up upper-cased)
        :return: select_ai.ProfileAttributes built from the fetched
            (name, value) rows
        :raises: ProfileNotFoundError
        """
        with cursor() as cr:
            cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = cr.fetchall()
            if attributes:
                return ProfileAttributes.create(**dict(attributes))
            else:
                raise ProfileNotFoundError(profile_name=profile_name)

    def get_attributes(self) -> ProfileAttributes:
        """Get AI profile attributes from the Database

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return self._get_attributes(profile_name=self.profile_name)

    def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Persist a single attribute with DBMS_CLOUD_AI.SET_ATTRIBUTE.

        :param str attribute_name: Name of the AI profile attribute
        :param attribute_value: Value to store in the database
        :return: None
        """
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database

        A ``Provider`` value is expanded: each of its ``dict()`` entries is
        saved as a separate attribute.

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float, Provider] attribute_value: Value of
         the profile attribute
        :return: None
        """
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            for k, v in attribute_value.dict().items():
                self._set_attribute(k, v)
        else:
            self._set_attribute(attribute_name, attribute_value)

    def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves it in the database

        :param ProfileAttributes attributes: Object specifying AI profile
         attributes
        :return: None
        :raises TypeError: ``attributes`` is not a ProfileAttributes
        """
        if not isinstance(attributes, ProfileAttributes):
            raise TypeError(
                "'attributes' must be an object of type"
                " select_ai.ProfileAttributes"
            )
        self.attributes = attributes
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )

    def create(self, replace: Optional[bool] = False) -> None:
        """Create an AI Profile in the Database

        :param bool replace: Set True to drop and re-create the profile when
         it already exists (ORA-20046), else False
        :return: None
        :raises AttributeError: profile attributes are None
        :raises: oracledb.DatabaseError
        """
        if self.attributes is None:
            raise AttributeError("Profile attributes cannot be None")
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if self.description:
            parameters["description"] = self.description

        with cursor() as cr:
            try:
                cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                (error,) = e.args
                # If already exists and replace is True then drop and recreate
                if error.code == 20046 and replace:
                    self.delete(force=True)
                    cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    def delete(self, force=False) -> None:
        """Deletes an AI profile from the database

        :param bool force: Ignores errors if AI profile does not exist.
        :return: None
        :raises: oracledb.DatabaseError
        """
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    def _from_db(cls, profile_name: str) -> "Profile":
        """Create a Profile object from attributes saved in the database

        :param str profile_name: name of the profile (looked up upper-cased,
         like every other lookup in this class)
        :return: select_ai.Profile
        :raises: ProfileNotFoundError
        """
        attributes = cls._get_attributes(profile_name=profile_name)
        # The profile is known to exist and its own attributes are passed
        # back in, so skip the "already exists" check (as list() does);
        # otherwise _init_profile raises ProfileExistsError whenever
        # raise_error_if_exists is enabled.
        return cls(
            profile_name=profile_name,
            attributes=attributes,
            raise_error_if_exists=False,
        )

    @classmethod
    def list(cls, profile_name_pattern: str = ".*") -> Iterator["Profile"]:
        """List AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expressions can be used
         to specify a pattern. Function REGEXP_LIKE is used to perform the
         match. Default value is ".*" i.e. match all AI profiles.

        :return: Iterator[Profile]
        """
        with cursor() as cr:
            cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = []
            for row in cr.fetchall():
                description = row[1]
                # NOTE(review): _get_profile_description reads the profile
                # description column as a LOB; normalize to str here as well.
                # Assumes LIST_USER_AI_PROFILES selects the same column.
                if description is not None and hasattr(description, "read"):
                    description = description.read()
                rows.append((row[0], description))
        # The cursor is closed before yielding, so a partially consumed
        # generator does not keep it open.
        for profile_name, description in rows:
            attributes = cls._get_attributes(profile_name=profile_name)
            yield cls(
                profile_name=profile_name,
                description=description,
                attributes=attributes,
                raise_error_if_exists=False,
            )

    def generate(
        self,
        prompt: str,
        action: Optional[Action] = Action.RUNSQL,
        params: Optional[Mapping] = None,
    ) -> Optional[str]:
        """Perform AI translation using this profile

        :param str prompt: Natural language prompt to translate
        :param select_ai.profile.Action action: DBMS_CLOUD_AI action to run
        :param params: Parameters to include in the LLM request (sent as
         JSON). For e.g. conversation_id for context-aware chats
        :return: the CLOB returned by DBMS_CLOUD_AI.GENERATE read into a str,
         or None when the function returns NULL
        :raises ValueError: prompt is empty or None
        """
        if not prompt:
            raise ValueError("prompt cannot be empty or None")
        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)
        with cursor() as cr:
            data = cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            if data is not None:
                return data.read()
        return None

    def chat(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Chat with the LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.CHAT, params=params)

    @contextmanager
    def chat_session(self, conversation: Conversation, delete: bool = False):
        """Starts a new chat session for context-aware conversations

        :param Conversation conversation: Conversation object to use for this
         chat session; it is created first when it has no id yet but has
         attributes
        :param bool delete: Delete conversation after session ends

        :return: Session bound to this profile and the conversation id
        """
        # Create outside the try: if creation fails there is nothing to
        # delete, and a failing delete() would mask the original error.
        if (
            conversation.conversation_id is None
            and conversation.attributes is not None
        ):
            conversation.create()
        params = {"conversation_id": conversation.conversation_id}
        try:
            yield Session(profile=self, params=params)
        finally:
            if delete:
                conversation.delete()

    def narrate(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Narrate the result of the SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.NARRATE, params=params)

    def explain_sql(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Explain the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.EXPLAINSQL, params=params)

    def run_sql(
        self, prompt: str, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL statement and return a pandas DataFrame
        built using the result set

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame built from the JSON result; an empty
         DataFrame when DBMS_CLOUD_AI.GENERATE returns NULL
        """
        result = self.generate(prompt, action=Action.RUNSQL, params=params)
        if result is None:
            # json.loads(None) would raise an unhelpful TypeError.
            return pandas.DataFrame()
        return pandas.DataFrame(json.loads(result))

    def show_sql(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Show the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.SHOWSQL, params=params)

    def show_prompt(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Show the prompt sent to LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.SHOWPROMPT, params=params)

    def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
         specification of the data to generate
        :return: None
        :raises ValueError: synthetic_data_attributes is None
        :raises TypeError: synthetic_data_attributes has the wrong type
        :raises: oracledb.DatabaseError
        """
        if synthetic_data_attributes is None:
            raise ValueError(
                "Param 'synthetic_data_attributes' cannot be None"
            )

        if not isinstance(synthetic_data_attributes, SyntheticDataAttributes):
            raise TypeError(
                "'synthetic_data_attributes' must be an object "
                "of type select_ai.SyntheticDataAttributes"
            )

        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
class Session:
    """Bundle a Profile with request parameters reused across calls.

    Every ``chat`` call forwards the same ``params`` mapping (for example a
    ``conversation_id``) to the profile, which is what keeps a
    context-aware conversation together.
    """

    def __init__(self, profile: Profile, params: Mapping):
        """
        :param profile: AI profile that serves every request of the session
        :param params: Parameters sent with every request of the session
        """
        self.profile = profile
        self.params = params

    def chat(self, prompt: str):
        """Send ``prompt`` to the profile's chat action with the session params."""
        session_params = self.params
        return self.profile.chat(prompt=prompt, params=session_params)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; returning None lets exceptions propagate.
        return None
|