ASUllmAPI 2.0.2__tar.gz → 2.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI/multithreading.py +8 -4
- asullmapi-2.0.4/ASUllmAPI/utils.py +84 -0
- asullmapi-2.0.4/ASUllmAPI/web_socket.py +165 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI.egg-info/PKG-INFO +1 -1
- {asullmapi-2.0.2 → asullmapi-2.0.4}/PKG-INFO +1 -1
- {asullmapi-2.0.2 → asullmapi-2.0.4}/pyproject.toml +1 -1
- asullmapi-2.0.2/ASUllmAPI/utils.py +0 -38
- asullmapi-2.0.2/ASUllmAPI/web_socket.py +0 -89
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI/__init__.py +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI/api.py +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI/model_config.py +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI.egg-info/SOURCES.txt +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI.egg-info/dependency_links.txt +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI.egg-info/requires.txt +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/ASUllmAPI.egg-info/top_level.txt +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/LICENSE +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/README.md +0 -0
- {asullmapi-2.0.2 → asullmapi-2.0.4}/setup.cfg +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
2
|
-
from typing import Dict
|
|
2
|
+
from typing import Dict, Union
|
|
3
3
|
|
|
4
4
|
from tqdm.auto import tqdm
|
|
5
5
|
|
|
@@ -9,9 +9,13 @@ from .utils import time_api
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@time_api
|
|
12
|
-
def batch_query_llm(model: ModelConfig,
|
|
13
|
-
|
|
14
|
-
|
|
12
|
+
def batch_query_llm(model: ModelConfig,
|
|
13
|
+
queries: Dict[Union[str, int], str],
|
|
14
|
+
max_threads: int,
|
|
15
|
+
num_retry: int = 3,
|
|
16
|
+
auto_increase_retry: bool = False,
|
|
17
|
+
success_sleep: float = 0.0,
|
|
18
|
+
fail_sleep: float = 1.0) -> Dict[Union[str, int], dict]:
|
|
15
19
|
with ThreadPoolExecutor(max_workers=max_threads) as executor:
|
|
16
20
|
# Submit tasks to the executor - order of return will be asynchronous.
|
|
17
21
|
# If `auto_increase_retry` enabled, then scaling API backoff is implemented.
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
3
|
+
import time
|
|
4
|
+
import json
|
|
5
|
+
import asyncio
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def begin_task_execution(async_func):
    """
    Make the coroutine function ``async_func`` callable synchronously.

    See https://stackoverflow.com/a/75341431 for underlying rationale.
    When no event loop is running in the calling thread, the coroutine is executed with
    ``asyncio.run``. When one is already running (e.g. inside Jupyter), ``asyncio.run``
    refuses to start a nested loop, so the coroutine is executed on a separate worker
    thread with its own event loop while the caller blocks for the result.

    Exceptions raised by the coroutine propagate to the caller unchanged.
    """
    from functools import wraps

    @wraps(async_func)  # keep the decorated function's name/docstring
    def wrap(*args, **kwargs):
        try:
            asyncio.get_running_loop()  # Raises RuntimeError if no running event loop
        except RuntimeError:
            return asyncio.run(async_func(*args, **kwargs))
        # The try above is deliberately narrow: a RuntimeError raised *by the coroutine*
        # must not be mistaken for "no running loop" and trigger a second, failing run.
        # Create a separate thread so we can block before returning.
        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()

    return wrap
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def load_json_buffer(string):
    """Parse ``string`` as JSON; return None when it is not valid JSON."""
    parsed = None
    try:
        parsed = json.loads(string)
    except json.JSONDecodeError:
        pass  # plain-text chunk: signal "not JSON" with None
    return parsed
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def split_dict_into_chunks(input_dict: Dict, n: int) -> List[Dict]:
    """
    Partition ``input_dict`` into ``n`` consecutive, order-preserving sub-dictionaries.

    Chunk sizes differ by at most one; the first ``len(input_dict) % n`` chunks receive
    the extra item. When ``n`` exceeds the number of items, trailing chunks are empty.

    Parameters:
        input_dict (dict): The dictionary to split.
        n (int): How many chunks to produce; must be positive.

    Returns:
        List[Dict]: The chunks in order, or ``[]`` when ``input_dict`` is empty.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("Number of chunks must be greater than 0")
    if not input_dict:
        return []

    pairs = list(input_dict.items())
    base_size, oversized = divmod(len(pairs), n)

    # Cumulative boundaries: offsets[k] is where chunk k starts, offsets[k + 1] where it ends.
    offsets = [0]
    for index in range(n):
        offsets.append(offsets[-1] + base_size + (1 if index < oversized else 0))

    return [dict(pairs[lo:hi]) for lo, hi in zip(offsets, offsets[1:])]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def time_api(func):
    """
    Decorator to measure the execution time of a function.

    After each call that returns normally, prints
    ``"<function name> executed in <seconds> seconds."`` and returns the wrapped
    function's result unchanged. If the call raises, nothing is printed and the
    exception propagates.
    """
    from functools import wraps

    # wraps() keeps __name__, __doc__, etc. so decorated functions stay introspectable.
    @wraps(func)
    def time_wrapper(*args, **kwargs):
        # Run the original function with its original arguments while timing it.
        # perf_counter() is monotonic and high-resolution; time.time() can jump if
        # the wall clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import traceback
|
|
5
|
+
from typing import Dict, Any, Union
|
|
6
|
+
import ssl
|
|
7
|
+
import time
|
|
8
|
+
|
|
9
|
+
import websockets
|
|
10
|
+
import certifi
|
|
11
|
+
|
|
12
|
+
from .model_config import ModelConfig
|
|
13
|
+
from .utils import load_json_buffer, begin_task_execution
|
|
14
|
+
|
|
15
|
+
# TLS context that verifies the endpoint's certificate against certifi's CA bundle;
# passed to every websockets.connect() call in this module.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
# Sentinel used to detect the last chunk of a streamed answer; it is stripped from
# the text accumulated into each result.
END_OF_STREAM = '<EOS>'
# Initial result for every query. Callers take a shallow .copy() per query, so the
# values must stay immutable (str/int); "success" is set to 1 once a stream completes.
DEFAULT_RESPONSE = {"response": "", "success": 0}
# NOTE(review): configures the root logger at import time, which also changes logging
# for any application importing this package; consider leaving this to the caller.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ResponseDataError(Exception):
    """Raised when the endpoint answers with JSON that is not a streamed response chunk."""

    def __init__(self, message="An error message has been sent by the endpoint."):
        # Keep the text on the instance as well as in ``args`` for convenient access.
        self.message = message
        super().__init__(self.message)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def _receive_stream(ws, result: Dict[str, Any], timeout: float) -> None:
    """
    Read chunks from ``ws`` into ``result`` until the end-of-stream marker arrives.

    Text is appended to ``result["response"]`` with ``END_OF_STREAM`` stripped. On
    completion ``result["success"]`` is set to 1, and ``result["metadata"]`` is set
    if the final JSON chunk carried one.

    :raises ResponseDataError: if a JSON object without a ``"response"`` key arrives;
        its fields are merged into ``result`` first.
    :raises asyncio.TimeoutError: if no chunk arrives within ``timeout`` seconds.
    """
    while True:
        # First chunk will typically take 29 seconds of time for every model.
        # Remaining chunks can take up to 9 minutes to be received.
        # First chunk will take more time with GeminiPro - long time to load in
        # comparison to other models.
        response = await asyncio.wait_for(ws.recv(), timeout)

        parsed_response = load_json_buffer(response)
        if isinstance(parsed_response, dict):
            if 'response' not in parsed_response:
                # Connection ID expires - first chunk took more than 29 seconds.
                result.update(parsed_response)
                raise ResponseDataError()
            result["response"] += parsed_response["response"].replace(END_OF_STREAM, "")
            if 'metadata' in parsed_response or END_OF_STREAM in parsed_response['response']:
                # The final chunk may carry <EOS> without a "metadata" field; indexing it
                # unconditionally raised KeyError and threw the finished answer away.
                if 'metadata' in parsed_response:
                    result["metadata"] = parsed_response["metadata"]
                result["success"] = 1
                return
        else:
            result["response"] += response.replace(END_OF_STREAM, "")
            if END_OF_STREAM in response:
                result["success"] = 1
                return


async def interact_with_websocket(uri: str, queue: asyncio.Queue,
                                  response_payloads: Dict[Union[str, int], Dict[str, Any]],
                                  ws_timeout_min: int = 8,
                                  error_threshold: int = 100,
                                  reconnect_timeout_secs: int = 2):
    """
    Drain ``queue`` over a WebSocket connection, streaming each answer into ``response_payloads``.

    Several copies of this coroutine may share one queue. Each opens its own connection
    and pulls ``(qid, payload)`` items until the queue is empty. A failed query is retried
    on a fresh connection until it succeeds or has failed ``error_threshold`` times in a
    row; it is then abandoned and keeps ``success == 0``.

    :param uri: WebSocket URI
    :param queue: An instantiated asyncio Queue with your dataset fully inserted, as
                  ``(qid, payload)`` tuples; ``payload["query"]`` is the prompt text, and
                  an empty query is marked done without being sent.
    :param response_payloads: A dictionary with base responses set up for processing,
                              keyed by qid; each entry needs "response" and "success" keys.
    :param ws_timeout_min: Minutes a connection may be used before it is proactively
                           recycled (checked before each query). It is also the per-chunk
                           receive timeout.
    :param error_threshold: The number of errors permissible for a query before the
                            worker abandons it and moves to the next question. It also
                            bounds consecutive failures while no query is in flight (e.g.
                            the endpoint refusing connections); at that limit the last
                            exception is re-raised.
    :param reconnect_timeout_secs: The number of seconds to wait before reconnecting.
    :return: the ``response_payloads`` object initially passed into the function
    """
    error_ct = 0  # consecutive errors for the query currently in flight
    idle_error_ct = 0  # consecutive failures while no query is in flight
    ws_timeout = ws_timeout_min * 60
    qid = None  # query in flight; None when nothing is owed to the queue
    tmp_input_payload = None

    # START - WEBSOCKET LOOP
    while not (queue.empty() and error_ct == 0):
        connection_start_time = time.monotonic()
        try:
            async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
                idle_error_ct = 0
                # START - QUERY QUEUE LOOP
                while not (queue.empty() and error_ct == 0):
                    # After an error, retry the same query instead of taking a new one.
                    if error_ct == 0:
                        qid, tmp_input_payload = await queue.get()

                    # To prevent unnecessary timeouts DURING a query, we cautiously opt out
                    # of the connection at a user specified threshold.
                    if time.monotonic() - connection_start_time >= ws_timeout:
                        raise asyncio.TimeoutError()

                    if tmp_input_payload["query"] != "":
                        await ws.send(json.dumps(tmp_input_payload))
                        await _receive_stream(ws, response_payloads[qid], ws_timeout)

                    # Reached without an exception: the query is finished (or was empty).
                    error_ct = 0
                    logging.info(f"Query ID {qid} completed... Message: {response_payloads[qid]['response']}")
                    queue.task_done()
                    qid = None
                # END - QUERY QUEUE LOOP
        except Exception as exc:  # includes asyncio.TimeoutError and websockets.ConnectionClosed
            if qid is None:
                # No query owned (typically the connection attempt itself failed), so there
                # is no result entry to update. Give up once the failures keep repeating.
                idle_error_ct += 1
                logging.error(f"Connection error {idle_error_ct} with no query in flight: {exc!r}")
                if idle_error_ct >= error_threshold:
                    raise
                await asyncio.sleep(reconnect_timeout_secs)
                continue

            if response_payloads[qid]["success"] == 0:
                error_ct += 1
                # Reset buffer stream so that you don't get messed by pre-existing data.
                response_payloads[qid]["response"] = ""

            if isinstance(exc, asyncio.TimeoutError):
                logging.error(f"Error {error_ct} on {qid} stream timeout: resetting connection...")
            elif isinstance(exc, ResponseDataError):
                logging.error(f"Error {error_ct} on {qid}: invalid response from endpoint.\n"
                              f"{response_payloads[qid]}")
            elif isinstance(exc, websockets.ConnectionClosed):
                logging.error(f"Error {error_ct} on {qid}: WebSocket connection closed on "
                              f"query ID {qid}. Reopening...")
            else:
                logging.error(f"Error {error_ct} on {qid}: {traceback.format_exc()}")

            # Prevent any further retries if at error limit: abandon the query (it keeps
            # success == 0) and account for it in the queue so queue.join() stays correct.
            if error_ct >= error_threshold:
                logging.error(f"Query ID {qid} abandoned after {error_ct} errors.")
                error_ct = 0
                queue.task_done()
                qid = None

            # Non-blocking back-off: time.sleep() here would freeze every other worker
            # running on the same event loop.
            await asyncio.sleep(reconnect_timeout_secs)
    # END - WEBSOCKET LOOP
    logging.info("WebSocket connection closed. Queue appears to be empty...")
    return response_payloads
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
                                 max_concurrent_tasks: int = 3,
                                 ws_timeout_min: int = 8,
                                 error_threshold: int = 100,
                                 reconnect_timeout_secs: int = 2) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Answer every query in ``queries`` over WebSocket connections and return the results.

    Each query is turned into a request with ``model.compute_payload`` and placed on a
    shared queue, which ``max_concurrent_tasks`` ``interact_with_websocket`` workers drain
    concurrently. Because of ``@begin_task_execution``, callers invoke this synchronously.

    :param model: Model configuration providing ``compute_payload`` and ``api_url``.
    :param queries: Prompt text keyed by query ID.
    :param max_concurrent_tasks: Number of worker coroutines to run.
    :param ws_timeout_min: Forwarded to ``interact_with_websocket``.
    :param error_threshold: Forwarded to ``interact_with_websocket``.
    :param reconnect_timeout_secs: Forwarded to ``interact_with_websocket``.
    :return: One result dict per query ID. Each starts as a copy of ``DEFAULT_RESPONSE``
             and is filled in by the workers.
    """
    payloads = {qid: model.compute_payload(message) for qid, message in queries.items()}
    response_payloads = {qid: DEFAULT_RESPONSE.copy() for qid in queries}

    # Nothing to send: skip creating the queue and the workers.
    if not response_payloads:
        return response_payloads

    queue = asyncio.Queue()
    for queued_item in payloads.items():
        await queue.put(queued_item)

    workers = [
        interact_with_websocket(model.api_url, queue, response_payloads,
                                ws_timeout_min, error_threshold, reconnect_timeout_secs)
        for _ in range(max_concurrent_tasks)
    ]
    await asyncio.gather(*workers)
    return response_payloads
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str,
                           ws_timeout_min: int = 8,
                           error_threshold: int = 100,
                           reconnect_timeout_secs: int = 2) -> Dict[str, Any]:
    """
    Send a single ``query`` over a WebSocket connection and return its result dict.

    An empty query is not sent; the untouched copy of ``DEFAULT_RESPONSE`` is returned.
    Because of ``@begin_task_execution``, callers invoke this synchronously.
    """
    response_payloads = {0: DEFAULT_RESPONSE.copy()}
    tmp_payload = model.compute_payload(query=query)

    if query == "":
        return response_payloads[0]

    queue = asyncio.Queue()
    await queue.put((0, tmp_payload))
    await interact_with_websocket(model.api_url, queue, response_payloads, ws_timeout_min,
                                  error_threshold, reconnect_timeout_secs)
    return response_payloads[0]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ASUllmAPI
