select-ai 1.0.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of select-ai might be problematic. Click here for more details.
- select_ai/__init__.py +52 -0
- select_ai/_abc.py +74 -0
- select_ai/_enums.py +14 -0
- select_ai/action.py +21 -0
- select_ai/admin.py +108 -0
- select_ai/async_profile.py +468 -0
- select_ai/base_profile.py +166 -0
- select_ai/conversation.py +249 -0
- select_ai/db.py +171 -0
- select_ai/errors.py +49 -0
- select_ai/profile.py +397 -0
- select_ai/provider.py +187 -0
- select_ai/sql.py +105 -0
- select_ai/synthetic_data.py +84 -0
- select_ai/vector_index.py +542 -0
- select_ai/version.py +8 -0
- select_ai-1.0.0.dev4.dist-info/METADATA +25 -0
- select_ai-1.0.0.dev4.dist-info/RECORD +21 -0
- select_ai-1.0.0.dev4.dist-info/WHEEL +5 -0
- select_ai-1.0.0.dev4.dist-info/licenses/LICENSE.txt +35 -0
- select_ai-1.0.0.dev4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,468 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from dataclasses import replace as dataclass_replace
|
|
11
|
+
from typing import (
|
|
12
|
+
AsyncGenerator,
|
|
13
|
+
List,
|
|
14
|
+
Mapping,
|
|
15
|
+
Optional,
|
|
16
|
+
Tuple,
|
|
17
|
+
Union,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
import oracledb
|
|
21
|
+
import pandas
|
|
22
|
+
|
|
23
|
+
from select_ai.action import Action
|
|
24
|
+
from select_ai.base_profile import BaseProfile, ProfileAttributes
|
|
25
|
+
from select_ai.conversation import AsyncConversation
|
|
26
|
+
from select_ai.db import async_cursor, async_get_connection
|
|
27
|
+
from select_ai.errors import ProfileNotFoundError
|
|
28
|
+
from select_ai.provider import Provider
|
|
29
|
+
from select_ai.sql import (
|
|
30
|
+
GET_USER_AI_PROFILE_ATTRIBUTES,
|
|
31
|
+
LIST_USER_AI_PROFILES,
|
|
32
|
+
)
|
|
33
|
+
from select_ai.synthetic_data import SyntheticDataAttributes
|
|
34
|
+
|
|
35
|
+
# Public API of this module. AsyncSession is constructed by
# AsyncProfile.chat_session() in this module and is not exported here.
__all__ = ["AsyncProfile"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AsyncProfile(BaseProfile):
    """AsyncProfile defines methods to interact with the underlying AI Provider
    asynchronously.

    Instances are awaitable: ``profile = await AsyncProfile(...)`` runs
    :meth:`_init_profile`, which loads and/or creates the profile in the
    database, and evaluates to the profile object itself.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The initialization coroutine is created lazily in __await__.
        # Creating it here would leave an un-awaited coroutine (and a
        # "coroutine was never awaited" RuntimeWarning) behind for every
        # instance that is never awaited, e.g. those built by _from_db()
        # and list().
        self._init_coroutine = None

    def __await__(self):
        if self._init_coroutine is None:
            self._init_coroutine = self._init_profile()
        return self._init_coroutine.__await__()

    async def _init_profile(self):
        """Initializes AI profile based on the passed attributes

        When ``profile_name`` is set, the saved attributes are looked up in
        the database. If the profile exists, its saved attributes are adopted
        when no attributes were passed; with ``merge=True`` the non-null
        passed attributes are overlaid on the saved ones and ``replace`` is
        forced. The profile is then created when ``replace`` is set or when
        it does not exist yet.

        :return: self
        :raises: ProfileNotFoundError if the profile does not exist and no
            attributes were passed
        :raises: oracledb.DatabaseError
        """
        if self.profile_name is not None:
            profile_exists = False
            try:
                saved_attributes = await self._get_attributes(
                    profile_name=self.profile_name
                )
                profile_exists = True
            except ProfileNotFoundError:
                # Without attributes there is nothing to create the
                # profile from.
                if self.attributes is None:
                    raise
            else:
                if self.attributes is None:
                    self.attributes = saved_attributes
                if self.merge:
                    self.replace = True
                    if self.attributes is not None:
                        self.attributes = dataclass_replace(
                            saved_attributes,
                            **self.attributes.dict(exclude_null=True),
                        )
            if self.replace or not profile_exists:
                await self.create(
                    replace=self.replace, description=self.description
                )
        return self

    @staticmethod
    async def _get_attributes(profile_name) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        The profile name is upper-cased before the lookup.

        :param str profile_name: Name of the profile
        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = await cr.fetchall()
            if attributes:
                return await ProfileAttributes.async_create(**dict(attributes))
            else:
                raise ProfileNotFoundError(profile_name=profile_name)

    async def get_attributes(self) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return await self._get_attributes(profile_name=self.profile_name)

    async def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Save a single attribute of this profile via
        DBMS_CLOUD_AI.SET_ATTRIBUTE (database only, not the Python object).
        """
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    async def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database

        When ``attribute_value`` is a Provider, each of its non-null fields
        is saved as a separate database attribute.

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float, Provider] attribute_value: Value
            of the profile attribute
        :return: None
        """
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            # Mirror ProfileAttributes.json(): skip unset provider fields and
            # translate Python field names into database attribute names.
            provider_fields = attribute_value.dict(exclude_null=True)
            for key, value in provider_fields.items():
                await self._set_attribute(Provider.key_alias(key), value)
        else:
            await self._set_attribute(attribute_name, attribute_value)

    async def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves it in the database

        :param ProfileAttributes attributes: Object specifying AI profile
            attributes
        :return: None
        """
        self.attributes = attributes
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )

    async def create(
        self, replace: bool = False, description: Optional[str] = None
    ) -> None:
        """Asynchronously create an AI Profile in the Database

        :param bool replace: Set True to replace else False
        :param description: The profile description
        :return: None
        :raises: oracledb.DatabaseError
        """
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if description:
            parameters["description"] = description

        async with async_cursor() as cr:
            try:
                await cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                (error,) = e.args
                # If already exists and replace is True then drop and recreate
                if "already exists" in error.message.lower() and replace:
                    await self.delete(force=True)
                    await cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    async def delete(self, force=False) -> None:
        """Asynchronously deletes an AI profile from the database

        :param bool force: Ignores errors if AI profile does not exist.
        :return: None
        :raises: oracledb.DatabaseError
        """
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    async def _from_db(cls, profile_name: str) -> "AsyncProfile":
        """Asynchronously create an AI Profile object from attributes
        saved in the database against the profile

        Uses the same lookup as :meth:`_get_attributes` (name upper-cased),
        so both code paths resolve profile names identically.

        :param str profile_name: Name of the profile
        :return: select_ai.AsyncProfile
        :raises: ProfileNotFoundError
        """
        attributes = await cls._get_attributes(profile_name=profile_name)
        return cls(profile_name=profile_name, attributes=attributes)

    @classmethod
    async def list(
        cls, profile_name_pattern: str
    ) -> AsyncGenerator["AsyncProfile", None]:
        """Asynchronously list AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expressions can be used
            to specify a pattern. Function REGEXP_LIKE is used to perform the
            match

        :return: AsyncGenerator[AsyncProfile, None]
        """
        async with async_cursor() as cr:
            await cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = await cr.fetchall()
            for row in rows:
                profile_name = row[0]
                description = row[1]
                attributes = await cls._get_attributes(
                    profile_name=profile_name
                )
                yield cls(
                    profile_name=profile_name,
                    description=description,
                    attributes=attributes,
                )

    async def generate(
        self, prompt, action=Action.SHOWSQL, params: Optional[Mapping] = None
    ) -> Union[pandas.DataFrame, str, None]:
        """Asynchronously perform AI translation using this profile

        :param str prompt: Natural language prompt to translate
        :param select_ai.profile.Action action: Action to perform
        :param params: Parameters to include in the LLM request. For e.g.
            conversation_id for context-aware chats
        :return: the CLOB returned by DBMS_CLOUD_AI.GENERATE read as text,
            or None when the function returns NULL
        """
        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)

        async with async_cursor() as cr:
            data = await cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            if data is not None:
                return await data.read()
        return None

    async def chat(self, prompt, params: Optional[Mapping] = None) -> str:
        """Asynchronously chat with the LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(prompt, action=Action.CHAT, params=params)

    @asynccontextmanager
    async def chat_session(
        self, conversation: AsyncConversation, delete: bool = False
    ):
        """Starts a new chat session for context-aware conversations

        :param AsyncConversation conversation: Conversation object to use for
            this chat session. It is created first when it has no
            conversation_id but does have attributes.
        :param bool delete: Delete conversation after session ends
        """
        # Create outside the try block: if creation fails there is no
        # conversation to delete, and a cleanup delete() would only mask the
        # original error.
        if (
            conversation.conversation_id is None
            and conversation.attributes is not None
        ):
            await conversation.create()
        try:
            params = {"conversation_id": conversation.conversation_id}
            yield AsyncSession(async_profile=self, params=params)
        finally:
            if delete:
                await conversation.delete()

    async def narrate(self, prompt, params: Optional[Mapping] = None) -> str:
        """Narrate the result of the SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.NARRATE, params=params
        )

    async def explain_sql(self, prompt: str, params: Optional[Mapping] = None):
        """Explain the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.EXPLAINSQL, params=params
        )

    async def run_sql(
        self, prompt, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL and return its result set

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame built from the JSON result; an empty
            DataFrame when DBMS_CLOUD_AI.GENERATE returns NULL
        """
        data = await self.generate(prompt, action=Action.RUNSQL, params=params)
        if data is None:
            # json.loads(None) would raise TypeError
            return pandas.DataFrame()
        return pandas.DataFrame(json.loads(data))

    async def show_sql(self, prompt, params: Optional[Mapping] = None):
        """Show the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWSQL, params=params
        )

    async def show_prompt(self, prompt: str, params: Optional[Mapping] = None):
        """Show the prompt sent to LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWPROMPT, params=params
        )

    async def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
            Specification of the data to generate
        :return: None
        :raises: oracledb.DatabaseError
        """
        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )

    async def run_pipeline(
        self,
        prompt_specifications: List[Tuple[str, Action]],
        continue_on_error: bool = False,
    ) -> List[Union[str, pandas.DataFrame]]:
        """Send Multiple prompts in a single roundtrip to the Database

        :param List[Tuple[str, Action]] prompt_specifications: List of
            2-element tuples. First element is the prompt and second is the
            corresponding action
        :param bool continue_on_error: True to continue on error else False
        :return: one entry per prompt, in order: the generated text, the
            operation's error object when it failed, or None when
            DBMS_CLOUD_AI.GENERATE returned NULL
        """
        pipeline = oracledb.create_pipeline()
        for prompt, action in prompt_specifications:
            parameters = {
                "prompt": prompt,
                "action": action,
                "profile_name": self.profile_name,
            }
            pipeline.add_callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                return_type=oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
        async_connection = await async_get_connection()
        pipeline_results = await async_connection.run_pipeline(
            pipeline, continue_on_error=continue_on_error
        )
        responses = []
        for result in pipeline_results:
            if result.error:
                responses.append(result.error)
            elif result.return_value is None:
                # A NULL CLOB has no .read(); report it as None
                responses.append(None)
            else:
                responses.append(await result.return_value.read())
        return responses
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class AsyncSession:
    """Keeps a fixed set of request parameters and re-sends them with every
    DBMS_CLOUD_AI request made through this session.

    This is handy for context-aware conversations, where each chat request
    must carry the same parameters (e.g. a conversation id).
    """

    def __init__(self, async_profile: AsyncProfile, params: Mapping):
        """
        :param async_profile: AI Profile that serves every request of the
            session
        :param params: Parameters attached to each request
        """
        self.async_profile = async_profile
        self.params = params

    async def chat(self, prompt: str):
        """Chat with the LLM using the session's stored parameters."""
        profile = self.async_profile
        shared_params = self.params
        return await profile.chat(prompt=prompt, params=shared_params)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The session holds no resources of its own; nothing to release.
        return None
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# -----------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025, Oracle and/or its affiliates.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at
|
|
5
|
+
# http://oss.oracle.com/licenses/upl.
|
|
6
|
+
# -----------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from abc import ABC
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import List, Mapping, Optional
|
|
12
|
+
|
|
13
|
+
import oracledb
|
|
14
|
+
|
|
15
|
+
from select_ai._abc import SelectAIDataClass
|
|
16
|
+
|
|
17
|
+
from .provider import Provider
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ProfileAttributes(SelectAIDataClass):
    """
    Use this class to define attributes to manage and configure the behavior of
    an AI profile

    :param bool comments: True to include column comments in the metadata used
        for generating SQL queries from natural language prompts.
    :param bool constraints: True to include referential integrity constraints
        such as primary and foreign keys in the metadata sent to the LLM.
    :param bool conversation: Indicates if conversation history is enabled for
        a profile.
    :param str credential_name: The name of the credential to access the AI
        provider APIs.
    :param bool enforce_object_list: Specifies whether to restrict the LLM
        to generate SQL that uses only tables covered by the object list.
    :param int max_tokens: Denotes the number of tokens to return per
        generation. Default is 1024.
    :param List[Mapping] object_list: Array of JSON objects specifying
        the owner and object names that are eligible for natural language
        translation to SQL.
    :param str object_list_mode: Specifies whether to send metadata for the
        most relevant tables or all tables to the LLM. Supported values are -
        'automated' and 'all'
    :param select_ai.Provider provider: AI Provider
    :param str stop_tokens: The generated text will be terminated at the
        beginning of the earliest stop sequence. Sequence will be incorporated
        into the text. The attribute value must be a valid array of string values
        in JSON format
    :param float temperature: Temperature is a non-negative float number used
        to tune the degree of randomness. Lower temperatures mean less random
        generations.
    :param str vector_index_name: Name of the vector index

    """

    annotations: Optional[str] = None
    case_sensitive_values: Optional[bool] = None
    comments: Optional[bool] = None
    # NOTE(review): documented above as bool but annotated str; confirm the
    # intended type before changing it (the base class may rely on it).
    constraints: Optional[str] = None
    conversation: Optional[bool] = None
    credential_name: Optional[str] = None
    enable_sources: Optional[bool] = None
    enable_source_offsets: Optional[bool] = None
    enforce_object_list: Optional[bool] = None
    max_tokens: Optional[int] = 1024
    object_list: Optional[List[Mapping]] = None
    object_list_mode: Optional[str] = None
    provider: Optional[Provider] = None
    seed: Optional[str] = None
    stop_tokens: Optional[str] = None
    streaming: Optional[str] = None
    temperature: Optional[float] = None
    vector_index_name: Optional[str] = None

    def json(self, exclude_null=True):
        """Serialize the attributes into the flat JSON object passed to
        DBMS_CLOUD_AI.

        Provider fields are flattened into the top level, with their names
        translated through ``Provider.key_alias``.

        :param bool exclude_null: Skip attributes whose value is None
        :return: str
        """
        attributes = {}
        for k, v in self.dict(exclude_null=exclude_null).items():
            if isinstance(v, Provider):
                for provider_k, provider_v in v.dict(
                    exclude_null=exclude_null
                ).items():
                    attributes[Provider.key_alias(provider_k)] = provider_v
            else:
                attributes[k] = v
        return json.dumps(attributes)

    @classmethod
    def create(cls, **kwargs):
        """Build attributes from flat attribute name/value pairs, e.g. the
        rows fetched for a saved profile.

        ``oracledb.LOB`` values are read first. Keys listed in
        ``Provider.keys()`` are translated with ``Provider.key_alias`` and
        used to build the ``provider`` via ``Provider.create``; every other
        key becomes a profile attribute.

        :return: an instance of ``cls``
        """
        provider_attributes = {}
        profile_attributes = {}
        for key, value in kwargs.items():
            if isinstance(value, oracledb.LOB):
                value = value.read()
            if key in Provider.keys():
                provider_attributes[Provider.key_alias(key)] = value
            else:
                profile_attributes[key] = value
        profile_attributes["provider"] = Provider.create(**provider_attributes)
        # Use cls so that subclasses get instances of their own type
        return cls(**profile_attributes)

    @classmethod
    async def async_create(cls, **kwargs):
        """Asynchronous counterpart of :meth:`create`.

        ``oracledb.AsyncLOB`` values are awaited and read, then the plain
        values are handed to :meth:`create`, so both constructors share the
        same provider/profile split.

        :return: an instance of ``cls``
        """
        resolved = {}
        for key, value in kwargs.items():
            if isinstance(value, oracledb.AsyncLOB):
                value = await value.read()
            resolved[key] = value
        return cls.create(**resolved)

    def set_attribute(self, key, value):
        """Set one attribute by its DBMS_CLOUD_AI attribute name.

        Provider keys (``Provider.keys()``) with a non-Provider value are set
        on ``self.provider`` under the field name given by
        ``Provider.key_alias``, the same translation :meth:`create` applies.
        Any other key, or a whole Provider value, is set on this object.
        """
        if key in Provider.keys() and not isinstance(value, Provider):
            setattr(self.provider, Provider.key_alias(key), value)
        else:
            setattr(self, key, value)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class BaseProfile(ABC):
    """
    Abstract base for Select AI profiles, i.e. the configuration Select AI
    uses when talking to AI service providers (LLMs). Instantiate
    select_ai.Profile or select_ai.AsyncProfile rather than this class.

    :param str profile_name: Name of the profile

    :param select_ai.ProfileAttributes attributes:
        Object specifying AI profile attributes

    :param str description: Description of the profile

    :param bool merge: Fetch the profile from the database, merge the
        attributes and save it back. Defaults to False

    :param bool replace: Replace the profile and its attributes in the
        database. Defaults to False

    """

    def __init__(
        self,
        profile_name: Optional[str] = None,
        attributes: Optional[ProfileAttributes] = None,
        description: Optional[str] = None,
        merge: Optional[bool] = False,
        replace: Optional[bool] = False,
    ):
        """Store the profile settings on the instance; no database work."""
        self.profile_name = profile_name
        self.description = description
        self.attributes = attributes
        self.replace = replace
        self.merge = merge

    def __repr__(self):
        class_name = type(self).__name__
        return "{}(profile_name={}, attributes={}, description={})".format(
            class_name, self.profile_name, self.attributes, self.description
        )
|