proscenium 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,214 @@
1
+ """
2
+ This module uses the [`aisuite`](https://github.com/andrewyng/aisuite) library
3
+ to interact with various LLM inference providers.
4
+
5
+ It provides functions to complete a simple chat prompt, evaluate a tool call,
6
+ and apply a list of tool calls to a chat prompt.
7
+
8
+ Providers tested with Proscenium include:
9
+
10
+ # AWS
11
+
12
+ Environment: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`
13
+
14
+ Valid model ids:
15
+ - `aws:meta.llama3-1-8b-instruct-v1:0`
16
+
17
+ # Anthropic
18
+
19
+ Environment: `ANTHROPIC_API_KEY`
20
+
21
+ Valid model ids:
22
+ - `anthropic:claude-3-5-sonnet-20240620`
23
+
24
+ # OpenAI
25
+
26
+ Environment: `OPENAI_API_KEY`
27
+
28
+ Valid model ids:
29
+ - `openai:gpt-4o`
30
+
31
+ # Ollama
32
+
33
+ Start the model from the command line, e.g. `ollama run llama3.2 --keepalive 2h`
34
+
35
+ Valid model ids:
36
+ - `ollama:llama3.2`
37
+ - `ollama:granite3.1-dense:2b`
38
+ """
39
+
40
+ from typing import Any
41
+
42
+ import json
43
+ from rich import print
44
+ from rich.console import Group
45
+ from rich.panel import Panel
46
+ from rich.table import Table
47
+ from rich.text import Text
48
+
49
+ from aisuite import Client
50
+ from aisuite.framework.message import ChatCompletionMessageToolCall
51
+
52
+ from proscenium.verbs.display.tools import complete_with_tools_panel
53
+
54
# Per-provider settings handed to aisuite when the shared client is built.
# The Ollama entry sets a request timeout of 180 (units as interpreted by
# aisuite's Ollama provider — presumably seconds; confirm).
# TODO expose this
provider_configs = {
    "ollama": dict(timeout=180),
}

# One module-level aisuite client, shared by every completion helper below.
client = Client(provider_configs=provider_configs)
60
+
61
+
62
def complete_simple(
    model_id: str, system_prompt: str, user_prompt: str, **kwargs
) -> str:
    """Send one system + user prompt pair to ``model_id`` and return the reply text.

    Args:
        model_id: aisuite model identifier, e.g. ``"openai:gpt-4o"``.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        **kwargs: Forwarded verbatim to ``client.chat.completions.create``.
            The key ``rich_output`` (default ``False``) is removed first; when
            true, the request and the reply are printed as rich panels.

    Returns:
        The ``content`` of the first choice's message.
    """
    show = kwargs.pop("rich_output", False)

    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt),
    ]

    if show:
        print(_complete_simple_call_panel(model_id, chat, kwargs))

    completion = client.chat.completions.create(
        model=model_id, messages=chat, **kwargs
    )
    answer = completion.choices[0].message.content

    if show:
        print(Panel(answer, title="Response"))

    return answer


def _complete_simple_call_panel(model_id: str, chat: list, extra: dict) -> "Panel":
    """Build the rich panel that describes a ``complete_simple`` request."""
    extra_lines = "\n".join(f"{key!s}: {value!s}" for key, value in extra.items())
    header = Text(f"\nmodel_id: {model_id}\n{extra_lines}\n")

    table = Table(title="Messages", show_lines=True)
    table.add_column("Role", justify="left")
    table.add_column("Content", justify="left")
    for entry in chat:
        table.add_row(entry["role"], entry["content"])

    return Panel(Group(header, table), title="complete_simple call")
