lionagi 0.0.111__py3-none-any.whl → 0.0.113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +7 -2
- lionagi/bridge/__init__.py +7 -0
- lionagi/bridge/langchain.py +131 -0
- lionagi/bridge/llama_index.py +157 -0
- lionagi/configs/__init__.py +7 -0
- lionagi/configs/oai_configs.py +49 -0
- lionagi/configs/openrouter_config.py +49 -0
- lionagi/core/__init__.py +15 -0
- lionagi/{session/conversation.py → core/conversations.py} +10 -17
- lionagi/core/flows.py +1 -0
- lionagi/core/instruction_sets.py +1 -0
- lionagi/{session/message.py → core/messages.py} +5 -5
- lionagi/core/sessions.py +262 -0
- lionagi/datastore/__init__.py +1 -0
- lionagi/datastore/chroma.py +1 -0
- lionagi/datastore/deeplake.py +1 -0
- lionagi/datastore/elasticsearch.py +1 -0
- lionagi/datastore/lantern.py +1 -0
- lionagi/datastore/pinecone.py +1 -0
- lionagi/datastore/postgres.py +1 -0
- lionagi/datastore/qdrant.py +1 -0
- lionagi/loader/__init__.py +12 -0
- lionagi/loader/chunker.py +157 -0
- lionagi/loader/reader.py +124 -0
- lionagi/objs/__init__.py +7 -0
- lionagi/objs/messenger.py +163 -0
- lionagi/objs/tool_registry.py +247 -0
- lionagi/schema/__init__.py +11 -0
- lionagi/schema/base_condition.py +1 -0
- lionagi/schema/base_schema.py +239 -0
- lionagi/schema/base_tool.py +9 -0
- lionagi/schema/data_logger.py +94 -0
- lionagi/services/__init__.py +14 -0
- lionagi/services/anthropic.py +1 -0
- lionagi/services/anyscale.py +0 -0
- lionagi/services/azure.py +1 -0
- lionagi/{api/oai_service.py → services/base_api_service.py} +74 -148
- lionagi/services/bedrock.py +0 -0
- lionagi/services/chatcompletion.py +48 -0
- lionagi/services/everlyai.py +0 -0
- lionagi/services/gemini.py +0 -0
- lionagi/services/gpt4all.py +0 -0
- lionagi/services/huggingface.py +0 -0
- lionagi/services/litellm.py +1 -0
- lionagi/services/localai.py +0 -0
- lionagi/services/mistralai.py +0 -0
- lionagi/services/oai.py +34 -0
- lionagi/services/ollama.py +1 -0
- lionagi/services/openllm.py +0 -0
- lionagi/services/openrouter.py +32 -0
- lionagi/services/perplexity.py +0 -0
- lionagi/services/predibase.py +0 -0
- lionagi/services/rungpt.py +0 -0
- lionagi/services/service_objs.py +282 -0
- lionagi/services/vllm.py +0 -0
- lionagi/services/xinference.py +0 -0
- lionagi/structure/__init__.py +7 -0
- lionagi/structure/relationship.py +128 -0
- lionagi/structure/structure.py +160 -0
- lionagi/tests/__init__.py +0 -0
- lionagi/tests/test_flatten_util.py +426 -0
- lionagi/tools/__init__.py +0 -0
- lionagi/tools/coder.py +1 -0
- lionagi/tools/planner.py +1 -0
- lionagi/tools/prompter.py +1 -0
- lionagi/tools/sandbox.py +1 -0
- lionagi/tools/scorer.py +1 -0
- lionagi/tools/summarizer.py +1 -0
- lionagi/tools/validator.py +1 -0
- lionagi/utils/__init__.py +46 -8
- lionagi/utils/api_util.py +63 -416
- lionagi/utils/call_util.py +347 -0
- lionagi/utils/flat_util.py +540 -0
- lionagi/utils/io_util.py +102 -0
- lionagi/utils/load_utils.py +190 -0
- lionagi/utils/sys_util.py +85 -660
- lionagi/utils/tool_util.py +82 -199
- lionagi/utils/type_util.py +81 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.111.dist-info → lionagi-0.0.113.dist-info}/METADATA +44 -15
- lionagi-0.0.113.dist-info/RECORD +84 -0
- lionagi/api/__init__.py +0 -8
- lionagi/api/oai_config.py +0 -16
- lionagi/session/__init__.py +0 -7
- lionagi/session/session.py +0 -380
- lionagi/utils/doc_util.py +0 -331
- lionagi/utils/log_util.py +0 -86
- lionagi-0.0.111.dist-info/RECORD +0 -20
- {lionagi-0.0.111.dist-info → lionagi-0.0.113.dist-info}/LICENSE +0 -0
- {lionagi-0.0.111.dist-info → lionagi-0.0.113.dist-info}/WHEEL +0 -0
- {lionagi-0.0.111.dist-info → lionagi-0.0.113.dist-info}/top_level.txt +0 -0
| @@ -0,0 +1,94 @@ | |
| 1 | 
            +
            from collections import deque
         | 
