select-ai 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of select-ai might be problematic. Click here for more details.

select_ai/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ from .action import Action
9
+ from .async_profile import AsyncProfile
10
+ from .base_profile import BaseProfile, ProfileAttributes
11
+ from .conversation import (
12
+ AsyncConversation,
13
+ Conversation,
14
+ ConversationAttributes,
15
+ )
16
+ from .credential import (
17
+ async_create_credential,
18
+ async_delete_credential,
19
+ create_credential,
20
+ delete_credential,
21
+ )
22
+ from .db import (
23
+ async_connect,
24
+ async_cursor,
25
+ async_disconnect,
26
+ async_is_connected,
27
+ connect,
28
+ cursor,
29
+ disconnect,
30
+ is_connected,
31
+ )
32
+ from .profile import Profile
33
+ from .provider import (
34
+ AnthropicProvider,
35
+ AWSProvider,
36
+ AzureProvider,
37
+ CohereProvider,
38
+ GoogleProvider,
39
+ HuggingFaceProvider,
40
+ OCIGenAIProvider,
41
+ OpenAIProvider,
42
+ Provider,
43
+ async_disable_provider,
44
+ async_enable_provider,
45
+ disable_provider,
46
+ enable_provider,
47
+ )
48
+ from .synthetic_data import (
49
+ SyntheticDataAttributes,
50
+ SyntheticDataParams,
51
+ )
52
+ from .vector_index import (
53
+ AsyncVectorIndex,
54
+ OracleVectorIndexAttributes,
55
+ VectorDistanceMetric,
56
+ VectorIndex,
57
+ VectorIndexAttributes,
58
+ )
59
+ from .version import __version__ as __version__
select_ai/_abc.py ADDED
@@ -0,0 +1,77 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ import typing
10
+ from abc import ABC
11
+ from dataclasses import dataclass, fields
12
+ from typing import Any, List, Mapping
13
+
14
+ __all__ = ["SelectAIDataClass"]
15
+
16
+
17
+ def _bool(value: Any) -> bool:
18
+ if isinstance(value, bool):
19
+ return value
20
+ if isinstance(value, int):
21
+ return bool(value)
22
+ if value.lower() in ("yes", "true", "t", "y", "1"):
23
+ return True
24
+ elif value.lower() in ("no", "false", "f", "n", "0"):
25
+ return False
26
+ else:
27
+ raise ValueError(f"Invalid boolean value: {value}")
28
+
29
+
30
+ @dataclass
31
+ class SelectAIDataClass(ABC):
32
+ """SelectAIDataClass is an abstract container for all data
33
+ models defined in the select_ai Python module
34
+ """
35
+
36
+ def __getitem__(self, item):
37
+ return getattr(self, item)
38
+
39
+ def __setitem__(self, key, value):
40
+ setattr(self, key, value)
41
+
42
+ @classmethod
43
+ def keys(cls):
44
+ return set([field.name for field in fields(cls)])
45
+
46
+ def dict(self, exclude_null=True):
47
+ attributes = {}
48
+ for k, v in self.__dict__.items():
49
+ if v is not None or not exclude_null:
50
+ attributes[k] = v
51
+ return attributes
52
+
53
+ def json(self, exclude_null=True):
54
+ return json.dumps(self.dict(exclude_null=exclude_null))
55
+
56
+ def __post_init__(self):
57
+ for field in fields(self):
58
+ value = getattr(self, field.name)
59
+ if value is not None:
60
+ if field.type is typing.Optional[int]:
61
+ setattr(self, field.name, int(value))
62
+ elif field.type is typing.Optional[str]:
63
+ setattr(self, field.name, str(value))
64
+ elif field.type is typing.Optional[bool]:
65
+ setattr(self, field.name, _bool(value))
66
+ elif field.type is typing.Optional[float]:
67
+ setattr(self, field.name, float(value))
68
+ elif field.type is typing.Optional[Mapping] and isinstance(
69
+ value, (str, bytes, bytearray)
70
+ ):
71
+ setattr(self, field.name, json.loads(value))
72
+ elif field.type is typing.Optional[
73
+ List[typing.Mapping]
74
+ ] and isinstance(value, (str, bytes, bytearray)):
75
+ setattr(self, field.name, json.loads(value))
76
+ else:
77
+ setattr(self, field.name, value)
select_ai/_enums.py ADDED
@@ -0,0 +1,14 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import enum
9
+
10
+
11
+ class StrEnum(str, enum.Enum):
12
+
13
+ def __str__(self):
14
+ return self.value
select_ai/action.py ADDED
@@ -0,0 +1,21 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ from select_ai._enums import StrEnum
9
+
10
+ __all__ = ["Action"]
11
+
12
+
13
+ class Action(StrEnum):
14
+ """Supported Select AI actions"""
15
+
16
+ RUNSQL = "runsql"
17
+ SHOWSQL = "showsql"
18
+ EXPLAINSQL = "explainsql"
19
+ NARRATE = "narrate"
20
+ CHAT = "chat"
21
+ SHOWPROMPT = "showprompt"
@@ -0,0 +1,528 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) 2025, Oracle and/or its affiliates.
3
+ #
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at
5
+ # http://oss.oracle.com/licenses/upl.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import json
9
+ from contextlib import asynccontextmanager
10
+ from dataclasses import replace as dataclass_replace
11
+ from typing import (
12
+ AsyncGenerator,
13
+ List,
14
+ Mapping,
15
+ Optional,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import oracledb
21
+ import pandas
22
+
23
+ from select_ai.action import Action
24
+ from select_ai.base_profile import BaseProfile, ProfileAttributes
25
+ from select_ai.conversation import AsyncConversation
26
+ from select_ai.db import async_cursor, async_get_connection
27
+ from select_ai.errors import ProfileExistsError, ProfileNotFoundError
28
+ from select_ai.provider import Provider
29
+ from select_ai.sql import (
30
+ GET_USER_AI_PROFILE,
31
+ GET_USER_AI_PROFILE_ATTRIBUTES,
32
+ LIST_USER_AI_PROFILES,
33
+ )
34
+ from select_ai.synthetic_data import SyntheticDataAttributes
35
+
36
+ __all__ = ["AsyncProfile"]
37
+
38
+
39
class AsyncProfile(BaseProfile):
    """AsyncProfile defines methods to interact with the underlying AI Provider
    asynchronously.

    Initialization happens in two steps. The constructor (inherited from
    BaseProfile) only stores the arguments. Awaiting the instance, e.g.
    ``profile = await AsyncProfile(profile_name="...", attributes=...)``,
    runs :meth:`_init_profile`, which reads the profile from the database
    and, when needed, creates or replaces it.
    """

    def __await__(self):
        """Run the asynchronous initialization and resolve to ``self``.

        The initialization coroutine is created here, not in ``__init__``.
        Instances that are never awaited (for example those yielded by
        :meth:`list` or returned by :meth:`_from_db`) therefore do not leave
        an un-awaited coroutine behind, which Python reports as a
        ``RuntimeWarning``.
        """
        return self._init_profile().__await__()

    async def _init_profile(self):
        """Initializes AI profile based on the passed attributes

        When ``profile_name`` is set:

        - If the profile exists and neither ``replace`` nor ``merge`` is set,
          passing ``attributes`` or ``description`` raises
          ProfileExistsError when ``raise_error_if_exists`` is true.
        - If the profile exists, a missing ``description`` and missing
          ``attributes`` are loaded from the database. With ``merge``, the
          passed non-null attributes are overlaid on the saved ones and the
          profile is replaced.
        - If the profile does not exist, ProfileNotFoundError is re-raised
          unless ``attributes`` or ``description`` was given, in which case
          the profile is created.

        :return: self
        :raises: oracledb.DatabaseError, ProfileExistsError,
            ProfileNotFoundError, ValueError
        """
        if self.profile_name:
            profile_exists = False
            try:
                saved_attributes = await self._get_attributes(
                    profile_name=self.profile_name
                )
                profile_exists = True
                if not self.replace and not self.merge:
                    if (
                        self.attributes is not None
                        or self.description is not None
                    ):
                        if self.raise_error_if_exists:
                            raise ProfileExistsError(self.profile_name)

                if self.description is None:
                    self.description = await self._get_profile_description(
                        profile_name=self.profile_name
                    )
            except ProfileNotFoundError:
                # Nothing to create the profile from: propagate.
                if self.attributes is None and self.description is None:
                    raise
            else:
                if self.attributes is None:
                    self.attributes = saved_attributes
                if self.merge:
                    # Merging always rewrites the stored profile.
                    self.replace = True
                    if self.attributes is not None:
                        self.attributes = dataclass_replace(
                            saved_attributes,
                            **self.attributes.dict(exclude_null=True),
                        )
            if self.replace or not profile_exists:
                await self.create(
                    replace=self.replace, description=self.description
                )
        else:  # profile name is None:
            if self.attributes is not None or self.description is not None:
                raise ValueError("'profile_name' cannot be empty or None")
        return self

    @staticmethod
    async def _get_profile_description(profile_name) -> Union[str, None]:
        """Get description of profile from USER_CLOUD_AI_PROFILES

        :param str profile_name: Name of profile (upper-cased for the lookup)
        :return: Description of profile, or None if it has none
        :rtype: str
        :raises: ProfileNotFoundError
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE,
                profile_name=profile_name.upper(),
            )
            profile = await cr.fetchone()
            if profile:
                if profile[1] is not None:
                    # The description column is read as a LOB.
                    return await profile[1].read()
                else:
                    return None
            else:
                raise ProfileNotFoundError(profile_name)

    @staticmethod
    async def _get_attributes(profile_name) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        :param str profile_name: Name of the profile (upper-cased for the
            lookup)
        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        async with async_cursor() as cr:
            await cr.execute(
                GET_USER_AI_PROFILE_ATTRIBUTES,
                profile_name=profile_name.upper(),
            )
            attributes = await cr.fetchall()
            if attributes:
                # Rows are (attribute_name, attribute_value) pairs.
                return await ProfileAttributes.async_create(**dict(attributes))
            else:
                raise ProfileNotFoundError(profile_name=profile_name)

    async def get_attributes(self) -> ProfileAttributes:
        """Asynchronously gets AI profile attributes from the Database

        :return: select_ai.ProfileAttributes
        :raises: ProfileNotFoundError
        """
        return await self._get_attributes(profile_name=self.profile_name)

    async def _set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float],
    ):
        """Persist a single attribute with DBMS_CLOUD_AI.SET_ATTRIBUTE."""
        parameters = {
            "profile_name": self.profile_name,
            "attribute_name": attribute_name,
            "attribute_value": attribute_value,
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTE", keyword_parameters=parameters
            )

    async def set_attribute(
        self,
        attribute_name: str,
        attribute_value: Union[bool, str, int, float, Provider],
    ):
        """Updates AI profile attribute on the Python object and also
        saves it in the database

        :param str attribute_name: Name of the AI profile attribute
        :param Union[bool, str, int, float, Provider] attribute_value: Value
            of the profile attribute. A Provider is saved as one database
            attribute per entry of ``Provider.dict()``.
        :return: None
        """
        self.attributes.set_attribute(attribute_name, attribute_value)
        if isinstance(attribute_value, Provider):
            for k, v in attribute_value.dict().items():
                await self._set_attribute(k, v)
        else:
            await self._set_attribute(attribute_name, attribute_value)

    async def set_attributes(self, attributes: ProfileAttributes):
        """Updates AI profile attributes on the Python object and also
        saves it in the database

        :param ProfileAttributes attributes: Object specifying AI profile
            attributes
        :return: None
        :raises TypeError: If ``attributes`` is not a ProfileAttributes
        """
        if not isinstance(attributes, ProfileAttributes):
            raise TypeError(
                "'attributes' must be an object of type "
                "select_ai.ProfileAttributes"
            )

        self.attributes = attributes
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.SET_ATTRIBUTES", keyword_parameters=parameters
            )

    async def create(
        self, replace: bool = False, description: Optional[str] = None
    ) -> None:
        """Asynchronously create an AI Profile in the Database

        :param bool replace: Set True to replace else False
        :param description: The profile description
        :return: None
        :raises AttributeError: If the profile has no attributes
        :raises: oracledb.DatabaseError
        """
        if self.attributes is None:
            raise AttributeError("Profile attributes cannot be None")
        parameters = {
            "profile_name": self.profile_name,
            "attributes": self.attributes.json(),
        }
        if description:
            parameters["description"] = description
        async with async_cursor() as cr:
            try:
                await cr.callproc(
                    "DBMS_CLOUD_AI.CREATE_PROFILE",
                    keyword_parameters=parameters,
                )
            except oracledb.DatabaseError as e:
                (error,) = e.args
                # Error 20046 is treated as "profile already exists": when
                # replace is True, drop and recreate it.
                if error.code == 20046 and replace:
                    await self.delete(force=True)
                    await cr.callproc(
                        "DBMS_CLOUD_AI.CREATE_PROFILE",
                        keyword_parameters=parameters,
                    )
                else:
                    raise

    async def delete(self, force=False) -> None:
        """Asynchronously deletes an AI profile from the database

        :param bool force: Ignores errors if AI profile does not exist.
        :return: None
        :raises: oracledb.DatabaseError
        """
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.DROP_PROFILE",
                keyword_parameters={
                    "profile_name": self.profile_name,
                    "force": force,
                },
            )

    @classmethod
    async def _from_db(cls, profile_name: str) -> "AsyncProfile":
        """Asynchronously create an AI Profile object from attributes
        saved in the database against the profile

        The name is looked up the same way as in :meth:`_get_attributes`
        (upper-cased). The returned object is not awaited. It is built with
        ``raise_error_if_exists=False`` (as in :meth:`list`), so awaiting it
        later does not fail just because the profile exists.

        :param str profile_name: Name of the profile
        :return: select_ai.AsyncProfile
        :raises: ProfileNotFoundError
        """
        attributes = await cls._get_attributes(profile_name=profile_name)
        return cls(
            profile_name=profile_name,
            attributes=attributes,
            raise_error_if_exists=False,
        )

    @classmethod
    async def list(
        cls, profile_name_pattern: str = ".*"
    ) -> AsyncGenerator["AsyncProfile", None]:
        """Asynchronously list AI Profiles saved in the database.

        :param str profile_name_pattern: Regular expressions can be used
            to specify a pattern. Function REGEXP_LIKE is used to perform the
            match. Default value is ".*" i.e. match all AI profiles.

        :return: AsyncGenerator[AsyncProfile]
        """
        entries = []
        async with async_cursor() as cr:
            await cr.execute(
                LIST_USER_AI_PROFILES,
                profile_name_pattern=profile_name_pattern,
            )
            rows = await cr.fetchall()
            for row in rows:
                description = row[1]
                # NOTE(review): _get_profile_description() reads the
                # description column as a LOB. If this query returns it the
                # same way, read it here so callers get a str. Values without
                # read() (str/None) pass through unchanged.
                if description is not None and hasattr(description, "read"):
                    description = await description.read()
                entries.append((row[0], description))
        # The cursor is released before yielding, so it is not held open
        # while the caller processes each profile.
        for profile_name, description in entries:
            attributes = await cls._get_attributes(profile_name=profile_name)
            yield cls(
                profile_name=profile_name,
                description=description,
                attributes=attributes,
                raise_error_if_exists=False,
            )

    async def generate(
        self,
        prompt: str,
        action=Action.SHOWSQL,
        params: Optional[Mapping] = None,
    ) -> Union[pandas.DataFrame, str, None]:
        """Asynchronously perform AI translation using this profile

        :param str prompt: Natural language prompt to translate
        :param select_ai.Action action: Select AI action to perform
        :param params: Parameters to include in the LLM request. For e.g.
            conversation_id for context-aware chats
        :return: The text returned by DBMS_CLOUD_AI.GENERATE, or None
        :raises ValueError: If ``prompt`` is empty or None
        """
        if not prompt:
            raise ValueError("prompt cannot be empty or None")

        parameters = {
            "prompt": prompt,
            "action": action,
            "profile_name": self.profile_name,
        }
        if params:
            parameters["params"] = json.dumps(params)

        async with async_cursor() as cr:
            data = await cr.callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
            if data is not None:
                return await data.read()
        return None

    async def chat(
        self, prompt, params: Optional[Mapping] = None
    ) -> Optional[str]:
        """Asynchronously chat with the LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(prompt, action=Action.CHAT, params=params)

    @asynccontextmanager
    async def chat_session(
        self, conversation: AsyncConversation, delete: bool = False
    ):
        """Starts a new chat session for context-aware conversations

        The conversation is created first if it has no ``conversation_id``
        but has attributes. Its id is then sent with every chat made through
        the yielded AsyncSession.

        :param AsyncConversation conversation: Conversation object to use for
            this chat session
        :param bool delete: Delete conversation after session ends
        """
        try:
            if (
                conversation.conversation_id is None
                and conversation.attributes is not None
            ):
                await conversation.create()
            params = {"conversation_id": conversation.conversation_id}
            async_session = AsyncSession(async_profile=self, params=params)
            yield async_session
        finally:
            if delete:
                await conversation.delete()

    async def narrate(
        self, prompt, params: Optional[Mapping] = None
    ) -> Optional[str]:
        """Narrate the result of the SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.NARRATE, params=params
        )

    async def explain_sql(self, prompt: str, params: Optional[Mapping] = None):
        """Explain the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.EXPLAINSQL, params=params
        )

    async def run_sql(
        self, prompt, params: Optional[Mapping] = None
    ) -> pandas.DataFrame:
        """Run the generated SQL and return its result as a DataFrame

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: pandas.DataFrame built from the JSON result. The DataFrame
            is empty when the database returns no result.
        """
        data = await self.generate(prompt, action=Action.RUNSQL, params=params)
        if data is None:
            # json.loads(None) would raise TypeError.
            return pandas.DataFrame()
        return pandas.DataFrame(json.loads(data))

    async def show_sql(self, prompt, params: Optional[Mapping] = None):
        """Show the generated SQL

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWSQL, params=params
        )

    async def show_prompt(self, prompt: str, params: Optional[Mapping] = None):
        """Show the prompt sent to LLM

        :param str prompt: Natural language prompt
        :param params: Parameters to include in the LLM request
        :return: str
        """
        return await self.generate(
            prompt, action=Action.SHOWPROMPT, params=params
        )

    async def generate_synthetic_data(
        self, synthetic_data_attributes: SyntheticDataAttributes
    ) -> None:
        """Generate synthetic data for a single table, multiple tables or a
        full schema.

        :param select_ai.SyntheticDataAttributes synthetic_data_attributes:
            Specification of the data to generate
        :return: None
        :raises ValueError: If ``synthetic_data_attributes`` is None
        :raises TypeError: If it is not a SyntheticDataAttributes
        :raises: oracledb.DatabaseError
        """
        if synthetic_data_attributes is None:
            raise ValueError("'synthetic_data_attributes' cannot be None")

        if not isinstance(synthetic_data_attributes, SyntheticDataAttributes):
            raise TypeError(
                "'synthetic_data_attributes' must be an object "
                "of type select_ai.SyntheticDataAttributes"
            )

        keyword_parameters = synthetic_data_attributes.prepare()
        keyword_parameters["profile_name"] = self.profile_name
        async with async_cursor() as cr:
            await cr.callproc(
                "DBMS_CLOUD_AI.GENERATE_SYNTHETIC_DATA",
                keyword_parameters=keyword_parameters,
            )

    async def run_pipeline(
        self,
        prompt_specifications: List[Tuple[str, Action]],
        continue_on_error: bool = False,
    ) -> List[Union[str, pandas.DataFrame]]:
        """Send Multiple prompts in a single roundtrip to the Database

        :param List[Tuple[str, Action]] prompt_specifications: List of
            2-element tuples. First element is the prompt and second is the
            corresponding action

        :param bool continue_on_error: True to continue on error else False
        :return: One entry per prompt, in order. Each entry is the response
            text, None when the database returned no value, or the pipeline
            error object when that operation failed.
        """
        if not prompt_specifications:
            # Nothing to send; avoid acquiring a connection.
            return []
        pipeline = oracledb.create_pipeline()
        for prompt, action in prompt_specifications:
            parameters = {
                "prompt": prompt,
                "action": action,
                "profile_name": self.profile_name,
            }
            pipeline.add_callfunc(
                "DBMS_CLOUD_AI.GENERATE",
                return_type=oracledb.DB_TYPE_CLOB,
                keyword_parameters=parameters,
            )
        async_connection = await async_get_connection()
        pipeline_results = await async_connection.run_pipeline(
            pipeline, continue_on_error=continue_on_error
        )
        responses = []
        for result in pipeline_results:
            if result.error:
                responses.append(result.error)
            elif result.return_value is None:
                # A NULL CLOB has nothing to read().
                responses.append(None)
            else:
                responses.append(await result.return_value.read())
        return responses
505
+
506
+
507
class AsyncSession:
    """Holds request parameters that are re-sent with every chat request
    made through it, e.g. the ``conversation_id`` of a context-aware chat.

    ``AsyncProfile.chat_session`` yields instances of this class. It can
    also be used directly with ``async with``: entering returns the session
    itself and exiting performs no cleanup.
    """

    def __init__(self, async_profile: AsyncProfile, params: Mapping):
        """
        :param async_profile: AI profile that serves the requests
        :param params: Parameters forwarded unchanged with every request
        """
        self.async_profile = async_profile
        self.params = params

    async def chat(self, prompt: str):
        """Chat with the LLM using this session's persisted parameters.

        :param str prompt: Natural language prompt
        :return: Whatever ``AsyncProfile.chat`` returns
        """
        profile = self.async_profile
        return await profile.chat(prompt=prompt, params=self.params)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release. Returning None lets exceptions propagate.
        return None