lionagi 0.0.306__py3-none-any.whl → 0.0.308__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78)
  1. lionagi/__init__.py +2 -5
  2. lionagi/core/__init__.py +7 -5
  3. lionagi/core/agent/__init__.py +3 -0
  4. lionagi/core/agent/base_agent.py +10 -12
  5. lionagi/core/branch/__init__.py +4 -0
  6. lionagi/core/branch/base_branch.py +81 -81
  7. lionagi/core/branch/branch.py +16 -28
  8. lionagi/core/branch/branch_flow_mixin.py +3 -7
  9. lionagi/core/branch/executable_branch.py +86 -56
  10. lionagi/core/branch/util.py +77 -162
  11. lionagi/core/{flow/direct → direct}/__init__.py +1 -1
  12. lionagi/core/{flow/direct/predict.py → direct/parallel_predict.py} +39 -17
  13. lionagi/core/direct/parallel_react.py +0 -0
  14. lionagi/core/direct/parallel_score.py +0 -0
  15. lionagi/core/direct/parallel_select.py +0 -0
  16. lionagi/core/direct/parallel_sentiment.py +0 -0
  17. lionagi/core/direct/predict.py +174 -0
  18. lionagi/core/{flow/direct → direct}/react.py +2 -2
  19. lionagi/core/{flow/direct → direct}/score.py +28 -23
  20. lionagi/core/{flow/direct → direct}/select.py +48 -45
  21. lionagi/core/direct/utils.py +83 -0
  22. lionagi/core/flow/monoflow/ReAct.py +6 -5
  23. lionagi/core/flow/monoflow/__init__.py +9 -0
  24. lionagi/core/flow/monoflow/chat.py +10 -10
  25. lionagi/core/flow/monoflow/chat_mixin.py +11 -10
  26. lionagi/core/flow/monoflow/followup.py +6 -5
  27. lionagi/core/flow/polyflow/__init__.py +1 -0
  28. lionagi/core/flow/polyflow/chat.py +15 -3
  29. lionagi/core/mail/mail_manager.py +18 -19
  30. lionagi/core/mail/schema.py +5 -4
  31. lionagi/core/messages/schema.py +18 -20
  32. lionagi/core/prompt/__init__.py +0 -0
  33. lionagi/core/prompt/prompt_template.py +0 -0
  34. lionagi/core/schema/__init__.py +2 -2
  35. lionagi/core/schema/action_node.py +11 -3
  36. lionagi/core/schema/base_mixin.py +56 -59
  37. lionagi/core/schema/base_node.py +34 -37
  38. lionagi/core/schema/condition.py +24 -0
  39. lionagi/core/schema/data_logger.py +96 -99
  40. lionagi/core/schema/data_node.py +19 -19
  41. lionagi/core/schema/prompt_template.py +0 -0
  42. lionagi/core/schema/structure.py +171 -169
  43. lionagi/core/session/__init__.py +1 -3
  44. lionagi/core/session/session.py +196 -214
  45. lionagi/core/tool/tool_manager.py +95 -103
  46. lionagi/integrations/__init__.py +1 -3
  47. lionagi/integrations/bridge/langchain_/documents.py +17 -18
  48. lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
  49. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
  50. lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
  51. lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
  52. lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
  53. lionagi/integrations/config/openrouter_configs.py +0 -1
  54. lionagi/integrations/provider/oai.py +26 -26
  55. lionagi/integrations/provider/services.py +38 -38
  56. lionagi/libs/__init__.py +34 -1
  57. lionagi/libs/ln_api.py +211 -221
  58. lionagi/libs/ln_async.py +53 -60
  59. lionagi/libs/ln_convert.py +118 -120
  60. lionagi/libs/ln_dataframe.py +32 -33
  61. lionagi/libs/ln_func_call.py +334 -342
  62. lionagi/libs/ln_nested.py +99 -107
  63. lionagi/libs/ln_parse.py +161 -165
  64. lionagi/libs/sys_util.py +52 -52
  65. lionagi/tests/test_core/test_session.py +254 -266
  66. lionagi/tests/test_core/test_session_base_util.py +299 -300
  67. lionagi/tests/test_core/test_tool_manager.py +70 -74
  68. lionagi/tests/test_libs/test_nested.py +2 -7
  69. lionagi/tests/test_libs/test_parse.py +2 -2
  70. lionagi/version.py +1 -1
  71. {lionagi-0.0.306.dist-info → lionagi-0.0.308.dist-info}/METADATA +4 -2
  72. lionagi-0.0.308.dist-info/RECORD +115 -0
  73. lionagi/core/flow/direct/utils.py +0 -43
  74. lionagi-0.0.306.dist-info/RECORD +0 -106
  75. lionagi/core/{flow/direct → direct}/sentiment.py +0 -0
  76. {lionagi-0.0.306.dist-info → lionagi-0.0.308.dist-info}/LICENSE +0 -0
  77. {lionagi-0.0.306.dist-info → lionagi-0.0.308.dist-info}/WHEEL +0 -0
  78. {lionagi-0.0.306.dist-info → lionagi-0.0.308.dist-info}/top_level.txt +0 -0