| 2 | 
            +
            from typing import List, Optional
         | 
| 3 | 
            +
            from ..utils.sys_util import create_path
         | 
| 4 | 
            +
            from ..utils.io_util import to_csv
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class DataLogger:
         | 
| 8 | 
            +
                """
         | 
| 9 | 
            +
                A class for logging data entries and exporting them as CSV files.
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                This class provides functionality to log data entries in a deque and 
         | 
| 12 | 
            +
                supports exporting the logged data to a CSV file. The DataLogger can 
         | 
| 13 | 
            +
                be configured to use a specific directory for saving files.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                Attributes:
         | 
| 16 | 
            +
                    dir (Optional[str]): 
         | 
| 17 | 
            +
                        The default directory where CSV files will be saved.
         | 
| 18 | 
            +
                    log (deque): 
         | 
| 19 | 
            +
                        A deque object that stores the logged data entries.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                Methods:
         | 
| 22 | 
            +
                    __call__:
         | 
| 23 | 
            +
                        Adds an entry to the log.
         | 
| 24 | 
            +
                    to_csv:
         | 
| 25 | 
            +
                        Exports the logged data to a CSV file and clears the log.
         | 
| 26 | 
            +
                    set_dir:
         | 
| 27 | 
            +
                        Sets the default directory for saving CSV files.
         | 
| 28 | 
            +
                """    
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def __init__(self, dir= None, log: list = None) -> None:
         | 
| 31 | 
            +
                    """
         | 
| 32 | 
            +
                    Initializes the DataLogger with an optional directory and initial log.
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    Parameters:
         | 
| 35 | 
            +
                        dir (Optional[str]): 
         | 
| 36 | 
            +
                            The directory where CSV files will be saved. Defaults to None.
         | 
| 37 | 
            +
                        log (Optional[List]): 
         | 
| 38 | 
            +
                            An initial list of log entries. Defaults to an empty list.
         | 
| 39 | 
            +
                    """        
         | 
| 40 | 
            +
                    self.dir = dir
         | 
| 41 | 
            +
                    self.log = deque(log) if log else deque()
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def __call__(self, entry):
         | 
| 44 | 
            +
                    """
         | 
| 45 | 
            +
                    Adds a new entry to the log.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    Parameters:
         | 
| 48 | 
            +
                        entry: 
         | 
| 49 | 
            +
                            The data entry to be added to the log.
         | 
| 50 | 
            +
                    """        
         | 
| 51 | 
            +
                    self.log.append(entry)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def to_csv(self, filename: str, dir: Optional[str] = None, verbose: bool = True, 
         | 
| 54 | 
            +
                           timestamp: bool = True, dir_exist_ok: bool = True, file_exist_ok: bool = False) -> None:
         | 
| 55 | 
            +
                    """
         | 
| 56 | 
            +
                    Exports the logged data to a CSV file and optionally clears the log.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    Parameters:
         | 
| 59 | 
            +
                        filename (str): 
         | 
| 60 | 
            +
                            The name of the CSV file.
         | 
| 61 | 
            +
                        dir (Optional[str]): 
         | 
| 62 | 
            +
                            The directory to save the file. Defaults to the instance's dir attribute.
         | 
| 63 | 
            +
                        verbose (bool): 
         | 
| 64 | 
            +
                            If True, prints a message upon completion. Defaults to True.
         | 
| 65 | 
            +
                        timestamp (bool): 
         | 
| 66 | 
            +
                            If True, appends a timestamp to the filename. Defaults to True.
         | 
| 67 | 
            +
                        dir_exist_ok (bool): 
         | 
| 68 | 
            +
                            If True, will not raise an error if the directory already exists. Defaults to True.
         | 
| 69 | 
            +
                        file_exist_ok (bool): 
         | 
| 70 | 
            +
                            If True, overwrites the file if it exists. Defaults to False.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    Side Effects:
         | 
| 73 | 
            +
                        Clears the log after saving the CSV file.
         | 
| 74 | 
            +
                        Prints a message indicating the save location and number of logs saved if verbose is True.
         | 
| 75 | 
            +
                    """        
         | 
