ASUllmAPI 2.0.2__tar.gz → 2.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  from concurrent.futures import ThreadPoolExecutor, as_completed
2
- from typing import Dict
2
+ from typing import Dict, Union
3
3
 
4
4
  from tqdm.auto import tqdm
5
5
 
@@ -9,9 +9,13 @@ from .utils import time_api
9
9
 
10
10
 
11
11
  @time_api
12
- def batch_query_llm(model: ModelConfig, queries: Dict[int, str], max_threads: int,
13
- num_retry: int = 3, auto_increase_retry: bool = False,
14
- success_sleep: float = 0.0, fail_sleep: float = 1.0) -> Dict[int, dict]:
12
+ def batch_query_llm(model: ModelConfig,
13
+ queries: Dict[Union[str, int], str],
14
+ max_threads: int,
15
+ num_retry: int = 3,
16
+ auto_increase_retry: bool = False,
17
+ success_sleep: float = 0.0,
18
+ fail_sleep: float = 1.0) -> Dict[Union[str, int], dict]:
15
19
  with ThreadPoolExecutor(max_workers=max_threads) as executor:
16
20
  # Submit tasks to the executor - order of return will be asynchronous.
17
21
  # If `auto_increase_retry` enabled, then scaling API backoff is implemented.
@@ -0,0 +1,84 @@
1
+ from typing import Dict, List
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import time
4
+ import json
5
+ import asyncio
6
+
7
+
8
def begin_task_execution(async_func):
    """
    Decorate an ``async def`` so it can be called like a regular blocking function.

    See https://stackoverflow.com/a/75341431 for the underlying rationale.
    When the caller is already inside a running event loop (e.g. a Jupyter
    cell), ``asyncio.run`` cannot be used on that thread. In that case the
    coroutine is run to completion on a fresh event loop in a single worker
    thread while the caller blocks on the result. Otherwise ``asyncio.run`` is
    used directly.

    :param async_func: the coroutine function to wrap.
    :return: a synchronous wrapper (carrying ``async_func``'s name and docstring)
        that returns whatever ``async_func`` returns; exceptions raised by
        ``async_func`` propagate unchanged.
    """
    import functools

    @functools.wraps(async_func)
    def wrap(*args, **kwargs):
        try:
            asyncio.get_running_loop()  # Raises RuntimeError if no loop is running.
        except RuntimeError:
            # No loop running on this thread: asyncio.run is safe here.
            return asyncio.run(async_func(*args, **kwargs))
        # A loop is already running on this thread, so run the coroutine on a
        # separate thread (with its own loop) and block until it finishes.
        # This sits outside the `try` so a RuntimeError raised by `async_func`
        # itself is not mistaken for "no running loop".
        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()

    return wrap
25
+
26
+
27
def load_json_buffer(string):
    """
    Decode ``string`` as JSON without raising on malformed input.

    :param string: text to decode.
    :return: the decoded JSON value, or ``None`` when ``string`` is not valid JSON.
    """
    try:
        decoded = json.loads(string)
    except json.JSONDecodeError:
        decoded = None
    return decoded
32
+
33
+
34
def split_dict_into_chunks(input_dict: Dict, n: int) -> List[Dict]:
    """
    Partition a dictionary into ``n`` consecutive, insertion-ordered pieces.

    Chunk sizes differ by at most one; the first ``len(input_dict) % n`` chunks
    receive the extra item. When ``n`` exceeds the number of items, the
    trailing chunks are empty dicts.

    Parameters:
        input_dict (dict): The dictionary to split.
        n (int): How many chunks to produce.

    Returns:
        List[Dict]: ``n`` dictionaries covering ``input_dict`` in order, or an
        empty list when ``input_dict`` is empty.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("Number of chunks must be greater than 0")
    if not input_dict:
        return []

    pairs = list(input_dict.items())
    base_size, extra = divmod(len(pairs), n)

    result = []
    offset = 0
    for index in range(n):
        size = base_size + 1 if index < extra else base_size
        result.append(dict(pairs[offset:offset + size]))
        offset += size
    return result
64
+
65
+
66
def time_api(func):
    """
    Decorator that prints how long each call to ``func`` took.

    The wrapper preserves ``func``'s name, docstring and other metadata
    (``functools.wraps``) so decorated API functions stay introspectable.
    It measures elapsed time with the monotonic ``time.perf_counter`` clock,
    which, unlike ``time.time``, cannot jump when the system clock changes.

    :param func: the function to time.
    :return: a wrapper that calls ``func`` with the same arguments, prints
        ``"<name> executed in <seconds> seconds."`` and returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def time_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
