ASUllmAPI 2.0.1__tar.gz → 2.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  from concurrent.futures import ThreadPoolExecutor, as_completed
2
- from typing import Dict
2
+ from typing import Dict, Union
3
3
 
4
4
  from tqdm.auto import tqdm
5
5
 
@@ -9,9 +9,13 @@ from .utils import time_api
9
9
 
10
10
 
11
11
  @time_api
12
- def batch_query_llm(model: ModelConfig, queries: Dict[int, str], max_threads: int,
13
- num_retry: int = 3, auto_increase_retry: bool = False,
14
- success_sleep: float = 0.0, fail_sleep: float = 1.0) -> Dict[int, dict]:
12
+ def batch_query_llm(model: ModelConfig,
13
+ queries: Dict[Union[str, int], str],
14
+ max_threads: int,
15
+ num_retry: int = 3,
16
+ auto_increase_retry: bool = False,
17
+ success_sleep: float = 0.0,
18
+ fail_sleep: float = 1.0) -> Dict[Union[str, int], dict]:
15
19
  with ThreadPoolExecutor(max_workers=max_threads) as executor:
16
20
  # Submit tasks to the executor - order of return will be asynchronous.
17
21
  # If `auto_increase_retry` enabled, then scaling API backoff is implemented.
@@ -0,0 +1,84 @@
1
+ from typing import Dict, List
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import time
4
+ import json
5
+ import asyncio
6
+
7
+
8
def begin_task_execution(async_func):
    """
    Decorate a coroutine function so it can be called synchronously.

    See https://stackoverflow.com/a/75341431 for the underlying rationale.
    When the caller already has a running event loop (e.g. inside Jupyter),
    ``asyncio.run`` cannot be used on the current thread, so the coroutine is
    run to completion on a fresh event loop in a single worker thread while
    the caller blocks. Without a running loop, ``asyncio.run`` is used directly.

    Parameters:
        async_func: A coroutine function (``async def``).

    Returns:
        A regular function taking the same arguments that returns the
        coroutine's result and re-raises any exception the coroutine raises.
    """
    from functools import wraps

    @wraps(async_func)
    def wrap(*args, **kwargs):
        try:
            asyncio.get_running_loop()  # Raises RuntimeError if no loop is running.
        except RuntimeError:
            return asyncio.run(async_func(*args, **kwargs))
        # Only the loop probe lives in the ``try`` above: a RuntimeError raised by
        # the coroutine itself must propagate, not trigger a second execution via
        # ``asyncio.run`` on a thread where it is forbidden.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()

    return wrap
25
+
26
+
27
def load_json_buffer(string):
    """
    Decode ``string`` as JSON without raising on malformed input.

    Parameters:
        string: The text to decode.

    Returns:
        The decoded JSON value, or ``None`` when ``string`` is not valid JSON.
    """
    try:
        decoded = json.loads(string)
    except json.JSONDecodeError:
        decoded = None
    return decoded