104
+
105
+
106
def evaluate_tool_call(
    tool_map: dict,
    tool_call: "ChatCompletionMessageToolCall",
    rich_output: bool = False,
) -> Any:
    """Run the local Python function requested by a single model tool call.

    Args:
        tool_map: Maps tool (function) names to the callables implementing them.
        tool_call: Tool call emitted by the model. ``function.name`` selects the
            callable and ``function.arguments`` holds its JSON-encoded keyword
            arguments.
        rich_output: When true, print the call and its result.

    Returns:
        Whatever the selected callable returns.

    Raises:
        KeyError: If the requested tool is not present in ``tool_map``.
        json.JSONDecodeError: If a non-empty argument string is not valid JSON.
    """
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments
    # json.loads("") / json.loads(None) would raise, so an empty or missing
    # argument string is treated as "call with no arguments".
    # TODO validate the arguments?
    function_args = json.loads(raw_arguments) if raw_arguments else {}

    if function_name not in tool_map:
        # Same exception type as a plain dict lookup, but with a useful message.
        raise KeyError(
            f"Model requested unknown tool {function_name!r}; "
            f"available tools: {list(tool_map)}"
        )

    if rich_output:
        print(f"Evaluating tool call: {function_name} with args {function_args}")

    function_response = tool_map[function_name](**function_args)

    if rich_output:
        print(f" Response: {function_response}")

    return function_response
123
+
124
+
125
def tool_response_message(
    tool_call: ChatCompletionMessageToolCall, tool_result: Any
) -> dict:
    """Wrap a tool's return value as a ``role="tool"`` chat message.

    The message carries the originating call's ``id`` and function name, and
    ``tool_result`` serialized with ``json.dumps`` as its content.
    """
    message = dict(role="tool")
    message["tool_call_id"] = tool_call.id
    message["name"] = tool_call.function.name
    message["content"] = json.dumps(tool_result)
    return message
135
+
136
+
137
def evaluate_tool_calls(
    tool_call_message, tool_map: dict, rich_output: bool = False
) -> list[dict]:
    """Execute every tool call in ``tool_call_message`` and collect the replies.

    Args:
        tool_call_message: Message whose ``tool_calls`` attribute lists the
            requested calls.
        tool_map: Tool name -> Python callable.
        rich_output: When true, print progress messages.

    Returns:
        One ``role="tool"`` message per call, in the order the calls appear.
    """
    if rich_output:
        print("Evaluating tool calls")

    replies = [
        tool_response_message(call, evaluate_tool_call(tool_map, call, rich_output))
        for call in tool_call_message.tool_calls
    ]

    if rich_output:
        print("Tool calls evaluated")

    return replies
156
+
157
+
158
def complete_for_tool_applications(
    model_id: str,
    messages: list,
    tool_desc_list: list,
    temperature: float,
    rich_output: bool = False,
):
    """Ask ``model_id`` to continue ``messages`` with ``tool_desc_list`` offered as tools.

    Args:
        model_id: aisuite model identifier.
        messages: Chat history sent as-is.
        tool_desc_list: Tool descriptions passed as ``tools``.
        temperature: Sampling temperature for the request.
        rich_output: When true, print a panel describing the request first.

    Returns:
        The raw completion response object (not just its text).
    """
    if rich_output:
        print(
            complete_with_tools_panel(
                "complete for tool applications",
                model_id,
                tool_desc_list,
                messages,
                temperature,
            )
        )

    return client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
        tools=tool_desc_list,  # tool_choice="auto",
    )
184
+
185
+
186
def complete_with_tool_results(
    model_id: str,
    messages: list,
    tool_call_message: dict,
    tool_evaluation_messages: list[dict],
    tool_desc_list: list,
    temperature: float,
    rich_output: bool = False,
):
    """Append tool results to the conversation and ask the model for its reply.

    Note: ``messages`` is modified in place. ``tool_call_message`` and then every
    entry of ``tool_evaluation_messages`` are appended, so afterwards the
    caller's list holds the full exchange.

    Args:
        model_id: aisuite model identifier.
        messages: Chat history; extended in place (see above).
        tool_call_message: The assistant message that requested the tool calls.
        tool_evaluation_messages: ``role="tool"`` messages with the results.
        tool_desc_list: Tool descriptions; only used for the rich display here.
        temperature: Sampling temperature for this follow-up request.
        rich_output: When true, print a panel describing the request first.

    Returns:
        The text content of the first choice's message.
    """
    messages.append(tool_call_message)
    messages.extend(tool_evaluation_messages)

    if rich_output:
        panel = complete_with_tools_panel(
            "complete call with tool results",
            model_id,
            tool_desc_list,
            messages,
            temperature,
        )
        print(panel)

    # The temperature was previously accepted (and displayed) but not sent, so
    # the follow-up silently used the provider default.
    # NOTE(review): tools are still not sent on this call; some providers may
    # require tool definitions when the history contains tool calls — confirm.
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
    )

    return response.choices[0].message.content