@@ -0,0 +1,165 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import traceback
5
+ from typing import Dict, Any, Union
6
+ import ssl
7
+ import time
8
+
9
+ import websockets
10
+ import certifi
11
+
12
+ from .model_config import ModelConfig
13
+ from .utils import load_json_buffer, begin_task_execution
14
+
15
# TLS context trusting certifi's CA bundle; passed to every websockets.connect call in this module.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
# Marker that ends a streamed answer: its presence in a received chunk stops the
# receive loop, and it is stripped from the stored response text.
END_OF_STREAM = '<EOS>'
# Initial per-query result. Callers take a shallow `.copy()`, which is enough
# because both values are immutable.
DEFAULT_RESPONSE = {"response": "", "success": 0}
# NOTE(review): calling basicConfig at import time configures the *root* logger of
# whatever application imports this package (INFO level, this format), and it is a
# no-op if the application already configured logging. Consider a module-level
# `logging.getLogger(__name__)` instead.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
19
+
20
+
21
class ResponseDataError(Exception):
    """
    Raised when the endpoint sends a JSON object that carries no ``response`` field.

    :ivar message: human-readable description (also the exception's sole argument).
    """

    def __init__(self, message="An error message has been sent by the endpoint."):
        self.message = message
        super().__init__(message)
25
+
26
+
27
async def _receive_stream(ws, qid: Union[str, int],
                          response_payloads: Dict[Union[str, int], Dict[str, Any]],
                          ws_timeout: float) -> None:
    """
    Read streamed chunks for one query until its end of stream.

    Each chunk's text, with ``END_OF_STREAM`` stripped, is appended to
    ``response_payloads[qid]["response"]``. On completion ``success`` is set to 1,
    and ``metadata`` is stored when the final JSON chunk carries it.

    :raises ResponseDataError: a JSON object without a ``response`` field arrived;
        its contents are merged into ``response_payloads[qid]``.
    :raises asyncio.TimeoutError: no chunk arrived within ``ws_timeout`` seconds.
    """
    result = response_payloads[qid]
    while True:
        # The first chunk typically takes ~29 seconds (longer with GeminiPro);
        # later chunks can take minutes, hence the generous per-chunk timeout.
        response = await asyncio.wait_for(ws.recv(), ws_timeout)
        parsed_response = load_json_buffer(response)
        if isinstance(parsed_response, dict):
            if 'response' not in parsed_response:
                # e.g. the connection ID expired because the first chunk took too long.
                result.update(parsed_response)
                raise ResponseDataError()
            chunk = parsed_response["response"]
            result["response"] += chunk.replace(END_OF_STREAM, "")
            if 'metadata' in parsed_response or END_OF_STREAM in chunk:
                # The final chunk may carry EOS without metadata; don't KeyError on it.
                if 'metadata' in parsed_response:
                    result["metadata"] = parsed_response["metadata"]
                result["success"] = 1
                return
        else:
            # Plain-text stream (or JSON that is not an object).
            result["response"] += response.replace(END_OF_STREAM, "")
            if END_OF_STREAM in response:
                result["success"] = 1
                return