lionagi/libs/ln_api.py CHANGED
@@ -29,26 +29,27 @@ class APIUtil:
29
29
  Returns the corresponding HTTP method function from the http_session object.
30
30
 
31
31
  Args:
32
- http_session: The session object from the aiohttp library.
33
- method: The HTTP method as a string.
32
+ http_session: The session object from the aiohttp library.
33
+ method: The HTTP method as a string.
34
34
 
35
35
  Returns:
36
- The Callable for the specified HTTP method.
36
+ The Callable for the specified HTTP method.
37
37
 
38
38
  Raises:
39
- ValueError: If the method is not one of the allowed ones.
39
+ ValueError: If the method is not one of the allowed ones.
40
40
 
41
41
  Examples:
42
- >>> session = aiohttp.ClientSession()
43
- >>> post_method = APIUtil.api_method(session, "post")
44
- >>> print(post_method)
45
- <bound method ClientSession._request of <aiohttp.client.ClientSession object at 0x...>>
42
+ >>> session = aiohttp.ClientSession()
43
+ >>> post_method = APIUtil.api_method(session, "post")
44
+ >>> print(post_method)
45
+ <bound method ClientSession._request of <aiohttp.client.ClientSession object at 0x...>>
46
46
  """
47
- if method not in ["post", "delete", "head", "options", "patch"]:
47
+ if method in {"post", "delete", "head", "options", "patch"}:
48
+ return getattr(http_session, method)
49
+ else:
48
50
  raise ValueError(
49
51
  "Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']"
50
52
  )
51
- return getattr(http_session, method)
52
53
 
53
54
  @staticmethod
54
55
  def api_error(response_json: Mapping[str, Any]) -> bool:
@@ -56,18 +57,18 @@ class APIUtil:
56
57
  Checks if the given response_json dictionary contains an "error" key.
57
58
 
58
59
  Args:
59
- response_json: The JSON assistant_response as a dictionary.
60
+ response_json: The JSON assistant_response as a dictionary.
60
61
 
61
62
  Returns:
62
- True if there is an error, False otherwise.
63
+ True if there is an error, False otherwise.
63
64
 
64
65
  Examples:
65
- >>> response_json_with_error = {"error": "Something went wrong"}
66
- >>> APIUtil.api_error(response_json_with_error)
67
- True
68
- >>> response_json_without_error = {"result": "Success"}
69
- >>> APIUtil.api_error(response_json_without_error)
70
- False
66
+ >>> response_json_with_error = {"error": "Something went wrong"}
67
+ >>> APIUtil.api_error(response_json_with_error)
68
+ True
69
+ >>> response_json_without_error = {"result": "Success"}
70
+ >>> APIUtil.api_error(response_json_without_error)
71
+ False
71
72
  """
72
73
  if "error" in response_json:
73
74
  logging.warning(f"API call failed with error: {response_json['error']}")
@@ -80,18 +81,18 @@ class APIUtil:
80
81
  Checks if the error message in the response_json dictionary contains the phrase "Rate limit".
81
82
 
82
83
  Args:
83
- response_json: The JSON assistant_response as a dictionary.
84
+ response_json: The JSON assistant_response as a dictionary.
84
85
 
85
86
  Returns:
86
- True if the phrase "Rate limit" is found, False otherwise.
87
+ True if the phrase "Rate limit" is found, False otherwise.
87
88
 
88
89
  Examples:
