lionagi 0.0.306__py3-none-any.whl → 0.0.307__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. lionagi/__init__.py +2 -5
  2. lionagi/core/__init__.py +7 -5
  3. lionagi/core/agent/__init__.py +3 -0
  4. lionagi/core/agent/base_agent.py +10 -12
  5. lionagi/core/branch/__init__.py +4 -0
  6. lionagi/core/branch/base_branch.py +81 -81
  7. lionagi/core/branch/branch.py +16 -28
  8. lionagi/core/branch/branch_flow_mixin.py +3 -7
  9. lionagi/core/branch/executable_branch.py +86 -56
  10. lionagi/core/branch/util.py +77 -162
  11. lionagi/core/{flow/direct → direct}/__init__.py +1 -1
  12. lionagi/core/{flow/direct/predict.py → direct/parallel_predict.py} +39 -17
  13. lionagi/core/direct/parallel_react.py +0 -0
  14. lionagi/core/direct/parallel_score.py +0 -0
  15. lionagi/core/direct/parallel_select.py +0 -0
  16. lionagi/core/direct/parallel_sentiment.py +0 -0
  17. lionagi/core/direct/predict.py +174 -0
  18. lionagi/core/{flow/direct → direct}/react.py +2 -2
  19. lionagi/core/{flow/direct → direct}/score.py +28 -23
  20. lionagi/core/{flow/direct → direct}/select.py +48 -45
  21. lionagi/core/direct/utils.py +83 -0
  22. lionagi/core/flow/monoflow/ReAct.py +6 -5
  23. lionagi/core/flow/monoflow/__init__.py +9 -0
  24. lionagi/core/flow/monoflow/chat.py +10 -10
  25. lionagi/core/flow/monoflow/chat_mixin.py +11 -10
  26. lionagi/core/flow/monoflow/followup.py +6 -5
  27. lionagi/core/flow/polyflow/__init__.py +1 -0
  28. lionagi/core/flow/polyflow/chat.py +15 -3
  29. lionagi/core/mail/mail_manager.py +18 -19
  30. lionagi/core/mail/schema.py +5 -4
  31. lionagi/core/messages/schema.py +18 -20
  32. lionagi/core/prompt/__init__.py +0 -0
  33. lionagi/core/prompt/prompt_template.py +0 -0
  34. lionagi/core/schema/__init__.py +2 -2
  35. lionagi/core/schema/action_node.py +11 -3
  36. lionagi/core/schema/base_mixin.py +56 -59
  37. lionagi/core/schema/base_node.py +35 -38
  38. lionagi/core/schema/condition.py +24 -0
  39. lionagi/core/schema/data_logger.py +96 -99
  40. lionagi/core/schema/data_node.py +19 -19
  41. lionagi/core/schema/prompt_template.py +0 -0
  42. lionagi/core/schema/structure.py +171 -169
  43. lionagi/core/session/__init__.py +1 -3
  44. lionagi/core/session/session.py +196 -214
  45. lionagi/core/tool/tool_manager.py +95 -103
  46. lionagi/integrations/__init__.py +1 -3
  47. lionagi/integrations/bridge/langchain_/documents.py +17 -18
  48. lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
  49. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
  50. lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
  51. lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
  52. lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
  53. lionagi/integrations/config/openrouter_configs.py +0 -1
  54. lionagi/integrations/provider/oai.py +26 -26
  55. lionagi/integrations/provider/services.py +38 -38
  56. lionagi/libs/__init__.py +34 -1
  57. lionagi/libs/ln_api.py +211 -221
  58. lionagi/libs/ln_async.py +53 -60
  59. lionagi/libs/ln_convert.py +118 -120
  60. lionagi/libs/ln_dataframe.py +32 -33
  61. lionagi/libs/ln_func_call.py +334 -342
  62. lionagi/libs/ln_nested.py +99 -107
  63. lionagi/libs/ln_parse.py +161 -165
  64. lionagi/libs/sys_util.py +52 -52
  65. lionagi/tests/test_core/test_session.py +254 -266
  66. lionagi/tests/test_core/test_session_base_util.py +299 -300
  67. lionagi/tests/test_core/test_tool_manager.py +70 -74
  68. lionagi/tests/test_libs/test_nested.py +2 -7
  69. lionagi/tests/test_libs/test_parse.py +2 -2
  70. lionagi/version.py +1 -1
  71. {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
  72. lionagi-0.0.307.dist-info/RECORD +115 -0
  73. lionagi/core/flow/direct/utils.py +0 -43
  74. lionagi-0.0.306.dist-info/RECORD +0 -106
  75. /lionagi/core/{flow/direct → direct}/sentiment.py +0 -0
  76. {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
  77. {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
  78. {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
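Editor's note: most of the renames above relocate the `direct` flow out of `lionagi/core/flow/` into a top-level `lionagi/core/direct/` package. A minimal, illustrative sketch of how a downstream import would track the move, assuming the modules keep their names at the new path (this snippet is not part of the diff):

# Illustrative only: import the moved `predict` module from its 0.0.307
# location, falling back to the 0.0.306 layout if the new path is absent.
try:
    from lionagi.core.direct import predict           # layout in 0.0.307
except ImportError:
    from lionagi.core.flow.direct import predict      # layout in 0.0.306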
lionagi/libs/ln_api.py CHANGED
@@ -29,26 +29,27 @@ class APIUtil:
  Returns the corresponding HTTP method function from the http_session object.

  Args:
- http_session: The session object from the aiohttp library.
- method: The HTTP method as a string.
+ http_session: The session object from the aiohttp library.
+ method: The HTTP method as a string.

  Returns:
- The Callable for the specified HTTP method.
+ The Callable for the specified HTTP method.

  Raises:
- ValueError: If the method is not one of the allowed ones.
+ ValueError: If the method is not one of the allowed ones.

  Examples:
- >>> session = aiohttp.ClientSession()
- >>> post_method = APIUtil.api_method(session, "post")
- >>> print(post_method)
- <bound method ClientSession._request of <aiohttp.client.ClientSession object at 0x...>>
+ >>> session = aiohttp.ClientSession()
+ >>> post_method = APIUtil.api_method(session, "post")
+ >>> print(post_method)
+ <bound method ClientSession._request of <aiohttp.client.ClientSession object at 0x...>>
  """
- if method not in ["post", "delete", "head", "options", "patch"]:
+ if method in {"post", "delete", "head", "options", "patch"}:
+ return getattr(http_session, method)
+ else:
  raise ValueError(
  "Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']"
  )
- return getattr(http_session, method)

  @staticmethod
  def api_error(response_json: Mapping[str, Any]) -> bool:
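Editor's note: the hunk above inverts the validation in `APIUtil.api_method` — allowed verbs are now checked with a set membership test and the handler is returned immediately, with the `ValueError` moved into the `else` branch. A minimal standalone sketch of the new control flow (the `aiohttp` type hint is the only assumption beyond what the hunk shows):

import aiohttp

def api_method(http_session: aiohttp.ClientSession, method: str = "post"):
    # Set membership is O(1) and the happy path returns early; any other
    # verb falls through to the same ValueError as before.
    if method in {"post", "delete", "head", "options", "patch"}:
        return getattr(http_session, method)
    else:
        raise ValueError(
            "Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']"
        )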
@@ -56,18 +57,18 @@ class APIUtil:
  Checks if the given response_json dictionary contains an "error" key.

  Args:
- response_json: The JSON assistant_response as a dictionary.
+ response_json: The JSON assistant_response as a dictionary.

  Returns:
- True if there is an error, False otherwise.
+ True if there is an error, False otherwise.

  Examples:
- >>> response_json_with_error = {"error": "Something went wrong"}
- >>> APIUtil.api_error(response_json_with_error)
- True
- >>> response_json_without_error = {"result": "Success"}
- >>> APIUtil.api_error(response_json_without_error)
- False
+ >>> response_json_with_error = {"error": "Something went wrong"}
+ >>> APIUtil.api_error(response_json_with_error)
+ True
+ >>> response_json_without_error = {"result": "Success"}
+ >>> APIUtil.api_error(response_json_without_error)
+ False
  """
  if "error" in response_json:
  logging.warning(f"API call failed with error: {response_json['error']}")
@@ -80,18 +81,18 @@ class APIUtil:
  Checks if the error message in the response_json dictionary contains the phrase "Rate limit".

  Args:
- response_json: The JSON assistant_response as a dictionary.
+ response_json: The JSON assistant_response as a dictionary.

  Returns:
- True if the phrase "Rate limit" is found, False otherwise.
+ True if the phrase "Rate limit" is found, False otherwise.

  Examples:
- >>> response_json_with_rate_limit = {"error": {"message": "Rate limit exceeded"}}
- >>> api_rate_limit_error(response_json_with_rate_limit)
- True
- >>> response_json_without_rate_limit = {"error": {"message": "Another error"}}
- >>> api_rate_limit_error(response_json_without_rate_limit)
- False
+ >>> response_json_with_rate_limit = {"error": {"message": "Rate limit exceeded"}}
+ >>> api_rate_limit_error(response_json_with_rate_limit)
+ True
+ >>> response_json_without_rate_limit = {"error": {"message": "Another error"}}
+ >>> api_rate_limit_error(response_json_without_rate_limit)
+ False
  """
  return "Rate limit" in response_json.get("error", {}).get("message", "")

@@ -102,21 +103,21 @@ class APIUtil:
  Extracts the API endpoint from a given URL using a regular expression.

  Args:
- request_url: The full URL to the API endpoint.
+ request_url: The full URL to the API endpoint.

  Returns:
- The extracted endpoint or an empty string if the pattern does not match.
+ The extracted endpoint or an empty string if the pattern does not match.

  Examples:
- >>> valid_url = "https://api.example.com/v1/users"
- >>> api_endpoint_from_url(valid_url)
- 'users'
- >>> invalid_url = "https://api.example.com/users"
- >>> api_endpoint_from_url(invalid_url)
- ''
+ >>> valid_url = "https://api.example.com/v1/users"
+ >>> api_endpoint_from_url(valid_url)
+ 'users'
+ >>> invalid_url = "https://api.example.com/users"
+ >>> api_endpoint_from_url(invalid_url)
+ ''
  """
  match = re.search(r"^https://[^/]+(/.+)?/v\d+/(.+)$", request_url)
- return match.group(2) if match else ""
+ return match[2] if match else ""

  @staticmethod
  async def unified_api_call(
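Editor's note: besides the docstring re-indentation, the only code change in the hunk above is `match.group(2)` becoming `match[2]`; `re.Match` has supported indexing since Python 3.6, so behaviour is unchanged. A small self-contained check:

import re

def api_endpoint_from_url(request_url: str) -> str:
    # match[2] is shorthand for match.group(2); behaviour is unchanged.
    match = re.search(r"^https://[^/]+(/.+)?/v\d+/(.+)$", request_url)
    return match[2] if match else ""

assert api_endpoint_from_url("https://api.example.com/v1/users") == "users"
assert api_endpoint_from_url("https://api.example.com/users") == ""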
@@ -126,22 +127,22 @@ class APIUtil:
  Makes an API call and automatically retries on rate limit error.

  Args:
- http_session: The session object from the aiohttp library.
- method: The HTTP method as a string.
- url: The URL to which the request is made.
- **kwargs: Additional keyword arguments to pass to the API call.
+ http_session: The session object from the aiohttp library.
+ method: The HTTP method as a string.
+ url: The URL to which the request is made.
+ **kwargs: Additional keyword arguments to pass to the API call.

  Returns:
- The JSON assistant_response as a dictionary.
+ The JSON assistant_response as a dictionary.

  Examples:
- >>> session = aiohttp.ClientSession()
- >>> success_url = "https://api.example.com/v1/success"
- >>> print(await unified_api_call(session, 'get', success_url))
- {'result': 'Success'}
- >>> rate_limit_url = "https://api.example.com/v1/rate_limit"
- >>> print(await unified_api_call(session, 'get', rate_limit_url))
- {'error': {'message': 'Rate limit exceeded'}}
+ >>> session = aiohttp.ClientSession()
+ >>> success_url = "https://api.example.com/v1/success"
+ >>> print(await unified_api_call(session, 'get', success_url))
+ {'result': 'Success'}
+ >>> rate_limit_url = "https://api.example.com/v1/rate_limit"
+ >>> print(await unified_api_call(session, 'get', rate_limit_url))
+ {'error': {'message': 'Rate limit exceeded'}}
  """
  api_call = APIUtil.api_method(http_session, method)
  retry_count = 3
@@ -189,14 +190,14 @@ class APIUtil:
  Retries an API call on failure, with exponential backoff.

  Args:
- http_session: The aiohttp client session.
- url: The URL to make the API call.
- retries: The number of times to retry.
- backoff_factor: The backoff factor for retries.
- **kwargs: Additional arguments for the API call.
+ http_session: The aiohttp client session.
+ url: The URL to make the API call.
+ retries: The number of times to retry.
+ backoff_factor: The backoff factor for retries.
+ **kwargs: Additional arguments for the API call.

  Returns:
- The assistant_response from the API call, if successful; otherwise, None.
+ The assistant_response from the API call, if successful; otherwise, None.
  """
  for attempt in range(retries):
  try:
@@ -227,27 +228,27 @@ class APIUtil:
  Uploads a file to a specified URL with a retry mechanism for handling failures.

  Args:
- http_session: The HTTP session object to use for making the request.
- url: The URL to which the file will be uploaded.
- file_path: The path to the file that will be uploaded.
- param_name: The name of the parameter expected by the server for the file upload.
- additional_data: Additional data to be sent with the upload.
- retries: The number of times to retry the upload in case of failure.
+ http_session: The HTTP session object to use for making the request.
+ url: The URL to which the file will be uploaded.
+ file_path: The path to the file that will be uploaded.
+ param_name: The name of the parameter expected by the server for the file upload.
+ additional_data: Additional data to be sent with the upload.
+ retries: The number of times to retry the upload in case of failure.

  Returns:
- The HTTP assistant_response object.
+ The HTTP assistant_response object.

  Examples:
- >>> session = aiohttp.ClientSession()
- >>> assistant_response = await APIUtil.upload_file_with_retry(session, 'http://example.com/upload', 'path/to/file.txt')
- >>> assistant_response.status
- 200
+ >>> session = aiohttp.ClientSession()
+ >>> assistant_response = await APIUtil.upload_file_with_retry(session, 'http://example.com/upload', 'path/to/file.txt')
+ >>> assistant_response.status
+ 200
  """
  for attempt in range(retries):
  try:
  with open(file_path, "rb") as file:
  files = {param_name: file}
- additional_data = additional_data if additional_data else {}
+ additional_data = additional_data or {}
  async with http_session.post(
  url, data={**files, **additional_data}
  ) as response:
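Editor's note: `additional_data if additional_data else {}` is collapsed to `additional_data or {}`. For the values this parameter takes (None or a mapping), the two forms agree, as this quick check illustrates:

# Both expressions fall back to an empty dict for None or an empty mapping.
for additional_data in (None, {}, {"purpose": "upload"}):
    assert (additional_data if additional_data else {}) == (additional_data or {})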
@@ -273,20 +274,20 @@ class APIUtil:
  Retrieves an OAuth token from the authentication server and caches it to avoid unnecessary requests.

  Args:
- http_session: The HTTP session object to use for making the request.
- auth_url: The URL of the authentication server.
- client_id: The client ID for OAuth authentication.
- client_secret: The client secret for OAuth authentication.
- scope: The scope for which the OAuth token is requested.
+ http_session: The HTTP session object to use for making the request.
+ auth_url: The URL of the authentication server.
+ client_id: The client ID for OAuth authentication.
+ client_secret: The client secret for OAuth authentication.
+ scope: The scope for which the OAuth token is requested.

  Returns:
- The OAuth token as a string.
+ The OAuth token as a string.

  Examples:
- >>> session = aiohttp.ClientSession()
- >>> token = await APIUtil.get_oauth_token_with_cache(session, 'http://auth.example.com', 'client_id', 'client_secret', 'read')
- >>> token
- 'mock_access_token'
+ >>> session = aiohttp.ClientSession()
+ >>> token = await APIUtil.get_oauth_token_with_cache(session, 'http://auth.example.com', 'client_id', 'client_secret', 'read')
+ >>> token
+ 'mock_access_token'
  """
  async with http_session.post(
  auth_url,
@@ -309,12 +310,12 @@ class APIUtil:
  Makes an API call.

  Args:
- http_session: The aiohttp client session.
- url: The URL for the API call.
- **kwargs: Additional arguments for the API call.
+ http_session: The aiohttp client session.
+ url: The URL for the API call.
+ **kwargs: Additional arguments for the API call.

  Returns:
- The assistant_response from the API call, if successful; otherwise, None.
+ The assistant_response from the API call, if successful; otherwise, None.
  """
  try:
  async with http_session.get(url, **kwargs) as response:
@@ -325,12 +326,11 @@ class APIUtil:
  return None

  @staticmethod
- # @lru_cache(maxsize=1024)
  def calculate_num_token(
  payload: Mapping[str, Any] = None,
  api_endpoint: str = None,
  token_encoding_name: str = None,
- ) -> int:
+ ) -> int: # sourcery skip: avoid-builtin-shadow
  """
  Calculates the number of tokens required for a request based on the payload and API endpoint.

@@ -339,20 +339,20 @@ class APIUtil:
  for the OpenAI API.

  Parameters:
- payload (Mapping[str, Any]): The payload of the request.
+ payload (Mapping[str, Any]): The payload of the request.

- api_endpoint (str): The specific API endpoint for the request.
+ api_endpoint (str): The specific API endpoint for the request.

- token_encoding_name (str): The name of the token encoding method.
+ token_encoding_name (str): The name of the token encoding method.

  Returns:
- int: The estimated number of tokens required for the request.
+ int: The estimated number of tokens required for the request.

  Example:
- >>> rate_limiter = OpenAIRateLimiter(100, 200)
- >>> payload = {'prompt': 'Translate the following text:', 'max_tokens': 50}
- >>> rate_limiter.calculate_num_token(payload, 'completions')
- # Expected token calculation for the given payload and endpoint.
+ >>> rate_limiter = OpenAIRateLimiter(100, 200)
+ >>> payload = {'prompt': 'Translate the following text:', 'max_tokens': 50}
+ >>> rate_limiter.calculate_num_token(payload, 'completions')
+ # Expected token calculation for the given payload and endpoint.
  """
  import tiktoken

@@ -371,21 +371,19 @@ class APIUtil:
  num_tokens += len(encoding.encode(value))
  if key == "name": # if there's a name, the role is omitted
  num_tokens -= (
- 1 # role is always required and always 1 token
+ 1
+ # role is always required and always 1 token
  )
  num_tokens += 2 # every reply is primed with <im_start>assistant
  return num_tokens + completion_tokens
- # normal completions
  else:
  prompt = payload["prompt"]
  if isinstance(prompt, str): # single prompt
  prompt_tokens = len(encoding.encode(prompt))
- num_tokens = prompt_tokens + completion_tokens
- return num_tokens
+ return prompt_tokens + completion_tokens
  elif isinstance(prompt, list): # multiple prompts
- prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
- num_tokens = prompt_tokens + completion_tokens * len(prompt)
- return num_tokens
+ prompt_tokens = sum(len(encoding.encode(p)) for p in prompt)
+ return prompt_tokens + completion_tokens * len(prompt)
  else:
  raise TypeError(
  'Expecting either string or list of strings for "prompt" field in completion request'
@@ -393,11 +391,9 @@ class APIUtil:
  elif api_endpoint == "embeddings":
  input = payload["input"]
  if isinstance(input, str): # single input
- num_tokens = len(encoding.encode(input))
- return num_tokens
+ return len(encoding.encode(input))
  elif isinstance(input, list): # multiple inputs
- num_tokens = sum([len(encoding.encode(i)) for i in input])
- return num_tokens
+ return sum(len(encoding.encode(i)) for i in input)
  else:
  raise TypeError(
  'Expecting either string or list of strings for "inputs" field in embedding request'
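Editor's note: in both the completions and embeddings branches above, the list comprehensions inside `sum()` become generator expressions and the intermediate `num_tokens` variables are dropped; the totals are identical, the generator just avoids building a temporary list. A tiny illustration with a stand-in tokenizer (`encode` here is only a placeholder for the tiktoken encoding used by the real method):

def encode(text: str) -> list:
    # Stand-in tokenizer: splits on whitespace instead of calling tiktoken.
    return text.split()

inputs = ["first input", "second longer input"]
old_total = sum([len(encode(i)) for i in inputs])  # builds a temporary list
new_total = sum(len(encode(i)) for i in inputs)    # generator, no list
assert old_total == new_total == 5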
@@ -413,11 +409,11 @@ class APIUtil:
  payload = {input_key: input_}

  for key in required_:
- payload.update({key: config[key]})
+ payload[key] = config[key]

  for key in optional_:
- if bool(config[key]) is True and convert.strip_lower(config[key]) != "none":
- payload.update({key: config[key]})
+ if bool(config[key]) and convert.strip_lower(config[key]) != "none":
+ payload[key] = config[key]

  return payload
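Editor's note: the single-key `payload.update({key: config[key]})` calls become plain item assignment and the redundant `is True` is dropped; the resulting payload is identical. A hedged sketch of the loop, with `convert.strip_lower` replaced by an inline stand-in since that helper lives elsewhere in lionagi.libs:

config = {"model": "gpt-4", "temperature": 0.7, "stop": "none"}
required_, optional_ = ["model"], ["temperature", "stop"]

payload = {"messages": []}
for key in required_:
    payload[key] = config[key]                      # was payload.update({key: ...})
for key in optional_:
    # stand-in for: bool(config[key]) and convert.strip_lower(config[key]) != "none"
    if bool(config[key]) and str(config[key]).strip().lower() != "none":
        payload[key] = config[key]

assert payload == {"messages": [], "model": "gpt-4", "temperature": 0.7}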
 
@@ -428,18 +424,18 @@ class StatusTracker:
  Keeps track of various task statuses within a system.

  Attributes:
- num_tasks_started (int): The number of tasks that have been initiated.
- num_tasks_in_progress (int): The number of tasks currently being processed.
- num_tasks_succeeded (int): The number of tasks that have completed successfully.
- num_tasks_failed (int): The number of tasks that have failed.
- num_rate_limit_errors (int): The number of tasks that failed due to rate limiting.
- num_api_errors (int): The number of tasks that failed due to API errors.
- num_other_errors (int): The number of tasks that failed due to other errors.
+ num_tasks_started (int): The number of tasks that have been initiated.
+ num_tasks_in_progress (int): The number of tasks currently being processed.
+ num_tasks_succeeded (int): The number of tasks that have completed successfully.
+ num_tasks_failed (int): The number of tasks that have failed.
+ num_rate_limit_errors (int): The number of tasks that failed due to rate limiting.
+ num_api_errors (int): The number of tasks that failed due to API errors.
+ num_other_errors (int): The number of tasks that failed due to other errors.

  Examples:
- >>> tracker = StatusTracker()
- >>> tracker.num_tasks_started += 1
- >>> tracker.num_tasks_succeeded += 1
+ >>> tracker = StatusTracker()
+ >>> tracker.num_tasks_started += 1
+ >>> tracker.num_tasks_succeeded += 1
  """

  num_tasks_started: int = 0
@@ -459,12 +455,12 @@ class BaseRateLimiter(ABC):
  the replenishment of request and token capacities at regular intervals.

  Attributes:
- interval: The time interval in seconds for replenishing capacities.
- max_requests: The maximum number of requests allowed per interval.
- max_tokens: The maximum number of tokens allowed per interval.
- available_request_capacity: The current available request capacity.
- available_token_capacity: The current available token capacity.
- rate_limit_replenisher_task: The asyncio task for replenishing capacities.
+ interval: The time interval in seconds for replenishing capacities.
+ max_requests: The maximum number of requests allowed per interval.
+ max_tokens: The maximum number of tokens allowed per interval.
+ available_request_capacity: The current available request capacity.
+ available_token_capacity: The current available token capacity.
+ rate_limit_replenisher_task: The asyncio task for replenishing capacities.
  """

  def __init__(
@@ -516,7 +512,8 @@ class BaseRateLimiter(ABC):
  ):
  self.available_request_capacity -= 1
  self.available_token_capacity -= (
- required_tokens # Assuming 1 token per request for simplicity
+ required_tokens
+ # Assuming 1 token per request for simplicity
  )
  return True
  return False
@@ -536,16 +533,16 @@ class BaseRateLimiter(ABC):
  Makes an API call to the specified endpoint using the provided HTTP session.

  Args:
- http_session: The aiohttp client session to use for the API call.
- endpoint: The API endpoint to call.
- base_url: The base URL of the API.
- api_key: The API key for authentication.
- max_attempts: The maximum number of attempts for the API call.
- method: The HTTP method to use for the API call.
- payload: The payload to send with the API call.
+ http_session: The aiohttp client session to use for the API call.
+ endpoint: The API endpoint to call.
+ base_url: The base URL of the API.
+ api_key: The API key for authentication.
+ max_attempts: The maximum number of attempts for the API call.
+ method: The HTTP method to use for the API call.
+ payload: The payload to send with the API call.

  Returns:
- The JSON assistant_response from the API call if successful, otherwise None.
+ The JSON assistant_response from the API call if successful, otherwise None.
  """
  endpoint = APIUtil.api_endpoint_from_url(base_url + endpoint)
  while True:
@@ -573,18 +570,17 @@ class BaseRateLimiter(ABC):
  ) as response:
  response_json = await response.json()

- if "error" in response_json:
- logging.warning(
- f"API call failed with error: {response_json['error']}"
- )
- attempts_left -= 1
-
- if "Rate limit" in response_json["error"].get(
- "message", ""
- ):
- await AsyncUtil.sleep(15)
- else:
+ if "error" not in response_json:
  return response_json
+ logging.warning(
+ f"API call failed with error: {response_json['error']}"
+ )
+ attempts_left -= 1
+
+ if "Rate limit" in response_json["error"].get(
+ "message", ""
+ ):
+ await AsyncUtil.sleep(15)
  except Exception as e:
  logging.warning(f"API call failed with exception: {e}")
  attempts_left -= 1
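Editor's note: `_call_api` now returns the response as soon as no "error" key is present and de-indents the failure handling, rather than nesting the success case under an `else`. The order of effects is unchanged. A self-contained sketch of the guard-clause shape (a helper invented for illustration, not the actual coroutine):

import logging

def handle_response(response_json: dict, attempts_left: int):
    # Return the body immediately when there is no "error" key; otherwise
    # log, consume an attempt, and report whether a rate-limit backoff
    # (a 15s sleep in the real coroutine) is needed.
    if "error" not in response_json:
        return response_json, attempts_left, False
    logging.warning(f"API call failed with error: {response_json['error']}")
    attempts_left -= 1
    rate_limited = "Rate limit" in response_json["error"].get("message", "")
    return None, attempts_left, rate_limited

assert handle_response({"ok": True}, 3) == ({"ok": True}, 3, False)
assert handle_response({"error": {"message": "Rate limit exceeded"}}, 3) == (None, 2, True)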
@@ -606,13 +602,13 @@ class BaseRateLimiter(ABC):
  Creates an instance of BaseRateLimiter and starts the replenisher task.

  Args:
- max_requests: The maximum number of requests allowed per interval.
- max_tokens: The maximum number of tokens allowed per interval.
- interval: The time interval in seconds for replenishing capacities.
- token_encoding_name: The name of the token encoding to use.
+ max_requests: The maximum number of requests allowed per interval.
+ max_tokens: The maximum number of tokens allowed per interval.
+ interval: The time interval in seconds for replenishing capacities.
+ token_encoding_name: The name of the token encoding to use.

  Returns:
- An instance of BaseRateLimiter with the replenisher task started.
+ An instance of BaseRateLimiter with the replenisher task started.
  """
  instance = cls(max_requests, max_tokens, interval, token_encoding_name)
  instance.rate_limit_replenisher_task = AsyncUtil.create_task(
@@ -646,25 +642,25 @@ class EndPoint:
  This class encapsulates the details of an API endpoint, including its rate limiter.

  Attributes:
- endpoint (str): The API endpoint path.
- rate_limiter_class (Type[li.BaseRateLimiter]): The class used for rate limiting requests to the endpoint.
- max_requests (int): The maximum number of requests allowed per interval.
- max_tokens (int): The maximum number of tokens allowed per interval.
- interval (int): The time interval in seconds for replenishing rate limit capacities.
- config (Mapping): Configuration parameters for the endpoint.
- rate_limiter (Optional[li.BaseRateLimiter]): The rate limiter instance for this endpoint.
+ endpoint (str): The API endpoint path.
+ rate_limiter_class (Type[li.BaseRateLimiter]): The class used for rate limiting requests to the endpoint.
+ max_requests (int): The maximum number of requests allowed per interval.
+ max_tokens (int): The maximum number of tokens allowed per interval.
+ interval (int): The time interval in seconds for replenishing rate limit capacities.
+ config (Mapping): Configuration parameters for the endpoint.
+ rate_limiter (Optional[li.BaseRateLimiter]): The rate limiter instance for this endpoint.

  Examples:
- # Example usage of EndPoint with SimpleRateLimiter
- endpoint = EndPoint(
- max_requests=100,
- max_tokens=1000,
- interval=60,
- endpoint_='chat/completions',
- rate_limiter_class=li.SimpleRateLimiter,
- config={'param1': 'value1'}
- )
- asyncio.run(endpoint.init_rate_limiter())
+ # Example usage of EndPoint with SimpleRateLimiter
+ endpoint = EndPoint(
+ max_requests=100,
+ max_tokens=1000,
+ interval=60,
+ endpoint_='chat/completions',
+ rate_limiter_class=li.SimpleRateLimiter,
+ config={'param1': 'value1'}
+ )
+ asyncio.run(endpoint.init_rate_limiter())
  """

  def __init__(
@@ -702,10 +698,10 @@ class BaseService:
  This class provides a foundation for services that need to make API calls with rate limiting.

  Attributes:
- api_key (Optional[str]): The API key used for authentication.
- schema (Mapping[str, Any]): The schema defining the service's endpoints.
- status_tracker (StatusTracker): The object tracking the status of API calls.
- endpoints (Mapping[str, EndPoint]): A dictionary of endpoint objects.
+ api_key (Optional[str]): The API key used for authentication.
+ schema (Mapping[str, Any]): The schema defining the service's endpoints.
+ status_tracker (StatusTracker): The object tracking the status of API calls.
+ endpoints (Mapping[str, EndPoint]): A dictionary of endpoint objects.
  """

  base_url: str = ""
@@ -739,7 +735,7 @@ class BaseService:
  Initializes the specified endpoint or all endpoints if none is specified.

  Args:
- endpoint_: The endpoint(s) to initialize. Can be a string, an EndPoint, a list of strings, or a list of EndPoints.
+ endpoint_: The endpoint(s) to initialize. Can be a string, an EndPoint, a list of strings, or a list of EndPoints.
  """

  if endpoint_:
@@ -756,45 +752,40 @@ class BaseService:
  self.schema.get(ep, {})
  if isinstance(ep, EndPoint):
  self.endpoints[ep.endpoint] = ep
+ elif ep == "chat/completions":
+ self.endpoints[ep] = EndPoint(
+ max_requests=self.chat_config_rate_limit.get(
+ "max_requests", 1000
+ ),
+ max_tokens=self.chat_config_rate_limit.get(
+ "max_tokens", 100000
+ ),
+ interval=self.chat_config_rate_limit.get("interval", 60),
+ endpoint_=ep,
+ token_encoding_name=self.token_encoding_name,
+ config=endpoint_config,
+ )
  else:
- if ep == "chat/completions":
- self.endpoints[ep] = EndPoint(
- max_requests=self.chat_config_rate_limit.get(
- "max_requests", 1000
- ),
- max_tokens=self.chat_config_rate_limit.get(
- "max_tokens", 100000
- ),
- interval=self.chat_config_rate_limit.get(
- "interval", 60
- ),
- endpoint_=ep,
- token_encoding_name=self.token_encoding_name,
- config=endpoint_config,
- )
- else:
- self.endpoints[ep] = EndPoint(
- max_requests=(
- endpoint_config.get("max_requests", 1000)
- if endpoint_config.get("max_requests", 1000)
- is not None
- else 1000
- ),
- max_tokens=(
- endpoint_config.get("max_tokens", 100000)
- if endpoint_config.get("max_tokens", 100000)
- is not None
- else 100000
- ),
- interval=(
- endpoint_config.get("interval", 60)
- if endpoint_config.get("interval", 60) is not None
- else 60
- ),
- endpoint_=ep,
- token_encoding_name=self.token_encoding_name,
- config=endpoint_config,
- )
+ self.endpoints[ep] = EndPoint(
+ max_requests=(
+ endpoint_config.get("max_requests", 1000)
+ if endpoint_config.get("max_requests", 1000) is not None
+ else 1000
+ ),
+ max_tokens=(
+ endpoint_config.get("max_tokens", 100000)
+ if endpoint_config.get("max_tokens", 100000) is not None
+ else 100000
+ ),
+ interval=(
+ endpoint_config.get("interval", 60)
+ if endpoint_config.get("interval", 60) is not None
+ else 60
+ ),
+ endpoint_=ep,
+ token_encoding_name=self.token_encoding_name,
+ config=endpoint_config,
+ )

  if not self.endpoints[ep]._has_initialized:
  await self.endpoints[ep].init_rate_limiter()
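Editor's note: the nested `else: if ep == "chat/completions": ... else: ...` becomes a flat `if / elif / else`, so the three cases (an EndPoint instance, the chat endpoint with its dedicated rate-limit config, and everything else with defaults) now sit at one level; the constructed EndPoint objects are unchanged. A simplified stand-in for the dispatch (names and defaults mirror the hunk, though the None-handling is condensed to `or` and the EndPoint-instance case is left to the caller; this is not the real method):

def resolve_rate_limits(ep: str, chat_config_rate_limit: dict, endpoint_config: dict) -> dict:
    # Flat dispatch mirroring the refactor: chat/completions reads its own
    # config, everything else falls back to per-endpoint config with defaults.
    if ep == "chat/completions":
        return {
            "max_requests": chat_config_rate_limit.get("max_requests", 1000),
            "max_tokens": chat_config_rate_limit.get("max_tokens", 100000),
            "interval": chat_config_rate_limit.get("interval", 60),
        }
    else:
        return {
            "max_requests": endpoint_config.get("max_requests") or 1000,
            "max_tokens": endpoint_config.get("max_tokens") or 100000,
            "interval": endpoint_config.get("interval") or 60,
        }

assert resolve_rate_limits("chat/completions", {"max_requests": 500}, {}) == {
    "max_requests": 500, "max_tokens": 100000, "interval": 60,
}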
@@ -820,20 +811,20 @@ class BaseService:
  Calls the specified API endpoint with the given payload and method.

  Args:
- payload: The payload to send with the API call.
- endpoint: The endpoint to call.
- method: The HTTP method to use for the call.
+ payload: The payload to send with the API call.
+ endpoint: The endpoint to call.
+ method: The HTTP method to use for the call.

  Returns:
- The assistant_response from the API call.
+ The assistant_response from the API call.

  Raises:
- ValueError: If the endpoint has not been initialized.
+ ValueError: If the endpoint has not been initialized.
  """
  if endpoint not in self.endpoints.keys():
  raise ValueError(f"The endpoint {endpoint} has not initialized.")
  async with aiohttp.ClientSession() as http_session:
- completion = await self.endpoints[endpoint].rate_limiter._call_api(
+ return await self.endpoints[endpoint].rate_limiter._call_api(
  http_session=http_session,
  endpoint=endpoint,
  base_url=self.base_url,
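Editor's note: `call_api` now returns the awaited `_call_api(...)` result directly from inside `async with aiohttp.ClientSession() ...` instead of stashing it in `completion` and returning after the block (the `- return completion` line in the next hunk). Returning from inside an async context manager still runs `__aexit__` before the caller sees the value, as this minimal stand-in demonstrates:

import asyncio

class FakeSession:
    # Minimal stand-in for aiohttp.ClientSession, used only to show that
    # __aexit__ runs even when we return from inside the `async with` body.
    closed = False
    async def __aenter__(self):
        return self
    async def __aexit__(self, *exc):
        FakeSession.closed = True

async def call_api():
    async with FakeSession():
        return {"choices": []}   # returned directly, no temporary variable

assert asyncio.run(call_api()) == {"choices": []}
assert FakeSession.closed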
@@ -842,7 +833,6 @@ class BaseService:
  payload=payload,
  **kwargs,
  )
- return completion


  class PayloadPackage:
@@ -853,13 +843,13 @@ class PayloadPackage:
  Creates a payload for the chat completion operation.

  Args:
- messages: The messages to include in the chat completion.
- llmconfig: Configuration for the language model.
- schema: The schema describing required and optional fields.
- **kwargs: Additional keyword arguments.
+ messages: The messages to include in the chat completion.
+ llmconfig: Configuration for the language model.
+ schema: The schema describing required and optional fields.
+ **kwargs: Additional keyword arguments.

  Returns:
- The constructed payload.
+ The constructed payload.
  """
  return APIUtil.create_payload(
  input_=messages,
@@ -876,13 +866,13 @@ class PayloadPackage:
  Creates a payload for the fine-tuning operation.

  Args:
- training_file: The file containing training data.
- llmconfig: Configuration for the language model.
- schema: The schema describing required and optional fields.
- **kwargs: Additional keyword arguments.
+ training_file: The file containing training data.
+ llmconfig: Configuration for the language model.
+ schema: The schema describing required and optional fields.
+ **kwargs: Additional keyword arguments.

  Returns:
- The constructed payload.
+ The constructed payload.
  """
  return APIUtil._create_payload(
  input_=training_file,