select-ai 1.2.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- select_ai/__init__.py +66 -0
- select_ai/_abc.py +82 -0
- select_ai/_enums.py +14 -0
- select_ai/_validations.py +123 -0
- select_ai/action.py +23 -0
- select_ai/agent/__init__.py +25 -0
- select_ai/agent/core.py +511 -0
- select_ai/agent/sql.py +82 -0
- select_ai/agent/task.py +521 -0
- select_ai/agent/team.py +590 -0
- select_ai/agent/tool.py +1129 -0
- select_ai/async_profile.py +648 -0
- select_ai/base_profile.py +265 -0
- select_ai/conversation.py +295 -0
- select_ai/credential.py +135 -0
- select_ai/db.py +191 -0
- select_ai/errors.py +113 -0
- select_ai/feedback.py +19 -0
- select_ai/privilege.py +135 -0
- select_ai/profile.py +579 -0
- select_ai/provider.py +195 -0
- select_ai/sql.py +111 -0
- select_ai/summary.py +61 -0
- select_ai/synthetic_data.py +90 -0
- select_ai/vector_index.py +642 -0
- select_ai/version.py +8 -0
- select_ai-1.2.0rc3.dist-info/METADATA +129 -0
- select_ai-1.2.0rc3.dist-info/RECORD +31 -0
- select_ai-1.2.0rc3.dist-info/WHEEL +5 -0
- select_ai-1.2.0rc3.dist-info/licenses/LICENSE.txt +35 -0
- select_ai-1.2.0rc3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from abc import ABC
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import List, Mapping, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
import oracledb
|
|
14
|
+
|
|
15
|
+
from select_ai._abc import SelectAIDataClass
|
|
16
|
+
from select_ai.action import Action
|
|
17
|
+
from select_ai.feedback import (
|
|
18
|
+
FeedbackOperation,
|
|
19
|
+
FeedbackType,
|
|
20
|
+
)
|
|
21
|
+
from select_ai.provider import Provider
|
|
22
|
+
from select_ai.summary import SummaryParams
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class ProfileAttributes(SelectAIDataClass):
    """
    Use this class to define attributes to manage and configure the behavior of
    an AI profile

    :param bool comments: True to include column comments in the metadata used
     for generating SQL queries from natural language prompts.
    :param bool constraints: True to include referential integrity constraints
     such as primary and foreign keys in the metadata sent to the LLM.
    :param bool conversation: Indicates if conversation history is enabled for
     a profile.
    :param str credential_name: The name of the credential to access the AI
     provider APIs.
    :param bool enforce_object_list: Specifies whether to restrict the LLM
     to generate SQL that uses only tables covered by the object list.
    :param int max_tokens: Denotes the number of tokens to return per
     generation. Default is 1024.
    :param List[Mapping] object_list: Array of JSON objects specifying
     the owner and object names that are eligible for natural language
     translation to SQL.
    :param str object_list_mode: Specifies whether to send metadata for the
     most relevant tables or all tables to the LLM. Supported values are -
     'automated' and 'all'
    :param select_ai.Provider provider: AI Provider. Its own attributes are
     flattened into the top-level JSON produced by :meth:`json`.
    :param str stop_tokens: The generated text will be terminated at the
     beginning of the earliest stop sequence. Sequence will be incorporated
     into the text. The attribute value must be a valid array of string values
     in JSON format
    :param float temperature: Temperature is a non-negative float number used
     to tune the degree of randomness. Lower temperatures mean less random
     generations.
    :param str vector_index_name: Name of the vector index

    The remaining fields (annotations, case_sensitive_values,
    enable_custom_source_uri, enable_sources, enable_source_offsets, seed,
    streaming) are emitted by :meth:`json` under their own names.
    """

    annotations: Optional[str] = None
    case_sensitive_values: Optional[bool] = None
    comments: Optional[bool] = None
    # NOTE(review): documented as bool above but annotated as str -- confirm
    # which type the database returns/accepts before changing it.
    constraints: Optional[str] = None
    conversation: Optional[bool] = None
    credential_name: Optional[str] = None
    enable_custom_source_uri: Optional[bool] = None
    enable_sources: Optional[bool] = None
    enable_source_offsets: Optional[bool] = None
    enforce_object_list: Optional[bool] = None
    max_tokens: Optional[int] = 1024
    object_list: Optional[List[Mapping]] = None
    object_list_mode: Optional[str] = None
    provider: Optional[Provider] = None
    seed: Optional[str] = None
    stop_tokens: Optional[str] = None
    streaming: Optional[str] = None
    temperature: Optional[float] = None
    vector_index_name: Optional[str] = None

    def __post_init__(self):
        """Run base-class post-init, then check the ``provider`` type.

        :raises ValueError: if ``provider`` is set to something that is not a
         select_ai.Provider
        """
        super().__post_init__()
        if self.provider and not isinstance(self.provider, Provider):
            raise ValueError(
                "'provider' must be an object of type select_ai.Provider"
            )

    def json(self, exclude_null=True):
        """Serialize the attributes to a JSON string.

        Provider fields are flattened to the top level, each renamed with
        ``Provider.key_alias``; every other field keeps its own name.

        :param bool exclude_null: Forwarded to ``dict()`` of this object and
         of the provider
        :return: JSON string
        """
        attributes = {}
        for k, v in self.dict(exclude_null=exclude_null).items():
            if isinstance(v, Provider):
                for provider_k, provider_v in v.dict(
                    exclude_null=exclude_null
                ).items():
                    attributes[Provider.key_alias(provider_k)] = provider_v
            else:
                attributes[k] = v
        return json.dumps(attributes)

    @classmethod
    def _from_flat_attributes(cls, attributes: Mapping):
        """Split a flat attribute mapping into provider/profile parts.

        Keys listed in ``Provider.keys()`` are renamed with
        ``Provider.key_alias`` and used to build the nested Provider via
        ``Provider.create``. All other keys are passed to the constructor
        as-is. LOB values must already have been read by the caller.
        """
        provider_attributes = {}
        profile_attributes = {}
        for k, v in attributes.items():
            if k in Provider.keys():
                provider_attributes[Provider.key_alias(k)] = v
            else:
                profile_attributes[k] = v
        profile_attributes["provider"] = Provider.create(**provider_attributes)
        # Use ``cls`` so subclasses get instances of themselves back.
        return cls(**profile_attributes)

    @classmethod
    def create(cls, **kwargs):
        """Build an instance from flat attribute key/value pairs.

        Any ``oracledb.LOB`` value is read into memory first.
        """
        attributes = {
            k: (v.read() if isinstance(v, oracledb.LOB) else v)
            for k, v in kwargs.items()
        }
        return cls._from_flat_attributes(attributes)

    @classmethod
    async def async_create(cls, **kwargs):
        """Async variant of :meth:`create`.

        Any ``oracledb.AsyncLOB`` value is awaited and read first.
        """
        attributes = {}
        for k, v in kwargs.items():
            if isinstance(v, oracledb.AsyncLOB):
                v = await v.read()
            attributes[k] = v
        return cls._from_flat_attributes(attributes)

    def set_attribute(self, key, value):
        """Set a single attribute.

        Keys recognized by ``Provider.keys()`` (when ``value`` is not itself a
        Provider) are set on the nested provider object. All other keys are
        set on this object.

        :raises AttributeError: if ``key`` is a provider key but no provider
         is configured
        """
        if key in Provider.keys() and not isinstance(value, Provider):
            if self.provider is None:
                raise AttributeError(
                    f"Cannot set provider attribute '{key}': "
                    "profile attributes have no provider"
                )
            setattr(self.provider, key, value)
        else:
            setattr(self, key, value)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class BaseProfile(ABC):
    """
    BaseProfile is an abstract base class representing a Profile
    for Select AI's interactions with AI service providers (LLMs).
    Use either select_ai.Profile or select_ai.AsyncProfile to
    instantiate an AI profile object.

    :param str profile_name : Name of the profile

    :param select_ai.ProfileAttributes attributes:
     Object specifying AI profile attributes

    :param str description: Description of the profile

    :param bool merge: Fetches the profile
     from database, merges the non-null attributes and saves it back
     in the database. Default value is False

    :param bool replace: Replaces the profile and attributes
     in the database. Default value is False

    :param bool raise_error_if_exists: Raise ProfileExistsError
     if profile exists in the database and replace = False and
     merge = False. Default value is True

    """

    def __init__(
        self,
        profile_name: Optional[str] = None,
        attributes: Optional[ProfileAttributes] = None,
        description: Optional[str] = None,
        merge: Optional[bool] = False,
        replace: Optional[bool] = False,
        raise_error_if_exists: Optional[bool] = True,
    ):
        """Initialize a base profile

        Only stores the given values; no database access happens here.

        :raises TypeError: if ``attributes`` is given but is not a
         select_ai.ProfileAttributes instance
        """
        self.profile_name = profile_name
        # Compare against None rather than relying on truthiness, so falsy
        # non-ProfileAttributes values (e.g. ``{}``) are rejected as well.
        if attributes is not None and not isinstance(
            attributes, ProfileAttributes
        ):
            raise TypeError(
                "'attributes' must be an object of type "
                "select_ai.ProfileAttributes"
            )
        self.attributes = attributes
        self.description = description
        self.merge = merge
        self.replace = replace
        self.raise_error_if_exists = raise_error_if_exists

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(profile_name={self.profile_name}, "
            f"attributes={self.attributes}, description={self.description})"
        )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def no_data_for_prompt(result) -> bool:
    """Tell whether a Select AI result carries no usable data.

    :param result: Value returned by a Select AI call
    :return: True when ``result`` is None or equals the sentinel text
     "No data found for the prompt.", otherwise False
    """
    if result is None:
        return True
    return bool(result == "No data found for the prompt.")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def validate_params_for_feedback(
    feedback_type: FeedbackType,
    feedback_content: str,
    prompt_spec: Optional[Tuple[str, Action]] = None,
    sql_id: Optional[str] = None,
    response: Optional[str] = None,
    operation: Optional[FeedbackOperation] = FeedbackOperation.ADD,
):
    """Validate feedback arguments and build the feedback keyword parameters.

    Exactly one of ``sql_id`` or ``prompt_spec`` must be given.

    :param FeedbackType feedback_type: Type of feedback; its ``.value`` is sent
    :param str feedback_content: Free-text feedback
    :param Tuple[str, Action] prompt_spec: ``(prompt, action)`` pair. The
     action must be RUNSQL, SHOWSQL or EXPLAINSQL. It is used to build
     ``sql_text`` as ``"select ai <action> <prompt>"``
    :param str sql_id: SQL ID to attach the feedback to
    :param str response: Required when adding NEGATIVE feedback through
     ``prompt_spec``. Included in the result whenever it is not None
    :param FeedbackOperation operation: Defaults to ADD; None is treated as ADD
    :return: dict of keyword parameters
    :raises AttributeError: on an invalid combination of arguments
    """
    if sql_id and prompt_spec:
        raise AttributeError(
            "Either sql_id or prompt_spec must be specified, not both"
        )
    if not sql_id and not prompt_spec:
        raise AttributeError("Either sql_id or prompt_spec must be specified")
    if operation is None:
        operation = FeedbackOperation.ADD
    parameters = {
        "feedback_type": feedback_type.value,
        "feedback_content": feedback_content,
        "operation": operation.value,
    }
    if prompt_spec:
        prompt, action = prompt_spec
        if action not in (Action.RUNSQL, Action.SHOWSQL, Action.EXPLAINSQL):
            raise AttributeError(
                "'action' must be one of 'RUNSQL', 'SHOWSQL' or 'EXPLAINSQL'"
            )
        if (
            operation == FeedbackOperation.ADD
            and feedback_type == FeedbackType.NEGATIVE
            and response is None
        ):
            raise AttributeError(
                "'response' must be specified if feedback_type is NEGATIVE"
            )
        # Use the enum value explicitly, like feedback_type/operation above,
        # instead of relying on Enum.__format__, whose output for mixed-in
        # enums differs between Python versions.
        parameters["sql_text"] = "select ai {} {}".format(
            Action(action).value, prompt
        )
    else:
        parameters["sql_id"] = sql_id
    # Previously validated but silently dropped from the parameters.
    if response is not None:
        parameters["response"] = response
    return parameters
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def validate_params_for_summary(
    prompt: Optional[str] = None,
    content: Optional[str] = None,
    location_uri: Optional[str] = None,
    credential_name: Optional[str] = None,
    params: Optional["SummaryParams"] = None,
):
    """Validate summarization arguments and build the keyword parameters.

    Exactly one of ``content`` or ``location_uri`` must be given. Only
    truthy values are included in the result.

    :param str prompt: Optional instruction for the summary
    :param str content: Inline text to summarize
    :param str location_uri: URI of the content to summarize
    :param str credential_name: Credential used to access ``location_uri``
    :param SummaryParams params: Extra parameters, serialized via ``.json()``
     under the ``"parameters"`` key
    :return: dict of keyword parameters
    :raises AttributeError: if both or neither of content/location_uri
     are given
    """
    if content and location_uri:
        raise AttributeError(
            "Either content or location_uri must be specified, not both"
        )
    if not content and not location_uri:
        raise AttributeError(
            "Either content or location_uri must be specified"
        )
    parameters = {}
    if content:
        parameters["content"] = content
    if location_uri:
        parameters["location_uri"] = location_uri
    if credential_name:
        parameters["credential_name"] = credential_name
    if prompt:
        parameters["prompt"] = prompt
    if params:
        parameters["parameters"] = params.json()
    return parameters
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import datetime
|
|
9
|
+
import json
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import AsyncGenerator, Iterator, Optional
|
|
12
|
+
|
|
13
|
+
import oracledb
|
|
14
|
+
|
|
15
|
+
from select_ai._abc import SelectAIDataClass
|
|
16
|
+
from select_ai.db import async_cursor, cursor
|
|
17
|
+
from select_ai.errors import ConversationNotFoundError
|
|
18
|
+
from select_ai.sql import (
|
|
19
|
+
GET_USER_CONVERSATION_ATTRIBUTES,
|
|
20
|
+
LIST_USER_CONVERSATIONS,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Public names exported by ``from select_ai.conversation import *``.
__all__ = ["AsyncConversation", "Conversation", "ConversationAttributes"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class ConversationAttributes(SelectAIDataClass):
    """Conversation Attributes

    :param str title: Conversation Title
    :param str description: Description of the conversation topic
    :param datetime.timedelta retention_days: The number of days the
     conversation will be stored in the database from its creation date. If
     the value is 0, the conversation will not be removed unless it is
     manually deleted by delete
    :param int conversation_length: Number of prompts to store for this
     conversation

    """

    title: Optional[str] = "New Conversation"
    description: Optional[str] = None
    retention_days: Optional[datetime.timedelta] = datetime.timedelta(days=7)
    conversation_length: Optional[int] = 10

    def json(self, exclude_null=True):
        """Serialize the attributes to a JSON string.

        ``datetime.timedelta`` values are written as their whole-day count
        (``.days``). Every other value is written unchanged.
        """

        def _to_jsonable(value):
            if isinstance(value, datetime.timedelta):
                return value.days
            return value

        serializable = {
            name: _to_jsonable(value)
            for name, value in self.dict(exclude_null=exclude_null).items()
        }
        return json.dumps(serializable)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class _BaseConversation:
    """State and ``repr`` shared by Conversation and AsyncConversation.

    :param str conversation_id: Conversation ID
    :param ConversationAttributes attributes: Conversation attributes
    """

    def __init__(
        self,
        conversation_id: Optional[str] = None,
        attributes: Optional[ConversationAttributes] = None,
    ):
        self.conversation_id, self.attributes = conversation_id, attributes

    def __repr__(self):
        cls_name = type(self).__name__
        return (
            f"{cls_name}(conversation_id={self.conversation_id}, "
            f"attributes={self.attributes})"
        )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Conversation(_BaseConversation):
    """Conversation class can be used to create, update and delete
    conversations in the database

    Typical usage is to combine this conversation object with an AI
    Profile.chat_session() to have context-aware conversations with
    the LLM provider

    :param str conversation_id: Conversation ID
    :param ConversationAttributes attributes: Conversation attributes

    """

    def create(self) -> str:
        """Creates a new conversation and returns the conversation_id
        to be used in context-aware conversations with LLMs

        If no attributes were supplied, a default ConversationAttributes()
        is used and stored on this object.

        :return: conversation_id
        """
        if self.attributes is None:
            # Without this, ``self.attributes.json()`` below would fail
            # with AttributeError on None.
            self.attributes = ConversationAttributes()
        with cursor() as cr:
            self.conversation_id = cr.callfunc(
                "DBMS_CLOUD_AI.CREATE_CONVERSATION",
                oracledb.DB_TYPE_VARCHAR,
                keyword_parameters={"attributes": self.attributes.json()},
            )
        return self.conversation_id

    def delete(self, force: bool = False):
        """Drops the conversation

        :param bool force: Passed as the ``force`` argument of
         DBMS_CLOUD_AI.DROP_CONVERSATION
        """
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.DROP_CONVERSATION",
                keyword_parameters={
                    "conversation_id": self.conversation_id,
                    "force": force,
                },
            )

    @classmethod
    def fetch(cls, conversation_id: str) -> "Conversation":
        """Fetch conversation attributes from the database
        and build a proxy object

        :param str conversation_id: Conversation ID
        :raises ConversationNotFoundError: if no such conversation exists
        """
        conversation = cls(conversation_id=conversation_id)
        conversation.attributes = conversation.get_attributes()
        return conversation

    def set_attributes(self, attributes: ConversationAttributes):
        """Updates the attributes of the conversation in the database,
        then refreshes ``self.attributes`` from the database"""
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.UPDATE_CONVERSATION",
                keyword_parameters={
                    "conversation_id": self.conversation_id,
                    "attributes": attributes.json(),
                },
            )
        self.attributes = self.get_attributes()

    def get_attributes(self) -> ConversationAttributes:
        """Get attributes of the conversation from the database

        Only title, description and retention_days are read from the query
        result; conversation_length keeps its dataclass default.

        :raises ConversationNotFoundError: if no row is returned
        """
        with cursor() as cr:
            cr.execute(
                GET_USER_CONVERSATION_ATTRIBUTES,
                conversation_id=self.conversation_id,
            )
            attributes = cr.fetchone()
            if attributes:
                conversation_title = attributes[0]
                if attributes[1]:
                    description = attributes[1].read()  # Oracle.LOB
                else:
                    description = None
                retention_days = attributes[2]
                return ConversationAttributes(
                    title=conversation_title,
                    description=description,
                    retention_days=retention_days,
                )
            else:
                raise ConversationNotFoundError(
                    conversation_id=self.conversation_id
                )

    @classmethod
    def list(cls) -> Iterator["Conversation"]:
        """List all conversations

        The cursor stays open until the generator is exhausted or closed.

        :return: Iterator[Conversation]
        """
        with cursor() as cr:
            cr.execute(
                LIST_USER_CONVERSATIONS,
            )
            for row in cr.fetchall():
                conversation_id = row[0]
                conversation_title = row[1]
                if row[2]:
                    description = row[2].read()  # Oracle.LOB
                else:
                    description = None
                retention_days = row[3]
                attributes = ConversationAttributes(
                    title=conversation_title,
                    description=description,
                    retention_days=retention_days,
                )
                yield cls(
                    attributes=attributes, conversation_id=conversation_id
                )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class AsyncConversation(_BaseConversation):
    """AsyncConversation class can be used to create, update and delete
    conversations in the database in an async manner

    Typical usage is to combine this conversation object with an
    AsyncProfile.chat_session() to have context-aware conversations

    :param str conversation_id: Conversation ID
    :param ConversationAttributes attributes: Conversation attributes

    """

    async def create(self) -> str:
        """Creates a new conversation and returns the conversation_id
        to be used in context-aware conversations with LLMs

        If no attributes were supplied, a default ConversationAttributes()
        is used and stored on this object.

        :return: conversation_id
        """
        if self.attributes is None:
            # Without this, ``self.attributes.json()`` below would fail
            # with AttributeError on None.
            self.attributes = ConversationAttributes()
        async with async_cursor() as cr:
            self.conversation_id = await cr.callfunc(
                "DBMS_CLOUD_AI.CREATE_CONVERSATION",
                oracledb.DB_TYPE_VARCHAR,
                keyword_parameters={"attributes": self.attributes.json()},
            )
        return self.conversation_id

    async def delete(self, force: bool = False):
        """Delete the conversation

        :param bool force: Passed as the ``force`` argument of
         DBMS_CLOUD_AI.DROP_CONVERSATION
        """
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.DROP_CONVERSATION",
                keyword_parameters={
                    "conversation_id": self.conversation_id,
                    "force": force,
                },
            )

    @classmethod
    async def fetch(cls, conversation_id: str) -> "AsyncConversation":
        """Fetch conversation attributes from the database

        :raises ConversationNotFoundError: if no such conversation exists
        """
        conversation = cls(conversation_id=conversation_id)
        conversation.attributes = await conversation.get_attributes()
        return conversation

    async def set_attributes(self, attributes: ConversationAttributes):
        """Updates the attributes of the conversation, then refreshes
        ``self.attributes`` from the database"""
        # Use the async cursor: the synchronous cursor() previously used here
        # blocked the event loop and went through the sync connection.
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.UPDATE_CONVERSATION",
                keyword_parameters={
                    "conversation_id": self.conversation_id,
                    "attributes": attributes.json(),
                },
            )
        self.attributes = await self.get_attributes()

    async def get_attributes(self) -> ConversationAttributes:
        """Get attributes of the conversation from the database

        Only title, description and retention_days are read from the query
        result; conversation_length keeps its dataclass default.

        :raises ConversationNotFoundError: if no row is returned
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_CONVERSATION_ATTRIBUTES,
                conversation_id=self.conversation_id,
            )
            attributes = await cr.fetchone()
            if attributes:
                conversation_title = attributes[0]
                if attributes[1]:
                    description = await attributes[1].read()  # Oracle.AsyncLOB
                else:
                    description = None
                retention_days = attributes[2]
                return ConversationAttributes(
                    title=conversation_title,
                    description=description,
                    retention_days=retention_days,
                )
            else:
                raise ConversationNotFoundError(
                    conversation_id=self.conversation_id
                )

    @classmethod
    async def list(cls) -> AsyncGenerator["AsyncConversation", None]:
        """List all conversations

        :return: AsyncGenerator[AsyncConversation, None]
        """
        async with async_cursor() as cr:
            await cr.execute(
                LIST_USER_CONVERSATIONS,
            )
            rows = await cr.fetchall()
            for row in rows:
                conversation_id = row[0]
                conversation_title = row[1]
                if row[2]:
                    description = await row[2].read()  # Oracle.AsyncLOB
                else:
                    description = None
                retention_days = row[3]
                attributes = ConversationAttributes(
                    title=conversation_title,
                    description=description,
                    retention_days=retention_days,
                )
                yield cls(
                    attributes=attributes, conversation_id=conversation_id
                )
|
select_ai/credential.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
from typing import Mapping
|
|
9
|
+
|
|
10
|
+
import oracledb
|
|
11
|
+
|
|
12
|
+
from .db import async_cursor, cursor
|
|
13
|
+
|
|
14
|
+
# Public API of this module; _validate_credential is intentionally private.
__all__ = [
    "async_create_credential",
    "async_delete_credential",
    "create_credential",
    "delete_credential",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _validate_credential(credential: Mapping[str, str]):
|
|
23
|
+
valid_keys = {
|
|
24
|
+
"credential_name",
|
|
25
|
+
"username",
|
|
26
|
+
"password",
|
|
27
|
+
"user_ocid",
|
|
28
|
+
"tenancy_ocid",
|
|
29
|
+
"private_key",
|
|
30
|
+
"fingerprint",
|
|
31
|
+
"comments",
|
|
32
|
+
}
|
|
33
|
+
for k in credential.keys():
|
|
34
|
+
if k.lower() not in valid_keys:
|
|
35
|
+
raise ValueError(
|
|
36
|
+
f"Invalid value {k}: {credential[k]} for credential object"
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
async def async_create_credential(credential: Mapping, replace: bool = False):
    """
    Async API to create credential.

    Creates a credential object using DBMS_CLOUD.CREATE_CREDENTIAL. If replace
    is True, the credential will be replaced if it already exists.

    :param Mapping credential: Keyword arguments for
     DBMS_CLOUD.CREATE_CREDENTIAL. Keys are checked case-insensitively
     against the supported names (credential_name, username, password, ...)
    :param bool replace: When the create fails with error code 20022
     (already exists), drop the credential and create it again.
     Default is False
    :raises ValueError: if ``credential`` contains an unsupported key
    :raises oracledb.DatabaseError: for any other database error
    """
    _validate_credential(credential)
    async with async_cursor() as cr:
        try:
            await cr.callproc(
                "DBMS_CLOUD.CREATE_CREDENTIAL", keyword_parameters=credential
            )
        except oracledb.DatabaseError as e:
            (error,) = e.args
            # If already exists and replace is True then drop and recreate
            if error.code == 20022 and replace:
                # Keys are validated case-insensitively, so look the name up
                # the same way rather than assuming a lower-case key.
                credential_name = next(
                    (
                        v
                        for k, v in credential.items()
                        if k.lower() == "credential_name"
                    ),
                    None,
                )
                await cr.callproc(
                    "DBMS_CLOUD.DROP_CREDENTIAL",
                    keyword_parameters={"credential_name": credential_name},
                )
                await cr.callproc(
                    "DBMS_CLOUD.CREATE_CREDENTIAL",
                    keyword_parameters=credential,
                )
            else:
                raise
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def async_delete_credential(credential_name: str, force: bool = False):
    """
    Async API to delete credential.

    Deletes a credential object using DBMS_CLOUD.DROP_CREDENTIAL

    :param str credential_name: Name of the credential to drop
    :param bool force: If True, ignore database error code 20004 (the
     credential does not exist). Default is False
    :raises oracledb.DatabaseError: for any other database error, or for
     20004 when ``force`` is False
    """
    async with async_cursor() as cr:
        try:
            await cr.callproc(
                "DBMS_CLOUD.DROP_CREDENTIAL",
                keyword_parameters={"credential_name": credential_name},
            )
        except oracledb.DatabaseError as e:
            (error,) = e.args
            if error.code == 20004 and force:  # does not exist
                pass
            else:
                raise
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def create_credential(credential: Mapping, replace: bool = False):
    """
    Creates a credential object using DBMS_CLOUD.CREATE_CREDENTIAL. If replace
    is True, the credential will be replaced if it already exists.

    :param Mapping credential: Keyword arguments for
     DBMS_CLOUD.CREATE_CREDENTIAL. Keys are checked case-insensitively
     against the supported names (credential_name, username, password, ...)
    :param bool replace: When the create fails with error code 20022
     (already exists), drop the credential and create it again.
     Default is False
    :raises ValueError: if ``credential`` contains an unsupported key
    :raises oracledb.DatabaseError: for any other database error
    """
    _validate_credential(credential)
    with cursor() as cr:
        try:
            cr.callproc(
                "DBMS_CLOUD.CREATE_CREDENTIAL", keyword_parameters=credential
            )
        except oracledb.DatabaseError as e:
            (error,) = e.args
            # If already exists and replace is True then drop and recreate
            if error.code == 20022 and replace:
                # Keys are validated case-insensitively, so look the name up
                # the same way rather than assuming a lower-case key.
                credential_name = next(
                    (
                        v
                        for k, v in credential.items()
                        if k.lower() == "credential_name"
                    ),
                    None,
                )
                cr.callproc(
                    "DBMS_CLOUD.DROP_CREDENTIAL",
                    keyword_parameters={"credential_name": credential_name},
                )
                cr.callproc(
                    "DBMS_CLOUD.CREATE_CREDENTIAL",
                    keyword_parameters=credential,
                )
            else:
                raise
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def delete_credential(credential_name: str, force: bool = False):
    """
    Deletes a credential object using DBMS_CLOUD.DROP_CREDENTIAL

    :param str credential_name: Name of the credential to drop
    :param bool force: If True, swallow database error code 20004 (the
     credential does not exist). Default is False
    :raises oracledb.DatabaseError: for any other database error, or for
     20004 when ``force`` is False
    """
    drop_args = {"credential_name": credential_name}
    with cursor() as cr:
        try:
            cr.callproc("DBMS_CLOUD.DROP_CREDENTIAL", keyword_parameters=drop_args)
        except oracledb.DatabaseError as exc:
            (err,) = exc.args
            missing = err.code == 20004  # does not exist
            if not (missing and force):
                raise
|