|
|
3
|
-
Version: 2.0.2
|
|
3
|
+
Version: 2.0.4
|
|
4
4
|
Summary: A simple python package to facilitate connection to ASU LLM API
|
|
5
5
|
Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ASUllmAPI
|
|
3
|
-
Version: 2.0.2
|
|
3
|
+
Version: 2.0.4
|
|
4
4
|
Summary: A simple python package to facilitate connection to ASU LLM API
|
|
5
5
|
Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
|
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
import time
|
|
2
|
-
import json
|
|
3
|
-
import asyncio
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def begin_task_execution(async_func):
    """Wrap the coroutine function ``async_func`` so that calling it runs it to completion via ``asyncio.run``."""
    def wrap(*args, **kwargs):
        coroutine = async_func(*args, **kwargs)
        return asyncio.run(coroutine)

    return wrap
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def load_json_buffer(string):
    """Decode ``string`` as JSON; text that is not valid JSON yields None instead of raising."""
    try:
        decoded = json.loads(string)
    except json.JSONDecodeError:
        decoded = None
    return decoded
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def time_api(func):
    """
    Decorator to measure the execution time of a function.

    After each call that returns normally, prints
    ``"<function name> executed in <seconds> seconds."`` and returns the wrapped
    function's result unchanged. If the call raises, nothing is printed.
    """
    from functools import wraps

    # Preserve the decorated function's __name__/__doc__ for introspection and logging.
    @wraps(func)
    def time_wrapper(*args, **kwargs):
        # perf_counter() is monotonic; time.time() can jump when the wall clock changes.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