async def interact_with_websocket(uri: str, queue: asyncio.Queue,
                                  response_payloads: Dict[Union[str, int], Dict[str, Any]],
                                  ws_timeout_min: int = 8,
                                  error_threshold: int = 100,
                                  reconnect_timeout_secs: int = 2):
    """
    Worker that drains ``queue`` over a (re)connecting WebSocket.

    Each queue item is ``(qid, payload)``. Non-empty payloads are sent as JSON, and the
    streamed answer is accumulated into ``response_payloads[qid]``. On any error, the
    connection is re-opened after ``reconnect_timeout_secs`` and the same query is
    retried, with its partial response cleared. After ``error_threshold`` consecutive
    failures the query is abandoned and its ``success`` stays 0. Several workers may
    share one queue.

    :param uri: WebSocket URI.
    :param queue: An asyncio Queue pre-filled with ``(qid, payload)`` tuples;
        ``task_done()`` is called exactly once for every item this worker takes.
    :param response_payloads: Per-qid result dicts, filled in place.
    :param ws_timeout_min: Minutes after which a connection is refreshed (between
        queries); also used as the per-chunk receive timeout.
    :param error_threshold: Consecutive failures permitted for a query before moving on.
        It is also the number of consecutive failures with no query in flight
        (e.g. failing to connect) after which this worker gives up.
    :param reconnect_timeout_secs: Seconds to wait before reconnecting after an error.
    :return: None; results are written into ``response_payloads``.
    """
    error_ct = 0          # consecutive failures of the query currently in flight
    connect_failures = 0  # consecutive failures while no query was in flight
    ws_timeout = ws_timeout_min * 60
    qid = None
    tmp_input_payload = None
    in_flight = False     # True between queue.get() and the matching queue.task_done()

    # START - WEBSOCKET LOOP
    while in_flight or not queue.empty():
        connection_start_time = time.time()
        served = 0  # queries completed on the current connection
        try:
            async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
                connect_failures = 0
                # START - QUERY QUEUE LOOP
                while in_flight or not queue.empty():
                    if not in_flight:
                        # Refresh a long-lived connection *between* queries, so it is not
                        # cut off mid-stream. This is not counted as a query error.
                        # Requiring `served` guarantees progress even for tiny timeouts.
                        if served and time.time() - connection_start_time >= ws_timeout:
                            logging.info("WebSocket connection open too long: refreshing...")
                            break
                        qid, tmp_input_payload = await queue.get()
                        in_flight = True

                    # Empty queries are never sent; they complete immediately.
                    if tmp_input_payload["query"] != "":
                        await ws.send(json.dumps(tmp_input_payload))
                        await _receive_stream(ws, qid, response_payloads, ws_timeout)

                    error_ct = 0
                    served += 1
                    in_flight = False
                    logging.info(f"Query ID {qid} completed... Message: {response_payloads[qid]['response']}")
                    queue.task_done()
                # END - QUERY QUEUE LOOP
        except Exception as exc:
            if in_flight:
                error_ct += 1
                # Reset buffer stream so a retry doesn't append to a partial answer.
                response_payloads[qid]["response"] = ""
                where = f"Error {error_ct} on {qid}"
            else:
                # Failure with no query in flight (e.g. the connection could not be opened).
                connect_failures += 1
                where = f"Connection error {connect_failures}"

            if isinstance(exc, asyncio.TimeoutError):
                logging.error(f"{where} stream timeout: resetting connection...")
            elif isinstance(exc, ResponseDataError):
                logging.error(f"{where}: invalid response from endpoint.\n"
                              f"{response_payloads[qid]}")
            elif isinstance(exc, websockets.ConnectionClosed):
                logging.error(f"{where}: WebSocket connection closed. Reopening...")
            else:
                logging.error(f"{where}: {traceback.format_exc()}")

            if in_flight and error_ct >= error_threshold:
                # Give up on this query (its success stays 0) and release it from the queue.
                logging.error(f"Query ID {qid} failed {error_ct} times; moving on.")
                error_ct = 0
                in_flight = False
                queue.task_done()
            elif not in_flight and connect_failures >= error_threshold:
                logging.error(f"Could not connect to {uri} after {connect_failures} attempts; "
                              f"leaving remaining queries unanswered.")
                return
            # Non-blocking back-off: time.sleep would freeze every other worker on this loop.
            await asyncio.sleep(reconnect_timeout_secs)
    # END - WEBSOCKET LOOP
    logging.info("WebSocket connection closed. Queue appears to be empty...")