89
- >>> response_json_with_rate_limit = {"error": {"message": "Rate limit exceeded"}}
90
- >>> api_rate_limit_error(response_json_with_rate_limit)
91
- True
92
- >>> response_json_without_rate_limit = {"error": {"message": "Another error"}}
93
- >>> api_rate_limit_error(response_json_without_rate_limit)
94
- False
90
+ >>> response_json_with_rate_limit = {"error": {"message": "Rate limit exceeded"}}
91
+ >>> api_rate_limit_error(response_json_with_rate_limit)
92
+ True
93
+ >>> response_json_without_rate_limit = {"error": {"message": "Another error"}}
94
+ >>> api_rate_limit_error(response_json_without_rate_limit)
95
+ False
95
96
  """
96
97
  return "Rate limit" in response_json.get("error", {}).get("message", "")
97
98
 
@@ -102,21 +103,21 @@ class APIUtil:
102
103
  Extracts the API endpoint from a given URL using a regular expression.
103
104
 
104
105
  Args:
105
- request_url: The full URL to the API endpoint.
106
+ request_url: The full URL to the API endpoint.
106
107
 
107
108
  Returns:
108
- The extracted endpoint or an empty string if the pattern does not match.
109
+ The extracted endpoint or an empty string if the pattern does not match.
109
110
 
110
111
  Examples:
111
- >>> valid_url = "https://api.example.com/v1/users"
112
- >>> api_endpoint_from_url(valid_url)
113
- 'users'
114
- >>> invalid_url = "https://api.example.com/users"
115
- >>> api_endpoint_from_url(invalid_url)
116
- ''
112
+ >>> valid_url = "https://api.example.com/v1/users"
113
+ >>> api_endpoint_from_url(valid_url)
114
+ 'users'
115
+ >>> invalid_url = "https://api.example.com/users"
116
+ >>> api_endpoint_from_url(invalid_url)
117
+ ''
117
118
  """
118
119
  match = re.search(r"^https://[^/]+(/.+)?/v\d+/(.+)$", request_url)
119
- return match.group(2) if match else ""
120
+ return match[2] if match else ""
120
121
 
121
122
  @staticmethod