| 76 | 
            +
                    dir = dir or self.dir
         | 
| 77 | 
            +
                    filepath = create_path(
         | 
| 78 | 
            +
                        dir=dir, filename=filename, timestamp=timestamp, dir_exist_ok=dir_exist_ok)
         | 
| 79 | 
            +
                    to_csv(list(self.log), filepath, file_exist_ok=file_exist_ok)
         | 
| 80 | 
            +
                    n_logs = len(list(self.log))
         | 
| 81 | 
            +
                    self.log = deque()
         | 
| 82 | 
            +
                    if verbose:
         | 
| 83 | 
            +
                        print(f"{n_logs} logs saved to {filepath}")
         | 
| 84 | 
            +
                        
         | 
| 85 | 
            +
                def set_dir(self, dir: str) -> None:
         | 
| 86 | 
            +
                    """
         | 
| 87 | 
            +
                    Sets the default directory for saving CSV files.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    Parameters:
         | 
| 90 | 
            +
                        dir (str): 
         | 
| 91 | 
            +
                            The directory to be set as the default for saving files.
         | 
| 92 | 
            +
                    """
         | 
| 93 | 
            +
                    self.dir = dir
         | 
| 94 | 
            +
                    
         | 
| @@ -0,0 +1,14 @@ | |
| 1 | 
            +
            from .chatcompletion import ChatCompletion
         | 
| 2 | 
            +
            from .base_api_service import BaseAPIService, BaseAPIRateLimiter
         | 
| 3 | 
            +
            from .oai import OpenAIService
         | 
| 4 | 
            +
            from .openrouter import OpenRouterService
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            __all__ = [
         | 
| 9 | 
            +
                "BaseAPIService",
         | 
| 10 | 
            +
                "OpenAIService",
         | 
| 11 | 
            +
                "OpenRouterService",
         | 
| 12 | 
            +
                "ChatCompletion", 
         | 
| 13 | 
            +
                "BaseAPIRateLimiter"
         | 
| 14 | 
            +
            ]
         | 
| @@ -0,0 +1 @@ | |
| 1 | 
            +
            # TODO
         | 
| 
            File without changes
         | 
| @@ -0,0 +1 @@ | |
| 1 | 
            +
            # TODO
         | 
| @@ -1,47 +1,17 @@ | |
| 1 | 
            -
            import  | 
| 2 | 
            -
            import dotenv
         | 
| 1 | 
            +
            import re
         | 
| 3 2 | 
             
            import asyncio
         | 
| 4 | 
            -
            import  | 
| 3 | 
            +
            import os
         | 
| 5 4 | 
             
            import tiktoken
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            from ..utils.api_util import AsyncQueue, StatusTracker, RateLimiter, BaseAPIService
         | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
            class OpenAIRateLimiter(RateLimiter):
         | 
| 14 | 
            -
                """
         | 
| 15 | 
            -
                A specialized RateLimiter for managing requests to the OpenAI API.
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                Extends the generic RateLimiter to enforce specific rate-limiting rules and limits
         | 
| 18 | 
            -
                as required by the OpenAI API. This includes maximum requests and tokens per minute
         | 
| 19 | 
            -
                and replenishing these limits at regular intervals.
         | 
| 20 | 
            -
             | 
| 21 | 
            -
                Attributes:
         | 
| 22 | 
            -
                    max_requests_per_minute (int):
         | 
| 23 | 
            -
                        Maximum number of requests allowed per minute.
         | 
| 24 | 
            -
                    max_tokens_per_minute (int):
         | 
| 25 | 
            -
                        Maximum number of tokens allowed per minute.
         | 
| 5 | 
            +
            import logging
         | 
| 6 | 
            +
            import aiohttp
         | 
| 7 | 
            +
            from typing import Generator, NoReturn, Dict, Any, Optional
         | 
| 8 | 
            +
            from .service_objs import BaseService, RateLimiter, StatusTracker, AsyncQueue
         | 
| 26 9 |  | 
| 27 | 
            -
             | 
| 28 | 
            -
                    rate_limit_replenisher:
         | 
| 29 | 
            -
                        Coroutine to replenish rate limits over time.
         | 
| 30 | 
            -
                    calculate_num_token:
         | 
| 31 | 
            -
                        Calculates the required tokens for a request.
         | 
| 32 | 
            -
                """
         | 
