ASUllmAPI 2.0.3__tar.gz → 2.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI/utils.py +19 -4
- asullmapi-2.0.5/ASUllmAPI/web_socket.py +166 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI.egg-info/PKG-INFO +1 -1
- {asullmapi-2.0.3 → asullmapi-2.0.5}/PKG-INFO +1 -1
- {asullmapi-2.0.3 → asullmapi-2.0.5}/pyproject.toml +1 -1
- asullmapi-2.0.3/ASUllmAPI/web_socket.py +0 -105
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI/__init__.py +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI/api.py +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI/model_config.py +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI/multithreading.py +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI.egg-info/SOURCES.txt +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI.egg-info/dependency_links.txt +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI.egg-info/requires.txt +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/ASUllmAPI.egg-info/top_level.txt +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/LICENSE +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/README.md +0 -0
- {asullmapi-2.0.3 → asullmapi-2.0.5}/setup.cfg +0 -0
|
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|
|
3
3
|
import time
|
|
4
4
|
import json
|
|
5
5
|
import asyncio
|
|
6
|
+
import sys
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def begin_task_execution(async_func):
|
|
@@ -13,17 +14,31 @@ def begin_task_execution(async_func):
|
|
|
13
14
|
If an event loop doesn't exist, it reverts back to existing asyncio logic.
|
|
14
15
|
"""
|
|
15
16
|
def wrap(*args, **kwargs):
|
|
16
|
-
try
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
# It is safer to use a if/else statement than try/except since
|
|
18
|
+
# the underlying code we wrap our function around can also raise RuntimeErrors.
|
|
19
|
+
# When this happens, multiple asyncio.run() executions occur, which is dangerous.
|
|
20
|
+
if is_jupyter():
|
|
19
21
|
with ThreadPoolExecutor(1) as pool:
|
|
20
22
|
result = pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()
|
|
21
|
-
|
|
23
|
+
else:
|
|
22
24
|
result = asyncio.run(async_func(*args, **kwargs))
|
|
23
25
|
return result
|
|
24
26
|
return wrap
|
|
25
27
|
|
|
26
28
|
|
|
29
|
+
def is_jupyter():
    """
    Tell whether the process is running inside an active IPython shell.

    IPython is never imported by this check: it is only consulted when it is
    already present in ``sys.modules``.

    :return: True when ``IPython.get_ipython()`` yields a shell object, False otherwise.
    """
    # IPython was never loaded, so we cannot be inside an IPython session.
    if 'IPython' not in sys.modules:
        return False

    try:
        from IPython import get_ipython
        shell = get_ipython()
    except ImportError:
        return False

    # get_ipython() returns None when no interactive shell is active.
    return shell is not None
|
|
40
|
+
|
|
41
|
+
|
|
27
42
|
def load_json_buffer(string):
|
|
28
43
|
try:
|
|
29
44
|
return json.loads(string)
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import traceback
|
|
5
|
+
from typing import Dict, Any, Union
|
|
6
|
+
import ssl
|
|
7
|
+
import time
|
|
8
|
+
|
|
9
|
+
import websockets
|
|
10
|
+
import certifi
|
|
11
|
+
|
|
12
|
+
from .model_config import ModelConfig
|
|
13
|
+
from .utils import load_json_buffer, begin_task_execution
|
|
14
|
+
|
|
15
|
+
# TLS context used for the WebSocket connections, built from certifi's CA bundle.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
# Marker stripped from streamed text; seeing it ends the read loop for a query.
END_OF_STREAM = '<EOS>'
# Template copied once per query: "response" accumulates the streamed text and
# "success" is switched to 1 when the stream completes.
DEFAULT_RESPONSE = {"response": "", "success": 0}
# Configures the root logger at import time (INFO level, timestamp-prefixed messages).
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ResponseDataError(Exception):
    """Raised when the endpoint answers with a payload that carries no usable response."""

    def __init__(self, message="An error message has been sent by the endpoint."):
        # Keep the text reachable both as an attribute and through Exception.args / str().
        self.message = message
        Exception.__init__(self, message)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def _collect_streamed_response(ws, slot: Dict[str, Any], ws_timeout: float) -> None:
    """
    Read chunks from `ws` into `slot` until the end of the stream is signalled.

    :param ws: An open WebSocket connection on which the query was already sent.
    :param slot: The response dictionary of the query being answered; mutated in place.
    :param ws_timeout: Maximum number of seconds to wait for each individual chunk.
    :raises ResponseDataError: when a JSON object without a 'response' key is received.
    :raises asyncio.TimeoutError: when a chunk does not arrive within `ws_timeout` seconds.
    """
    while True:
        response = await asyncio.wait_for(ws.recv(), ws_timeout)
        parsed_response = load_json_buffer(response)

        if not isinstance(parsed_response, dict):
            # Plain-text stream: the chunk itself is the text.
            slot["response"] += response.replace(END_OF_STREAM, "")
            if END_OF_STREAM in response:
                slot["success"] = 1
                return
            continue

        if 'response' not in parsed_response:
            # The endpoint sent an error object instead of a chunk; keep it for the caller
            # and let the connection loop decide whether to retry.
            slot.update(parsed_response)
            raise ResponseDataError()

        chunk = parsed_response["response"]
        slot["response"] += chunk.replace(END_OF_STREAM, "")
        has_metadata = 'metadata' in parsed_response
        if has_metadata or END_OF_STREAM in chunk:
            # A final chunk may end on END_OF_STREAM without carrying metadata,
            # so only copy the metadata when it is really there.
            if has_metadata:
                slot["metadata"] = parsed_response["metadata"]
            slot["success"] = 1
            return


async def interact_with_websocket(uri: str, queue: asyncio.Queue,
                                  response_payloads: Dict[Union[str, int], Dict[str, Any]],
                                  ws_timeout_min: int = 8,
                                  error_threshold: int = 100,
                                  reconnect_timeout_secs: int = 2):
    """
    Drain `queue` over a WebSocket connection, reconnecting after every failure.

    :param uri: WebSocket URI
    :param queue: An instantiated asyncio Queue with your dataset fully inserted.
                  Items are `(query_id, payload)` tuples.
    :param response_payloads: A dictionary with base responses set up for processing.
    :param ws_timeout_min: The number of minutes before a connection to the WebSocket is terminated on the async task.
    :param error_threshold: The number of errors permissible for a query in the WebSocket before the program
                            moves to the next question. The same limit bounds consecutive connection
                            failures that happen while no query is being processed.
    :param reconnect_timeout_secs: The number of seconds before another connection is instantiated.
    :return: the `response_payloads` object initially passed into the function
    """
    error_ct = 0
    ws_timeout = ws_timeout_min * 60
    qid = None
    tmp_input_payload = None
    # True from the moment a query is dequeued until it is completed or abandoned.
    # Guarantees exactly one task_done() per get() and that retries only touch a live query.
    in_flight = False

    # START - WEBSOCKET LOOP
    while in_flight or not queue.empty():
        connection_start_time = time.monotonic()
        try:
            async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
                # START - QUERY QUEUE LOOP
                while in_flight or not queue.empty():
                    if not in_flight:
                        qid, tmp_input_payload = await queue.get()
                        in_flight = True
                        error_ct = 0

                    # To prevent unnecessary timeouts DURING a query, we can
                    # cautiously opt out of the connection at a user specified threshold.
                    if time.monotonic() - connection_start_time >= ws_timeout:
                        raise asyncio.TimeoutError()

                    # Empty queries are never sent; they complete with the default response.
                    if tmp_input_payload["query"] != "":
                        await ws.send(json.dumps(tmp_input_payload))
                        await _collect_streamed_response(ws, response_payloads[qid], ws_timeout)

                    logging.info(f"Query ID {qid} completed... Message: {response_payloads[qid]['response']}")
                    in_flight = False
                    error_ct = 0
                    queue.task_done()
                # END - QUERY QUEUE LOOP
        except Exception as exc:
            # asyncio.TimeoutError and websockets.ConnectionClosed are Exception subclasses;
            # they are told apart below purely to log a more helpful message.
            if isinstance(exc, asyncio.TimeoutError):
                logging.error(f"Error {error_ct} on {qid} stream timeout: resetting connection...")
            elif isinstance(exc, ResponseDataError):
                logging.error(f"Error {error_ct} on {qid}: invalid response from endpoint.\n"
                              f"{response_payloads[qid]}")
            elif isinstance(exc, websockets.ConnectionClosed):
                logging.error(f"Error {error_ct} on {qid}: WebSocket connection closed on "
                              f"query ID {qid}. Reopening...")
            else:
                logging.error(f"Error {error_ct} on {qid}: {traceback.format_exc()}")

            error_ct += 1
            if in_flight:
                # Reset buffer stream so that you don't get messed by pre-existing data.
                response_payloads[qid]["response"] = ""

            # prevent any further retries if at error limit.
            if error_ct >= error_threshold:
                if not in_flight:
                    # Nothing could even be dequeued: the endpoint looks unreachable.
                    logging.error(f"Giving up after {error_ct} consecutive connection failures; "
                                  f"remaining queued queries are left unanswered.")
                    break
                # Abandon the current query (its "success" stays 0) and move on.
                in_flight = False
                error_ct = 0
                queue.task_done()

            # Non-blocking pause so the other workers keep running while this one backs off.
            await asyncio.sleep(reconnect_timeout_secs)
    # END - WEBSOCKET LOOP
    logging.info("WebSocket connection closed. Queue appears to be empty...")
    return response_payloads
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
                                 max_concurrent_tasks: int = 3,
                                 ws_timeout_min: int = 8,
                                 error_threshold: int = 100,
                                 reconnect_timeout_secs: int = 2) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Send every query in `queries` to the model's WebSocket endpoint using a pool of workers.

    :param model: Model configuration; supplies `compute_payload` and `api_url`.
    :param queries: Mapping of query ID to query text.
    :param max_concurrent_tasks: Upper bound on the number of concurrent WebSocket workers.
                                 Never more workers than queries are started, and at least
                                 one is started whenever there is something to send.
    :param ws_timeout_min: Forwarded to `interact_with_websocket`.
    :param error_threshold: Forwarded to `interact_with_websocket`.
    :param reconnect_timeout_secs: Forwarded to `interact_with_websocket`.
    :return: Mapping of query ID to its response dictionary (a copy of DEFAULT_RESPONSE
             filled in by the workers).
    """
    payloads = {qid: model.compute_payload(message) for qid, message in queries.items()}
    response_payloads = {qid: DEFAULT_RESPONSE.copy() for qid in payloads}

    if payloads:
        queue = asyncio.Queue()
        for item in payloads.items():
            # The queue is unbounded, so put_nowait can never raise QueueFull here.
            queue.put_nowait(item)

        # Extra workers would only open a connection to find the queue already drained.
        worker_count = max(1, min(max_concurrent_tasks, len(payloads)))
        tasks = [interact_with_websocket(model.api_url, queue, response_payloads,
                                         ws_timeout_min, error_threshold, reconnect_timeout_secs)
                 for _ in range(worker_count)]
        await asyncio.gather(*tasks)

    return response_payloads
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str,
                           ws_timeout_min: int = 8,
                           error_threshold: int = 100,
                           reconnect_timeout_secs: int = 2) -> Dict[str, Any]:
    """
    Send a single query to the model's WebSocket endpoint.

    :param model: Model configuration; supplies `compute_payload` and `api_url`.
    :param query: The query text. An empty string is never sent.
    :param ws_timeout_min: Forwarded to `interact_with_websocket`.
    :param error_threshold: Forwarded to `interact_with_websocket`.
    :param reconnect_timeout_secs: Forwarded to `interact_with_websocket`.
    :return: The response dictionary for the query (a copy of DEFAULT_RESPONSE, filled in
             by `interact_with_websocket` when the query is not empty).
    """
    results = {0: DEFAULT_RESPONSE.copy()}
    payload = model.compute_payload(query=query)

    # Nothing to ask: hand back the untouched default response.
    if query == "":
        return results[0]

    # The single query is filed under ID 0, matching the key prepared in `results`.
    pending = asyncio.Queue()
    await pending.put((0, payload))
    await interact_with_websocket(model.api_url, pending, results,
                                  ws_timeout_min, error_threshold, reconnect_timeout_secs)
    return results[0]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ASUllmAPI
