select-ai 1.0.0.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of select-ai might be problematic. Click here for more details.

@@ -0,0 +1,468 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from contextlib import asynccontextmanager
10
+ from dataclasses import replace as dataclass_replace
11
+ from typing import (
12
+ AsyncGenerator,
13
+ List,
14
+ Mapping,
15
+ Optional,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import oracledb
21
+ import pandas
22
+
23
+ from select_ai.action import Action
24
+ from select_ai.base_profile import BaseProfile, ProfileAttributes
25
+ from select_ai.conversation import AsyncConversation
26
+ from select_ai.db import async_cursor, async_get_connection
27
+ from select_ai.errors import ProfileNotFoundError
28
+ from select_ai.provider import Provider
29
+ from select_ai.sql import (
30
+ GET_USER_AI_PROFILE_ATTRIBUTES,
31
+ LIST_USER_AI_PROFILES,
32
+ )
33
+ from select_ai.synthetic_data import SyntheticDataAttributes
34
+
35
__all__ = [
    "AsyncProfile",
]
36
+
37
+
38
class AsyncProfile(BaseProfile):
    """AsyncProfile defines methods to interact with the underlying AI
    Provider asynchronously.

    Instances are awaitable. Awaiting one runs :meth:`_init_profile`, which
    reconciles the object with the database and evaluates to the profile
    itself::

        profile = await AsyncProfile(profile_name="...", attributes=...)
    """

    def __await__(self):
        # Build the initialization coroutine on demand rather than in
        # __init__. Creating it eagerly caused two problems:
        # - Every instance that is never awaited (e.g. the objects built by
        #   _from_db() and list()) left behind a never-awaited coroutine and
        #   a RuntimeWarning.
        # - A second ``await`` failed with "cannot reuse already awaited
        #   coroutine".
        # Each await now re-runs the initialization.
        return self._init_profile().__await__()

    async def _init_profile(self):
        """Reconcile this object with the AI profile stored in the database.

        - Profile exists, no attributes passed: the saved attributes are
          adopted.
        - Profile exists and ``merge`` is set: the non-null passed attributes
          are layered over the saved ones, and ``replace`` is forced on so
          the result is written back.
        - Profile does not exist: it is created from the passed attributes.
          If none were passed, ProfileNotFoundError propagates.
        - ``replace`` is set: the profile is (re-)created.

        :return: self, so that ``await AsyncProfile(...)`` yields the profile
        :raises: ProfileNotFoundError, oracledb.DatabaseError
        """
        if self.profile_name is None:
            return self
        profile_exists = False
        try:
            saved_attributes = await self._get_attributes(
                profile_name=self.profile_name
            )
            profile_exists = True
        except ProfileNotFoundError:
            # A missing profile is only fatal when there is nothing to
            # create it from.
            if self.attributes is None:
                raise
        else:
            if self.attributes is None:
                self.attributes = saved_attributes
            if self.merge:
                self.replace = True
                if self.attributes is not None:
                    # Non-null user-supplied attributes override saved ones.
                    self.attributes = dataclass_replace(
                        saved_attributes,
                        **self.attributes.dict(exclude_null=True),
                    )
        if self.replace or not profile_exists:
            await self.create(
                replace=self.replace, description=self.description
            )
        return self

    @staticmethod
    async def _get_attributes(profile_name) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database.

        The name is upper-cased before the lookup.

        :param str profile_name: Name of the profile
        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = await cr.fetchall()
            if attributes:
                return await ProfileAttributes.async_create(**dict(attributes))
            raise ProfileNotFoundError(profile_name=profile_name)

    async def get_attributes(self) -> ProfileAttributes:
        """Asynchronously gets this AI profile's attributes from the Database.

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return await self._get_attributes(profile_name=self.profile_name)

    async def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Persist one attribute via DBMS_CLOUD_AI.SET_ATTRIBUTE."""
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    async def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database.

        A Provider value is expanded into one SET_ATTRIBUTE call per non-null
        provider field. Field names are translated with
        ``Provider.key_alias``, the same mapping ProfileAttributes.json()
        uses when the profile is created.

        :param str attribute_name: Name of the AI profile attribute
        :param attribute_value: Value of the profile attribute
        :return: None
        """
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            provider_fields = attribute_value.dict(exclude_null=True)
            for field_name, field_value in provider_fields.items():
                await self._set_attribute(
                    Provider.key_alias(field_name), field_value
                )
        else:
            await self._set_attribute(attribute_name, attribute_value)

    async def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves them in the database.

        :param ProfileAttributes attributes: Object specifying AI profile
         attributes
        :return: None
        """
        self.attributes = attributes
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )

    async def create(
        self, replace: Optional[bool] = False, description: Optional[str] = None
    ) -> None:
        """Asynchronously create an AI Profile in the Database.

        :param bool replace: If the profile already exists, drop and
         re-create it when True; otherwise re-raise the database error
        :param description: The profile description
        :return: None
        :raises: oracledb.DatabaseError
        """
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if description:
            parameters["description"] = description

        async with async_cursor() as cr:
            try:
                await cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                (error,) = e.args
                # If already exists and replace is True then drop and recreate
                if "already exists" in error.message.lower() and replace:
                    await self.delete(force=True)
                    await cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    async def delete(self, force=False) -> None:
        """Asynchronously deletes an AI profile from the database.

        :param bool force: Passed as DROP_PROFILE's ``force`` argument
         (ignores errors if the AI profile does not exist)
        :return: None
        :raises: oracledb.DatabaseError
        """
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    async def _from_db(cls, profile_name: str) -> "AsyncProfile":
        """Asynchronously build an AI Profile object from the attributes
        saved in the database for ``profile_name``.

        The lookup goes through _get_attributes, so the name is upper-cased
        like every other lookup in this class.

        :param str profile_name: Name of the profile
        :return: AsyncProfile (not awaited; no further DB round-trip)
        :raises: ProfileNotFoundError
        """
        attributes = await cls._get_attributes(profile_name=profile_name)
        return cls(profile_name=profile_name, attributes=attributes)

    @classmethod
    async def list(
        cls, profile_name_pattern: str = ".*"
    ) -> AsyncGenerator["AsyncProfile", None]:
        """Asynchronously list AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expression matched with
         REGEXP_LIKE against profile names. Defaults to ".*" (all profiles).
        :return: AsyncGenerator yielding AsyncProfile objects. The yielded
         objects are not awaited.
        """
        async with async_cursor() as cr:
            await cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = await cr.fetchall()
            for row in rows:
                profile_name = row[0]
                description = row[1]
                attributes = await cls._get_attributes(
                    profile_name=profile_name
                )
                yield cls(
                    profile_name=profile_name,
                    description=description,
                    attributes=attributes,
                )

    async def generate(
        self, prompt, action=Action.SHOWSQL, params: Optional[Mapping] = None
    ) -> Union[pandas.DataFrame, str, None]:
        """Asynchronously perform AI translation using this profile.

        :param str prompt: Natural language prompt to translate
        :param select_ai.Action action: DBMS_CLOUD_AI.GENERATE action
        :param params: Parameters to include in the LLM request, e.g.
         conversation_id for context-aware chats. JSON-encoded before sending.
        :return: The CLOB text returned by GENERATE, or None if it was NULL
        """
        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)

        async with async_cursor() as cr:
            data = await cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
        if data is not None:
            return await data.read()
        return None

    async def chat(self, prompt, params: Optional[Mapping] = None) -> Optional[str]:
        """Asynchronously chat with the LLM.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str, or None if the database returned NULL
        """
        return await self.generate(prompt, action=Action.CHAT, params=params)

    @asynccontextmanager
    async def chat_session(
        self, conversation: AsyncConversation, delete: bool = False
    ):
        """Starts a new chat session for context-aware conversations.

        The conversation is created first if it has attributes but no
        conversation_id yet.

        :param AsyncConversation conversation: Conversation object to use for
         this chat session
        :param bool delete: Delete the conversation when the session ends
        :return: AsyncSession bound to this profile and the conversation_id
        """
        try:
            if (
                conversation.conversation_id is None
                and conversation.attributes is not None
            ):
                await conversation.create()
            params = {"conversation_id": conversation.conversation_id}
            # AsyncSession is defined later in this module; it is resolved at
            # call time.
            async_session = AsyncSession(async_profile=self, params=params)
            yield async_session
        finally:
            if delete:
                await conversation.delete()

    async def narrate(self, prompt, params: Optional[Mapping] = None) -> Optional[str]:
        """Narrate the result of the SQL.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str, or None if the database returned NULL
        """
        return await self.generate(
            prompt, action=Action.NARRATE, params=params
        )

    async def explain_sql(self, prompt: str, params: Optional[Mapping] = None):
        """Explain the generated SQL.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str, or None if the database returned NULL
        """
        return await self.generate(
            prompt, action=Action.EXPLAINSQL, params=params
        )

    async def run_sql(
        self, prompt, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL and return its result set.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame built from the JSON rows returned by the
         database. The DataFrame is empty if GENERATE returned NULL.
        """
        data = await self.generate(prompt, action=Action.RUNSQL, params=params)
        if data is None:
            # json.loads(None) would raise TypeError.
            return pandas.DataFrame()
        return pandas.DataFrame(json.loads(data))

    async def show_sql(self, prompt, params: Optional[Mapping] = None):
        """Show the generated SQL.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str, or None if the database returned NULL
        """
        return await self.generate(
            prompt, action=Action.SHOWSQL, params=params
        )

    async def show_prompt(self, prompt: str, params: Optional[Mapping] = None):
        """Show the prompt sent to the LLM.

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str, or None if the database returned NULL
        """
        return await self.generate(
            prompt, action=Action.SHOWPROMPT, params=params
        )

    async def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
         Specification passed to DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA
        :return: None
        :raises: oracledb.DatabaseError
        """
        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )

    async def run_pipeline(
        self,
        prompt_specifications: List[Tuple[str, Action]],
        continue_on_error: bool = False,
    ) -> List[Union[str, pandas.DataFrame]]:
        """Send multiple prompts in a single round-trip to the Database.

        :param List[Tuple[str, Action]] prompt_specifications: List of
         2-element tuples. The first element is the prompt and the second
         is the corresponding action.
        :param bool continue_on_error: True to continue on error else False
        :return: One entry per prompt, in order. Each entry is one of:
         the CLOB text of the response; None if the response was NULL; or
         the operation's error object for a failed operation.
        """
        pipeline = oracledb.create_pipeline()
        for prompt, action in prompt_specifications:
            parameters = {
                "prompt": prompt,
                "action": action,
                "profile_name": self.profile_name,
            }
            pipeline.add_callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                return_type=oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
        async_connection = await async_get_connection()
        pipeline_results = await async_connection.run_pipeline(
            pipeline, continue_on_error=continue_on_error
        )
        responses = []
        for result in pipeline_results:
            if result.error:
                responses.append(result.error)
            elif result.return_value is None:
                # A NULL CLOB has no .read(); report it as None.
                responses.append(None)
            else:
                responses.append(await result.return_value.read())
        return responses