| 10 | 
            +
            class BaseAPIRateLimiter(RateLimiter):
         | 
| 33 11 |  | 
| 34 12 | 
             
                def __init__(
         | 
| 35 13 | 
             
                    self, max_requests_per_minute: int, max_tokens_per_minute: int
         | 
| 36 14 | 
             
                ) -> None:
         | 
| 37 | 
            -
                    """
         | 
| 38 | 
            -
                    Initializes the rate limiter with specific limits for OpenAI API.
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                    Parameters:
         | 
| 41 | 
            -
                        max_requests_per_minute (int): The maximum number of requests allowed per minute.
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                        max_tokens_per_minute (int): The maximum number of tokens that can accumulate per minute.
         | 
| 44 | 
            -
                    """
         | 
| 45 15 | 
             
                    super().__init__(max_requests_per_minute, max_tokens_per_minute)
         | 
| 46 16 |  | 
| 47 17 | 
             
                @classmethod
         | 
| @@ -155,114 +125,59 @@ class OpenAIRateLimiter(RateLimiter): | |
| 155 125 | 
             
                        )
         | 
| 156 126 |  | 
| 157 127 |  | 
| 158 | 
            -
            class  | 
| 159 | 
            -
                 | 
| 160 | 
            -
                 | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
             | 
| 166 | 
            -
                     | 
| 167 | 
            -
             | 
| 168 | 
            -
             | 
| 169 | 
            -
                     | 
| 170 | 
            -
                 | 
| 171 | 
            -
             | 
| 172 | 
            -
                 | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
                     | 
| 176 | 
            -
             | 
| 177 | 
            -
                     | 
| 178 | 
            -
             | 
| 179 | 
            -
                     | 
| 180 | 
            -
             | 
| 181 | 
            -
                     | 
| 182 | 
            -
             | 
| 183 | 
            -
                     | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
                     | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
                        max_attempts (int): The maximum number of attempts for calling an API endpoint.
         | 
| 194 | 
            -
             | 
| 195 | 
            -
                        status_tracker (Optional[StatusTracker]): Tracker for API call outcomes.
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                        ratelimiter (Optional[OpenAIRateLimiter]): Rate limiter for OpenAI's limits.
         | 
| 198 | 
            -
             | 
| 199 | 
            -
                        queue (Optional[AsyncQueue]): Queue for managing asynchronous API calls.
         | 
| 200 | 
            -
             | 
| 201 | 
            -
                    Example:
         | 
| 202 | 
            -
                        >>> service = OpenAIService(
         | 
| 203 | 
            -
                        ...     api_key="api-key-123",
         | 
| 204 | 
            -
                        ...     token_encoding_name="utf-8",
         | 
| 205 | 
            -
                        ...     max_attempts=5,
         | 
| 206 | 
            -
                        ...     status_tracker=None,
         | 
| 207 | 
            -
                        ...     rate_limiter=OpenAIRateLimiter(100, 200),
         | 
| 208 | 
            -
                        ...     queue=AsyncQueue()
         | 
| 209 | 
            -
                        ... )
         | 
| 210 | 
            -
                        # Service is configured for interacting with OpenAI API.
         | 
| 211 | 
            -
                    """
         | 
| 212 | 
            -
                    api_key = api_key or os.getenv("OPENAI_API_KEY")
         | 
| 213 | 
            -
                    super().__init__(
         | 
| 214 | 
            -
                        api_key,
         | 
| 215 | 
            -
                        token_encoding_name,
         | 
| 216 | 
            -
                        max_attempts,
         | 
| 217 | 
            -
                        max_requests_per_minute,
         | 
| 218 | 
            -
                        max_tokens_per_minute,
         | 
| 219 | 
            -
                        ratelimiter,
         | 
| 220 | 
            -
                        status_tracker,
         | 
| 221 | 
            -
                        queue,
         | 
| 222 | 
            -
                    )
         | 
| 223 | 
            -
             | 
| 224 | 
            -
                async def call_api(
         | 
| 225 | 
            -
                    self, http_session, endpoint, payload: Dict[str, any] = None
         | 
| 226 | 
            -
                ) -> Optional[Dict[str, any]]:
         | 
| 227 | 
            -
                    """
         | 