@@ -0,0 +1,9 @@
1
+ from rich.text import Text
2
+
3
+
4
def header() -> Text:
    """Return the Proscenium banner: two bold lines of rich ``Text``."""
    # TODO version, timestamp, ...
    banner = Text()
    for line in ("Proscenium 🎭\n", "The AI Alliance\n"):
        banner.append(line, style="bold")
    return banner
@@ -0,0 +1,35 @@
1
+ from rich.table import Table
2
+
3
+
4
def messages_table(messages: list) -> Table:
    """Render chat messages as a two-column (Role, Content) rich table.

    Each message may be a ``dict`` or an object exposing the same fields as
    attributes. The Content cell depends on the role:

    - ``tool``: the tool call id, function name and result, one per line;
    - ``assistant``: ``str(message)``;
    - anything else: the message's ``content``.
    """

    def field(message, name):
        # Uniform access for dict messages and attribute-style message objects.
        # (Previously the object branch used message['content'] for tool
        # messages, which raises TypeError on non-subscriptable objects.)
        if isinstance(message, dict):
            return message[name]
        return getattr(message, name)

    table = Table(title="Messages in Chat Context", show_lines=True)
    table.add_column("Role", justify="left")
    table.add_column("Content", justify="left")
    for message in messages:
        role = field(message, "role")
        if role == "tool":
            content = (
                f"tool call id = {field(message, 'tool_call_id')}\n"
                f"fn name = {field(message, 'name')}\n"
                f"result = {field(message, 'content')}"
            )
        elif role == "assistant":
            content = str(message)
        else:
            content = field(message, "content")
        table.add_row(role, content)

    return table
File without changes
@@ -0,0 +1,68 @@
1
+ from rich.table import Table
2
+ from rich.panel import Panel
3
+ from rich.text import Text
4
+ from rich.console import Group
5
+ from pymilvus import MilvusClient
6
+
7
+
8
def chunk_hits_table(chunks: list[dict]) -> Table:
    """Tabulate search hits, each a dict with ``id``, ``distance`` and ``entity["text"]``."""
    table = Table(title="Closest Chunks", show_lines=True)
    table.add_column("id", justify="right")
    table.add_column("distance")
    table.add_column("entity.text", justify="right")

    # Lazily built so each row is computed and added one hit at a time.
    rows = (
        (str(hit["id"]), str(hit["distance"]), hit["entity"]["text"])
        for hit in chunks
    )
    for row in rows:
        table.add_row(*row)
    return table
17
+
18
+
19
def collection_panel(client: MilvusClient, collection_name: str) -> Panel:
    """Summarize a Milvus collection as a rich panel.

    Three sections are stacked: selected entries of ``describe_collection``
    ("Params"), a table of the collection's fields, and every entry of
    ``get_collection_stats`` ("Stats").
    """
    stats = client.get_collection_stats(collection_name)
    desc = client.describe_collection(collection_name)

    # (display label, key in the describe_collection result), in display order.
    labelled_keys = (
        ("Collection Name", "collection_name"),
        ("Auto ID", "auto_id"),
        ("Num Shards", "num_shards"),
        ("Description", "description"),
        ("Functions", "functions"),
        ("Aliases", "aliases"),
        ("Collection ID", "collection_id"),
        ("Consistency Level", "consistency_level"),
        ("Properties", "properties"),
        ("Num Partitions", "num_partitions"),
        ("Enable Dynamic Field", "enable_dynamic_field"),
    )
    params_text = Text(
        "\n" + "\n".join(f"{label}: {desc[key]}" for label, key in labelled_keys)
    )
    params_panel = Panel(params_text, title="Params")

    fields_table = Table(title="Fields", show_lines=True)
    for heading in (
        "id",
        "name",
        "description",
        "type",
        "params",
        "auto_id",
        "is_primary",
    ):
        fields_table.add_column(heading, justify="left")
    for field in desc["fields"]:
        fields_table.add_row(
            str(field["field_id"]),
            field["name"],
            field["description"],
            field["type"].name,  # Milvus DataType member
            "\n".join(f"{k}: {v}" for k, v in field["params"].items()),
            str(field.get("auto_id", "-")),
            str(field.get("is_primary", "-")),
        )

    stats_panel = Panel(
        Text("\n".join(f"{k}: {v}" for k, v in stats.items())), title="Stats"
    )

    return Panel(
        Group(params_panel, fields_table, stats_panel),
        title=f"Collection {collection_name}",
    )