445
+
446
+
447
class AsyncSession:
    """Carries request parameters (for example a ``conversation_id``) across
    successive DBMS_CLOUD_AI chat calls made through an AsyncProfile.

    It may be used as an async context manager. Entering yields the session
    itself, and exiting performs no cleanup.
    """

    def __init__(self, async_profile: AsyncProfile, params: Mapping):
        """
        :param async_profile: AI profile that serves this session's requests
        :param params: Parameters re-sent with every request of the session
        """
        self.async_profile = async_profile
        self.params = params

    async def chat(self, prompt: str):
        """Chat through the session's profile with the stored parameters.

        :param str prompt: Natural language prompt
        :return: Whatever ``AsyncProfile.chat`` returns
        """
        profile = self.async_profile
        response = await profile.chat(prompt=prompt, params=self.params)
        return response

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # A falsy return lets any in-flight exception propagate.
        return None
@@ -0,0 +1,166 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from abc import ABC
10
+ from dataclasses import dataclass
11
+ from typing import List, Mapping, Optional
12
+
13
+ import oracledb
14
+
15
+ from select_ai._abc import SelectAIDataClass
16
+
17
+ from .provider import Provider
18
+
19
+
20
@dataclass
class ProfileAttributes(SelectAIDataClass):
    """
    Use this class to define attributes to manage and configure the behavior of
    an AI profile.

    :param bool comments: True to include column comments in the metadata used
     for generating SQL queries from natural language prompts.
    :param bool constraints: True to include referential integrity constraints
     such as primary and foreign keys in the metadata sent to the LLM.
    :param bool conversation: Indicates if conversation history is enabled for
     a profile.
    :param str credential_name: The name of the credential to access the AI
     provider APIs.
    :param bool enforce_object_list: Specifies whether to restrict the LLM
     to generate SQL that uses only tables covered by the object list.
    :param int max_tokens: Denotes the number of tokens to return per
     generation. Default is 1024.
    :param List[Mapping] object_list: Array of JSON objects specifying
     the owner and object names that are eligible for natural language
     translation to SQL.
    :param str object_list_mode: Specifies whether to send metadata for the
     most relevant tables or all tables to the LLM. Supported values are
     'automated' and 'all'.
    :param select_ai.Provider provider: AI Provider
    :param str stop_tokens: The generated text will be terminated at the
     beginning of the earliest stop sequence. Sequence will be incorporated
     into the text. The attribute value must be a valid array of string values
     in JSON format.
    :param float temperature: Temperature is a non-negative float number used
     to tune the degree of randomness. Lower temperatures mean less random
     generations.
    :param str vector_index_name: Name of the vector index

    The remaining fields (annotations, case_sensitive_values, enable_sources,
    enable_source_offsets, seed, streaming) are serialized by :meth:`json`
    under their own names, like every other non-provider attribute.
    """

    annotations: Optional[str] = None
    case_sensitive_values: Optional[bool] = None
    comments: Optional[bool] = None
    # NOTE(review): documented as bool above but annotated str; confirm the
    # intended type before changing it.
    constraints: Optional[str] = None
    conversation: Optional[bool] = None
    credential_name: Optional[str] = None
    enable_sources: Optional[bool] = None
    enable_source_offsets: Optional[bool] = None
    enforce_object_list: Optional[bool] = None
    max_tokens: Optional[int] = 1024
    object_list: Optional[List[Mapping]] = None
    object_list_mode: Optional[str] = None
    provider: Optional[Provider] = None
    seed: Optional[str] = None
    stop_tokens: Optional[str] = None
    streaming: Optional[str] = None
    temperature: Optional[float] = None
    vector_index_name: Optional[str] = None

    def json(self, exclude_null=True):
        """Serialize the attributes to a JSON object string.

        Provider fields are flattened into the top-level object, each key
        translated through ``Provider.key_alias``.

        :param bool exclude_null: Omit attributes whose value is None
        :return: str
        """
        attributes = {}
        for k, v in self.dict(exclude_null=exclude_null).items():
            if isinstance(v, Provider):
                for provider_k, provider_v in v.dict(
                    exclude_null=exclude_null
                ).items():
                    attributes[Provider.key_alias(provider_k)] = provider_v
            else:
                attributes[k] = v
        return json.dumps(attributes)

    @classmethod
    def _from_attribute_values(cls, attribute_values):
        """Build an instance of ``cls`` from (name, value) pairs whose LOB
        values have already been read.

        Names listed in ``Provider.keys()`` are translated with
        ``Provider.key_alias`` and gathered into a Provider. All other names
        become profile attributes.
        """
        provider_attributes = {}
        profile_attributes = {}
        for k, v in attribute_values:
            if k in Provider.keys():
                provider_attributes[Provider.key_alias(k)] = v
            else:
                profile_attributes[k] = v
        profile_attributes["provider"] = Provider.create(**provider_attributes)
        # Use cls, not ProfileAttributes, so subclasses construct themselves.
        return cls(**profile_attributes)

    @classmethod
    def create(cls, **kwargs):
        """Build attributes from name/value pairs fetched from the database.

        ``oracledb.LOB`` values are read synchronously first.
        """
        return cls._from_attribute_values(
            (k, v.read() if isinstance(v, oracledb.LOB) else v)
            for k, v in kwargs.items()
        )

    @classmethod
    async def async_create(cls, **kwargs):
        """Async variant of :meth:`create`.

        ``oracledb.AsyncLOB`` values are awaited and read first.
        """
        attribute_values = []
        for k, v in kwargs.items():
            if isinstance(v, oracledb.AsyncLOB):
                v = await v.read()
            attribute_values.append((k, v))
        return cls._from_attribute_values(attribute_values)

    def set_attribute(self, key, value):
        """Set one attribute on this object.

        Provider keys with a non-Provider value are set on ``self.provider``.
        Everything else is set on this object.
        """
        if key in Provider.keys() and not isinstance(value, Provider):
            setattr(self.provider, key, value)
        else:
            setattr(self, key, value)
122
+
123
+
124
class BaseProfile(ABC):
    """
    Common state shared by the synchronous and asynchronous Select AI
    profile classes. Instantiate select_ai.Profile or select_ai.AsyncProfile
    rather than this base class.

    :param str profile_name: Name of the profile

    :param select_ai.ProfileAttributes attributes:
     Object specifying AI profile attributes

    :param str description: Description of the profile

    :param bool merge: Fetch the profile from the database, merge the
     attributes and save it back in the database. Default value is False

    :param bool replace: Replace the profile and attributes in the
     database. Default value is False
    """

    def __init__(
        self,
        profile_name: Optional[str] = None,
        attributes: Optional[ProfileAttributes] = None,
        description: Optional[str] = None,
        merge: Optional[bool] = False,
        replace: Optional[bool] = False,
    ):
        """Record the profile configuration; no database access happens here."""
        self.profile_name, self.attributes = profile_name, attributes
        self.description = description
        self.merge, self.replace = merge, replace

    def __repr__(self):
        text = (
            f"{self.__class__.__name__}(profile_name={self.profile_name}, "
            f"attributes={self.attributes}, description={self.description})"
        )
        return text