select-ai 1.2.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
select_ai/profile.py ADDED
@@ -0,0 +1,579 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from contextlib import contextmanager
10
+ from dataclasses import replace as dataclass_replace
11
+ from pprint import pformat
12
+ from typing import Generator, Iterator, Mapping, Optional, Tuple, Union
13
+
14
+ import oracledb
15
+ import pandas
16
+
17
+ from select_ai import Conversation
18
+ from select_ai.action import Action
19
+ from select_ai.base_profile import (
20
+ BaseProfile,
21
+ ProfileAttributes,
22
+ no_data_for_prompt,
23
+ validate_params_for_feedback,
24
+ validate_params_for_summary,
25
+ )
26
+ from select_ai.db import cursor
27
+ from select_ai.errors import ProfileExistsError, ProfileNotFoundError
28
+ from select_ai.feedback import FeedbackOperation, FeedbackType
29
+ from select_ai.provider import Provider
30
+ from select_ai.sql import (
31
+ GET_USER_AI_PROFILE,
32
+ GET_USER_AI_PROFILE_ATTRIBUTES,
33
+ LIST_USER_AI_PROFILES,
34
+ )
35
+ from select_ai.summary import SummaryParams
36
+ from select_ai.synthetic_data import SyntheticDataAttributes
37
+
38
+
39
class Profile(BaseProfile):
    """Profile class represents an AI Profile. It defines
    attributes and methods to interact with the underlying
    AI Provider. All methods in this class are synchronous
    or blocking
    """

    def __init__(self, *args, **kwargs):
        """Build the profile object and reconcile it with the database

        All arguments are forwarded unchanged to ``BaseProfile``; the saved
        profile is then looked up (and created or replaced when required) by
        ``_init_profile``.
        """
        super().__init__(*args, **kwargs)
        self._init_profile()

    def _init_profile(self) -> None:
        """Initializes AI profile based on the passed attributes

        With a profile name set, the saved attributes are looked up first:

        * profile found, neither ``replace`` nor ``merge`` requested, but
          attributes or a description were passed: ``ProfileExistsError``
          is raised when ``raise_error_if_exists`` is set
        * profile found and no description passed (and not replacing): the
          saved description is loaded
        * profile found and no attributes passed: the saved attributes are
          used
        * ``merge`` requested: the passed attributes are laid over the saved
          ones and the profile is re-created with ``replace=True``
        * profile not found: the error is re-raised only when neither
          attributes nor a description were passed, otherwise the profile
          is created

        :return: None
        :raises: oracledb.DatabaseError
        :raises: ProfileExistsError
        :raises: ProfileNotFoundError
        :raises: ValueError if attributes or a description are passed
         without a profile name
        """
        if self.profile_name:
            profile_exists = False
            try:
                saved_attributes = self._get_attributes(
                    profile_name=self.profile_name
                )
                profile_exists = True
                # The caller is trying to define a profile which is already
                # saved, without asking to replace or merge it
                if (
                    not (self.replace or self.merge)
                    and (
                        self.attributes is not None
                        or self.description is not None
                    )
                    and self.raise_error_if_exists
                ):
                    raise ProfileExistsError(self.profile_name)

                if self.description is None and not self.replace:
                    self.description = self._get_profile_description(
                        profile_name=self.profile_name
                    )
            except ProfileNotFoundError:
                # Nothing saved and nothing to create the profile from
                if self.attributes is None and self.description is None:
                    raise
            else:
                if self.attributes is None:
                    self.attributes = saved_attributes
                if self.merge:
                    # A merge is persisted by re-creating the profile
                    self.replace = True
                    if self.attributes is not None:
                        # Values set by the caller override the saved ones
                        self.attributes = dataclass_replace(
                            saved_attributes,
                            **self.attributes.dict(exclude_null=True),
                        )
            if self.replace or not profile_exists:
                self.create(replace=self.replace)
        else:  # profile name is None
            if self.attributes is not None or self.description is not None:
                raise ValueError(
                    "Attribute 'profile_name' cannot be empty or None"
                )

    @staticmethod
    def _get_profile_description(profile_name) -> Union[str, None]:
        """Get description of profile from USER_CLOUD_AI_PROFILES

        :param str profile_name: Name of the profile; it is upper-cased
         before being bound to the query
        :return: Union[str, None] profile description
        :raises: ProfileNotFoundError
        """
        with cursor() as cr:
            cr.execute(GET_USER_AI_PROFILE, profile_name=profile_name.upper())
            profile = cr.fetchone()
            if not profile:
                raise ProfileNotFoundError(profile_name)
            # The second column holds the description as a LOB (or NULL)
            description = profile[1]
            if description is None:
                return None
            return description.read()

    @staticmethod
    def _get_attributes(profile_name) -> ProfileAttributes:
        """Get AI profile attributes from the Database

        :param str profile_name: Name of the profile; it is upper-cased
         before being bound to the query
        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        with cursor() as cr:
            cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = cr.fetchall()
            if not attributes:
                raise ProfileNotFoundError(profile_name=profile_name)
            # Each fetched row is a (name, value) pair
            return ProfileAttributes.create(**dict(attributes))

    def get_attributes(self) -> ProfileAttributes:
        """Get AI profile attributes from the Database

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return self._get_attributes(profile_name=self.profile_name)

    def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Save a single attribute by calling DBMS_CLOUD_AI.SET_ATTRIBUTE

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float] attribute_value: Value of the
         profile attribute
        :return: None
        """
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float, Provider] attribute_value: Value of
         the profile attribute
        :return: None

        """
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            # A provider is saved as one database attribute per entry of
            # its dict() form
            for k, v in attribute_value.dict().items():
                self._set_attribute(k, v)
        else:
            self._set_attribute(attribute_name, attribute_value)

    def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves it in the database

        :param ProfileAttributes attributes: Object specifying AI profile
         attributes
        :return: None
        :raises: TypeError if attributes is not a select_ai.ProfileAttributes
        """
        if not isinstance(attributes, ProfileAttributes):
            raise TypeError(
                "'attributes' must be an object of type"
                " select_ai.ProfileAttributes"
            )
        parameters = {
            "profile_name": self.profile_name,
            "attributes": attributes.json(),
        }
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )
        # Refresh the Python object from what is now saved in the database
        self.attributes = self.get_attributes()

    def create(self, replace: Optional[bool] = False) -> None:
        """Create an AI Profile in the Database

        :param bool replace: Set True to replace else False
        :return: None
        :raises: AttributeError if the profile has no attributes
        :raises: oracledb.DatabaseError
        """
        if self.attributes is None:
            raise AttributeError("Profile attributes cannot be None")
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if self.description:
            parameters["description"] = self.description

        with cursor() as cr:
            try:
                cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                # Read the error code defensively so that an unexpected
                # args layout re-raises the database error instead of
                # masking it with an unpacking failure
                error = e.args[0] if e.args else None
                # If already exists and replace is True then drop and recreate
                if replace and getattr(error, "code", None) == 20046:
                    self.delete(force=True)
                    cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    def delete(self, force=False) -> None:
        """Deletes an AI profile from the database

        :param bool force: Ignores errors if AI profile does not exist.
        :return: None
        :raises: oracledb.DatabaseError
        """
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    def fetch(cls, profile_name: str) -> "Profile":
        """Create a proxy Profile object from fetched attributes saved in the
        database

        :param str profile_name: The name of the AI profile
        :return: select_ai.Profile
        :raises: ProfileNotFoundError
        """
        return cls(profile_name=profile_name, raise_error_if_exists=False)

    def _save_feedback(
        self,
        feedback_type: Optional[FeedbackType] = None,
        prompt_spec: Optional[Tuple[str, Action]] = None,
        sql_id: Optional[str] = None,
        response: Optional[str] = None,
        feedback_content: Optional[str] = None,
        operation: Optional[FeedbackOperation] = FeedbackOperation.ADD,
    ):
        """
        Internal method to provide feedback

        The arguments are checked and converted by
        ``validate_params_for_feedback``; the resulting parameters, plus this
        profile's name, are passed to DBMS_CLOUD_AI.FEEDBACK.
        """
        params = validate_params_for_feedback(
            feedback_type=feedback_type,
            feedback_content=feedback_content,
            prompt_spec=prompt_spec,
            sql_id=sql_id,
            response=response,
            operation=operation,
        )
        params["profile_name"] = self.profile_name
        with cursor() as cr:
            cr.callproc("DBMS_CLOUD_AI.FEEDBACK", keyword_parameters=params)

    def add_positive_feedback(
        self,
        prompt_spec: Optional[Tuple[str, Action]] = None,
        sql_id: Optional[str] = None,
    ):
        """
        Give positive feedback to the LLM

        :param Tuple[str, Action] prompt_spec: First element is the prompt and
         second is the corresponding action
        :param str sql_id: SQL identifier from V$MAPPED_SQL view
        """
        self._save_feedback(
            feedback_type=FeedbackType.POSITIVE,
            prompt_spec=prompt_spec,
            sql_id=sql_id,
        )

    def add_negative_feedback(
        self,
        prompt_spec: Optional[Tuple[str, Action]] = None,
        sql_id: Optional[str] = None,
        response: Optional[str] = None,
        feedback_content: Optional[str] = None,
    ):
        """
        Give negative feedback to the LLM

        :param Tuple[str, Action] prompt_spec: First element is the prompt and
         second is the corresponding action
        :param str sql_id: SQL identifier from V$MAPPED_SQL view
        :param str response: Expected SQL from LLM
        :param str feedback_content: Actual feedback in natural language
        """
        self._save_feedback(
            feedback_type=FeedbackType.NEGATIVE,
            prompt_spec=prompt_spec,
            sql_id=sql_id,
            response=response,
            feedback_content=feedback_content,
        )

    def delete_feedback(
        self,
        prompt_spec: Optional[Tuple[str, Action]] = None,
        sql_id: Optional[str] = None,
    ):
        """
        Delete feedback from the database

        :param Tuple[str, Action] prompt_spec: First element is the prompt and
         second is the corresponding action
        :param str sql_id: SQL identifier from V$MAPPED_SQL view

        """
        self._save_feedback(
            operation=FeedbackOperation.DELETE,
            prompt_spec=prompt_spec,
            sql_id=sql_id,
        )

    @classmethod
    def list(
        cls, profile_name_pattern: str = ".*"
    ) -> Generator["Profile", None, None]:
        """List AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expressions can be used
         to specify a pattern. Function REGEXP_LIKE is used to perform the
         match. Default value is ".*" i.e. match all AI profiles.

        :return: Iterator[Profile]
        """
        with cursor() as cr:
            cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = cr.fetchall()
        # All rows are already fetched, so the cursor is released before
        # yielding: building each Profile opens its own cursor, and the
        # consumer may never exhaust this generator
        for row in rows:
            yield cls(profile_name=row[0], raise_error_if_exists=False)

    def generate(
        self,
        prompt: str,
        action: Optional[Action] = Action.RUNSQL,
        params: Optional[Mapping] = None,
    ) -> Union[pandas.DataFrame, str, None]:
        """Perform AI translation using this profile

        :param str prompt: Natural language prompt to translate
        :param select_ai.profile.Action action: Action to perform, RUNSQL
         by default
        :param params: Parameters to include in the LLM request. For e.g.
         conversation_id for context-aware chats
        :return: Union[pandas.DataFrame, str, None]. A DataFrame for RUNSQL
         (empty when there is no data), otherwise the text returned by the
         database or None when it returns nothing
        :raises: ValueError if prompt is empty or None
        """
        if not prompt:
            raise ValueError("prompt cannot be empty or None")
        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)
        with cursor() as cr:
            data = cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            # Read the CLOB while the cursor is still open
            result = data.read() if data is not None else None
        if action != Action.RUNSQL:
            return result
        # A missing result cannot be parsed as JSON, so it is reported as
        # an empty dataframe just like the "no data" response
        if result is None or no_data_for_prompt(result):
            return pandas.DataFrame()
        return pandas.DataFrame(json.loads(result))

    def chat(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Chat with the LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.CHAT, params=params)

    @contextmanager
    def chat_session(self, conversation: Conversation, delete: bool = False):
        """Starts a new chat session for context-aware conversations

        The conversation is created first when it has no conversation_id
        yet but does have attributes. The yielded Session passes the
        conversation_id as a request parameter on every chat.

        :param Conversation conversation: Conversation object to use for this
         chat session
        :param bool delete: Delete conversation after session ends

        :return: Context manager yielding a Session
        """
        try:
            if (
                conversation.conversation_id is None
                and conversation.attributes is not None
            ):
                conversation.create()
            params = {"conversation_id": conversation.conversation_id}
            session = Session(profile=self, params=params)
            yield session
        finally:
            # Runs whether or not the body of the "with" block raised
            if delete:
                conversation.delete()

    def narrate(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Narrate the result of the SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.NARRATE, params=params)

    def explain_sql(
        self, prompt: str, params: Optional[Mapping] = None
    ) -> str:
        """Explain the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.EXPLAINSQL, params=params)

    def run_sql(
        self, prompt: str, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL statement and return a pandas Dataframe
        built using the result set

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame
        """
        return self.generate(prompt, action=Action.RUNSQL, params=params)

    def show_sql(self, prompt: str, params: Optional[Mapping] = None) -> str:
        """Show the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.SHOWSQL, params=params)

    def show_prompt(
        self, prompt: str, params: Optional[Mapping] = None
    ) -> str:
        """Show the prompt sent to LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return self.generate(prompt, action=Action.SHOWPROMPT, params=params)

    def summarize(
        self,
        content: Optional[str] = None,
        prompt: Optional[str] = None,
        location_uri: Optional[str] = None,
        credential_name: Optional[str] = None,
        params: Optional[SummaryParams] = None,
    ) -> Optional[str]:
        """Generate summary

        :param str prompt: Natural language prompt to guide the summary
         generation
        :param str content: Specifies the text you want to summarize
        :param str location_uri: Provides the URI where the text is stored or
         the path to a local file stored
        :param str credential_name: Identifies the credential object used to
         authenticate with the object store
        :param select_ai.summary.SummaryParams params: Parameters to include
         in the LLM request
        :return: Optional[str] the summary, or None when the database
         returns nothing
        """
        parameters = validate_params_for_summary(
            prompt=prompt,
            location_uri=location_uri,
            content=content,
            credential_name=credential_name,
            params=params,
        )
        parameters["profile_name"] = self.profile_name
        with cursor() as cr:
            data = cr.callfunc(
                "DBMS_CLOUD_AI.SUMMARIZE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            # Read the CLOB while the cursor is still open
            return data.read() if data is not None else None

    def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
         Object describing the synthetic data to generate
        :return: None
        :raises: ValueError if synthetic_data_attributes is None
        :raises: TypeError if synthetic_data_attributes has the wrong type
        :raises: oracledb.DatabaseError

        """
        if synthetic_data_attributes is None:
            raise ValueError(
                "Param 'synthetic_data_attributes' cannot be None"
            )

        if not isinstance(synthetic_data_attributes, SyntheticDataAttributes):
            raise TypeError(
                "'synthetic_data_attributes' must be an object "
                "of type select_ai.SyntheticDataAttributes"
            )

        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        with cursor() as cr:
            cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )
555
+
556
+
557
class Session:
    """Session lets you persist request parameters across DBMS_CLOUD_AI
    requests. This is useful in context-aware conversations

    A Session can also be used as a context manager; entering returns the
    session itself and exiting does nothing.
    """

    def __init__(self, profile: Profile, params: Mapping):
        """Bind request parameters to a profile

        :param profile: An AI Profile to use in this session
        :param params: Parameters to be persisted across requests
        """
        self.profile = profile
        self.params = params

    def chat(self, prompt: str):
        """Chat through the session's profile using the persisted parameters

        :param str prompt: Natural language prompt
        :return: Whatever the profile's ``chat`` returns
        """
        reply = self.profile.chat(prompt=prompt, params=self.params)
        return reply

    def __enter__(self):
        """Return this session for use in a ``with`` block"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Nothing to clean up; exceptions raised in the block propagate"""
        return None