32
+
33
+
34
def split_dict_into_chunks(input_dict: Dict, n: int) -> List[Dict]:
    """
    Partition a dictionary into ``n`` consecutive, near-equal pieces.

    Insertion order is preserved. The first ``len(input_dict) % n`` pieces
    receive one extra item; when ``n`` exceeds the number of items, the
    trailing pieces are empty dictionaries.

    Parameters:
        input_dict (dict): The dictionary to split.
        n (int): How many pieces to produce; must be positive.

    Returns:
        List[Dict]: ``n`` dictionaries in order, or ``[]`` if ``input_dict`` is empty.

    Raises:
        ValueError: If ``n`` is not greater than 0.
    """
    if n <= 0:
        raise ValueError("Number of chunks must be greater than 0")
    if not input_dict:
        return []

    pairs = list(input_dict.items())
    base_size, extra = divmod(len(pairs), n)
    # Running offsets: piece k spans pairs[bounds[k]:bounds[k + 1]].
    bounds = [0]
    for index in range(n):
        bounds.append(bounds[-1] + base_size + (1 if index < extra else 0))
    return [dict(pairs[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
64
+
65
+
66
def time_api(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    The wrapper keeps ``func``'s name, docstring and other metadata (via
    ``functools.wraps``) so decorated public functions stay introspectable.
    Elapsed time is measured with ``time.perf_counter``, a monotonic
    high-resolution clock, so wall-clock adjustments cannot skew it.

    Parameters:
        func: The callable to time.

    Returns:
        A wrapper taking the same arguments that returns ``func``'s result.
    """
    from functools import wraps

    @wraps(func)
    def time_wrapper(*args, **kwargs):
        # Run the wrapped function with its original arguments and report the
        # elapsed time. Nothing is printed if the call raises.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
@@ -0,0 +1,105 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from typing import Dict, Any, Union
5
+ import ssl
6
+
7
+ import websockets
8
+ import certifi
9
+
10
+ from .model_config import ModelConfig
11
+ from .utils import load_json_buffer, begin_task_execution, split_dict_into_chunks
12
+
13
# TLS context that validates server certificates against certifi's CA bundle.
_CA_BUNDLE_PATH = certifi.where()
SSL_CONTEXT = ssl.create_default_context(cafile=_CA_BUNDLE_PATH)
# Marker whose presence in a received chunk signals the end of a streamed
# response; it is stripped from the accumulated text.
END_OF_STREAM = "<EOS>"
# NOTE(review): not referenced in this module; it is a mutable dict, so copy it
# (e.g. ``dict(DEFAULT_RESPONSE)``) before handing it out.
DEFAULT_RESPONSE = dict(response="")
# NOTE(review): configuring the root logger at import time also changes logging
# for the importing application; consider leaving this to the caller.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
17
+
18
+
19
async def _receive_response(ws, qid: Union[str, int], record: Dict[str, Any]) -> None:
    """
    Read chunks from ``ws`` into ``record`` until the answer to query ``qid`` ends.

    Returns once an end-of-stream marker or ``metadata`` is seen, or when the
    server sends a JSON object without a ``response`` field (treated as an
    error payload and merged into ``record``). ``websockets.ConnectionClosed``
    raised by ``ws.recv()`` propagates to the caller.
    """
    while True:
        # Wait for the duration of the WebSocket connection for ANY response
        message = await ws.recv()
        parsed = load_json_buffer(message)
        if isinstance(parsed, dict):
            if 'response' in parsed:
                # Intended use case: JSON chunk carrying part of the answer.
                chunk = parsed["response"]
                record["response"] += chunk.replace(END_OF_STREAM, "")
                if 'metadata' in parsed or END_OF_STREAM in chunk:
                    # ``.get``: an end-of-stream chunk need not carry ``metadata``;
                    # indexing it directly raised KeyError.
                    record["metadata"] = parsed.get("metadata")
                    record["success"] = 1
                    return
            else:
                # Error: user denied access to the response for some reason.
                logging.log(logging.ERROR, f"Data error: query ID {qid}: {parsed}")
                record.update(parsed)
                return
        else:
            # TEXT chunk: stored in the same record format as JSON responses.
            record["response"] += message.replace(END_OF_STREAM, "")
            if END_OF_STREAM in message:
                record["success"] = 1
                return


async def interact_with_websocket(uri: str, query_payloads: Dict[Union[str, int], Dict[str, Any]],
                                  max_retries: int = 2) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Send queries one at a time over a WebSocket and collect the streamed answers.

    Queries are sent sequentially, in ``query_payloads`` order, over a single
    connection. If the server closes the connection mid-query, a new
    connection is opened and the same query is re-sent (its partial answer is
    discarded first). Once a query has seen more than ``max_retries`` closed
    connections in a row it is abandoned (``success`` stays 0) and processing
    continues with the next query. Other exceptions, including failures to
    connect, propagate.

    Parameters:
        uri: WebSocket endpoint URL.
        query_payloads: Mapping of query ID to the JSON-serializable payload to send.
        max_retries: Reconnect-and-resend attempts allowed per query after a
            closed connection; default 2 (three attempts in total).

    Returns:
        Mapping of query ID to a record with ``response`` (text) and
        ``success`` (1 or 0), plus ``metadata`` on JSON completions or the
        server's error fields on a data error.
    """
    response_payloads = {qid: {"response": ""} for qid in query_payloads}
    if not query_payloads:
        # Nothing to send: don't open a connection just to close it again.
        return response_payloads

    queue = list(query_payloads.items())
    position = 0  # Index of the query currently being processed.
    error_ct = 0  # Consecutive closed connections for the current query.

    # WEBSOCKET LOOP: opens a new connection whenever the previous one closed early.
    while position < len(queue):
        async with websockets.connect(uri, ssl=SSL_CONTEXT) as ws:
            try:
                # QUERY LOOP: while there are queries remaining to send to the server.
                while position < len(queue):
                    qid, payload = queue[position]
                    record = response_payloads[qid]
                    record["success"] = 0
                    # Discard partial output from a previous, dropped attempt so a
                    # retry does not concatenate two streams.
                    record["response"] = ""
                    await ws.send(json.dumps(payload))
                    await _receive_response(ws, qid, record)
                    logging.log(level=logging.INFO, msg=f"Query ID {qid} completed...")
                    position += 1
                    error_ct = 0
            except websockets.ConnectionClosed:
                error_ct += 1
                qid = queue[position][0]
                logging.log(level=logging.INFO, msg=f"Error {error_ct}: WebSocket connection closed on query ID "
                                                    f"{qid}. Reopening...")
                if error_ct > max_retries:
                    # Give up on this query. Without advancing here, the counter was
                    # reset and the same query re-sent forever.
                    logging.log(logging.ERROR, f"Query ID {qid} abandoned after {error_ct} closed connections.")
                    position += 1
                    error_ct = 0
    # END - WEBSOCKET LOOP
    logging.log(logging.INFO, msg="WebSocket connection closed.")
    return response_payloads
83
+
84
+
85
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[Union[str, int], str],
                                 max_concurrent_tasks: int = 3) -> Dict[Union[str, int], Dict[str, Any]]:
    """
    Send many queries over up to ``max_concurrent_tasks`` concurrent WebSocket connections.

    Queries are split into contiguous batches. Each batch is processed
    sequentially on its own connection, and all batches run concurrently.
    Callable synchronously thanks to ``begin_task_execution``.

    Parameters:
        model: Supplies ``api_url`` and builds each request via ``compute_payload``.
        queries: Mapping of query ID to prompt text.
        max_concurrent_tasks: Maximum number of simultaneous connections; must be > 0.

    Returns:
        Mapping of query ID to its response record, in input order.

    Raises:
        ValueError: If ``max_concurrent_tasks`` is not greater than 0.
    """
    payloads = {qid: model.compute_payload(message) for qid, message in queries.items()}
    # split_dict_into_chunks pads with empty dicts when there are fewer queries
    # than tasks; skip those so no connection is opened for an empty batch.
    batches = [batch for batch in split_dict_into_chunks(payloads, max_concurrent_tasks) if batch]
    tasks = [interact_with_websocket(model.api_url, batch) for batch in batches]

    results = await asyncio.gather(*tasks)
    final_results = {}
    for result_dict in results:
        final_results.update(result_dict)
    return final_results
99
+
100
+
101
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str) -> Dict[str, Any]:
    """
    Send a single query over a WebSocket and return its response record.

    Runs synchronously via ``begin_task_execution``. The query is sent under
    ID 0, and only that ID's record is returned.
    """
    request = model.compute_payload(query=query)
    only_id = 0
    responses = await interact_with_websocket(model.api_url, {only_id: request})
    return responses[only_id]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.1
3
+ Version: 2.0.3
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -13,6 +13,8 @@ Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  Requires-Dist: requests
15
15
  Requires-Dist: tqdm
16
+ Requires-Dist: websockets
17
+ Requires-Dist: certifi
16
18
 
17
19
  # ASU LLM API
18
20
  This package allows individuals at Arizona State University to access ASU GPT through its API. You will need an API access token and an API endpoint in order to use this package.
@@ -0,0 +1,4 @@
1
+ requests
2
+ tqdm
3
+ websockets
4
+ certifi
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ASUllmAPI
3
- Version: 2.0.1
3
+ Version: 2.0.3
4
4
  Summary: A simple python package to facilitate connection to ASU LLM API
5
5
  Author-email: Stella Wenxing Liu <stellawenxingliu@gmail.com>, Varun Shourie <svarun195@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/ASU/aiml-ssmdv-student-support-ml-data-visualization
@@ -13,6 +13,8 @@ Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  Requires-Dist: requests
15
15
  Requires-Dist: tqdm
16
+ Requires-Dist: websockets
17
+ Requires-Dist: certifi
16
18
 
17
19
  # ASU LLM API
18
20
  This package allows individuals at Arizona State University to access ASU GPT through its API. You will need an API access token and an API endpoint in order to use this package.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASUllmAPI"
3
- version = "2.0.1"
3
+ version = "2.0.3"
4
4
  authors = [
5
5
  { name="Stella Wenxing Liu", email="stellawenxingliu@gmail.com" },
6
6
  { name="Varun Shourie", email="svarun195@gmail.com" }
@@ -15,7 +15,9 @@ classifiers = [
15
15
  ]
16
16
  dependencies = [
17
17
  "requests",
18
- "tqdm"
18
+ "tqdm",
19
+ "websockets",
20
+ "certifi"
19
21
  ]
20
22
 
21
23
  [project.urls]
@@ -1,38 +0,0 @@
1
- import time
2
- import json
3
- import asyncio
4
-
5
-
6
def begin_task_execution(async_func):
    """
    Decorate a coroutine function so it can be called synchronously.

    ``asyncio.run`` refuses to start when the current thread already has a
    running event loop (e.g. inside Jupyter). In that case the coroutine is
    run on a fresh loop in a single worker thread while the caller blocks.
    Otherwise ``asyncio.run`` is used directly.

    Parameters:
        async_func: A coroutine function (``async def``).

    Returns:
        A regular function taking the same arguments that returns the
        coroutine's result and re-raises any exception it raises.
    """
    from concurrent.futures import ThreadPoolExecutor
    from functools import wraps

    @wraps(async_func)
    def wrap(*args, **kwargs):
        try:
            asyncio.get_running_loop()  # Raises RuntimeError if no loop is running.
        except RuntimeError:
            return asyncio.run(async_func(*args, **kwargs))
        # A loop is already running on this thread: run the coroutine elsewhere.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()

    return wrap
11
-
12
-
13
def load_json_buffer(string):
    """Return ``json.loads(string)``, or ``None`` if ``string`` is not valid JSON."""
    try:
        value = json.loads(string)
    except json.JSONDecodeError:
        return None
    else:
        return value
18
-
19
-
20
def time_api(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    ``functools.wraps`` preserves ``func``'s name and docstring on the
    wrapper. ``time.perf_counter`` is monotonic, so clock adjustments cannot
    skew the measurement.

    Parameters:
        func: The callable to time.

    Returns:
        A wrapper taking the same arguments that returns ``func``'s result.
    """
    from functools import wraps

    @wraps(func)
    def time_wrapper(*args, **kwargs):
        # Run the wrapped function with its original arguments and report the
        # elapsed time. Nothing is printed if the call raises.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed_time:.4f} seconds.")
        return result

    # We return the augmented function's reference.
    return time_wrapper
@@ -1,91 +0,0 @@
1
- import asyncio
2
- import json
3
- from typing import Dict
4
-
5
- import websockets
6
-
7
- from .model_config import ModelConfig
8
- from .utils import load_json_buffer, begin_task_execution
9
-
10
-
11
async def interact_with_websocket(uri: str, query_payload: Dict[str, str], qid: int):
    """
    Send one query over a new WebSocket connection and collect the streamed reply.

    Parameters:
        uri: WebSocket endpoint URL.
        query_payload: JSON-serializable request body.
        qid: Query ID used as the key of the returned mapping.

    Returns:
        ``{qid: record}``, where ``record`` holds the accumulated ``response``
        text plus ``metadata``, ``error``, or the server's ``Forbidden``
        payload when the server sent them.

    Raises:
        RuntimeWarning: On JSON payloads this function does not recognise.
    """
    end_of_stream = "<EOS>"
    payload = {qid: {"response": ""}}
    record = payload[qid]

    async with websockets.connect(uri) as websocket:
        # Send the user-provided message to the WebSocket server
        await websocket.send(json.dumps(query_payload))

        # Loop to receive messages until the server is finished
        while True:
            try:
                response = await websocket.recv()
                parsed_response = load_json_buffer(response)

                if isinstance(parsed_response, dict):
                    # Case: json/text response and user is denied entry
                    if 'message' in parsed_response:
                        if parsed_response['message'] == 'Forbidden':
                            # Embed entire Forbidden message payload and leave response blank.
                            record.update(parsed_response)
                            break
                        else:
                            raise RuntimeWarning(f"Unknown ASU LLM endpoint edge case detected: `message` as key in "
                                                 f"parsed response: {parsed_response}.")

                    # Case: json response and user is not denied entry
                    elif 'response' in parsed_response:
                        chunk = parsed_response["response"]
                        record["response"] += chunk.replace(end_of_stream, "")
                        if 'metadata' in parsed_response or end_of_stream in chunk:
                            # ``.get``: an end-of-stream chunk need not carry ``metadata``.
                            record["metadata"] = parsed_response.get("metadata")
                            break
                    # Edge case: user gets a message to contact the administrator after an error
                    elif 'error' in parsed_response:
                        record["error"] = parsed_response["error"]
                        # Stop here; previously the loop kept waiting for frames the
                        # server may never send, hanging until the connection closed.
                        break
                    # Unknown edge case: JSON that is not a `message`, `response`, or `error`
                    else:
                        raise RuntimeWarning("Unknown ASU LLM endpoint edge case detected: JSON parsed but data "
                                             "does not contain a message or response field.")

                # Intended case: user asked for a response type of `text` and was not denied entry.
                else:
                    record["response"] += response.replace(end_of_stream, "")
                    if end_of_stream in response:
                        break

            except websockets.ConnectionClosed:
                print(f"Question {qid} connection closed by the server")
                break
    print(f"......Question {qid} response sent by the WebSocket server.")
    return payload
62
-
63
-
64
async def limited_task(semaphore: asyncio.Semaphore, task):
    """Await ``task`` while holding ``semaphore``, bounding how many run at once."""
    await semaphore.acquire()
    try:
        return await task
    finally:
        semaphore.release()
67
-
68
-
69
@begin_task_execution
async def batch_query_llm_socket(model: ModelConfig, queries: Dict[int, str], max_concurrent_tasks: int = 3):
    """
    Send every query on its own WebSocket connection, at most
    ``max_concurrent_tasks`` at a time, and merge the per-query results.
    """
    coroutines = []
    for query_id, prompt in queries.items():
        request = model.compute_payload(prompt)
        coroutines.append(interact_with_websocket(model.api_url, request, query_id))
    gate = asyncio.Semaphore(max_concurrent_tasks)
    gated = [limited_task(gate, coroutine) for coroutine in coroutines]

    # Run all gated coroutines concurrently; results come back in submission order.
    gathered = await asyncio.gather(*gated)
    # Desired format: {1: {"response": ..., "metadata": ...}, 2: {...}}
    merged = {}
    for single in gathered:
        merged.update(single)
    return merged
85
-
86
-
87
@begin_task_execution
async def query_llm_socket(model: ModelConfig, query: str):
    """Send one query under ID 0 and return the ``{0: record}`` mapping."""
    request = model.compute_payload(query=query)
    return await interact_with_websocket(model.api_url, request, qid=0)
@@ -1,2 +0,0 @@
1
- requests
2
- tqdm
File without changes
File without changes
File without changes
File without changes