@@ -0,0 +1,25 @@
1
+ from typing import List
2
+ from rich.table import Table
3
+
4
+
5
def triples_table(triples: List[tuple[str, str, str]], title: str) -> Table:
    """Render (subject, predicate, object) triples as a table titled ``title``."""
    table = Table(title=title, show_lines=False)
    for heading in ("Subject", "Predicate", "Object"):
        table.add_column(heading, justify="left")
    for cells in triples:
        table.add_row(*cells)
    return table
15
+
16
+
17
def pairs_table(subject_predicate_pairs: List[tuple[str, str]], title: str) -> Table:
    """Render (subject, predicate) pairs as a two-column table titled ``title``."""
    table = Table(title=title, show_lines=False)
    for heading in ("Subject", "Predicate"):
        table.add_column(heading, justify="left")
    for cells in subject_predicate_pairs:
        table.add_row(*cells)
    return table
@@ -0,0 +1,64 @@
1
+ from rich.console import Group
2
+ from rich.table import Table
3
+ from rich.text import Text
4
+ from rich.panel import Panel
5
+
6
+ from .chat import messages_table
7
+
8
+
9
def parameters_table(parameters: dict) -> Table:
    """Tabulate the properties of a JSON-schema ``parameters`` object.

    Args:
        parameters: The ``parameters`` entry of a function description. It is
            read as a dict (the old ``list[dict]`` annotation was wrong); its
            ``properties`` maps parameter names to their schemas.

    Returns:
        A borderless table with one row per parameter: name, type, description.
        ``properties``, ``type`` and ``description`` are optional in JSON
        Schema, so missing entries render as empty cells instead of raising
        ``KeyError``.
    """
    table = Table(title="Parameters", show_lines=False, box=None)
    table.add_column("name", justify="right")
    table.add_column("type", justify="left")
    table.add_column("description", justify="left")

    for name, props in parameters.get("properties", {}).items():
        # str(): a schema "type" may be a list (e.g. ["string", "null"]),
        # which rich cannot render directly.
        table.add_row(name, str(props.get("type", "")), props.get("description", ""))

    # TODO denote required params

    return table
22
+
23
+
24
def function_description_panel(fd: dict) -> Panel:
    """Render one tool description (``{"type": ..., "function": {...}}``) as a panel.

    The panel shows a "<type> <name>: <description>" line followed by the
    function's parameters table.
    """
    function = fd["function"]
    summary = Text(f"{fd['type']} {function['name']}: {function['description']}\n")
    return Panel(Group(summary, parameters_table(function["parameters"])))
35
+
36
+
37
def function_descriptions_panel(function_descriptions: list[dict]) -> Panel:
    """Stack one sub-panel per tool description inside a titled outer panel."""
    return Panel(
        Group(*map(function_description_panel, function_descriptions)),
        title="Function Descriptions",
    )
44
+
45
+
46
def complete_with_tools_panel(
    title: str, model_id: str, tool_desc_list: list, messages: list, temperature: float
) -> Panel:
    """Build the panel describing a tool-enabled completion request.

    It shows the model id and temperature, the offered tool descriptions, and
    the chat messages, all under ``title``.
    """
    settings = Text(f"\nmodel_id: {model_id}\ntemperature: {temperature}\n")
    tools_section = function_descriptions_panel(tool_desc_list)
    history_section = messages_table(messages)
    return Panel(Group(settings, tools_section, history_section), title=title)
