select-ai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of select-ai might be problematic. Click here for more details.

@@ -0,0 +1,534 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from contextlib import asynccontextmanager
10
+ from dataclasses import replace as dataclass_replace
11
+ from typing import (
12
+ AsyncGenerator,
13
+ List,
14
+ Mapping,
15
+ Optional,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import oracledb
21
+ import pandas
22
+
23
+ from select_ai.action import Action
24
+ from select_ai.base_profile import BaseProfile, ProfileAttributes
25
+ from select_ai.conversation import AsyncConversation
26
+ from select_ai.db import async_cursor, async_get_connection
27
+ from select_ai.errors import ProfileExistsError, ProfileNotFoundError
28
+ from select_ai.provider import Provider
29
+ from select_ai.sql import (
30
+ GET_USER_AI_PROFILE,
31
+ GET_USER_AI_PROFILE_ATTRIBUTES,
32
+ LIST_USER_AI_PROFILES,
33
+ )
34
+ from select_ai.synthetic_data import SyntheticDataAttributes
35
+
36
# Names exported by ``from <this module> import *``. AsyncSession is
# deliberately not listed here.
__all__ = ["AsyncProfile"]
37
+
38
+
39
class AsyncProfile(BaseProfile):
    """AsyncProfile defines methods to interact with the underlying AI Provider
    asynchronously.

    Construct it with the same arguments as BaseProfile and ``await`` the
    instance to run profile initialization against the database; awaiting
    returns the profile itself::

        profile = await AsyncProfile(profile_name="MY_PROFILE")
    """

    def __await__(self):
        # The initialization coroutine is built only when the object is
        # actually awaited. Creating it eagerly in __init__ left an un-awaited
        # coroutine behind for instances that are never awaited (e.g. those
        # yielded by list() or returned by _from_db()), which Python reports
        # as a "coroutine ... was never awaited" RuntimeWarning.
        return self._init_profile().__await__()

    async def _init_profile(self):
        """Initializes AI profile based on the passed attributes

        When ``profile_name`` is set, the saved profile is looked up:

        - if it exists and neither ``replace`` nor ``merge`` is set, passing
          ``attributes`` or ``description`` raises ProfileExistsError when
          ``raise_error_if_exists`` is true;
        - a missing ``description``/``attributes`` is filled in from the
          saved profile;
        - with ``merge``, the passed non-null attributes are overlaid on the
          saved ones and the profile is replaced in the database;
        - if it does not exist, ProfileNotFoundError is re-raised when neither
          attributes nor description were given, otherwise it is created.

        :return: self
        :raises: ProfileExistsError, ProfileNotFoundError, ValueError,
            oracledb.DatabaseError
        """
        if not self.profile_name:
            # Nothing to look up or create without a name; supplying
            # attributes or a description in that case is a caller error.
            if self.attributes is not None or self.description is not None:
                raise ValueError("'profile_name' cannot be empty or None")
            return self

        profile_exists = False
        try:
            saved_attributes = await self._get_attributes(
                profile_name=self.profile_name
            )
            profile_exists = True
            if not self.replace and not self.merge:
                if (
                    self.attributes is not None
                    or self.description is not None
                ):
                    if self.raise_error_if_exists:
                        raise ProfileExistsError(self.profile_name)

            if self.description is None:
                self.description = await self._get_profile_description(
                    profile_name=self.profile_name
                )
        except ProfileNotFoundError:
            # The profile can only be created if the caller supplied
            # something to create it with.
            if self.attributes is None and self.description is None:
                raise
        else:
            if self.attributes is None:
                self.attributes = saved_attributes
            if self.merge:
                # Merging always rewrites the stored profile.
                self.replace = True
                if self.attributes is not None:
                    self.attributes = dataclass_replace(
                        saved_attributes,
                        **self.attributes.dict(exclude_null=True),
                    )
        if self.replace or not profile_exists:
            await self.create(
                replace=self.replace, description=self.description
            )
        return self

    @staticmethod
    async def _get_profile_description(profile_name) -> Union[str, None]:
        """Get description of profile from USER_CLOUD_AI_PROFILES

        :param str profile_name: Name of profile; upper-cased before lookup
        :return: Description of profile, or None if it has none
        :rtype: Optional[str]
        :raises: ProfileNotFoundError if the profile does not exist
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE,
                profile_name=profile_name.upper(),
            )
            profile = await cr.fetchone()
            if not profile:
                raise ProfileNotFoundError(profile_name)
            # Index 1 is the description; it is read through an async read()
            # while the cursor context is still open.
            if profile[1] is None:
                return None
            return await profile[1].read()

    @staticmethod
    async def _get_attributes(profile_name) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        :param str profile_name: Name of the profile; upper-cased before lookup
        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError if no attributes are stored for it
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = await cr.fetchall()
            if not attributes:
                raise ProfileNotFoundError(profile_name=profile_name)
            # Each row is an (attribute name, attribute value) pair.
            return await ProfileAttributes.async_create(**dict(attributes))

    async def get_attributes(self) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return await self._get_attributes(profile_name=self.profile_name)

    async def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Persist a single attribute via DBMS_CLOUD_AI.SET_ATTRIBUTE.

        Only the database is updated; the Python object is left untouched.
        """
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    async def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float, Provider] attribute_value: Value
            of the profile attribute. A Provider is saved as one database
            attribute per entry of ``Provider.dict()``.
        :return: None
        :raises: AttributeError if the profile has no attributes object
        """
        if self.attributes is None:
            # Same error create() raises, instead of an opaque
            # "'NoneType' object has no attribute" message.
            raise AttributeError("Profile attributes cannot be None")
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            for k, v in attribute_value.dict().items():
                await self._set_attribute(k, v)
        else:
            await self._set_attribute(attribute_name, attribute_value)

    async def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves it in the database

        :param ProfileAttributes attributes: Object specifying AI profile
            attributes
        :return: None
        :raises: TypeError if attributes is not a ProfileAttributes
        """
        if not isinstance(attributes, ProfileAttributes):
            raise TypeError(
                "'attributes' must be an object of type "
                "select_ai.ProfileAttributes"
            )

        self.attributes = attributes
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )

    async def create(
        self, replace: bool = False, description: Optional[str] = None
    ) -> None:
        """Asynchronously create an AI Profile in the Database

        :param bool replace: Set True to replace else False
        :param description: The profile description (omitted when falsy)
        :return: None
        :raises: AttributeError if attributes is None;
            oracledb.DatabaseError
        """
        if self.attributes is None:
            raise AttributeError("Profile attributes cannot be None")
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if description:
            parameters["description"] = description
        async with async_cursor() as cr:
            try:
                await cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                (error,) = e.args
                # Error 20046 is treated as "profile already exists": when
                # replacing, drop it and create it again.
                if error.code == 20046 and replace:
                    await self.delete(force=True)
                    await cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    async def delete(self, force: bool = False) -> None:
        """Asynchronously deletes an AI profile from the database

        :param bool force: Ignores errors if AI profile does not exist.
        :return: None
        :raises: oracledb.DatabaseError
        """
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    async def _from_db(cls, profile_name: str) -> "AsyncProfile":
        """Asynchronously create an AI Profile object from attributes
        saved in the database against the profile

        The lookup goes through _get_attributes so the name is matched the
        same way as everywhere else in this class (upper-cased).

        :param str profile_name: Name of the profile
        :return: AsyncProfile (not awaited)
        :raises: ProfileNotFoundError
        """
        attributes = await cls._get_attributes(profile_name=profile_name)
        return cls(profile_name=profile_name, attributes=attributes)

    @classmethod
    async def list(
        cls, profile_name_pattern: str = ".*"
    ) -> AsyncGenerator["AsyncProfile", None]:
        """Asynchronously list AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expressions can be used
            to specify a pattern. Function REGEXP_LIKE is used to perform the
            match. Default value is ".*" i.e. match all AI profiles.
        :return: AsyncGenerator[AsyncProfile, None]; yielded profiles are not
            awaited
        """
        async with async_cursor() as cr:
            await cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = await cr.fetchall()
        # Every row is already fetched, so the cursor is released before
        # yielding: an async generator can stay suspended indefinitely or be
        # abandoned without aclose(), which would otherwise pin the cursor.
        for row in rows:
            profile_name = row[0]
            # NOTE(review): _get_profile_description reads the description
            # column of GET_USER_AI_PROFILE via an async read(); confirm that
            # LIST_USER_AI_PROFILES returns it as a plain string here.
            description = row[1]
            attributes = await cls._get_attributes(profile_name=profile_name)
            yield cls(
                profile_name=profile_name,
                description=description,
                attributes=attributes,
                raise_error_if_exists=False,
            )

    async def generate(
        self,
        prompt: str,
        action=Action.SHOWSQL,
        params: Optional[Mapping] = None,
    ) -> Union[pandas.DataFrame, str, None]:
        """Asynchronously perform AI translation using this profile

        :param str prompt: Natural language prompt to translate
        :param select_ai.profile.Action action: Action passed to
            DBMS_CLOUD_AI.GENERATE (default Action.SHOWSQL)
        :param params: Parameters to include in the LLM request. For e.g.
            conversation_id for context-aware chats. Sent JSON-encoded.
        :return: pandas.DataFrame for Action.RUNSQL (empty when there is no
            response), otherwise the response text or None
        :raises: ValueError if prompt is empty or None
        """
        if not prompt:
            raise ValueError("prompt cannot be empty or None")

        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)

        async with async_cursor() as cr:
            data = await cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            # GENERATE may return NULL; otherwise read the CLOB while the
            # cursor is still open.
            result = await data.read() if data is not None else None

        if action == Action.RUNSQL:
            # RUNSQL responses are JSON; load them into a DataFrame.
            if result:
                return pandas.DataFrame(json.loads(result))
            return pandas.DataFrame()
        return result

    async def chat(self, prompt, params: Optional[Mapping] = None) -> str:
        """Asynchronously chat with the LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(prompt, action=Action.CHAT, params=params)

    @asynccontextmanager
    async def chat_session(
        self, conversation: AsyncConversation, delete: bool = False
    ):
        """Starts a new chat session for context-aware conversations

        If the conversation has no id yet but has attributes, it is created
        first. The yielded AsyncSession sends the conversation id with every
        request.

        :param AsyncConversation conversation: Conversation object to use for
            this chat session
        :param bool delete: Delete conversation after session ends
        :return: AsyncSession
        """
        try:
            if (
                conversation.conversation_id is None
                and conversation.attributes is not None
            ):
                await conversation.create()
            params = {"conversation_id": conversation.conversation_id}
            async_session = AsyncSession(async_profile=self, params=params)
            yield async_session
        finally:
            # Runs even when the body of the `async with` raises.
            if delete:
                await conversation.delete()

    async def narrate(self, prompt, params: Optional[Mapping] = None) -> str:
        """Narrate the result of the SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.NARRATE, params=params
        )

    async def explain_sql(self, prompt: str, params: Optional[Mapping] = None):
        """Explain the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.EXPLAINSQL, params=params
        )

    async def run_sql(
        self, prompt, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL and return its result set

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame (empty when there is no response)
        """
        return await self.generate(prompt, action=Action.RUNSQL, params=params)

    async def show_sql(self, prompt, params: Optional[Mapping] = None):
        """Show the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWSQL, params=params
        )

    async def show_prompt(self, prompt: str, params: Optional[Mapping] = None):
        """Show the prompt sent to LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWPROMPT, params=params
        )

    async def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
            Description of the data to generate
        :return: None
        :raises: ValueError if the argument is None; TypeError if it is not a
            SyntheticDataAttributes; oracledb.DatabaseError
        """
        if synthetic_data_attributes is None:
            raise ValueError("'synthetic_data_attributes' cannot be None")

        if not isinstance(synthetic_data_attributes, SyntheticDataAttributes):
            raise TypeError(
                "'synthetic_data_attributes' must be an object "
                "of type select_ai.SyntheticDataAttributes"
            )

        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )

    async def run_pipeline(
        self,
        prompt_specifications: List[Tuple[str, Action]],
        continue_on_error: bool = False,
    ) -> List[Union[str, pandas.DataFrame]]:
        """Send Multiple prompts in a single roundtrip to the Database

        :param List[Tuple[str, Action]] prompt_specifications: List of
            2-element tuples. First element is the prompt and second is the
            corresponding action
        :param bool continue_on_error: True to continue on error else False
        :return: One entry per prompt, in order: the response text, None when
            GENERATE returned NULL, or the operation's error object.

        NOTE(review): unlike generate(), RUNSQL responses are returned here
        as raw JSON text, not converted to a pandas.DataFrame.
        """
        pipeline = oracledb.create_pipeline()
        for prompt, action in prompt_specifications:
            parameters = {
                "prompt": prompt,
                "action": action,
                "profile_name": self.profile_name,
            }
            pipeline.add_callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                return_type=oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
        async_connection = await async_get_connection()
        pipeline_results = await async_connection.run_pipeline(
            pipeline, continue_on_error=continue_on_error
        )
        responses = []
        for result in pipeline_results:
            if result.error:
                # A failed operation's error is returned in place of its
                # response.
                responses.append(result.error)
            elif result.return_value is None:
                # GENERATE returned NULL; mirror generate() instead of
                # failing on None.read().
                responses.append(None)
            else:
                responses.append(await result.return_value.read())
        return responses
511
+
512
+
513
class AsyncSession:
    """Reuse one set of request parameters across DBMS_CLOUD_AI calls.

    Every request issued through the session carries the same ``params``
    mapping (for example a ``conversation_id`` for context-aware chats).
    The session is also an async context manager: entering returns the
    session itself and exiting performs no cleanup.
    """

    def __init__(self, async_profile: AsyncProfile, params: Mapping):
        """
        :param async_profile: AI profile that serves this session's requests
        :param params: Parameters sent with every request in this session
        """
        self.async_profile = async_profile
        self.params = params

    async def chat(self, prompt: str):
        """Chat through the session's profile with the session parameters.

        :param str prompt: Natural language prompt
        :return: Whatever the profile's ``chat`` returns
        """
        profile = self.async_profile
        session_params = self.params
        return await profile.chat(prompt=prompt, params=session_params)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; a falsy result lets exceptions propagate.
        return None