| 228 | 
            -
                    Call an OpenAI API endpoint with a specific payload and handle the response.
         | 
| 229 | 
            -
             | 
| 230 | 
            -
                    Parameters:
         | 
| 231 | 
            -
                        http_session: The session object for making HTTP requests.
         | 
| 232 | 
            -
             | 
| 233 | 
            -
                        endpoint (str): The full URL of the OpenAI API endpoint to be called.
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                        payload (Dict[str, any]): The payload to send with the API request.
         | 
| 236 | 
            -
             | 
| 237 | 
            -
                    Returns:
         | 
| 238 | 
            -
                        Optional[Dict[str, any]]: The response data from the API call or None if the call fails.
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                    Raises:
         | 
| 241 | 
            -
                        asyncio.TimeoutError: If the request attempts exceed the configured maximum limit.
         | 
| 128 | 
            +
            class BaseAPIService(BaseService):
         | 
| 129 | 
            +
                
         | 
| 130 | 
            +
                def __init__(self, api_key: str = None, 
         | 
| 131 | 
            +
                             status_tracker = None,
         | 
| 132 | 
            +
                             queue = None, endpoint=None, schema=None, 
         | 
| 133 | 
            +
                             ratelimiter=None, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
         | 
| 134 | 
            +
                    self.api_key = api_key
         | 
| 135 | 
            +
                    self.status_tracker = status_tracker or StatusTracker()
         | 
| 136 | 
            +
                    self.queue = queue or AsyncQueue()
         | 
| 137 | 
            +
                    self.endpoint=endpoint
         | 
| 138 | 
            +
                    self.schema = schema
         | 
| 139 | 
            +
                    self.rate_limiter = ratelimiter(max_requests_per_minute, max_tokens_per_minute)
         | 
| 140 | 
            +
                
         | 
| 141 | 
            +
                @staticmethod                    
         | 
| 142 | 
            +
                def api_methods(http_session, method="post"):
         | 
| 143 | 
            +
                    if method not in ["post", "delete", "head", "options", "patch"]:
         | 
| 144 | 
            +
                        raise ValueError("Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']")
         | 
| 145 | 
            +
                    elif method == "post":
         | 
| 146 | 
            +
                        return http_session.post
         | 
| 147 | 
            +
                    elif method == "delete":
         | 
| 148 | 
            +
                        return http_session.delete
         | 
| 149 | 
            +
                    elif method == "head":
         | 
| 150 | 
            +
                        return http_session.head
         | 
| 151 | 
            +
                    elif method == "options":
         | 
| 152 | 
            +
                        return http_session.options
         | 
| 153 | 
            +
                    elif method == "patch":
         | 
| 154 | 
            +
                        return http_session.patch
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                @staticmethod
         | 
| 157 | 
            +
                def api_endpoint_from_url(request_url: str) -> str:
         | 
| 158 | 
            +
                    match = re.search(r"^https://[^/]+/v\d+/(.+)$", request_url)
         | 
| 159 | 
            +
                    if match:
         | 
| 160 | 
            +
                        return match.group(1)
         | 
| 161 | 
            +
                    else:
         | 
| 162 | 
            +
                        return ""
         | 
| 242 163 |  | 
| 243 | 
            -
             | 
| 244 | 
            -
             | 
| 245 | 
            -
             | 
| 246 | 
            -
             | 
| 247 | 
            -
                         | 
| 248 | 
            -
                         | 
| 249 | 
            -
                        ... )
         | 
| 250 | 
            -
                        # Calls the specified API endpoint with the given payload.
         | 