@@ -0,0 +1,10 @@
1
+ from rich.text import Text
2
+
3
+
4
def header() -> Text:
    """Return the Proscenium banner as rich ``Text``.

    The literal below uses console markup (``[bold]...[/bold]``) and an emoji
    code (``:performing_arts:``). ``Text(...)`` stores its argument verbatim,
    so the tags and the emoji code were previously printed literally.
    ``Text.from_markup`` parses the markup and substitutes the emoji.
    """
    # TODO version, timestamp, ...
    return Text.from_markup(
        """[bold]Proscenium[/bold] :performing_arts:
[bold]The AI Alliance[/bold]"""
    )
@@ -0,0 +1,61 @@
1
+ import logging
2
+ from string import Formatter
3
+
4
+ import json
5
+ from pydantic import BaseModel
6
+
7
+ from proscenium.verbs.complete import complete_simple
8
+
9
# System message sent with every request made by extract_to_pydantic_model.
extraction_system_prompt: str = "You are an entity extractor"
10
+
11
+
12
class PartialFormatter(Formatter):
    """A ``string.Formatter`` that leaves unknown *named* fields in place.

    ``PartialFormatter().format("{a} {b}", a=1)`` returns ``"1 {b}"``, so a
    template can be filled in several passes. Missing positional fields still
    raise ``IndexError``. A preserved field keeps only its bare name: any
    format spec, conversion or attribute/index access is applied to the
    placeholder text rather than being carried over.
    """

    def get_value(self, key, args, kwargs):
        try:
            value = super().get_value(key, args, kwargs)
        except KeyError:
            value = "".join(("{", key, "}"))
        return value