126
+
127
+
128
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
                                 max_concurrent_tasks: int = 3,
                                 ws_timeout_min: int = 8,
                                 error_threshold: int = 100,
                                 reconnect_timeout_secs: int = 2) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Answer many queries over concurrent WebSocket workers (callable synchronously).

    :param model: model configuration providing ``api_url`` and ``compute_payload``.
    :param queries: mapping of query ID to query text.
    :param max_concurrent_tasks: upper bound on simultaneous WebSocket workers; no more
        workers than there are queries are started.
    :param ws_timeout_min: forwarded to ``interact_with_websocket``.
    :param error_threshold: forwarded to ``interact_with_websocket``.
    :param reconnect_timeout_secs: forwarded to ``interact_with_websocket``.
    :return: mapping of every query ID to its result dict, which starts as a copy of
        ``DEFAULT_RESPONSE`` and is filled in by the workers.
    """
    payloads = {qid: model.compute_payload(message) for qid, message in queries.items()}
    response_payloads = {qid: DEFAULT_RESPONSE.copy() for qid in queries}

    # Each worker opens its own connection. Workers beyond the number of queries
    # would only open a connection and close it again without doing any work.
    worker_count = min(max_concurrent_tasks, len(payloads))
    if worker_count > 0:
        queue = asyncio.Queue()
        for item in payloads.items():
            queue.put_nowait(item)
        await asyncio.gather(*(
            interact_with_websocket(model.api_url, queue, response_payloads,
                                    ws_timeout_min, error_threshold, reconnect_timeout_secs)
            for _ in range(worker_count)
        ))

    return response_payloads
150
+
151
+
152
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str,
                           ws_timeout_min: int = 8,
                           error_threshold: int = 100,
                           reconnect_timeout_secs: int = 2) -> Dict[str, Any]:
    """
    Answer a single query over a WebSocket (callable synchronously).

    :param model: model configuration providing ``api_url`` and ``compute_payload``.
    :param query: the query text. An empty query is not sent; the untouched default
        result (``{"response": "", "success": 0}``) is returned instead.
    :param ws_timeout_min: forwarded to ``interact_with_websocket``.
    :param error_threshold: forwarded to ``interact_with_websocket``.
    :param reconnect_timeout_secs: forwarded to ``interact_with_websocket``.
    :return: the result dict for the query.
    """
    response_payloads = {0: DEFAULT_RESPONSE.copy()}
    tmp_payload = model.compute_payload(query=query)
    if query == "":
        # Previously fell through to use an unbound `queue` (UnboundLocalError).
        return response_payloads[0]

    queue = asyncio.Queue()
    queue.put_nowait((0, tmp_payload))
    await interact_with_websocket(model.api_url, queue, response_payloads, ws_timeout_min,
                                  error_threshold, reconnect_timeout_secs)
    return response_payloads[0]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.2
3
+ Version: 2.0.4
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.2
3
+ Version: 2.0.4
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASUllmAPI"
3
- version = "2.0.2"
3
+ version = "2.0.4"
4
4
  authors = [
5
5
  { name="Stella Wenxing Liu", email="stellawenxingliu@gmail.com" },
6
6
  { name="Varun Shourie", email="svarun195@gmail.com" }
@@ -1,38 +0,0 @@
1
- import time
2
- import json
3
- import asyncio
4
-
5
-
6
def begin_task_execution(async_func):
    """
    Decorate an ``async def`` so it can be called synchronously via ``asyncio.run``.

    The wrapper keeps ``async_func``'s name and docstring (``functools.wraps``).
    NOTE: ``asyncio.run`` raises RuntimeError when called from a thread whose
    event loop is already running (e.g. inside Jupyter).

    :param async_func: the coroutine function to wrap.
    :return: a blocking wrapper returning the coroutine's result.
    """
    import functools

    @functools.wraps(async_func)
    def wrap(*args, **kwargs):
        return asyncio.run(async_func(*args, **kwargs))

    return wrap
11
-
12
-
13
def load_json_buffer(string):
    """
    Decode ``string`` as JSON without raising on malformed input.

    :param string: text to decode.
    :return: the decoded JSON value, or ``None`` when ``string`` is not valid JSON.
    """
    try:
        decoded = json.loads(string)
    except json.JSONDecodeError:
        decoded = None
    return decoded
18
-
19
-
20
def time_api(func):
    """
    Decorator that prints how long each call to ``func`` took.

    The wrapper preserves ``func``'s name, docstring and other metadata
    (``functools.wraps``). It times calls with the monotonic
    ``time.perf_counter`` clock rather than ``time.time``, which can jump
    when the system clock changes.

    :param func: the function to time.
    :return: a wrapper that calls ``func`` with the same arguments, prints
        ``"<name> executed in <seconds> seconds."`` and returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def time_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
