select-ai 1.2.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,648 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from contextlib import asynccontextmanager
10
+ from dataclasses import replace as dataclass_replace
11
+ from typing import (
12
+ AsyncGenerator,
13
+ List,
14
+ Mapping,
15
+ Optional,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import oracledb
21
+ import pandas
22
+
23
+ from select_ai.action import Action
24
+ from select_ai.base_profile import (
25
+ BaseProfile,
26
+ ProfileAttributes,
27
+ no_data_for_prompt,
28
+ validate_params_for_feedback,
29
+ validate_params_for_summary,
30
+ )
31
+ from select_ai.conversation import AsyncConversation
32
+ from select_ai.db import async_cursor, async_get_connection
33
+ from select_ai.errors import ProfileExistsError, ProfileNotFoundError
34
+ from select_ai.feedback import (
35
+ FeedbackOperation,
36
+ FeedbackType,
37
+ )
38
+ from select_ai.provider import Provider
39
+ from select_ai.sql import (
40
+ GET_USER_AI_PROFILE,
41
+ GET_USER_AI_PROFILE_ATTRIBUTES,
42
+ LIST_USER_AI_PROFILES,
43
+ )
44
+ from select_ai.summary import SummaryParams
45
+ from select_ai.synthetic_data import SyntheticDataAttributes
46
+
47
+ __all__ = ["AsyncProfile"]
48
+
49
+
50
+ class AsyncProfile(BaseProfile):
51
+ """AsyncProfile defines methods to interact with the underlying AI Provider
52
+ asynchronously.
53
+ """
54
+
55
def __init__(self, *args, **kwargs):
    """Build the profile object and stage its asynchronous initialisation.

    All positional and keyword arguments are forwarded unchanged to the
    parent initialiser. Nothing is awaited here: the ``_init_profile()``
    coroutine is only created and stored on ``self._init_coroutine`` so
    that ``__await__`` can drive it later.
    """
    super().__init__(*args, **kwargs)
    # A constructor cannot await, so keep the coroutine object for later.
    pending_init = self._init_profile()
    self._init_coroutine = pending_init
57
+ self._init_coroutine = self._init_profile()
58
+
59
def __await__(self):
    """Make the instance awaitable.

    Delegates to the coroutine stored in ``self._init_coroutine`` by
    ``__init__``; awaiting the instance therefore runs ``_init_profile()``
    and evaluates to whatever that coroutine returns.
    """
    return self._init_coroutine.__await__()
62
+
63
async def _init_profile(self):
    """Initializes AI profile based on the passed attributes

    When ``profile_name`` is set, the saved attributes are looked up and,
    depending on ``replace`` / ``merge`` / ``raise_error_if_exists``, the
    profile is reused, merged with the saved attributes, or (re)created.

    :return: self
    :raises: ProfileExistsError when the profile already exists, the caller
        supplied attributes or a description, neither ``replace`` nor
        ``merge`` is set and ``raise_error_if_exists`` is true
    :raises: ProfileNotFoundError when the profile does not exist and
        neither attributes nor a description were supplied
    :raises: ValueError when attributes or a description are given without
        a profile name
    :raises: oracledb.DatabaseError
    """
    if not self.profile_name:
        # Nothing can be looked up or created without a name.
        if self.attributes is not None or self.description is not None:
            raise ValueError("'profile_name' cannot be empty or None")
        return self

    profile_exists = False
    try:
        saved_attributes = await self._get_attributes(
            profile_name=self.profile_name
        )
        profile_exists = True
        if (
            not self.replace
            and not self.merge
            and (
                self.attributes is not None
                or self.description is not None
            )
            and self.raise_error_if_exists
        ):
            raise ProfileExistsError(self.profile_name)

        # Reuse the stored description unless the profile is being replaced.
        if self.description is None and not self.replace:
            self.description = await self._get_profile_description(
                profile_name=self.profile_name
            )
    except ProfileNotFoundError:
        # A missing profile is only an error when there is nothing to
        # create it from.
        if self.attributes is None and self.description is None:
            raise
    else:
        if self.attributes is None:
            self.attributes = saved_attributes
        if self.merge:
            # Merging is implemented as "replace with the union": saved
            # values are overridden by the non-null values passed in.
            self.replace = True
            if self.attributes is not None:
                overrides = self.attributes.dict(exclude_null=True)
                self.attributes = dataclass_replace(
                    saved_attributes, **overrides
                )

    if self.replace or not profile_exists:
        await self.create(replace=self.replace)
    return self
107
+
108
@staticmethod
async def _get_profile_description(profile_name) -> Optional[str]:
    """Get description of profile from USER_CLOUD_AI_PROFILES

    :param str profile_name: Name of profile; upper-cased before it is
        bound to the query
    :return: Description of profile, or None when the stored description
        column is NULL
    :rtype: Optional[str]
    :raises: ProfileNotFoundError when the query returns no row

    """
    async with async_cursor() as cr:
        await cr.execute(
            GET_USER_AI_PROFILE,
            profile_name=profile_name.upper(),
        )
        profile = await cr.fetchone()
        # One guard covers both "no row" and an empty row; the original
        # repeated this check in a trailing branch that could not be
        # reached after the None test.
        if not profile:
            raise ProfileNotFoundError(profile_name)
        description_lob = profile[1]
        if description_lob is None:
            return None
        # The column value is read through its awaitable ``read()``.
        return await description_lob.read()
133
+
134
@staticmethod
async def _get_attributes(profile_name) -> ProfileAttributes:
    """Asynchronously gets AI profile attributes from the Database

    :param str profile_name: Name of the profile; upper-cased before it is
        bound to the query
    :return: ProfileAttributes built from the fetched rows
    :raises: ProfileNotFoundError when no attribute rows are returned

    """
    async with async_cursor() as cr:
        await cr.execute(
            GET_USER_AI_PROFILE_ATTRIBUTES,
            profile_name=profile_name.upper(),
        )
        rows = await cr.fetchall()
        if not rows:
            raise ProfileNotFoundError(profile_name=profile_name)
        # The rows are turned into a mapping and passed as keyword
        # arguments to the attribute factory.
        attribute_map = dict(rows)
        return await ProfileAttributes.async_create(**attribute_map)
153
+
154
async def get_attributes(self) -> ProfileAttributes:
    """Asynchronously gets this profile's attributes from the Database

    Instance-level wrapper around ``_get_attributes`` that uses
    ``self.profile_name``.

    :return: ProfileAttributes
    :raises: ProfileNotFoundError
    """
    saved_attributes = await self._get_attributes(
        profile_name=self.profile_name
    )
    return saved_attributes
161
+
162
async def _set_attribute(
    self,
    attribute_name: str,
    attribute_value: Union[bool, str, int, float],
):
    """Persist a single profile attribute via DBMS_CLOUD_AI.SET_ATTRIBUTE

    :param str attribute_name: Name of the AI profile attribute
    :param attribute_value: Value to store for that attribute
    :return: None
    """
    call_arguments = dict(
        profile_name=self.profile_name,
        attribute_name=attribute_name,
        attribute_value=attribute_value,
    )
    async with async_cursor() as cr:
        await cr.callproc(
            "DBMS_CLOUD_AI.SET_ATTRIBUTE",
            keyword_parameters=call_arguments,
        )
176
+
177
async def set_attribute(
    self,
    attribute_name: str,
    attribute_value: Union[bool, str, int, float, Provider],
):
    """Updates AI profile attribute on the Python object and also
    saves it in the database

    The in-memory ``self.attributes`` is updated first. A ``Provider``
    value is then written field by field, one SET_ATTRIBUTE call per entry
    of ``attribute_value.dict()``; any other value is written as a single
    attribute.

    :param str attribute_name: Name of the AI profile attribute
    :param Union[bool, str, int, float, Provider] attribute_value: Value of
     the profile attribute
    :return: None

    """
    self.attributes.set_attribute(attribute_name, attribute_value)
    if not isinstance(attribute_value, Provider):
        await self._set_attribute(attribute_name, attribute_value)
        return
    for provider_key, provider_value in attribute_value.dict().items():
        await self._set_attribute(provider_key, provider_value)
197
+
198
async def set_attributes(self, attributes: ProfileAttributes):
    """Updates AI profile attributes on the Python object and also
    saves it in the database

    The attributes are sent as JSON to DBMS_CLOUD_AI.SET_ATTRIBUTES, after
    which ``self.attributes`` is refreshed from the database.

    :param ProfileAttributes attributes: Object specifying AI profile
     attributes
    :return: None
    :raises: TypeError when ``attributes`` is not a ProfileAttributes
    """
    if not isinstance(attributes, ProfileAttributes):
        raise TypeError(
            "'attributes' must be an object of type "
            "select_ai.ProfileAttributes"
        )
    call_arguments = dict(
        profile_name=self.profile_name,
        attributes=attributes.json(),
    )
    async with async_cursor() as cr:
        await cr.callproc(
            "DBMS_CLOUD_AI.SET_ATTRIBUTES",
            keyword_parameters=call_arguments,
        )
    # Re-read so the Python object mirrors what the database now holds.
    refreshed = await self.get_attributes()
    self.attributes = refreshed
220
+
221
async def create(self, replace: Optional[bool] = False) -> None:
    """Asynchronously create an AI Profile in the Database

    Calls DBMS_CLOUD_AI.CREATE_PROFILE with the profile name, the JSON form
    of ``self.attributes`` and, when it is non-empty, the description.

    :param bool replace: Set True to replace else False. When the database
        reports error code 20046 and ``replace`` is true, the profile is
        dropped and created again.
    :return: None
    :raises: AttributeError when ``self.attributes`` is None
    :raises: oracledb.DatabaseError
    """
    if self.attributes is None:
        raise AttributeError("Profile attributes cannot be None")
    parameters = {
        "profile_name": self.profile_name,
        "attributes": self.attributes.json(),
    }
    if self.description:
        parameters["description"] = self.description
    async with async_cursor() as cr:
        try:
            await cr.callproc(
                "DBMS_CLOUD_AI.CREATE_PROFILE",
                keyword_parameters=parameters,
            )
        except oracledb.DatabaseError as e:
            # Read the error code defensively: unpacking ``e.args`` as a
            # one-element tuple would raise ValueError/AttributeError and
            # hide the real database error if the args have another shape.
            error = e.args[0] if e.args else None
            already_exists = getattr(error, "code", None) == 20046
            if not (already_exists and replace):
                raise
            # If already exists and replace is True then drop and recreate
            await self.delete(force=True)
            await cr.callproc(
                "DBMS_CLOUD_AI.CREATE_PROFILE",
                keyword_parameters=parameters,
            )
253
+
254
async def delete(self, force=False) -> None:
    """Asynchronously deletes an AI profile from the database

    Calls DBMS_CLOUD_AI.DROP_PROFILE for ``self.profile_name``.

    :param bool force: Ignores errors if AI profile does not exist.
    :return: None
    :raises: oracledb.DatabaseError

    """
    drop_arguments = dict(profile_name=self.profile_name, force=force)
    async with async_cursor() as cr:
        await cr.callproc(
            "DBMS_CLOUD_AI.DROP_PROFILE",
            keyword_parameters=drop_arguments,
        )
270
+
271
@classmethod
async def fetch(cls, profile_name: str) -> "AsyncProfile":
    """Asynchronously create an AI Profile object from attributes
    saved in the database

    :param str profile_name: Name of the saved profile
    :return: The awaited profile instance
    :raises: ProfileNotFoundError
    """
    # Constructing only stages initialisation; awaiting the instance runs it.
    pending_profile = cls(profile_name, raise_error_if_exists=False)
    return await pending_profile
281
+
282
async def _save_feedback(
    self,
    feedback_type: FeedbackType = None,
    prompt_spec: Tuple[str, Action] = None,
    sql_id: Optional[str] = None,
    response: Optional[str] = None,
    feedback_content: Optional[str] = None,
    operation: Optional[FeedbackOperation] = FeedbackOperation.ADD,
):
    """
    Internal method to provide feedback

    The arguments are passed through ``validate_params_for_feedback``; the
    mapping it returns, plus this profile's name, is sent to
    DBMS_CLOUD_AI.FEEDBACK.
    """
    feedback_arguments = validate_params_for_feedback(
        feedback_type=feedback_type,
        feedback_content=feedback_content,
        prompt_spec=prompt_spec,
        sql_id=sql_id,
        response=response,
        operation=operation,
    )
    feedback_arguments["profile_name"] = self.profile_name
    async with async_cursor() as cr:
        await cr.callproc(
            "DBMS_CLOUD_AI.FEEDBACK",
            keyword_parameters=feedback_arguments,
        )
307
+
308
async def add_positive_feedback(
    self,
    prompt_spec: Optional[Tuple[str, Action]] = None,
    sql_id: Optional[str] = None,
):
    """
    Give positive feedback to the LLM

    :param Tuple[str, Action] prompt_spec: First element is the prompt and
     second is the corresponding action
    :param str sql_id: SQL identifier from V$MAPPED_SQL view
    """
    feedback_arguments = dict(
        feedback_type=FeedbackType.POSITIVE,
        prompt_spec=prompt_spec,
        sql_id=sql_id,
    )
    await self._save_feedback(**feedback_arguments)
325
+
326
async def add_negative_feedback(
    self,
    prompt_spec: Optional[Tuple[str, Action]] = None,
    sql_id: Optional[str] = None,
    response: Optional[str] = None,
    feedback_content: Optional[str] = None,
):
    """
    Give negative feedback to the LLM

    :param Tuple[str, Action] prompt_spec: First element is the prompt and
     second is the corresponding action
    :param str sql_id: SQL identifier from V$MAPPED_SQL view
    :param str response: Expected SQL from LLM
    :param str feedback_content: Actual feedback in natural language
    """
    feedback_arguments = dict(
        feedback_type=FeedbackType.NEGATIVE,
        prompt_spec=prompt_spec,
        sql_id=sql_id,
        response=response,
        feedback_content=feedback_content,
    )
    await self._save_feedback(**feedback_arguments)
349
+
350
async def delete_feedback(
    self,
    prompt_spec: Tuple[str, Action] = None,
    sql_id: Optional[str] = None,
):
    """
    Delete feedback from the database

    Runs the feedback call with ``FeedbackOperation.DELETE``.

    :param Tuple[str, Action] prompt_spec: First element is the prompt and
     second is the corresponding action
    :param str sql_id: SQL identifier from V$MAPPED_SQL view

    """
    feedback_arguments = dict(
        operation=FeedbackOperation.DELETE,
        prompt_spec=prompt_spec,
        sql_id=sql_id,
    )
    await self._save_feedback(**feedback_arguments)
368
+
369
@classmethod
async def list(
    cls, profile_name_pattern: str = ".*"
) -> AsyncGenerator["AsyncProfile", None]:
    """Asynchronously list AI Profiles saved in the database.

    All matching rows are fetched first; one initialised profile object is
    then yielded per row, built from the row's first column.

    :param str profile_name_pattern: Regular expressions can be used
     to specify a pattern. Function REGEXP_LIKE is used to perform the
     match. Default value is ".*" i.e. match all AI profiles.

    :return: Async generator of AsyncProfile objects
    """
    async with async_cursor() as cr:
        await cr.execute(
            LIST_USER_AI_PROFILES,
            profile_name_pattern=profile_name_pattern,
        )
        for row in await cr.fetchall():
            # Awaiting the instance runs its initialisation.
            pending_profile = cls(
                profile_name=row[0], raise_error_if_exists=False
            )
            yield await pending_profile
392
+
393
async def generate(
    self, prompt: str, action=Action.SHOWSQL, params: Mapping = None
) -> Union[pandas.DataFrame, str, None]:
    """Asynchronously perform AI translation using this profile

    Calls DBMS_CLOUD_AI.GENERATE and reads the returned CLOB. For
    ``Action.RUNSQL`` the text is parsed as JSON into a DataFrame (an empty
    DataFrame when ``no_data_for_prompt`` says there is no data); for every
    other action the text is returned as is.

    :param str prompt: Natural language prompt to translate
    :param select_ai.profile.Action action: Action to perform, SHOWSQL by
     default
    :param params: Parameters to include in the LLM request. For e.g.
     conversation_id for context-aware chats
    :return: Union[pandas.DataFrame, str, None]
    :raises: ValueError when ``prompt`` is empty or None
    """
    if not prompt:
        raise ValueError("prompt cannot be empty or None")

    call_arguments = dict(
        prompt=prompt,
        action=action,
        profile_name=self.profile_name,
    )
    if params:
        # Request parameters are sent as a JSON document.
        call_arguments["params"] = json.dumps(params)

    async with async_cursor() as cr:
        data = await cr.callfunc(
            "DBMS_CLOUD_AI.GENERATE",
            oracledb.DB_TYPE_CLOB,
            keyword_parameters=call_arguments,
        )
        result = None if data is None else await data.read()

    if action != Action.RUNSQL:
        return result
    if no_data_for_prompt(result):  # empty dataframe
        return pandas.DataFrame()
    return pandas.DataFrame(json.loads(result))
432
+
433
async def chat(self, prompt, params: Mapping = None) -> str:
    """Asynchronously chat with the LLM

    Shorthand for ``generate`` with ``Action.CHAT``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: str
    """
    reply = await self.generate(prompt, action=Action.CHAT, params=params)
    return reply
441
+
442
@asynccontextmanager
async def chat_session(
    self, conversation: AsyncConversation, delete: bool = False
):
    """Starts a new chat session for context-aware conversations

    The conversation is created first when it has attributes but no id
    yet. The yielded ``AsyncSession`` carries the conversation id as its
    request parameters.

    :param AsyncConversation conversation: Conversation object to use for
     this chat session
    :param bool delete: Delete conversation after session ends

    """
    try:
        not_yet_created = (
            conversation.conversation_id is None
            and conversation.attributes is not None
        )
        if not_yet_created:
            await conversation.create()
        session_params = dict(conversation_id=conversation.conversation_id)
        yield AsyncSession(async_profile=self, params=session_params)
    finally:
        # Runs on normal exit and on error alike.
        if delete:
            await conversation.delete()
465
+
466
async def narrate(self, prompt, params: Mapping = None) -> str:
    """Narrate the result of the SQL

    Shorthand for ``generate`` with ``Action.NARRATE``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: str
    """
    narration = await self.generate(
        prompt, action=Action.NARRATE, params=params
    )
    return narration
476
+
477
async def explain_sql(self, prompt: str, params: Mapping = None):
    """Explain the generated SQL

    Shorthand for ``generate`` with ``Action.EXPLAINSQL``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: str
    """
    explanation = await self.generate(
        prompt, action=Action.EXPLAINSQL, params=params
    )
    return explanation
487
+
488
async def run_sql(
    self, prompt, params: Mapping = None
) -> pandas.DataFrame:
    """Run the generated SQL and return its result

    Shorthand for ``generate`` with ``Action.RUNSQL``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: pandas.DataFrame
    """
    frame = await self.generate(prompt, action=Action.RUNSQL, params=params)
    return frame
498
+
499
async def show_sql(self, prompt, params: Mapping = None):
    """Show the generated SQL

    Shorthand for ``generate`` with ``Action.SHOWSQL``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: str
    """
    generated_sql = await self.generate(
        prompt, action=Action.SHOWSQL, params=params
    )
    return generated_sql
509
+
510
async def show_prompt(self, prompt: str, params: Mapping = None):
    """Show the prompt sent to LLM

    Shorthand for ``generate`` with ``Action.SHOWPROMPT``.

    :param str prompt: Natural language prompt
    :param params: Parameters to include in the LLM request
    :return: str
    """
    llm_prompt = await self.generate(
        prompt, action=Action.SHOWPROMPT, params=params
    )
    return llm_prompt
520
+
521
async def summarize(
    self,
    content: str = None,
    prompt: str = None,
    location_uri: str = None,
    credential_name: str = None,
    params: SummaryParams = None,
) -> str:
    """Generate summary

    The arguments are passed through ``validate_params_for_summary``; the
    mapping it returns, plus this profile's name, is sent to
    DBMS_CLOUD_AI.SUMMARIZE.

    :param str prompt: Natural language prompt to guide the summary
     generation
    :param str content: Specifies the text you want to summarize
    :param str location_uri: Provides the URI where the text is stored or
     the path to a local file stored
    :param str credential_name: Identifies the credential object used to
     authenticate with the object store
    :param select_ai.summary.SummaryParams params: Parameters to include
     in the LLM request
    :return: Summary text, or None when the call returns no value
    """
    call_arguments = validate_params_for_summary(
        prompt=prompt,
        location_uri=location_uri,
        content=content,
        credential_name=credential_name,
        params=params,
    )
    call_arguments["profile_name"] = self.profile_name
    async with async_cursor() as cr:
        data = await cr.callfunc(
            "DBMS_CLOUD_AI.SUMMARIZE",
            oracledb.DB_TYPE_CLOB,
            keyword_parameters=call_arguments,
        )
        if not data:
            return None
        return await data.read()
556
+
557
async def generate_synthetic_data(
    self, synthetic_data_attributes: SyntheticDataAttributes
) -> None:
    """Generate synthetic data for a single table, multiple tables or a
    full schema.

    The keyword arguments come from ``synthetic_data_attributes.prepare()``
    with this profile's name added, and are sent to
    DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA.

    :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
     Description of the data to generate
    :return: None
    :raises: ValueError when ``synthetic_data_attributes`` is None
    :raises: TypeError when it is not a SyntheticDataAttributes
    :raises: oracledb.DatabaseError

    """
    if synthetic_data_attributes is None:
        raise ValueError("'synthetic_data_attributes' cannot be None")
    if not isinstance(synthetic_data_attributes, SyntheticDataAttributes):
        raise TypeError(
            "'synthetic_data_attributes' must be an object "
            "of type select_ai.SyntheticDataAttributes"
        )

    call_arguments = synthetic_data_attributes.prepare()
    call_arguments["profile_name"] = self.profile_name
    async with async_cursor() as cr:
        await cr.callproc(
            "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
            keyword_parameters=call_arguments,
        )
584
+
585
async def run_pipeline(
    self,
    prompt_specifications: List[Tuple[str, Action]],
    continue_on_error: bool = False,
) -> List[Union[str, pandas.DataFrame]]:
    """Send Multiple prompts in a single roundtrip to the Database

    One DBMS_CLOUD_AI.GENERATE call is queued per (prompt, action) pair and
    the whole pipeline is run on the connection returned by
    ``async_get_connection``.

    :param List[Tuple[str, Action]] prompt_specifications: List of
     2-element tuples. First element is the prompt and second is the
     corresponding action

    :param bool continue_on_error: True to continue on error else False
    :return: One entry per pipeline result, in result order: the text read
     from the returned CLOB, None when the call returned no value, or the
     result's error object when that operation failed
    """
    pipeline = oracledb.create_pipeline()
    for prompt, action in prompt_specifications:
        pipeline.add_callfunc(
            "DBMS_CLOUD_AI.GENERATE",
            return_type=oracledb.DB_TYPE_CLOB,
            keyword_parameters={
                "prompt": prompt,
                "action": action,
                "profile_name": self.profile_name,
            },
        )
    async_connection = await async_get_connection()
    pipeline_results = await async_connection.run_pipeline(
        pipeline, continue_on_error=continue_on_error
    )
    responses = []
    for result in pipeline_results:
        if result.error:
            responses.append(result.error)
            continue
        lob_data = result.return_value
        # Guard a missing return value the same way generate() does,
        # instead of failing on ``None.read()``.
        if lob_data is None:
            responses.append(None)
        else:
            responses.append(await lob_data.read())
    return responses
625
+
626
+
627
class AsyncSession:
    """AsyncSession lets you persist request parameters across DBMS_CLOUD_AI
    requests. This is useful in context-aware conversations
    """

    def __init__(self, async_profile: AsyncProfile, params: Mapping):
        """Bind a profile and a fixed set of request parameters together.

        :param async_profile: An AI Profile to use in this session
        :param params: Parameters to be persisted across requests
        """
        self.async_profile = async_profile
        self.params = params

    async def chat(self, prompt: str):
        """Chat through the bound profile, always sending the session
        parameters.

        :param str prompt: Natural language prompt
        :return: Whatever the profile's ``chat`` returns
        """
        reply = await self.async_profile.chat(
            prompt=prompt, params=self.params
        )
        return reply

    async def __aenter__(self):
        # Entering the session hands back the session itself.
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; returning None lets exceptions propagate.
        return None