18
+
19
+
20
# Shared module-level PartialFormatter: unknown named fields survive formatting.
partial_formatter: PartialFormatter = PartialFormatter()
21
+
22
# Prompt template for extraction, with two named fields:
#   {extraction_description} - description of the target data class
#   {text}                   - the text to extract from
# extract_to_pydantic_model fills only {text} (via str.format), so
# {extraction_description} presumably has to be filled beforehand, e.g. with
# partial_formatter — confirm against callers.
raw_extraction_template: str = """\
Below is a description of a data class for storing information extracted from text:

{extraction_description}

Find the information in the following text, and provide them in the specified JSON response format.
Only answer in JSON.:

{text}
"""
+ """
32
+
33
+
34
def extract_to_pydantic_model(
    extraction_model_id: str,
    extraction_template: str,
    clazz: type[BaseModel],
    text: str,
    verbose: bool = False,
) -> "BaseModel | None":
    """Ask an LLM to extract structured data from ``text`` into a ``clazz`` instance.

    Args:
        extraction_model_id: aisuite model id used for the extraction call.
        extraction_template: Prompt template. Its ``{text}`` field is filled
            with ``text`` via ``str.format``; any other field must already be
            filled, or ``str.format`` raises ``KeyError``.
        clazz: Pydantic model class for the expected output. Its JSON schema is
            sent as the ``response_format`` schema.
        text: The text to extract from.
        verbose: Forwarded as ``rich_output`` to print the request and reply.

    Returns:
        A validated ``clazz`` instance, or ``None`` if the model's reply could
        not be parsed as JSON or validated against ``clazz`` (the error is logged).
    """
    extract_str = complete_simple(
        extraction_model_id,
        extraction_system_prompt,
        extraction_template.format(text=text),
        response_format={
            "type": "json_object",
            "schema": clazz.model_json_schema(),
        },
        rich_output=verbose,
    )

    logging.info("extract_to_pydantic_model: extract_str = <<<%s>>>", extract_str)

    try:
        # The reply is untrusted model output, so it is validated. model_construct
        # (used previously) skips validation and can yield instances that lack
        # required fields or hold wrongly typed values.
        return clazz.model_validate(json.loads(extract_str))
    except Exception as e:
        # Deliberately best-effort: any parse/validation failure (including a
        # None reply) is logged and reported as None rather than raised.
        logging.error("extract_to_pydantic_model: Exception: %s", e)

    return None
@@ -0,0 +1,8 @@
1
+ from gofannon.base import BaseTool
2
+
3
+
4
+ def process_tools(tools: list[BaseTool]) -> tuple[dict, list]:
5
+ applied_tools = [F() for F in tools]
6
+ tool_map = {f.name: f.fn for f in applied_tools}
7
+ tool_desc_list = [f.definition for f in applied_tools]
8
+ return tool_map, tool_desc_list
@@ -0,0 +1,9 @@
1
+ from neo4j import GraphDatabase
2
+ from neo4j import Driver
3
+
4
+
5
+ def knowledge_graph_client(uri: str, username: str, password: str) -> Driver:
6
+
7
+ driver = GraphDatabase.driver(uri, auth=(username, password))
8
+
9
+ return driver
@@ -0,0 +1,55 @@
1
+ from typing import List
2
+
3
+ import os
4
+ import logging
5
+
6
+ from langchain_core.documents.base import Document
7
+
8
+ from langchain_community.document_loaders import TextLoader
9
+ from langchain_community.document_loaders.hugging_face_dataset import (
10
+ HuggingFaceDatasetLoader,
11
+ )
12
+
13
# Import-time side effects of this module:
# - set TOKENIZERS_PARALLELISM=false (presumably to silence Hugging Face
#   tokenizers' fork/parallelism warnings — confirm);
# - drop records below ERROR from the langchain text-splitter logger.
os.environ.update(TOKENIZERS_PARALLELISM="false")
logging.getLogger("langchain_text_splitters.base").setLevel(logging.ERROR)
15
+
16
+
17
+ def load_file(filename: str) -> List[Document]:
18
+
19
+ loader = TextLoader(filename)
20
+ documents = loader.load()
21
+
22
+ return documents
23
+
24
+
25
+ def load_hugging_face_dataset(
26
+ dataset_name: str, page_content_column: str = "text"
27
+ ) -> List[Document]:
28
+
29
+ loader = HuggingFaceDatasetLoader(
30
+ dataset_name, page_content_column=page_content_column
31
+ )
32
+ documents = loader.load()
33
+
34
+ return documents
35
+
36
+
37
+ import httpx
38
+ from pydantic.networks import HttpUrl
39
+ from pathlib import Path
40
+
41
+
42
async def url_to_file(url: HttpUrl, data_file: Path, overwrite: bool = False):
    """Download ``url`` into ``data_file`` unless the file already exists.

    Args:
        url: Address to fetch, as a pydantic ``HttpUrl`` or a plain string. It
            is passed to httpx as ``str(url)`` because httpx accepts only
            ``str``/``httpx.URL``, and pydantic v2 URL types are neither.
        data_file: Destination path. A ``str`` is also accepted.
        overwrite: Re-download even if ``data_file`` already exists.

    Raises:
        httpx.HTTPStatusError: If the final response has an error status.
        httpx.RequestError: On transport failures (DNS, connection, timeout).
    """
    data_file = Path(data_file)
    if data_file.exists() and not overwrite:
        return

    # httpx does not follow redirects unless asked to; without this, URLs
    # that redirect (common for dataset/raw-file hosts) fail to download.
    async with httpx.AsyncClient(follow_redirects=True) as client:
        response = await client.get(str(url))
        response.raise_for_status()

    # client.get() has already read the full body, so it is safe to write it
    # after the client is closed.
    data_file.write_bytes(response.content)
@@ -0,0 +1,8 @@
1
def format_chat_history(chat_history) -> str:
    """Render a chat log as text, one "<sender> to <receiver>:" section per message.

    Each message must provide ``sender``, ``receiver`` and ``content`` keys.
    Sections are separated by a line of 80 dashes; an empty history yields "".
    """
    separator = f"{'-' * 80}\n"
    sections = (
        f"{entry['sender']} to {entry['receiver']}:\n\n{entry['content']}\n\n"
        for entry in chat_history
    )
    return separator.join(sections)