|
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import json
|
|
3
|
-
from typing import Dict
|
|
4
|
-
import ssl
|
|
5
|
-
import warnings
|
|
6
|
-
|
|
7
|
-
import websockets
|
|
8
|
-
import certifi
|
|
9
|
-
|
|
10
|
-
from .model_config import ModelConfig
|
|
11
|
-
from .utils import load_json_buffer, begin_task_execution
|
|
12
|
-
|
|
13
|
-
# TLS context that verifies server certificates against certifi's CA bundle;
# passed to the websockets.connect() call in this module.
SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
async def interact_with_websocket(uri: str, query_payload: Dict[str, str], qid: int):
    """
    Send one query over a new WebSocket connection and collect the streamed reply.

    :param uri: WebSocket URI of the endpoint.
    :param query_payload: JSON-serialisable request body sent as the first message.
    :param qid: Key under which the result is returned.
    :return: ``{qid: result}``, where ``result["response"]`` is the concatenated text with
             ``<EOS>`` markers stripped. ``"metadata"`` is added when the endpoint sends it.
             Error or unrecognised JSON payloads are merged into ``result`` as-is.
    """
    end_of_stream = "<EOS>"
    payload = {qid: {"response": ""}}

    async with websockets.connect(uri, ssl=SSL_CONTEXT) as websocket:
        # Send the user-provided message to the WebSocket server
        await websocket.send(json.dumps(query_payload))

        # Loop to receive messages until the server is finished
        while True:
            try:
                response = await websocket.recv()
                parsed_response = load_json_buffer(response)

                if isinstance(parsed_response, dict):
                    # Case: json/text response and user is denied entry
                    if 'message' in parsed_response or 'error' in parsed_response:
                        # Embed entire Forbidden message payload and leave response blank.
                        payload[qid].update(parsed_response)
                        break
                    # Case: json response and user is not denied entry
                    elif 'response' in parsed_response:
                        payload[qid]["response"] += parsed_response["response"].replace(end_of_stream, "")
                        # The final chunk may carry <EOS> without a "metadata" field; reading it
                        # unconditionally raised KeyError, which escaped this coroutine and
                        # aborted the whole batch.
                        if 'metadata' in parsed_response:
                            payload[qid]["metadata"] = parsed_response["metadata"]
                            break
                        if end_of_stream in parsed_response['response']:
                            break
                    # Unknown edge case: json getting returned that is not a `message`, `response`, `error`
                    # (not observed yet)
                    else:
                        warnings.warn(f"Unknown ASU LLM endpoint edge case detected: JSON parsed but data does not "
                                      f"contain a message or response field.", RuntimeWarning)
                        payload[qid].update(parsed_response)
                        break
                # Intended case: user asked for a response type of `text` and was not denied entry.
                else:
                    payload[qid]["response"] += response.replace(end_of_stream, "")
                    if end_of_stream in response:
                        break

            except websockets.ConnectionClosed:
                print(f"Question {qid} connection closed by the server")
                break
    print(f"......Question {qid} response sent by the WebSocket server.")
    return payload
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
async def limited_task(semaphore: asyncio.Semaphore, task):
    """Await ``task`` while holding ``semaphore``, bounding how many such awaits run at once."""
    await semaphore.acquire()
    try:
        return await task
    finally:
        semaphore.release()
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[int, str], max_concurrent_tasks: int = 3):
    """
    Run every query in ``queries`` on its own WebSocket connection, at most
    ``max_concurrent_tasks`` at a time, and merge the per-query results.

    :return: Mapping of query ID to its result dict, e.g.
             ``{1: {"response": ..., "metadata": ...}, 2: {...}}``.
    """
    coroutines = [interact_with_websocket(model.api_url, model.compute_payload(message), qid)
                  for qid, message in queries.items()]
    gate = asyncio.Semaphore(max_concurrent_tasks)

    # Gather all tasks to run them concurrently and collect results
    results = await asyncio.gather(*[limited_task(gate, coro) for coro in coroutines])

    merged = {}
    for partial in results:
        merged.update(partial)
    return merged
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str):
    """Send one ``query`` over a WebSocket connection; return ``{0: result}``."""
    request = model.compute_payload(query=query)
    return await interact_with_websocket(model.api_url, request, qid=0)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|