lionagi 0.0.209__py3-none-any.whl → 0.0.211__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +2 -4
- lionagi/api_service/base_endpoint.py +65 -0
- lionagi/api_service/base_rate_limiter.py +121 -0
- lionagi/api_service/base_service.py +146 -0
- lionagi/api_service/chat_completion.py +6 -0
- lionagi/api_service/embeddings.py +6 -0
- lionagi/api_service/payload_package.py +47 -0
- lionagi/api_service/status_tracker.py +29 -0
- lionagi/core/__init__.py +3 -3
- lionagi/core/branch.py +22 -3
- lionagi/core/session.py +14 -2
- lionagi/schema/__init__.py +5 -8
- lionagi/schema/base_schema.py +821 -0
- lionagi/structures/graph.py +1 -1
- lionagi/structures/relationship.py +1 -1
- lionagi/structures/structure.py +1 -1
- lionagi/tools/tool_manager.py +0 -163
- lionagi/tools/tool_util.py +2 -1
- lionagi/utils/__init__.py +5 -6
- lionagi/utils/api_util.py +6 -1
- lionagi/version.py +1 -1
- {lionagi-0.0.209.dist-info → lionagi-0.0.211.dist-info}/METADATA +3 -18
- lionagi-0.0.211.dist-info/RECORD +56 -0
- lionagi/agents/planner.py +0 -1
- lionagi/agents/prompter.py +0 -1
- lionagi/agents/scorer.py +0 -1
- lionagi/agents/summarizer.py +0 -1
- lionagi/agents/validator.py +0 -1
- lionagi/bridge/__init__.py +0 -22
- lionagi/bridge/langchain.py +0 -195
- lionagi/bridge/llama_index.py +0 -266
- lionagi/datastores/__init__.py +0 -1
- lionagi/datastores/chroma.py +0 -1
- lionagi/datastores/deeplake.py +0 -1
- lionagi/datastores/elasticsearch.py +0 -1
- lionagi/datastores/lantern.py +0 -1
- lionagi/datastores/pinecone.py +0 -1
- lionagi/datastores/postgres.py +0 -1
- lionagi/datastores/qdrant.py +0 -1
- lionagi/iservices/anthropic.py +0 -79
- lionagi/iservices/anyscale.py +0 -0
- lionagi/iservices/azure.py +0 -1
- lionagi/iservices/bedrock.py +0 -0
- lionagi/iservices/everlyai.py +0 -0
- lionagi/iservices/gemini.py +0 -0
- lionagi/iservices/gpt4all.py +0 -0
- lionagi/iservices/huggingface.py +0 -0
- lionagi/iservices/litellm.py +0 -33
- lionagi/iservices/localai.py +0 -0
- lionagi/iservices/openllm.py +0 -0
- lionagi/iservices/openrouter.py +0 -44
- lionagi/iservices/perplexity.py +0 -0
- lionagi/iservices/predibase.py +0 -0
- lionagi/iservices/rungpt.py +0 -0
- lionagi/iservices/vllm.py +0 -0
- lionagi/iservices/xinference.py +0 -0
- lionagi/loaders/__init__.py +0 -18
- lionagi/loaders/chunker.py +0 -166
- lionagi/loaders/load_util.py +0 -240
- lionagi/loaders/reader.py +0 -122
- lionagi/models/__init__.py +0 -0
- lionagi/models/base_model.py +0 -0
- lionagi/models/imodel.py +0 -53
- lionagi/parsers/__init__.py +0 -1
- lionagi/schema/async_queue.py +0 -158
- lionagi/schema/base_condition.py +0 -1
- lionagi/schema/base_node.py +0 -422
- lionagi/schema/base_tool.py +0 -44
- lionagi/schema/data_logger.py +0 -131
- lionagi/schema/data_node.py +0 -88
- lionagi/schema/status_tracker.py +0 -37
- lionagi/tests/test_utils/test_encrypt_util.py +0 -323
- lionagi/utils/encrypt_util.py +0 -283
- lionagi-0.0.209.dist-info/RECORD +0 -98
- /lionagi/{agents → api_service}/__init__.py +0 -0
- /lionagi/{iservices → services}/__init__.py +0 -0
- /lionagi/{iservices → services}/base_service.py +0 -0
- /lionagi/{iservices → services}/mistralai.py +0 -0
- /lionagi/{iservices → services}/mlx_service.py +0 -0
- /lionagi/{iservices → services}/oai.py +0 -0
- /lionagi/{iservices → services}/ollama.py +0 -0
- /lionagi/{iservices → services}/services.py +0 -0
- /lionagi/{iservices → services}/transformers.py +0 -0
- {lionagi-0.0.209.dist-info → lionagi-0.0.211.dist-info}/LICENSE +0 -0
- {lionagi-0.0.209.dist-info → lionagi-0.0.211.dist-info}/WHEEL +0 -0
- {lionagi-0.0.209.dist-info → lionagi-0.0.211.dist-info}/top_level.txt +0 -0
lionagi/__init__.py
CHANGED
@@ -19,12 +19,10 @@ from .version import __version__
 from dotenv import load_dotenv
 
 from .utils import *
-from .schema import *
+from .schema.base_schema import *
 from .structures import *
-from .
-from .iservices import *
+from .api_service import *
 from .tools import *
-from .core import *
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
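In practical terms, the top-level star imports now pull from `schema.base_schema` and the new `api_service` package, while `from .core import *` has been dropped. A minimal import sketch based only on the module paths visible in this diff (whether each name is also re-exported at the package top level is not verified here):

```python
# Sketch only: module paths are taken from the files touched in this diff.
from lionagi.schema import Tool, DataNode      # now backed by schema.base_schema
from lionagi.core.branch import Branch         # core is no longer star-imported at the top level
from lionagi.core.session import Session
```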
lionagi/api_service/base_endpoint.py
ADDED
@@ -0,0 +1,65 @@
+from typing import Any, Dict, NoReturn, Optional, Type, List, Union
+from .base_rate_limiter import BaseRateLimiter, SimpleRateLimiter
+
+
+class BaseEndpoint:
+    """
+    Represents an API endpoint with rate limiting capabilities.
+
+    This class encapsulates the details of an API endpoint, including its rate limiter.
+
+    Attributes:
+        endpoint (str): The API endpoint path.
+        rate_limiter_class (Type[li.BaseRateLimiter]): The class used for rate limiting requests to the endpoint.
+        max_requests (int): The maximum number of requests allowed per interval.
+        max_tokens (int): The maximum number of tokens allowed per interval.
+        interval (int): The time interval in seconds for replenishing rate limit capacities.
+        config (Dict): Configuration parameters for the endpoint.
+        rate_limiter (Optional[li.BaseRateLimiter]): The rate limiter instance for this endpoint.
+
+    Examples:
+        # Example usage of EndPoint with SimpleRateLimiter
+        endpoint = EndPoint(
+            max_requests=100,
+            max_tokens=1000,
+            interval=60,
+            endpoint_='chat/completions',
+            rate_limiter_class=li.SimpleRateLimiter,
+            config={'param1': 'value1'}
+        )
+        asyncio.run(endpoint.init_rate_limiter())
+    """
+
+    def __init__(
+        self,
+        max_requests: int = 1_000,
+        max_tokens: int = 100_000,
+        interval: int = 60,
+        endpoint_: Optional[str] = None,
+        rate_limiter_class: Type[BaseRateLimiter] = SimpleRateLimiter,
+        encode_kwargs=None,
+        token_encoding_name=None,
+        config: Dict = None,
+    ) -> None:
+        self.endpoint = endpoint_ or 'chat/completions'
+        self.rate_limiter_class = rate_limiter_class
+        self.max_requests = max_requests
+        self.max_tokens = max_tokens
+        self.interval = interval
+        self.token_encoding_name = token_encoding_name
+        self.config = config or {}
+        self.rate_limiter: Optional[BaseRateLimiter] = None
+        self._has_initialized = False
+        self.encode_kwargs = encode_kwargs or {}
+
+    async def init_rate_limiter(self) -> None:
+        """Initializes the rate limiter for the endpoint."""
+        self.rate_limiter = await self.rate_limiter_class.create(
+            self.max_requests, self.max_tokens, self.interval, self.token_encoding_name
+        )
+        self._has_initialized = True
+
+
+class Embedding(BaseEndpoint):
+    ...
+
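The docstring's example above still refers to `EndPoint` and a `li.` prefix from an earlier draft. A minimal usage sketch against the class as actually defined, assuming the new `lionagi.api_service` module path shown in this diff; the encoding name is an illustrative placeholder:

```python
import asyncio

from lionagi.api_service.base_endpoint import BaseEndpoint
from lionagi.api_service.base_rate_limiter import SimpleRateLimiter


async def main():
    # Allow 100 requests and 10,000 tokens per 60-second interval for chat/completions.
    endpoint = BaseEndpoint(
        max_requests=100,
        max_tokens=10_000,
        interval=60,
        endpoint_="chat/completions",
        rate_limiter_class=SimpleRateLimiter,
        token_encoding_name="cl100k_base",  # placeholder encoding name
    )
    await endpoint.init_rate_limiter()      # builds the limiter via SimpleRateLimiter.create()
    assert endpoint._has_initialized
    await endpoint.rate_limiter.stop_replenishing()

asyncio.run(main())
```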
lionagi/api_service/base_rate_limiter.py
ADDED
@@ -0,0 +1,121 @@
+import asyncio
+import logging
+from abc import ABC
+from typing import Dict, NoReturn, Optional
+
+from ..utils import APIUtil
+
+
+class BaseRateLimiter(ABC):
+    def __init__(self, max_requests: int, max_tokens: int, interval: int = 60, token_encoding_name=None) -> None:
+        self.interval: int = interval
+        self.max_requests: int = max_requests
+        self.max_tokens: int = max_tokens
+        self.available_request_capacity: int = max_requests
+        self.available_token_capacity: int = max_tokens
+        self.rate_limit_replenisher_task: Optional[asyncio.Task[NoReturn]] = None
+        self._stop_replenishing: asyncio.Event = asyncio.Event()
+        self._lock: asyncio.Lock = asyncio.Lock()
+        self.token_encoding_name = token_encoding_name
+
+    async def start_replenishing(self) -> NoReturn:
+        """Starts the replenishment of rate limit capacities at regular intervals."""
+        try:
+            while not self._stop_replenishing.is_set():
+                await asyncio.sleep(self.interval)
+                async with self._lock:
+                    self.available_request_capacity = self.max_requests
+                    self.available_token_capacity = self.max_tokens
+        except asyncio.CancelledError:
+            logging.info("Rate limit replenisher task cancelled.")
+
+        except Exception as e:
+            logging.error(f"An error occurred in the rate limit replenisher: {e}")
+
+    async def stop_replenishing(self) -> None:
+        """Stops the replenishment task."""
+        if self.rate_limit_replenisher_task:
+            self.rate_limit_replenisher_task.cancel()
+            await self.rate_limit_replenisher_task
+        self._stop_replenishing.set()
+
+    async def request_permission(self, required_tokens) -> bool:
+        """Requests permission to make an API call.
+
+        Returns True if the request can be made immediately, otherwise False.
+        """
+        async with self._lock:
+            if self.available_request_capacity > 0 and self.available_token_capacity > 0:
+                self.available_request_capacity -= 1
+                self.available_token_capacity -= required_tokens  # Assuming 1 token per request for simplicity
+                return True
+            return False
+
+    async def _call_api(
+        self,
+        http_session,
+        endpoint: str,
+        base_url: str,
+        api_key: str,
+        max_attempts: int = 3,
+        method: str = "post",
+        payload: Dict[str, any]=None,
+        **kwargs,
+    ) -> Optional[Dict[str, any]]:
+        endpoint = APIUtil.api_endpoint_from_url(base_url + endpoint)
+        while True:
+            if self.available_request_capacity < 1 or self.available_token_capacity < 10:  # Minimum token count
+                await asyncio.sleep(1)  # Wait for capacity
+                continue
+            required_tokens = APIUtil.calculate_num_token(payload, endpoint, self.token_encoding_name, **kwargs)
+
+            if await self.request_permission(required_tokens):
+                request_headers = {"Authorization": f"Bearer {api_key}"}
+                attempts_left = max_attempts
+
+                while attempts_left > 0:
+                    try:
+                        method = APIUtil.api_method(http_session, method)
+                        async with method(
+                            url=(base_url+endpoint), headers=request_headers, json=payload
+                        ) as response:
+                            response_json = await response.json()
+
+                            if "error" in response_json:
+                                logging.warning(
+                                    f"API call failed with error: {response_json['error']}"
+                                )
+                                attempts_left -= 1
+
+                                if "Rate limit" in response_json["error"].get("message", ""):
+                                    await asyncio.sleep(15)
+                            else:
+                                return response_json
+                    except Exception as e:
+                        logging.warning(f"API call failed with exception: {e}")
+                        attempts_left -= 1
+
+                logging.error("API call failed after all attempts.")
+                break
+            else:
+                await asyncio.sleep(1)
+
+    @classmethod
+    async def create(cls, max_requests: int, max_tokens: int, interval: int = 60, token_encoding_name = None) -> 'BaseRateLimiter':
+        instance = cls(max_requests, max_tokens, interval, token_encoding_name)
+        instance.rate_limit_replenisher_task = asyncio.create_task(
+            instance.start_replenishing()
+        )
+        return instance
+
+
+class SimpleRateLimiter(BaseRateLimiter):
+    """
+    A simple implementation of a rate limiter.
+
+    Inherits from BaseRateLimiter and provides a basic rate limiting mechanism.
+    """
+
+    def __init__(self, max_requests: int, max_tokens: int, interval: int = 60, token_encoding_name=None) -> None:
+        """Initializes the SimpleRateLimiter with the specified parameters."""
+        super().__init__(max_requests, max_tokens, interval, token_encoding_name)
lionagi/api_service/base_service.py
ADDED
@@ -0,0 +1,146 @@
+import asyncio
+import logging
+import aiohttp
+from abc import ABC
+from dataclasses import dataclass
+from typing import Any, Dict, NoReturn, Optional, Type, List, Union
+
+from ..utils import nget, APIUtil, to_list, lcall
+from .base_rate_limiter import BaseRateLimiter, SimpleRateLimiter
+from .status_tracker import StatusTracker
+
+from .base_endpoint import BaseEndpoint
+
+
+class BaseService:
+    """
+    Base class for services that interact with API endpoints.
+
+    This class provides a foundation for services that need to make API calls with rate limiting.
+
+    Attributes:
+        api_key (Optional[str]): The API key used for authentication.
+        schema (Dict[str, Any]): The schema defining the service's endpoints.
+        status_tracker (StatusTracker): The object tracking the status of API calls.
+        endpoints (Dict[str, BaseEndpoint]): A dictionary of endpoint objects.
+    """
+
+    base_url: str = ''
+    available_endpoints: list = []
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        schema: Dict[str, Any] = None,
+        token_encoding_name: str = None,
+        max_tokens : int = 100_000,
+        max_requests : int = 1_000,
+        interval: int = 60
+    ) -> None:
+        self.api_key = api_key
+        self.schema = schema or {}
+        self.status_tracker = StatusTracker()
+        self.endpoints: Dict[str, BaseEndpoint] = {}
+        self.token_encoding_name = token_encoding_name
+        self.chat_config = {
+            'max_requests': max_requests,
+            'max_tokens': max_tokens,
+            'interval': interval,
+            "token_encoding_name": token_encoding_name
+        }
+
+
+    async def init_endpoint(self, endpoint_: Optional[Union[List[str], List[BaseEndpoint], str, BaseEndpoint]] = None) -> None:
+        """
+        Initializes the specified endpoint or all endpoints if none is specified.
+
+        Args:
+            endpoint_: The endpoint(s) to initialize. Can be a string, an BaseEndpoint, a list of strings, or a list of BaseEndpoints.
+        """
+
+        if endpoint_:
+            endpoint_ = to_list(endpoint_, flatten=True, dropna=True)
+
+            for ep in endpoint_:
+                self._check_endpoints(ep)
+
+                if ep not in self.endpoints:
+                    endpoint_config = self._get_endpoint(ep)
+
+                    if endpoint_config is not None:
+                        if ep == "chat/completions":
+                            self.endpoints[ep] = BaseEndpoint(
+                                max_requests=self.chat_config.get('max_requests', 1000),
+                                max_tokens=self.chat_config.get('max_tokens', 100000),
+                                interval=self.chat_config.get('interval', 60),
+                                endpoint_=ep,
+                                token_encoding_name=self.token_encoding_name,
+                                config=endpoint_config,
+                            )
+                        else:
+                            self.endpoints[ep] = BaseEndpoint(
+                                max_requests=endpoint_config.get('max_requests', 1000) if endpoint_config.get('max_requests', 1000) is not None else 1000,
+                                max_tokens=endpoint_config.get('max_tokens', 100000) if endpoint_config.get('max_tokens', 100000) is not None else 100000,
+                                interval=endpoint_config.get('interval', 60) if endpoint_config.get('interval', 60) is not None else 60,
+                                endpoint_=ep,
+                                token_encoding_name=self.token_encoding_name,
+                                config=endpoint_config,
+                            )
+
+                if not self.endpoints[ep]._has_initialized:
+                    await self.endpoints[ep].init_rate_limiter()
+
+        else:
+            for ep in self.available_endpoints:
+                endpoint_config = nget(self.schema, [ep, 'config'])
+                self.schema.get(ep, {})
+                if ep not in self.endpoints:
+                    self.endpoints[ep] = BaseEndpoint(
+                        max_requests=endpoint_config.get('max_requests', 1000),
+                        max_tokens=endpoint_config.get('max_tokens', 100000),
+                        interval=endpoint_config.get('interval', 60),
+                        endpoint_=ep,
+                        token_encoding_name=self.token_encoding_name,
+                        config=endpoint_config,
+                    )
+                if not self.endpoints[ep]._has_initialized:
+                    await self.endpoints[ep].init_rate_limiter()
+
+    async def call_api(self, payload, endpoint, method, **kwargs):
+        """
+        Calls the specified API endpoint with the given payload and method.
+
+        Args:
+            payload: The payload to send with the API call.
+            endpoint: The endpoint to call.
+            method: The HTTP method to use for the call.
+            kwargs are for tiktoken encoding
+
+        Returns:
+            The response from the API call.
+
+        Raises:
+            ValueError: If the endpoint has not been initialized.
+        """
+        if endpoint not in self.endpoints.keys():
+            raise ValueError(f'The endpoint {endpoint} has not initialized.')
+        async with aiohttp.ClientSession() as http_session:
+            completion = await self.endpoints[endpoint].rate_limiter._call_api(
+                http_session=http_session, endpoint=endpoint, base_url=self.base_url, api_key=self.api_key,
+                method=method, payload=payload, **kwargs)
+            return completion
+
+    def _check_endpoints(self, endpoint_):
+        f = lambda ep: ValueError (f"Endpoint {ep} not available for service {self.__class__.__name__}")
+        if not endpoint_ in self.available_endpoints:
+            raise f(endpoint_)
+
+    def _get_endpoint(self, endpoint_):
+        if endpoint_ not in self.endpoints:
+            endpoint_config = nget(self.schema, [endpoint_, 'config'])
+            self.schema.get(endpoint_, {})
+
+            if isinstance(endpoint_, BaseEndpoint):
+                self.endpoints[endpoint_.endpoint] = endpoint_
+                return None
+            return endpoint_config
lionagi/api_service/payload_package.py
ADDED
@@ -0,0 +1,47 @@
+from lionagi.utils.api_util import APIUtil
+
+class PayloadCreation:
+
+    @classmethod
+    def chat_completion(cls, messages, llmconfig, schema, **kwargs):
+        """
+        Creates a payload for the chat completion operation.
+
+        Args:
+            messages: The messages to include in the chat completion.
+            llmconfig: Configuration for the language model.
+            schema: The schema describing required and optional fields.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The constructed payload.
+        """
+        return APIUtil._create_payload(
+            input_=messages,
+            config=llmconfig,
+            required_=schema['required'],
+            optional_=schema['optional'],
+            input_key="messages",
+            **kwargs)
+
+    @classmethod
+    def fine_tuning(cls, training_file, llmconfig, schema, **kwargs):
+        """
+        Creates a payload for the fine-tuning operation.
+
+        Args:
+            training_file: The file containing training data.
+            llmconfig: Configuration for the language model.
+            schema: The schema describing required and optional fields.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The constructed payload.
+        """
+        return APIUtil._create_payload(
+            input_=training_file,
+            config=llmconfig,
+            required_=schema['required'],
+            optional_=schema['optional'],
+            input_key="training_file",
+            **kwargs)
lionagi/api_service/status_tracker.py
ADDED
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class StatusTracker:
+    """
+    Keeps track of various task statuses within a system.
+
+    Attributes:
+        num_tasks_started (int): The number of tasks that have been initiated.
+        num_tasks_in_progress (int): The number of tasks currently being processed.
+        num_tasks_succeeded (int): The number of tasks that have completed successfully.
+        num_tasks_failed (int): The number of tasks that have failed.
+        num_rate_limit_errors (int): The number of tasks that failed due to rate limiting.
+        num_api_errors (int): The number of tasks that failed due to API errors.
+        num_other_errors (int): The number of tasks that failed due to other errors.
+
+    Examples:
+        >>> tracker = StatusTracker()
+        >>> tracker.num_tasks_started += 1
+        >>> tracker.num_tasks_succeeded += 1
+    """
+    num_tasks_started: int = 0
+    num_tasks_in_progress: int = 0
+    num_tasks_succeeded: int = 0
+    num_tasks_failed: int = 0
+    num_rate_limit_errors: int = 0
+    num_api_errors: int = 0
+    num_other_errors: int = 0
lionagi/core/__init__.py
CHANGED
lionagi/core/branch.py
CHANGED
@@ -9,9 +9,9 @@ from lionagi.utils.sys_util import create_path, is_same_dtype
 from lionagi.utils import as_dict, lcall,to_df, to_list, CoreUtil
 
 
-from lionagi.
-from lionagi.
-from lionagi.
+from lionagi.services.base_service import BaseService, StatusTracker
+from lionagi.services.oai import OpenAIService
+from lionagi.services.openrouter import OpenRouterService
 from lionagi.configs.oai_configs import oai_schema
 from lionagi.configs.openrouter_configs import openrouter_schema
 from lionagi.schema import DataLogger, Tool
@@ -28,6 +28,25 @@ except:
     pass
 
 class Branch:
+    """
+    Represents a branch in a conversation with messages, instruction sets, and tool management.
+
+    A `Branch` is a subset of a conversation that contains messages, instruction sets, and tools for managing interactions
+    within the conversation. It encapsulates the state and behavior of a specific branch of conversation flow.
+
+    Attributes:
+        _cols (List[str]): A list of column names for the DataFrame containing messages.
+        messages (pd.DataFrame): A DataFrame containing messages for the branch.
+        instruction_sets (Dict[str, InstructionSet]): A dictionary of instruction sets associated with the branch.
+        tool_manager (ToolManager): The tool manager for managing tools within the branch.
+        service (Optional[BaseService]): The service associated with the branch.
+        llmconfig (Optional[Dict]): Configuration for the LLM (Large Language Model) service.
+        name (Optional[str]): The name of the branch.
+        pending_ins (Dict): Dictionary to store pending inputs for the branch.
+        pending_outs (Deque): Queue to store pending outputs for the branch.
+        logger (Optional[DataLogger]): Logger for data logging.
+        status_tracker (StatusTracker): Tracks the status of the branch.
+    """
     _cols = ["node_id", "role", "sender", "timestamp", "content"]
 
     def __init__(self, name: Optional[str] = None, messages: Optional[pd.DataFrame] = None,
lionagi/core/session.py
CHANGED
@@ -5,14 +5,26 @@ import pandas as pd
 from lionagi.utils.sys_util import create_path
 from lionagi.utils import to_list, to_df
 from lionagi.schema import Tool
-from lionagi.
+from lionagi.services.base_service import BaseService
 from lionagi.core.branch import Branch
 from lionagi.core.branch_manager import BranchManager
 from lionagi.core.messages import Instruction, System
 
 
 class Session:
-
+    """
+    Represents a session for managing conversations and branches.
+
+    A `Session` encapsulates the state and behavior for managing conversations and their branches.
+    It provides functionality for initializing and managing conversation sessions, including setting up default
+    branches, configuring language learning models, managing tools, and handling session data logging.
+
+    Attributes:
+        branches (Dict[str, Branch]): A dictionary of branch instances associated with the session.
+        service (Optional[BaseService]): The external service instance associated with the session.
+        branch_manager (BranchManager): The manager for handling branches within the session.
+        logger (Optional[Any]): The logger instance for session data logging.
+    """
     def __init__(
         self,
         system: Optional[Union[str, System]] = None,
lionagi/schema/__init__.py
CHANGED
@@ -1,11 +1,8 @@
-from .
-from .base_tool import Tool
-from .data_logger import DataLogger
-from .data_node import DataNode
+from .base_schema import BaseNode, Tool, DataLogger, DataNode
 
 __all__ = [
-    "BaseNode",
-    "
-    "
-    "
+    "BaseNode",
+    "Tool",
+    "DataLogger",
+    "DataNode"
 ]