|
|
3
|
-
Version: 2.0.3
|
|
3
|
+
Version: 2.0.5
|
|
4
4
|
Summary: A simple python package to facilitate connection to ASU LLM API
|
|
5
5
|
Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ASUllmAPI
|
|
3
|
-
Version: 2.0.3
|
|
3
|
+
Version: 2.0.5
|
|
4
4
|
Summary: A simple python package to facilitate connection to ASU LLM API
|
|
5
5
|
Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import json
|
|
3
|
-
import logging
|
|
4
|
-
from typing import Dict, Any, Union
|
|
5
|
-
import ssl
|
|
6
|
-
|
|
7
|
-
import websockets
|
|
8
|
-
import certifi
|
|
9
|
-
|
|
10
|
-
from .model_config import ModelConfig
|
|
11
|
-
from .utils import load_json_buffer, begin_task_execution, split_dict_into_chunks
|
|
12
|
-
|
|
13
|
-
# TLS context used for the WebSocket connections, built from certifi's CA bundle.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
# Marker stripped from streamed text; seeing it ends the read loop for a query.
END_OF_STREAM = '<EOS>'
# Template for an empty per-query result.
DEFAULT_RESPONSE = {"response": ""}
# Configures the root logger at import time (INFO level, timestamp-prefixed messages).
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
async def interact_with_websocket(uri: str, query_payloads: Dict[Union[str, int], Dict[str, Any]]):
    """
    Send each payload in `query_payloads` over a WebSocket connection, one after the other.

    A query interrupted by a closed connection is retried on a fresh connection up to two
    more times; after the third failure it is abandoned (its "success" stays 0) and the
    next query is processed.

    :param uri: WebSocket URI
    :param query_payloads: Mapping of query ID to the JSON-serialisable payload to send.
    :return: Mapping of query ID to a dictionary holding the accumulated "response" text,
             a "success" flag (1 when the stream completed, 0 otherwise) and, when the
             endpoint supplied it, "metadata".
    """
    response_payloads = {qid: {"response": ""} for qid in query_payloads}
    payload_iterator = iter(query_payloads.items())
    tmp_qid, tmp_input_payload = next(payload_iterator, (None, None))
    error_ct = 0

    # WEBSOCKET LOOP: opens a new connection whenever the previous one was closed while
    # queries remain. No connection is opened when there is nothing to send.
    while tmp_qid is not None:
        async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
            try:
                # QUERY LOOP: While there are queries remaining to send to the server.
                while tmp_qid is not None:
                    slot = response_payloads[tmp_qid]
                    slot["success"] = 0
                    await ws.send(json.dumps(tmp_input_payload))
                    # QUERY CHUNK LOOP: While chunks are being received by the WebSocket Client.
                    while True:
                        # Wait for the duration of the WebSocket connection for ANY response
                        response = await ws.recv()
                        parsed_response = load_json_buffer(response)
                        if isinstance(parsed_response, dict):
                            if 'response' not in parsed_response:
                                # Error: user denied access to the response for some reason.
                                logging.log(logging.ERROR, f"Data error: query ID {tmp_qid}: {parsed_response}")
                                slot.update(parsed_response)
                                break
                            # Intended use case: user receives JSON response.
                            chunk = parsed_response["response"]
                            slot["response"] += chunk.replace(END_OF_STREAM, "")
                            has_metadata = 'metadata' in parsed_response
                            if has_metadata or END_OF_STREAM in chunk:
                                # The last chunk may end on END_OF_STREAM without metadata.
                                if has_metadata:
                                    slot["metadata"] = parsed_response["metadata"]
                                slot["success"] = 1
                                break
                        else:
                            # Intended use case: user receives TEXT response from websocket
                            slot["response"] += response.replace(END_OF_STREAM, "")
                            if END_OF_STREAM in response:
                                slot["success"] = 1
                                break
                    # END - QUERY CHUNK LOOP
                    logging.log(level=logging.INFO, msg=f"Query ID {tmp_qid} completed...")
                    error_ct = 0
                    tmp_qid, tmp_input_payload = next(payload_iterator, (None, None))
                # END - QUERY LOOP
            except websockets.ConnectionClosed:
                error_ct += 1
                logging.log(level=logging.INFO, msg=f"Error {error_ct}: WebSocket connection closed on query ID "
                                                    f"{tmp_qid}. Reopening...")
                if error_ct <= 2:
                    # Retry the same query; drop the partial text so it is not duplicated.
                    response_payloads[tmp_qid]["response"] = ""
                else:
                    # Retries exhausted: move on instead of re-sending the query forever.
                    error_ct = 0
                    tmp_qid, tmp_input_payload = next(payload_iterator, (None, None))
    # END - WEBSOCKET LOOP
    logging.log(logging.INFO, msg="WebSocket connection closed.")
    return response_payloads
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
                                 max_concurrent_tasks: int = 3) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Send every query in `queries` to the model's WebSocket endpoint in concurrent batches.

    :param model: Model configuration; supplies `compute_payload` and `api_url`.
    :param queries: Mapping of query ID to query text.
    :param max_concurrent_tasks: Passed to `split_dict_into_chunks` to partition the payloads;
                                 one `interact_with_websocket` task runs per resulting batch.
    :return: Mapping of query ID to its response dictionary, merged across all batches.
    """
    payloads = {qid: model.compute_payload(message) for qid, message in queries.items()}

    # One coroutine per batch, all awaited together.
    workers = [interact_with_websocket(model.api_url, batch)
               for batch in split_dict_into_chunks(payloads, max_concurrent_tasks)]
    batch_results = await asyncio.gather(*workers)

    merged = {}
    for partial in batch_results:
        merged.update(partial)
    return merged
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str) -> Dict[str, Any]:
    """
    Send a single query to the model's WebSocket endpoint.

    :param model: Model configuration; supplies `compute_payload` and `api_url`.
    :param query: The query text.
    :return: The response dictionary that `interact_with_websocket` produced for the query.
    """
    # The lone query is filed under ID 0 and looked up again under the same key.
    single_batch = {0: model.compute_payload(query=query)}
    responses = await interact_with_websocket(model.api_url, single_batch)
    return responses[0]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|