@@ -1,89 +0,0 @@
1
- import asyncio
2
- import json
3
- from typing import Dict
4
- import ssl
5
- import warnings
6
-
7
- import websockets
8
- import certifi
9
-
10
- from .model_config import ModelConfig
11
- from .utils import load_json_buffer, begin_task_execution
12
-
13
# TLS context trusting certifi's CA bundle; passed to websockets.connect below.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
14
-
15
-
16
async def interact_with_websocket(uri: str, query_payload: Dict[str, str], qid: int):
    """
    Send one query over a new WebSocket connection and collect the streamed reply.

    :param uri: WebSocket URI.
    :param query_payload: JSON-serialisable request body.
    :param qid: identifier used as the key of the returned dict.
    :return: ``{qid: {"response": str, ...}}``. Endpoint error payloads (``message`` /
        ``error`` keys) and the final chunk's ``metadata`` are merged into the inner dict.
    """
    end_of_stream = "<EOS>"
    payload = {qid: {"response": ""}}
    result = payload[qid]

    async with websockets.connect(uri, ssl=SSL_CONTEXT) as websocket:
        # Send the user-provided message to the WebSocket server
        await websocket.send(json.dumps(query_payload))

        # Loop to receive messages until the server is finished
        while True:
            try:
                response = await websocket.recv()
            except websockets.ConnectionClosed:
                print(f"Question {qid} connection closed by the server")
                break
            parsed_response = load_json_buffer(response)

            if not isinstance(parsed_response, dict):
                # Intended case: `text` response type and the user was not denied entry.
                result["response"] += response.replace(end_of_stream, "")
                if end_of_stream in response:
                    break
            elif 'message' in parsed_response or 'error' in parsed_response:
                # User denied entry: embed the whole error payload, leave response blank.
                result.update(parsed_response)
                break
            elif 'response' in parsed_response:
                chunk = parsed_response["response"]
                result["response"] += chunk.replace(end_of_stream, "")
                if 'metadata' in parsed_response or end_of_stream in chunk:
                    # The final chunk may mark EOS without carrying metadata;
                    # indexing it unconditionally raised KeyError.
                    if 'metadata' in parsed_response:
                        result["metadata"] = parsed_response["metadata"]
                    break
            else:
                # Unknown edge case: JSON that is not a `message`, `response` or `error`.
                warnings.warn(f"Unknown ASU LLM endpoint edge case detected: JSON parsed but data does not "
                              f"contain a message or response field.", RuntimeWarning)
                result.update(parsed_response)
                break
    print(f"......Question {qid} response sent by the WebSocket server.")
    return payload
60
-
61
-
62
async def limited_task(semaphore: asyncio.Semaphore, task):
    """
    Await ``task`` while holding ``semaphore``, capping how many run at once.

    :param semaphore: limiter shared by all tasks of a batch.
    :param task: an awaitable to run.
    :return: the awaited result of ``task``.
    """
    await semaphore.acquire()
    try:
        return await task
    finally:
        semaphore.release()
65
-
66
-
67
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[int, str], max_concurrent_tasks: int = 3):
    """
    Send every query concurrently, with at most ``max_concurrent_tasks`` in flight.

    :param model: model configuration providing ``api_url`` and ``compute_payload``.
    :param queries: mapping of query ID to query text.
    :param max_concurrent_tasks: semaphore size limiting simultaneous connections.
    :return: ``{qid: {"response": ..., ...}, ...}`` merged from each query's result.
    """
    coroutines = [
        interact_with_websocket(model.api_url, model.compute_payload(message), qid)
        for qid, message in queries.items()
    ]
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    # Gather all tasks to run them concurrently and collect results
    results = await asyncio.gather(*(limited_task(semaphore, coro) for coro in coroutines))

    merged = {}
    for partial in results:
        merged.update(partial)
    return merged
83
-
84
-
85
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str):
    """
    Send a single query and return ``{0: {"response": ..., ...}}``.

    :param model: model configuration providing ``api_url`` and ``compute_payload``.
    :param query: the query text.
    """
    request_body = model.compute_payload(query=query)
    return await interact_with_websocket(model.api_url, request_body, qid=0)
File without changes
File without changes
File without changes
File without changes