ASUllmAPI 2.0.3__tar.gz → 2.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
3
3
  import time
4
4
  import json
5
5
  import asyncio
6
+ import sys
6
7
 
7
8
 
8
9
  def begin_task_execution(async_func):
@@ -13,17 +14,31 @@ def begin_task_execution(async_func):
13
14
  If an event loop doesn't exist, it reverts back to existing asyncio logic.
14
15
  """
15
16
  def wrap(*args, **kwargs):
16
- try:
17
- asyncio.get_running_loop() # Triggers RuntimeError if no running event loop
18
- # Create a separate thread so we can block before returning
17
+ # It is safer to use an if/else statement than try/except since
18
+ # the underlying code we wrap our function around can also raise RuntimeErrors.
19
+ # When this happens, multiple asyncio.run() executions occur, which is dangerous.
20
+ if is_jupyter():
19
21
  with ThreadPoolExecutor(1) as pool:
20
22
  result = pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()
21
- except RuntimeError:
23
+ else:
22
24
  result = asyncio.run(async_func(*args, **kwargs))
23
25
  return result
24
26
  return wrap
25
27
 
26
28
 
29
+ def is_jupyter():
30
+ try:
31
+ # Check if 'IPython' is in sys.modules
32
+ if 'IPython' in sys.modules:
33
+ from IPython import get_ipython
34
+ # Check if we're in an IPython environment
35
+ if get_ipython() is not None:
36
+ return True
37
+ except ImportError:
38
+ pass
39
+ return False
40
+
41
+
27
42
  def load_json_buffer(string):
28
43
  try:
29
44
  return json.loads(string)
@@ -0,0 +1,166 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import traceback
5
+ from typing import Dict, Any, Union
6
+ import ssl
7
+ import time
8
+
9
+ import websockets
10
+ import certifi
11
+
12
+ from .model_config import ModelConfig
13
+ from .utils import load_json_buffer, begin_task_execution
14
+
15
+ SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
16
+ END_OF_STREAM = '<EOS>'
17
+ DEFAULT_RESPONSE = {"response": "", "success": 0}
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
19
+
20
+
21
+ class ResponseDataError(Exception):
22
+ def __init__(self, message="An error message has been sent by the endpoint."):
23
+ super().__init__(message)
24
+ self.message = message
25
+
26
+
27
+ async def interact_with_websocket(uri: str, queue: asyncio.Queue,
28
+ response_payloads: Dict[Union[str, int], Dict[str, Any]],
29
+ ws_timeout_min: int = 8,
30
+ error_threshold: int = 100,
31
+ reconnect_timeout_secs: int = 2):
32
+ """
33
+ :param uri: WebSocket URI
34
+ :param queue: An instantiated asyncio Queue with your dataset fully inserted.
35
+ :param response_payloads: A dictionary with base responses set up for processing.
36
+ :param ws_timeout_min: The number of minutes before a connection to the WebSocket is terminated on the async task.
37
+ :param error_threshold: The number of errors permissible for a query in the WebSocket before the program
38
+ moves to the next question.
39
+ :param reconnect_timeout_secs: The number of seconds before another connection is instantiated.
40
+ :return: the `response_payloads` object initially passed into the function
41
+ """
42
+ error_ct = 0
43
+ ws_timeout = ws_timeout_min * 60
44
+ qid = None
45
+ tmp_input_payload = None
46
+
47
+ # START - WEBSOCKET LOOP
48
+ while not (queue.empty() and error_ct == 0):
49
+ connection_start_time = time.time()
50
+ try:
51
+ async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
52
+ # START - QUERY QUEUE LOOP
53
+ while not (queue.empty() and error_ct == 0):
54
+ if error_ct == 0:
55
+ qid, tmp_input_payload = await queue.get()
56
+
57
+ # To prevent unnecessary timeouts DURING a query, we can
58
+ # cautiously opt out of the connection at a user-specified threshold.
59
+ query_time = time.time()
60
+ connection_time = query_time - connection_start_time
61
+ if connection_time >= ws_timeout:
62
+ raise asyncio.TimeoutError()
63
+
64
+ if tmp_input_payload["query"] != "":
65
+ await ws.send(json.dumps(tmp_input_payload))
66
+ # START - QUERY CHUNK LOOP
67
+ while True:
68
+ # First chunk will typically take 29 seconds of time for every model
69
+ # Remaining chunks can take up to 9 minutes to be received.
70
+ # First chunk will take more time with GeminiPro - long time to load in comparison
71
+ # to other models
72
+ response = await asyncio.wait_for(ws.recv(), ws_timeout)
73
+
74
+ parsed_response = load_json_buffer(response)
75
+ if isinstance(parsed_response, dict):
76
+ if 'response' in parsed_response.keys():
77
+ cleaned = parsed_response["response"].replace(END_OF_STREAM, "")
78
+ response_payloads[qid]["response"] += cleaned
79
+ if 'metadata' in parsed_response.keys() \
80
+ or END_OF_STREAM in parsed_response['response']:
81
+ response_payloads[qid]["metadata"] = parsed_response["metadata"]
82
+ response_payloads[qid]["success"] = 1
83
+ error_ct = 0
84
+ break
85
+ else:
86
+ # Connection ID expires - first chunk took more than 29 seconds.
87
+ response_payloads[qid].update(parsed_response)
88
+ raise ResponseDataError()
89
+ else:
90
+ response_payloads[qid]["response"] += response.replace(END_OF_STREAM, "")
91
+ if END_OF_STREAM in response:
92
+ response_payloads[qid]["success"] = 1
93
+ error_ct = 0
94
+ break
95
+ # END - QUERY CHUNK LOOP
96
+ if error_ct == 0:
97
+ logging.info(f"Query ID {qid} completed... Message: {response_payloads[qid]['response']}")
98
+ queue.task_done()
99
+ # END - QUERY QUEUE LOOP
100
+ except (asyncio.TimeoutError, websockets.ConnectionClosed, Exception) as exc:
101
+ if isinstance(exc, asyncio.TimeoutError):
102
+ logging.error(f"Error {error_ct} on {qid} stream timeout: resetting connection...")
103
+ elif isinstance(exc, ResponseDataError):
104
+ logging.error(f"Error {error_ct} on {qid}: invalid response from endpoint.\n"
105
+ f"{response_payloads[qid]}")
106
+ elif isinstance(exc, websockets.ConnectionClosed):
107
+ logging.error(f"Error {error_ct} on {qid}: WebSocket connection closed on "
108
+ f"query ID {qid}. Reopening...")
109
+ else:
110
+ logging.error(f"Error {error_ct} on {qid}: {traceback.format_exc()}")
111
+
112
+ time.sleep(reconnect_timeout_secs)
113
+ # If the query is already complete, we don't want to increment the error count
114
+ if response_payloads[qid]["success"] == 0:
115
+ error_ct += 1
116
+ # Reset the response buffer so partial data from the failed attempt does not carry over into the retry.
117
+ response_payloads[qid]["response"] = ""
118
+ else:
119
+ logging.info(f"Query ID {qid} completed...")
120
+ queue.task_done()
121
+
122
+ # prevent any further retries if at error limit.
123
+ if error_ct == error_threshold:
124
+ error_ct = 0
125
+ # END - WEBSOCKET LOOP
126
+ logging.info("WebSocket connection closed. Queue appears to be empty...")
127
+
128
+
129
+ @begin_task_execution
130
+ async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
131
+ max_concurrent_tasks: int = 3,
132
+ ws_timeout_min: int = 8,
133
+ error_threshold: int = 100,
134
+ reconnect_timeout_secs: int = 2) -> Dict[Union[str, int], Dict[str, Any]]:
135
+ payloads = {}
136
+ for qid, message in queries.items():
137
+ payloads[qid] = model.compute_payload(message)
138
+
139
+ response_payloads = {qid: DEFAULT_RESPONSE.copy() for qid in queries}
140
+
141
+ if len(response_payloads) > 0:
142
+ queue = asyncio.Queue()
143
+ for qid, payload in payloads.items():
144
+ await queue.put((qid, payload))
145
+ tasks = [interact_with_websocket(model.api_url, queue, response_payloads,
146
+ ws_timeout_min, error_threshold, reconnect_timeout_secs)
147
+ for _ in range(max_concurrent_tasks)]
148
+ await asyncio.gather(*tasks)
149
+
150
+ return response_payloads
151
+
152
+
153
+ @begin_task_execution
154
+ async def query_llm_socket(model: ModelConfig, query: str,
155
+ ws_timeout_min: int = 8,
156
+ error_threshold: int = 100,
157
+ reconnect_timeout_secs: int = 2) -> Dict[str, Any]:
158
+ response_payloads = {0: DEFAULT_RESPONSE.copy()}
159
+ tmp_payload = model.compute_payload(query=query)
160
+ if query != "":
161
+ queue = asyncio.Queue()
162
+ await queue.put((0, tmp_payload))
163
+
164
+ await interact_with_websocket(model.api_url, queue, response_payloads, ws_timeout_min,
165
+ error_threshold, reconnect_timeout_secs)
166
+ return response_payloads[0]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.3
3
+ Version: 2.0.5
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.3
3
+ Version: 2.0.5
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASUllmAPI"
3
- version = "2.0.3"
3
+ version = "2.0.5"
4
4
  authors = [
5
5
  { name="Stella Wenxing Liu", email="stellawenxingliu@gmail.com" },
6
6
  { name="Varun Shourie", email="svarun195@gmail.com" }
@@ -1,105 +0,0 @@
1
- import asyncio
2
- import json
3
- import logging
4
- from typing import Dict, Any, Union
5
- import ssl
6
-
7
- import websockets
8
- import certifi
9
-
10
- from .model_config import ModelConfig
11
- from .utils import load_json_buffer, begin_task_execution, split_dict_into_chunks
12
-
13
- SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
14
- END_OF_STREAM = '<EOS>'
15
- DEFAULT_RESPONSE = {"response": ""}
16
- logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
17
-
18
-
19
- async def interact_with_websocket(uri: str, query_payloads: Dict[Union[str, int], Dict[str, Any]]):
20
- response_payloads = {qid: {"response": ""} for qid in query_payloads}
21
- payload_iterator = iter(query_payloads.items())
22
- tmp_qid, tmp_input_payload = next(payload_iterator, (None, None))
23
- retry = False
24
- ws_ready_for_closure = False
25
- error_ct = 0
26
-
27
- # WEBSOCKET LOOP: Indefinitely opens a new connection when a new one is closed in the SAME asyncio task.
28
- while not ws_ready_for_closure:
29
- async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
30
- try:
31
- if ws_ready_for_closure or tmp_qid is None:
32
- break
33
- # QUERY LOOP: While there are queries remaining to send to the server.
34
- while tmp_qid is not None:
35
- response_payloads[tmp_qid]["success"] = 0
36
- if not retry:
37
- error_ct = 0
38
- await ws.send(json.dumps(tmp_input_payload))
39
- # QUERY CHUNK LOOP: While chunks are being received by the WebSocket Client.
40
- while True:
41
- # Wait for the duration of the WebSocket connection for ANY response
42
- response = await ws.recv()
43
- parsed_response = load_json_buffer(response)
44
- # Parse response stream from WebSocket server based on every use case
45
- if isinstance(parsed_response, dict):
46
- # Intended use case: user receives JSON response.
47
- if 'response' in parsed_response.keys():
48
- cleaned = parsed_response["response"].replace(END_OF_STREAM, "")
49
- response_payloads[tmp_qid]["response"] += cleaned
50
- if 'metadata' in parsed_response.keys() or END_OF_STREAM in parsed_response['response']:
51
- response_payloads[tmp_qid]["metadata"] = parsed_response["metadata"]
52
- response_payloads[tmp_qid]["success"] = 1
53
- retry = False
54
- break
55
- # Error: user denied access to the response for some reason.
56
- else:
57
- logging.log(logging.ERROR, f"Data error: query ID {tmp_qid}: {parsed_response}")
58
- response_payloads[tmp_qid].update(parsed_response)
59
- retry = False
60
- break
61
- # Intended use case: user receives TEXT response from websocket
62
- # Data format is streamlined regardless of which response type the user selects.
63
- else:
64
- response_payloads[tmp_qid]["response"] += response.replace(END_OF_STREAM, "")
65
- if END_OF_STREAM in response:
66
- retry = False
67
- response_payloads[tmp_qid]["success"] = 1
68
- break
69
- # END - QUERY CHUNK LOOP
70
- if not retry:
71
- logging.log(level=logging.INFO, msg=f"Query ID {tmp_qid} completed...")
72
- tmp_qid, tmp_input_payload = next(payload_iterator, (None, None))
73
- # END - QUERY LOOP
74
- ws_ready_for_closure = True
75
- except websockets.ConnectionClosed:
76
- error_ct += 1
77
- logging.log(level=logging.INFO, msg=f"Error {error_ct}: WebSocket connection closed on query ID "
78
- f"{tmp_qid}. Reopening...")
79
- retry = True if error_ct <= 2 else False
80
- # END - WEBSOCKET LOOP
81
- logging.log(logging.INFO, msg="WebSocket connection closed.")
82
- return response_payloads
83
-
84
-
85
- @begin_task_execution
86
- async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
87
- max_concurrent_tasks: int = 3) -> Dict[Union[str, int], Dict[str, Any]]:
88
- payloads = {}
89
- for qid, message in queries.items():
90
- payloads[qid] = model.compute_payload(message)
91
- batches = split_dict_into_chunks(payloads, max_concurrent_tasks)
92
- tasks = [interact_with_websocket(model.api_url, batch) for batch in batches]
93
-
94
- results = await asyncio.gather(*tasks)
95
- final_results = {}
96
- for result_dict in results:
97
- final_results.update(result_dict)
98
- return final_results
99
-
100
-
101
- @begin_task_execution
102
- async def query_llm_socket(model: ModelConfig, query: str) -> Dict[str, Any]:
103
- tmp_payload = model.compute_payload(query=query)
104
- result = await interact_with_websocket(model.api_url, {0: tmp_payload})
105
- return result[0]
File without changes
File without changes
File without changes
File without changes