122
123
  async def unified_api_call(
@@ -126,22 +127,22 @@ class APIUtil:
126
127
  Makes an API call and automatically retries on rate limit error.
127
128
 
128
129
  Args:
129
- http_session: The session object from the aiohttp library.
130
- method: The HTTP method as a string.
131
- url: The URL to which the request is made.
132
- **kwargs: Additional keyword arguments to pass to the API call.
130
+ http_session: The session object from the aiohttp library.
131
+ method: The HTTP method as a string.
132
+ url: The URL to which the request is made.
133
+ **kwargs: Additional keyword arguments to pass to the API call.
133
134
 
134
135
  Returns:
135
- The JSON assistant_response as a dictionary.
136
+ The JSON assistant_response as a dictionary.
136
137
 
137
138
  Examples:
138
- >>> session = aiohttp.ClientSession()
139
- >>> success_url = "https://api.example.com/v1/success"
140
- >>> print(await unified_api_call(session, 'get', success_url))
141
- {'result': 'Success'}
142
- >>> rate_limit_url = "https://api.example.com/v1/rate_limit"
143
- >>> print(await unified_api_call(session, 'get', rate_limit_url))
144
- {'error': {'message': 'Rate limit exceeded'}}
139
+ >>> session = aiohttp.ClientSession()
140
+ >>> success_url = "https://api.example.com/v1/success"
141
+ >>> print(await unified_api_call(session, 'get', success_url))
142
+ {'result': 'Success'}
143
+ >>> rate_limit_url = "https://api.example.com/v1/rate_limit"
144
+ >>> print(await unified_api_call(session, 'get', rate_limit_url))
145
+ {'error': {'message': 'Rate limit exceeded'}}
145
146
  """
146
147
  api_call = APIUtil.api_method(http_session, method)
147
148
  retry_count = 3
@@ -189,14 +190,14 @@ class APIUtil:
189
190
  Retries an API call on failure, with exponential backoff.
190
191
 
191
192
  Args:
192
- http_session: The aiohttp client session.
193
- url: The URL to make the API call.
194
- retries: The number of times to retry.
195
- backoff_factor: The backoff factor for retries.
196
- **kwargs: Additional arguments for the API call.
193
+ http_session: The aiohttp client session.
194
+ url: The URL to make the API call.
195
+ retries: The number of times to retry.
196
+ backoff_factor: The backoff factor for retries.
197
+ **kwargs: Additional arguments for the API call.
197
198
 
198
199
  Returns:
199
- The assistant_response from the API call, if successful; otherwise, None.
200
+ The assistant_response from the API call, if successful; otherwise, None.
200
201
  """
201
202
  for attempt in range(retries):
202
203
  try:
@@ -227,27 +228,27 @@ class APIUtil:
227
228
  Uploads a file to a specified URL with a retry mechanism for handling failures.
228
229
 
229
230
  Args:
230
- http_session: The HTTP session object to use for making the request.
231
- url: The URL to which the file will be uploaded.
232
- file_path: The path to the file that will be uploaded.
233
- param_name: The name of the parameter expected by the server for the file upload.
234
- additional_data: Additional data to be sent with the upload.
235
- retries: The number of times to retry the upload in case of failure.
231
+ http_session: The HTTP session object to use for making the request.
232
+ url: The URL to which the file will be uploaded.
233
+ file_path: The path to the file that will be uploaded.
234
+ param_name: The name of the parameter expected by the server for the file upload.
235
+ additional_data: Additional data to be sent with the upload.
236
+ retries: The number of times to retry the upload in case of failure.
236
237
 
237
238
  Returns:
238
- The HTTP assistant_response object.
239
+ The HTTP assistant_response object.
239
240
 
240
241
  Examples:
241
- >>> session = aiohttp.ClientSession()
242
- >>> assistant_response = await APIUtil.upload_file_with_retry(session, 'http://example.com/upload', 'path/to/file.txt')
243
- >>> assistant_response.status
244
- 200
242
+ >>> session = aiohttp.ClientSession()
243
+ >>> assistant_response = await APIUtil.upload_file_with_retry(session, 'http://example.com/upload', 'path/to/file.txt')
244
+ >>> assistant_response.status
245
+ 200
245
246
  """
246
247
  for attempt in range(retries):
247
248
  try:
248
249
  with open(file_path, "rb") as file:
249
250
  files = {param_name: file}
250
- additional_data = additional_data if additional_data else {}
251
+ additional_data = additional_data or {}
251
252
  async with http_session.post(
252
253
  url, data={**files, **additional_data}
253
254
  ) as response:
@@ -273,20 +274,20 @@ class APIUtil:
273
274
  Retrieves an OAuth token from the authentication server and caches it to avoid unnecessary requests.
274
275
 
275
276
  Args:
276
- http_session: The HTTP session object to use for making the request.
277
- auth_url: The URL of the authentication server.
278
- client_id: The client ID for OAuth authentication.
279
- client_secret: The client secret for OAuth authentication.
280
- scope: The scope for which the OAuth token is requested.
277
+ http_session: The HTTP session object to use for making the request.
278
+ auth_url: The URL of the authentication server.
279
+ client_id: The client ID for OAuth authentication.
280
+ client_secret: The client secret for OAuth authentication.
281
+ scope: The scope for which the OAuth token is requested.
281
282
 
282
283
  Returns:
283
- The OAuth token as a string.
284
+ The OAuth token as a string.
284
285
 
285
286
  Examples:
286
- >>> session = aiohttp.ClientSession()
287
- >>> token = await APIUtil.get_oauth_token_with_cache(session, 'http://auth.example.com', 'client_id', 'client_secret', 'read')
288
- >>> token
289
- 'mock_access_token'
287
+ >>> session = aiohttp.ClientSession()
288
+ >>> token = await APIUtil.get_oauth_token_with_cache(session, 'http://auth.example.com', 'client_id', 'client_secret', 'read')
289
+ >>> token
290
+ 'mock_access_token'
290
291
  """
291
292
  async with http_session.post(
292
293
  auth_url,
@@ -309,12 +310,12 @@ class APIUtil:
309
310
  Makes an API call.
310
311
 
311
312
  Args:
312
- http_session: The aiohttp client session.
313
- url: The URL for the API call.
314
- **kwargs: Additional arguments for the API call.
313
+ http_session: The aiohttp client session.
314
+ url: The URL for the API call.
315
+ **kwargs: Additional arguments for the API call.
315
316
 
316
317
  Returns:
317
- The assistant_response from the API call, if successful; otherwise, None.
318
+ The assistant_response from the API call, if successful; otherwise, None.
318
319
  """
319
320
  try:
320
321
  async with http_session.get(url, **kwargs) as response:
@@ -325,12 +326,11 @@ class APIUtil:
325
326
  return None
326
327
 
327
328
  @staticmethod
328
- # @lru_cache(maxsize=1024)
329
329
  def calculate_num_token(
330
330
  payload: Mapping[str, Any] = None,
331
331
  api_endpoint: str = None,
332
332
  token_encoding_name: str = None,
333
- ) -> int:
333
+ ) -> int: # sourcery skip: avoid-builtin-shadow
334
334
  """
335
335
  Calculates the number of tokens required for a request based on the payload and API endpoint.
336
336
 
@@ -339,20 +339,20 @@ class APIUtil:
339
339
  for the OpenAI API.
340
340
 
341
341
  Parameters:
342
- payload (Mapping[str, Any]): The payload of the request.
342
+ payload (Mapping[str, Any]): The payload of the request.
343
343
 
344
- api_endpoint (str): The specific API endpoint for the request.
344
+ api_endpoint (str): The specific API endpoint for the request.
345
345
 
346
- token_encoding_name (str): The name of the token encoding method.
346
+ token_encoding_name (str): The name of the token encoding method.
347
347
 
348
348
  Returns:
349
- int: The estimated number of tokens required for the request.
349
+ int: The estimated number of tokens required for the request.
350
350
 
351
351
  Example:
352
- >>> rate_limiter = OpenAIRateLimiter(100, 200)
353
- >>> payload = {'prompt': 'Translate the following text:', 'max_tokens': 50}
354
- >>> rate_limiter.calculate_num_token(payload, 'completions')
355
- # Expected token calculation for the given payload and endpoint.
352
+ >>> rate_limiter = OpenAIRateLimiter(100, 200)
353
+ >>> payload = {'prompt': 'Translate the following text:', 'max_tokens': 50}
354
+ >>> rate_limiter.calculate_num_token(payload, 'completions')
355
+ # Expected token calculation for the given payload and endpoint.
356
356
  """
357
357
  import tiktoken
358
358
 
@@ -371,21 +371,19 @@ class APIUtil:
371
371
  num_tokens += len(encoding.encode(value))
372
372
  if key == "name": # if there's a name, the role is omitted
373
373
  num_tokens -= (
374
- 1 # role is always required and always 1 token
374
+ 1
375
+ # role is always required and always 1 token
375
376
  )
376
377
  num_tokens += 2 # every reply is primed with <im_start>assistant
377
378
  return num_tokens + completion_tokens
378
- # normal completions
379
379
  else:
380
380
  prompt = payload["prompt"]
381
381
  if isinstance(prompt, str): # single prompt
382
382
  prompt_tokens = len(encoding.encode(prompt))
383
- num_tokens = prompt_tokens + completion_tokens
384
- return num_tokens
383
+ return prompt_tokens + completion_tokens
385
384
  elif isinstance(prompt, list): # multiple prompts
386
- prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
387
- num_tokens = prompt_tokens + completion_tokens * len(prompt)
388
- return num_tokens
385
+ prompt_tokens = sum(len(encoding.encode(p)) for p in prompt)
386
+ return prompt_tokens + completion_tokens * len(prompt)
389
387
  else:
390
388
  raise TypeError(
391
389
  'Expecting either string or list of strings for "prompt" field in completion request'
@@ -393,11 +391,9 @@ class APIUtil:
393
391
  elif api_endpoint == "embeddings":
394
392
  input = payload["input"]
395
393
  if isinstance(input, str): # single input
396
- num_tokens = len(encoding.encode(input))
397
- return num_tokens
394
+ return len(encoding.encode(input))
398
395
  elif isinstance(input, list): # multiple inputs
399
- num_tokens = sum([len(encoding.encode(i)) for i in input])
400
- return num_tokens
396
+ return sum(len(encoding.encode(i)) for i in input)
401
397
  else:
402
398
  raise TypeError(
403
399
  'Expecting either string or list of strings for "inputs" field in embedding request'
@@ -413,11 +409,11 @@ class APIUtil:
413
409
  payload = {input_key: input_}
414
410
 
415
411
  for key in required_:
416
- payload.update({key: config[key]})
412
+ payload[key] = config[key]
417
413
 
418
414
  for key in optional_:
419
- if bool(config[key]) is True and convert.strip_lower(config[key]) != "none":
420
- payload.update({key: config[key]})
415
+ if bool(config[key]) and convert.strip_lower(config[key]) != "none":
416
+ payload[key] = config[key]
421
417
 
422
418
  return payload
423
419
 
@@ -428,18 +424,18 @@ class StatusTracker:
428
424
  Keeps track of various task statuses within a system.
429
425
 
430
426
  Attributes:
431
- num_tasks_started (int): The number of tasks that have been initiated.
432
- num_tasks_in_progress (int): The number of tasks currently being processed.
433
- num_tasks_succeeded (int): The number of tasks that have completed successfully.
434
- num_tasks_failed (int): The number of tasks that have failed.
435
- num_rate_limit_errors (int): The number of tasks that failed due to rate limiting.
436
- num_api_errors (int): The number of tasks that failed due to API errors.
437
- num_other_errors (int): The number of tasks that failed due to other errors.
427
+ num_tasks_started (int): The number of tasks that have been initiated.
428
+ num_tasks_in_progress (int): The number of tasks currently being processed.
429
+ num_tasks_succeeded (int): The number of tasks that have completed successfully.
430
+ num_tasks_failed (int): The number of tasks that have failed.
431
+ num_rate_limit_errors (int): The number of tasks that failed due to rate limiting.
432
+ num_api_errors (int): The number of tasks that failed due to API errors.
433
+ num_other_errors (int): The number of tasks that failed due to other errors.
438
434
 
439
435
  Examples:
440
- >>> tracker = StatusTracker()
441
- >>> tracker.num_tasks_started += 1
442
- >>> tracker.num_tasks_succeeded += 1
436
+ >>> tracker = StatusTracker()
437
+ >>> tracker.num_tasks_started += 1
438
+ >>> tracker.num_tasks_succeeded += 1
443
439
  """
444
440
 
445
441
  num_tasks_started: int = 0
@@ -459,12 +455,12 @@ class BaseRateLimiter(ABC):
459
455
  the replenishment of request and token capacities at regular intervals.
460
456
 
461
457
  Attributes:
462
- interval: The time interval in seconds for replenishing capacities.
463
- max_requests: The maximum number of requests allowed per interval.
464
- max_tokens: The maximum number of tokens allowed per interval.
465
- available_request_capacity: The current available request capacity.
466
- available_token_capacity: The current available token capacity.
467
- rate_limit_replenisher_task: The asyncio task for replenishing capacities.
458
+ interval: The time interval in seconds for replenishing capacities.
459
+ max_requests: The maximum number of requests allowed per interval.
460
+ max_tokens: The maximum number of tokens allowed per interval.
461
+ available_request_capacity: The current available request capacity.
462
+ available_token_capacity: The current available token capacity.
463
+ rate_limit_replenisher_task: The asyncio task for replenishing capacities.
468
464
  """
469
465
 
470
466
  def __init__(
@@ -516,7 +512,8 @@ class BaseRateLimiter(ABC):
516
512
  ):
517
513
  self.available_request_capacity -= 1
518
514
  self.available_token_capacity -= (
519
- required_tokens # Assuming 1 token per request for simplicity
515
+ required_tokens
516
+ # Assuming 1 token per request for simplicity
520
517
  )
521
518
  return True
522
519
  return False
@@ -536,16 +533,16 @@ class BaseRateLimiter(ABC):
536
533
  Makes an API call to the specified endpoint using the provided HTTP session.
537
534
 
538
535
  Args:
539
- http_session: The aiohttp client session to use for the API call.
540
- endpoint: The API endpoint to call.
541
- base_url: The base URL of the API.
542
- api_key: The API key for authentication.
543
- max_attempts: The maximum number of attempts for the API call.
544
- method: The HTTP method to use for the API call.
545
- payload: The payload to send with the API call.
536
+ http_session: The aiohttp client session to use for the API call.
537
+ endpoint: The API endpoint to call.
538
+ base_url: The base URL of the API.
539
+ api_key: The API key for authentication.
540
+ max_attempts: The maximum number of attempts for the API call.
541
+ method: The HTTP method to use for the API call.
542
+ payload: The payload to send with the API call.
546
543
 
547
544
  Returns:
548
- The JSON assistant_response from the API call if successful, otherwise None.
545
+ The JSON assistant_response from the API call if successful, otherwise None.
549
546
  """
550
547
  endpoint = APIUtil.api_endpoint_from_url(base_url + endpoint)
551
548
  while True:
@@ -573,18 +570,17 @@ class BaseRateLimiter(ABC):
573
570
  ) as response:
574
571
  response_json = await response.json()
575
572
 
576
- if "error" in response_json:
577
- logging.warning(
578
- f"API call failed with error: {response_json['error']}"
579
- )
580
- attempts_left -= 1
581
-
582
- if "Rate limit" in response_json["error"].get(
583
- "message", ""
584
- ):
585
- await AsyncUtil.sleep(15)
586
- else:
573
+ if "error" not in response_json:
587
574
  return response_json
575
+ logging.warning(
576
+ f"API call failed with error: {response_json['error']}"
577
+ )
578
+ attempts_left -= 1
579
+
580
+ if "Rate limit" in response_json["error"].get(
581
+ "message", ""
582
+ ):
583
+ await AsyncUtil.sleep(15)
588
584
  except Exception as e:
589
585
  logging.warning(f"API call failed with exception: {e}")
590
586
  attempts_left -= 1
@@ -606,13 +602,13 @@ class BaseRateLimiter(ABC):
606
602
  Creates an instance of BaseRateLimiter and starts the replenisher task.
607
603
 
608
604
  Args:
609
- max_requests: The maximum number of requests allowed per interval.
610
- max_tokens: The maximum number of tokens allowed per interval.
611
- interval: The time interval in seconds for replenishing capacities.
612
- token_encoding_name: The name of the token encoding to use.
605
+ max_requests: The maximum number of requests allowed per interval.
606
+ max_tokens: The maximum number of tokens allowed per interval.
607
+ interval: The time interval in seconds for replenishing capacities.
608
+ token_encoding_name: The name of the token encoding to use.
613
609
 
614
610
  Returns:
615
- An instance of BaseRateLimiter with the replenisher task started.
611
+ An instance of BaseRateLimiter with the replenisher task started.
616
612
  """
617
613
  instance = cls(max_requests, max_tokens, interval, token_encoding_name)
618
614
  instance.rate_limit_replenisher_task = AsyncUtil.create_task(
@@ -646,25 +642,25 @@ class EndPoint:
646
642
  This class encapsulates the details of an API endpoint, including its rate limiter.
647
643
 
648
644
  Attributes:
649
- endpoint (str): The API endpoint path.
650
- rate_limiter_class (Type[li.BaseRateLimiter]): The class used for rate limiting requests to the endpoint.
651
- max_requests (int): The maximum number of requests allowed per interval.
652
- max_tokens (int): The maximum number of tokens allowed per interval.
653
- interval (int): The time interval in seconds for replenishing rate limit capacities.
654
- config (Mapping): Configuration parameters for the endpoint.
655
- rate_limiter (Optional[li.BaseRateLimiter]): The rate limiter instance for this endpoint.
645
+ endpoint (str): The API endpoint path.
646
+ rate_limiter_class (Type[li.BaseRateLimiter]): The class used for rate limiting requests to the endpoint.
647
+ max_requests (int): The maximum number of requests allowed per interval.
648
+ max_tokens (int): The maximum number of tokens allowed per interval.
649
+ interval (int): The time interval in seconds for replenishing rate limit capacities.
650
+ config (Mapping): Configuration parameters for the endpoint.
651
+ rate_limiter (Optional[li.BaseRateLimiter]): The rate limiter instance for this endpoint.
656
652
 
657
653
  Examples:
658
- # Example usage of EndPoint with SimpleRateLimiter
659
- endpoint = EndPoint(
660
- max_requests=100,
661
- max_tokens=1000,
662
- interval=60,
663
- endpoint_='chat/completions',
664
- rate_limiter_class=li.SimpleRateLimiter,
665
- config={'param1': 'value1'}
666
- )
667
- asyncio.run(endpoint.init_rate_limiter())
654
+ # Example usage of EndPoint with SimpleRateLimiter
655
+ endpoint = EndPoint(
656
+ max_requests=100,
657
+ max_tokens=1000,
658
+ interval=60,
659
+ endpoint_='chat/completions',
660
+ rate_limiter_class=li.SimpleRateLimiter,
661
+ config={'param1': 'value1'}
662
+ )
663
+ asyncio.run(endpoint.init_rate_limiter())
668
664
  """
669
665
 
670
666
  def __init__(
@@ -702,10 +698,10 @@ class BaseService:
702
698
  This class provides a foundation for services that need to make API calls with rate limiting.
703
699
 
704
700
  Attributes:
705
- api_key (Optional[str]): The API key used for authentication.
706
- schema (Mapping[str, Any]): The schema defining the service's endpoints.
707
- status_tracker (StatusTracker): The object tracking the status of API calls.
708
- endpoints (Mapping[str, EndPoint]): A dictionary of endpoint objects.
701
+ api_key (Optional[str]): The API key used for authentication.
702
+ schema (Mapping[str, Any]): The schema defining the service's endpoints.
703
+ status_tracker (StatusTracker): The object tracking the status of API calls.
704
+ endpoints (Mapping[str, EndPoint]): A dictionary of endpoint objects.
709
705
  """
710
706
 
711
707
  base_url: str = ""
@@ -739,7 +735,7 @@ class BaseService:
739
735
  Initializes the specified endpoint or all endpoints if none is specified.
740
736
 
741
737
  Args:
742
- endpoint_: The endpoint(s) to initialize. Can be a string, an EndPoint, a list of strings, or a list of EndPoints.
738
+ endpoint_: The endpoint(s) to initialize. Can be a string, an EndPoint, a list of strings, or a list of EndPoints.
743
739
  """
744
740
 
745
741
  if endpoint_:
@@ -756,45 +752,40 @@ class BaseService:
756
752
  self.schema.get(ep, {})
757
753
  if isinstance(ep, EndPoint):
758
754
  self.endpoints[ep.endpoint] = ep
755
+ elif ep == "chat/completions":
756
+ self.endpoints[ep] = EndPoint(
757
+ max_requests=self.chat_config_rate_limit.get(
758
+ "max_requests", 1000
759
+ ),
760
+ max_tokens=self.chat_config_rate_limit.get(
761
+ "max_tokens", 100000
762
+ ),
763
+ interval=self.chat_config_rate_limit.get("interval", 60),
764
+ endpoint_=ep,
765
+ token_encoding_name=self.token_encoding_name,
766
+ config=endpoint_config,
767
+ )
759
768
  else:
760
- if ep == "chat/completions":
761
- self.endpoints[ep] = EndPoint(
762
- max_requests=self.chat_config_rate_limit.get(
763
- "max_requests", 1000
764
- ),
765
- max_tokens=self.chat_config_rate_limit.get(
766
- "max_tokens", 100000
767
- ),
768
- interval=self.chat_config_rate_limit.get(
769
- "interval", 60
770
- ),
771
- endpoint_=ep,
772
- token_encoding_name=self.token_encoding_name,
773
- config=endpoint_config,
774
- )
775
- else:
776
- self.endpoints[ep] = EndPoint(
777
- max_requests=(
778
- endpoint_config.get("max_requests", 1000)
779
- if endpoint_config.get("max_requests", 1000)
780
- is not None
781
- else 1000
782
- ),
783
- max_tokens=(
784
- endpoint_config.get("max_tokens", 100000)
785
- if endpoint_config.get("max_tokens", 100000)
786
- is not None
787
- else 100000
788
- ),
789
- interval=(
790
- endpoint_config.get("interval", 60)
791
- if endpoint_config.get("interval", 60) is not None
792
- else 60
793
- ),
794
- endpoint_=ep,
795
- token_encoding_name=self.token_encoding_name,
796
- config=endpoint_config,
797
- )
769
+ self.endpoints[ep] = EndPoint(
770
+ max_requests=(
771
+ endpoint_config.get("max_requests", 1000)
772
+ if endpoint_config.get("max_requests", 1000) is not None
773
+ else 1000
774
+ ),
775
+ max_tokens=(
776
+ endpoint_config.get("max_tokens", 100000)
777
+ if endpoint_config.get("max_tokens", 100000) is not None
778
+ else 100000
779
+ ),
780
+ interval=(
781
+ endpoint_config.get("interval", 60)
782
+ if endpoint_config.get("interval", 60) is not None
783
+ else 60
784
+ ),
785
+ endpoint_=ep,
786
+ token_encoding_name=self.token_encoding_name,
787
+ config=endpoint_config,
788
+ )
798
789
 
799
790
  if not self.endpoints[ep]._has_initialized:
800
791
  await self.endpoints[ep].init_rate_limiter()
@@ -820,20 +811,20 @@ class BaseService:
820
811
  Calls the specified API endpoint with the given payload and method.
821
812
 
822
813
  Args:
823
- payload: The payload to send with the API call.
824
- endpoint: The endpoint to call.
825
- method: The HTTP method to use for the call.
814
+ payload: The payload to send with the API call.
815
+ endpoint: The endpoint to call.
816
+ method: The HTTP method to use for the call.
826
817
 
827
818
  Returns:
828
- The assistant_response from the API call.
819
+ The assistant_response from the API call.
829
820
 
830
821
  Raises:
831
- ValueError: If the endpoint has not been initialized.
822
+ ValueError: If the endpoint has not been initialized.
832
823
  """
833
824
  if endpoint not in self.endpoints.keys():
834
825
  raise ValueError(f"The endpoint {endpoint} has not initialized.")
835
826
  async with aiohttp.ClientSession() as http_session:
836
- completion = await self.endpoints[endpoint].rate_limiter._call_api(
827
+ return await self.endpoints[endpoint].rate_limiter._call_api(
837
828
  http_session=http_session,
838
829
  endpoint=endpoint,
839
830
  base_url=self.base_url,
@@ -842,7 +833,6 @@ class BaseService:
842
833
  payload=payload,
843
834
  **kwargs,
844
835
  )
845
- return completion
846
836
 
847
837
 
848
838
  class PayloadPackage:
@@ -853,13 +843,13 @@ class PayloadPackage:
853
843
  Creates a payload for the chat completion operation.
854
844
 
855
845
  Args:
856
- messages: The messages to include in the chat completion.
857
- llmconfig: Configuration for the language model.
858
- schema: The schema describing required and optional fields.
859
- **kwargs: Additional keyword arguments.
846
+ messages: The messages to include in the chat completion.
847
+ llmconfig: Configuration for the language model.
848
+ schema: The schema describing required and optional fields.
849
+ **kwargs: Additional keyword arguments.
860
850
 
861
851
  Returns:
862
- The constructed payload.
852
+ The constructed payload.
863
853
  """
864
854
  return APIUtil.create_payload(
865
855
  input_=messages,
@@ -876,13 +866,13 @@ class PayloadPackage:
876
866
  Creates a payload for the fine-tuning operation.
877
867
 
878
868
  Args:
879
- training_file: The file containing training data.
880
- llmconfig: Configuration for the language model.
881
- schema: The schema describing required and optional fields.
882
- **kwargs: Additional keyword arguments.
869
+ training_file: The file containing training data.
870
+ llmconfig: Configuration for the language model.
871
+ schema: The schema describing required and optional fields.
872
+ **kwargs: Additional keyword arguments.
883
873
 
884
874
  Returns:
885
- The constructed payload.
875
+ The constructed payload.
886
876
  """
887
877
  return APIUtil._create_payload(
888
878
  input_=training_file,