| 251 | 
            -
                    """
         | 
| 252 | 
            -
                    endpoint = self.api_endpoint_from_url(self.base_url + endpoint)
         | 
| 164 | 
            +
                @staticmethod
         | 
| 165 | 
            +
                def task_id_generator_function() -> Generator[int, None, None]:
         | 
| 166 | 
            +
                    task_id = 0
         | 
| 167 | 
            +
                    while True:
         | 
| 168 | 
            +
                        yield task_id
         | 
| 169 | 
            +
                        task_id += 1
         | 
| 253 170 |  | 
| 171 | 
            +
                async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
         | 
| 172 | 
            +
                    endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
         | 
| 173 | 
            +
                    
         | 
| 254 174 | 
             
                    while True:
         | 
| 255 | 
            -
                        if  | 
| 256 | 
            -
                            self.rate_limiter.available_request_capacity < 1
         | 
| 257 | 
            -
                            or self.rate_limiter.available_token_capacity < 10
         | 
| 258 | 
            -
                        ):  # Minimum token count
         | 
| 175 | 
            +
                        if self.rate_limiter.available_request_capacity < 1 or self.rate_limiter.available_token_capacity < 10:  # Minimum token count
         | 
| 259 176 | 
             
                            await asyncio.sleep(1)  # Wait for capacity
         | 
| 260 177 | 
             
                            continue
         | 
| 261 | 
            -
             | 
| 262 | 
            -
                        required_tokens = self.rate_limiter.calculate_num_token(
         | 
| 263 | 
            -
             | 
| 264 | 
            -
                        )
         | 
| 265 | 
            -
             | 
| 178 | 
            +
                        
         | 
| 179 | 
            +
                        required_tokens = self.rate_limiter.calculate_num_token(payload, endpoint_, self.token_encoding_name)
         | 
| 180 | 
            +
                        
         | 
| 266 181 | 
             
                        if self.rate_limiter.available_token_capacity >= required_tokens:
         | 
| 267 182 | 
             
                            self.rate_limiter.available_request_capacity -= 1
         | 
| 268 183 | 
             
                            self.rate_limiter.available_token_capacity -= required_tokens
         | 
| @@ -272,10 +187,9 @@ class OpenAIService(BaseAPIService): | |
| 272 187 |  | 
| 273 188 | 
             
                            while attempts_left > 0:
         | 
| 274 189 | 
             
                                try:
         | 
| 275 | 
            -
                                     | 
| 276 | 
            -
             | 
| 277 | 
            -
                                        headers=request_headers,
         | 
| 278 | 
            -
                                        json=payload,
         | 
| 190 | 
            +
                                    method = self.api_methods(http_session, method)                         
         | 
| 191 | 
            +
                                    async with method(
         | 
| 192 | 
            +
                                        url=(self.base_url+endpoint_), headers=request_headers, json=payload
         | 
| 279 193 | 
             
                                    ) as response:
         | 
| 280 194 | 
             
                                        response_json = await response.json()
         | 
| 281 195 |  | 
| @@ -285,9 +199,7 @@ class OpenAIService(BaseAPIService): | |
| 285 199 | 
             
                                            )
         | 
| 286 200 | 
             
                                            attempts_left -= 1
         | 
| 287 201 |  | 
| 288 | 
            -
                                            if "Rate limit" in response_json["error"].get(
         | 
| 289 | 
            -
                                                "message", ""
         | 
| 290 | 
            -
                                            ):
         | 
| 202 | 
            +
                                            if "Rate limit" in response_json["error"].get("message", ""):
         | 
| 291 203 | 
             
                                                await asyncio.sleep(15)
         | 
| 292 204 | 
             
                                        else:
         | 
| 293 205 | 
             
                                            return response_json
         | 
| @@ -298,4 +210,18 @@ class OpenAIService(BaseAPIService): | |
| 298 210 | 
             
                            logging.error("API call failed after all attempts.")
         | 
| 299 211 | 
             
                            break
         | 
| 300 212 | 
             
                        else:
         | 
| 301 | 
            -
                            await asyncio.sleep(1) | 
| 213 | 
            +
                            await asyncio.sleep(1)
         | 
| 214 | 
            +
                
         | 
| 215 | 
            +
                async def _serve(self, payload, endpoint_="chat/completions", method="post"):
                    """
                    Send one request through ``self._call_api`` and return its result.

                    A fresh ``aiohttp.ClientSession`` is opened for this single call and
                    closed when the call finishes.

                    Parameters:
                        payload: Request body forwarded to ``_call_api``.
                        endpoint_ (str): Endpoint path (e.g. ``"chat/completions"``)
                            forwarded to ``_call_api``.
                        method (str): HTTP method name forwarded to ``_call_api``.

                    Returns:
                        Whatever ``_call_api`` returns for this request.

                    Raises:
                        Any exception raised while making the call is re-raised unchanged
                        after ``self.status_tracker.num_tasks_failed`` is incremented.
                    """
                    try:
                        async with aiohttp.ClientSession() as http_session:
                            return await self._call_api(
                                http_session=http_session,
                                endpoint_=endpoint_,
                                payload=payload,
                                method=method,
                            )
                    except Exception:
                        # Count the failure on the shared status tracker, then let the
                        # caller see the original exception and traceback.
                        self.status_tracker.num_tasks_failed += 1
                        raise
         | 
| 227 | 
            +
                        
         | 
| 
            File without changes
         | 
| @@ -0,0 +1,48 @@ | |
| 1 | 
            +
            import abc
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class BaseEndpoint(abc.ABC):
                """
                Abstract description of a single API endpoint.

                Subclasses supply the endpoint path, build the request payload, and
                post-process the response.
                """

                @property
                @abc.abstractmethod
                def endpoint(self) -> str:
                    """
                    Endpoint path relative to the service base URL.

                    Subclasses may satisfy this with a plain class attribute.
                    """

                @abc.abstractmethod
                def create_payload(self, **kwargs):
                    """
                    Create a payload for the request based on configuration.

                    Parameters:
                        **kwargs: Additional keyword arguments for configuration.

                    Returns:
                        dict: The payload for the request.
                    """

                @abc.abstractmethod
                def process_response(self, response):
                    """
                    Process the response from the API call.

                    Parameters:
                        response: The response to process.
                    """


            class ChatCompletion(BaseEndpoint):
                """Payload construction for the chat-completions endpoint."""

                # Same path the API services use by default for chat requests.
                endpoint: str = "chat/completions"

                @classmethod
                def create_payload(cls, messages, llmconfig, schema, **kwargs):
                    """
                    Build the request body for a chat-completions call.

                    Parameters:
                        messages: Conversation messages, stored under ``"messages"``.
                        llmconfig (dict): Base model configuration.
                        schema (dict): Must provide ``"required"`` and ``"optional"``
                            iterables of configuration key names.
                        **kwargs: Per-call overrides; they take precedence over
                            ``llmconfig``.

                    Returns:
                        dict: ``{"messages": messages, ...}`` plus every required key and
                        every optional key that carries a value.

                    Raises:
                        KeyError: If a required key is missing from both ``llmconfig``
                            and ``kwargs``.
                    """
                    config = {**llmconfig, **kwargs}
                    payload = {"messages": messages}
                    for key in schema["required"]:
                        payload[key] = config[key]

                    for key in schema["optional"]:
                        # Missing optional keys are simply skipped. Falsy values and the
                        # string "none" (any case) are treated as "not set" and omitted.
                        value = config.get(key)
                        if value and str(value).lower() != "none":
                            payload[key] = value
                    return payload

                def process_response(self, session, payload, completion):
                    """
                    Post-process a chat completion.

                    Not implemented yet: currently does nothing and returns ``None``.
                    """
                    ...
         | 
| 48 | 
            +
                    
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| @@ -0,0 +1 @@ | |
| 1 | 
            +
            # TODO
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
    
        lionagi/services/oai.py
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
| 1 | 
            +
            from os import getenv
         | 
| 2 | 
            +
            import dotenv
         | 
| 3 | 
            +
            from .base_api_service import BaseAPIService, BaseAPIRateLimiter
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # Load variables from a local .env file into os.environ (python-dotenv) so the
            # getenv("OPENAI_API_KEY") fallback in OpenAIService.__init__ can see them.
            dotenv.load_dotenv()
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            class OpenAIService(BaseAPIService):
                """
                API service for the OpenAI REST API.

                Thin configuration layer over ``BaseAPIService``. It fixes the base URL,
                resolves the API key (explicit argument first, then the environment
                variable named by ``_key_scheme``), and exposes :meth:`serve`.
                """

                # Environment variable consulted when no ``api_key`` is passed.
                _key_scheme = "OPENAI_API_KEY"

                base_url = "https://api.openai.com/v1/"

                def __init__(
                    self,
                    api_key: str = None,
                    token_encoding_name: str = "cl100k_base",
                    max_attempts: int = 3,
                    max_requests_per_minute: int = 500,
                    max_tokens_per_minute: int = 150_000,
                    ratelimiter=BaseAPIRateLimiter,
                    status_tracker=None,
                    queue=None,
                ):
                    """
                    Parameters:
                        api_key: OpenAI API key; if falsy, ``$OPENAI_API_KEY`` is used.
                        token_encoding_name: Name of the token encoding, stored on the
                            instance (presumably a tiktoken encoding; verify with users).
                        max_attempts: Maximum attempts per request, stored on the
                            instance (presumably read by ``_call_api``).
                        max_requests_per_minute: Request budget passed to the base class.
                        max_tokens_per_minute: Token budget passed to the base class.
                        ratelimiter: Rate-limiter class passed to the base class.
                        status_tracker: Optional status tracker passed to the base class.
                        queue: Optional queue passed to the base class.
                    """
                    super().__init__(
                        api_key=api_key or getenv(self._key_scheme),
                        status_tracker=status_tracker,
                        queue=queue,
                        ratelimiter=ratelimiter,
                        max_requests_per_minute=max_requests_per_minute,
                        max_tokens_per_minute=max_tokens_per_minute,
                    )
                    self.token_encoding_name = token_encoding_name
                    self.max_attempts = max_attempts

                async def serve(self, payload, endpoint_="chat/completions", method="post"):
                    """
                    Send ``payload`` to ``endpoint_`` and return the response.

                    Delegates to ``BaseAPIService._serve``.
                    """
                    return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
         | 
| 34 | 
            +
                
         | 
| @@ -0,0 +1 @@ | |
| 1 | 
            +
            # TODO
         | 
| 
            File without changes
         | 
| @@ -0,0 +1,32 @@ | |
| 1 | 
            +
            from os import getenv
         | 
| 2 | 
            +
            from .base_api_service import BaseAPIService, BaseAPIRateLimiter
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            class OpenRouterService(BaseAPIService):
                """
                API service for the OpenRouter REST API.

                Thin configuration layer over ``BaseAPIService``. It fixes the base URL,
                resolves the API key (explicit argument first, then the environment
                variable named by ``_key_scheme``), and exposes :meth:`serve`.
                """

                # Environment variable consulted when no ``api_key`` is passed.
                _key_scheme = "OPENROUTER_API_KEY"

                base_url = "https://openrouter.ai/api/v1/"

                def __init__(
                    self,
                    api_key: str = None,
                    token_encoding_name: str = "cl100k_base",
                    max_attempts: int = 3,
                    max_requests_per_minute: int = 500,
                    max_tokens_per_minute: int = 150_000,
                    ratelimiter=BaseAPIRateLimiter,
                    status_tracker=None,
                    queue=None,
                ):
                    """
                    Parameters:
                        api_key: OpenRouter API key; if falsy, ``$OPENROUTER_API_KEY``
                            is used.
                        token_encoding_name: Name of the token encoding, stored on the
                            instance.
                        max_attempts: Maximum attempts per request, stored on the
                            instance (presumably read by ``_call_api``).
                        max_requests_per_minute: Request budget passed to the base class.
                        max_tokens_per_minute: Token budget passed to the base class.
                        ratelimiter: Rate-limiter class passed to the base class.
                        status_tracker: Optional status tracker passed to the base class.
                        queue: Optional queue passed to the base class.
                    """
                    super().__init__(
                        api_key=api_key or getenv(self._key_scheme),
                        status_tracker=status_tracker,
                        queue=queue,
                        ratelimiter=ratelimiter,
                        max_requests_per_minute=max_requests_per_minute,
                        max_tokens_per_minute=max_tokens_per_minute,
                    )
                    self.token_encoding_name = token_encoding_name
                    self.max_attempts = max_attempts

                async def serve(self, payload, endpoint_="chat/completions", method="post"):
                    """
                    Send ``payload`` to ``endpoint_`` and return the response.

                    Delegates to ``BaseAPIService._serve``. ``method`` defaults to
                    ``"post"`` (the same default ``_serve`` uses), so existing calls are
                    unaffected.
                    """
                    return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
         | 
| 32 | 
            +
                
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |