proscenium-0.0.1-py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- proscenium/__init__.py +0 -0
- proscenium/scripts/__init__.py +0 -0
- proscenium/scripts/chunk_space.py +33 -0
- proscenium/scripts/document_enricher.py +80 -0
- proscenium/scripts/entity_resolver.py +89 -0
- proscenium/scripts/graph_rag.py +43 -0
- proscenium/scripts/knowledge_graph.py +39 -0
- proscenium/scripts/rag.py +63 -0
- proscenium/scripts/tools.py +103 -0
- proscenium/verbs/__init__.py +0 -0
- proscenium/verbs/chunk.py +40 -0
- proscenium/verbs/complete.py +214 -0
- proscenium/verbs/display/__init__.py +9 -0
- proscenium/verbs/display/chat.py +35 -0
- proscenium/verbs/display/huggingface.py +0 -0
- proscenium/verbs/display/milvus.py +68 -0
- proscenium/verbs/display/neo4j.py +25 -0
- proscenium/verbs/display/tools.py +64 -0
- proscenium/verbs/display.py +10 -0
- proscenium/verbs/extract.py +61 -0
- proscenium/verbs/invoke.py +8 -0
- proscenium/verbs/know.py +9 -0
- proscenium/verbs/read.py +55 -0
- proscenium/verbs/remember.py +8 -0
- proscenium/verbs/vector_database.py +146 -0
- proscenium/verbs/write.py +11 -0
- proscenium-0.0.1.dist-info/LICENSE +201 -0
- proscenium-0.0.1.dist-info/METADATA +52 -0
- proscenium-0.0.1.dist-info/RECORD +30 -0
- proscenium-0.0.1.dist-info/WHEEL +4 -0
proscenium/verbs/complete.py
@@ -0,0 +1,214 @@
```python
"""
This module uses the [`aisuite`](https://github.com/andrewyng/aisuite) library
to interact with various LLM inference providers.

It provides functions to complete a simple chat prompt, evaluate a tool call,
and apply a list of tool calls to a chat prompt.

Providers tested with Proscenium include:

# AWS

Environment: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`

Valid model ids:
- `aws:meta.llama3-1-8b-instruct-v1:0`

# Anthropic

Environment: `ANTHROPIC_API_KEY`

Valid model ids:
- `anthropic:claude-3-5-sonnet-20240620`

# OpenAI

Environment: `OPENAI_API_KEY`

Valid model ids:
- `openai:gpt-4o`

# Ollama

Command line, eg `ollama run llama3.2 --keepalive 2h`

Valid model ids:
- `ollama:llama3.2`
- `ollama:granite3.1-dense:2b`
"""

from typing import Any

import json
from rich import print
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from aisuite import Client
from aisuite.framework.message import ChatCompletionMessageToolCall

from proscenium.verbs.display.tools import complete_with_tools_panel

provider_configs = {
    # TODO expose this
    "ollama": {"timeout": 180},
}

client = Client(provider_configs=provider_configs)


def complete_simple(
    model_id: str, system_prompt: str, user_prompt: str, **kwargs
) -> str:

    rich_output = kwargs.pop("rich_output", False)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    if rich_output:

        kwargs_text = "\n".join([str(k) + ": " + str(v) for k, v in kwargs.items()])

        params_text = Text(
            f"""
model_id: {model_id}
{kwargs_text}
"""
        )

        messages_table = Table(title="Messages", show_lines=True)
        messages_table.add_column("Role", justify="left")
        messages_table.add_column("Content", justify="left")  # style="green"
        for message in messages:
            messages_table.add_row(message["role"], message["content"])

        call_panel = Panel(
            Group(params_text, messages_table), title="complete_simple call"
        )
        print(call_panel)

    response = client.chat.completions.create(
        model=model_id, messages=messages, **kwargs
    )
    response = response.choices[0].message.content

    if rich_output:
        print(Panel(response, title="Response"))

    return response


def evaluate_tool_call(
    tool_map: dict, tool_call: ChatCompletionMessageToolCall, rich_output: bool = False
) -> Any:

    function_name = tool_call.function.name
    # TODO validate the arguments?
    function_args = json.loads(tool_call.function.arguments)

    if rich_output:
        print(f"Evaluating tool call: {function_name} with args {function_args}")

    function_response = tool_map[function_name](**function_args)

    if rich_output:
        print(f" Response: {function_response}")

    return function_response


def tool_response_message(
    tool_call: ChatCompletionMessageToolCall, tool_result: Any
) -> dict:

    return {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "name": tool_call.function.name,
        "content": json.dumps(tool_result),
    }


def evaluate_tool_calls(
    tool_call_message, tool_map: dict, rich_output: bool = False
) -> list[dict]:

    tool_call: ChatCompletionMessageToolCall

    if rich_output:
        print("Evaluating tool calls")

    new_messages: list[dict] = []

    for tool_call in tool_call_message.tool_calls:
        function_response = evaluate_tool_call(tool_map, tool_call, rich_output)
        new_messages.append(tool_response_message(tool_call, function_response))

    if rich_output:
        print("Tool calls evaluated")

    return new_messages


def complete_for_tool_applications(
    model_id: str,
    messages: list,
    tool_desc_list: list,
    temperature: float,
    rich_output: bool = False,
):

    if rich_output:
        panel = complete_with_tools_panel(
            "complete for tool applications",
            model_id,
            tool_desc_list,
            messages,
            temperature,
        )
        print(panel)

    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
        tools=tool_desc_list,  # tool_choice="auto",
    )

    return response


def complete_with_tool_results(
    model_id: str,
    messages: list,
    tool_call_message: dict,
    tool_evaluation_messages: list[dict],
    tool_desc_list: list,
    temperature: float,
    rich_output: bool = False,
):

    messages.append(tool_call_message)
    messages.extend(tool_evaluation_messages)

    if rich_output:
        panel = complete_with_tools_panel(
            "complete call with tool results",
            model_id,
            tool_desc_list,
            messages,
            temperature,
        )
        print(panel)

    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
    )

    return response.choices[0].message.content
```
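A minimal usage sketch for `complete_simple`, assuming `OPENAI_API_KEY` is set in the environment (any of the provider-prefixed model ids from the docstring above would work the same way; the prompts here are hypothetical):

```python
from proscenium.verbs.complete import complete_simple

# rich_output is popped from kwargs before the remaining
# keyword arguments are forwarded to aisuite.
answer = complete_simple(
    "openai:gpt-4o",
    "You are a terse assistant.",
    "Name one part of a theater stage.",
    rich_output=True,
)
print(answer)
```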
proscenium/verbs/display/chat.py
@@ -0,0 +1,35 @@
```python
from rich.table import Table


def messages_table(messages: list) -> Table:

    table = Table(title="Messages in Chat Context", show_lines=True)
    table.add_column("Role", justify="left")
    table.add_column("Content", justify="left")
    for message in messages:
        if type(message) is dict:
            role = message["role"]
            content = ""
            if role == "tool":
                content = f"""tool call id = {message['tool_call_id']}
fn name = {message['name']}
result = {message['content']}"""
            elif role == "assistant":
                content = f"""{str(message)}"""
            else:
                content = message["content"]
            table.add_row(role, content)
        else:
            # message is an object, so use attribute access throughout
            role = message.role
            content = ""
            if role == "tool":
                content = f"""tool call id = {message.tool_call_id}
fn name = {message.name}
result = {message.content}"""
            elif role == "assistant":
                content = f"""{str(message)}"""
            else:
                content = message.content
            table.add_row(role, content)

    return table
```
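A quick sketch of rendering this table, with an illustrative dict-style message list of the kind built in `complete.py`:

```python
from rich.console import Console

from proscenium.verbs.display.chat import messages_table

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a proscenium?"},
]

Console().print(messages_table(messages))
```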
proscenium/verbs/display/huggingface.py
File without changes
proscenium/verbs/display/milvus.py
@@ -0,0 +1,68 @@
```python
from rich.table import Table
from rich.panel import Panel
from rich.text import Text
from rich.console import Group
from pymilvus import MilvusClient


def chunk_hits_table(chunks: list[dict]) -> Table:

    table = Table(title="Closest Chunks", show_lines=True)
    table.add_column("id", justify="right")
    table.add_column("distance")
    table.add_column("entity.text", justify="right")
    for chunk in chunks:
        table.add_row(str(chunk["id"]), str(chunk["distance"]), chunk["entity"]["text"])
    return table


def collection_panel(client: MilvusClient, collection_name: str) -> Panel:

    stats = client.get_collection_stats(collection_name)
    desc = client.describe_collection(collection_name)

    params_text = Text(
        f"""
Collection Name: {desc['collection_name']}
Auto ID: {desc['auto_id']}
Num Shards: {desc['num_shards']}
Description: {desc['description']}
Functions: {desc['functions']}
Aliases: {desc['aliases']}
Collection ID: {desc['collection_id']}
Consistency Level: {desc['consistency_level']}
Properties: {desc['properties']}
Num Partitions: {desc['num_partitions']}
Enable Dynamic Field: {desc['enable_dynamic_field']}"""
    )

    params_panel = Panel(params_text, title="Params")

    fields_table = Table(title="Fields", show_lines=True)
    fields_table.add_column("id", justify="left")
    fields_table.add_column("name", justify="left")
    fields_table.add_column("description", justify="left")
    fields_table.add_column("type", justify="left")
    fields_table.add_column("params", justify="left")
    fields_table.add_column("auto_id", justify="left")
    fields_table.add_column("is_primary", justify="left")
    for field in desc["fields"]:
        fields_table.add_row(
            str(field["field_id"]),  # int
            field["name"],
            field["description"],
            field["type"].name,  # Milvus DataType
            "\n".join([f"{k}: {v}" for k, v in field["params"].items()]),
            str(field.get("auto_id", "-")),  # bool
            str(field.get("is_primary", "-")),  # bool
        )

    stats_text = Text("\n".join([f"{k}: {v}" for k, v in stats.items()]))
    stats_panel = Panel(stats_text, title="Stats")

    panel = Panel(
        Group(params_panel, fields_table, stats_panel),
        title=f"Collection {collection_name}",
    )

    return panel
```
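A sketch of rendering `collection_panel`, assuming a local Milvus Lite database file and an existing collection named `chunks` (both the file path and the collection name are hypothetical):

```python
from pymilvus import MilvusClient
from rich.console import Console

from proscenium.verbs.display.milvus import collection_panel

client = MilvusClient("./milvus_demo.db")  # local Milvus Lite file
Console().print(collection_panel(client, "chunks"))
```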
proscenium/verbs/display/neo4j.py
@@ -0,0 +1,25 @@
```python
from typing import List
from rich.table import Table


def triples_table(triples: List[tuple[str, str, str]], title: str) -> Table:

    table = Table(title=title, show_lines=False)
    table.add_column("Subject", justify="left")
    table.add_column("Predicate", justify="left")
    table.add_column("Object", justify="left")
    for triple in triples:
        table.add_row(*triple)

    return table


def pairs_table(subject_predicate_pairs: List[tuple[str, str]], title: str) -> Table:

    table = Table(title=title, show_lines=False)
    table.add_column("Subject", justify="left")
    table.add_column("Predicate", justify="left")
    for pair in subject_predicate_pairs:
        table.add_row(*pair)

    return table
```
proscenium/verbs/display/tools.py
@@ -0,0 +1,64 @@
```python
from rich.console import Group
from rich.table import Table
from rich.text import Text
from rich.panel import Panel

from .chat import messages_table


def parameters_table(parameters: dict) -> Table:
    # `parameters` is a JSON-schema-style dict with a "properties" key

    table = Table(title="Parameters", show_lines=False, box=None)
    table.add_column("name", justify="right")
    table.add_column("type", justify="left")
    table.add_column("description", justify="left")

    for name, props in parameters["properties"].items():
        table.add_row(name, props["type"], props["description"])

    # TODO denote required params

    return table


def function_description_panel(fd: dict) -> Panel:

    fn = fd["function"]

    text = Text(f"{fd['type']} {fn['name']}: {fn['description']}\n")

    pt = parameters_table(fn["parameters"])

    panel = Panel(Group(text, pt))

    return panel


def function_descriptions_panel(function_descriptions: list[dict]) -> Panel:

    sub_panels = [function_description_panel(fd) for fd in function_descriptions]

    panel = Panel(Group(*sub_panels), title="Function Descriptions")

    return panel


def complete_with_tools_panel(
    title: str, model_id: str, tool_desc_list: list, messages: list, temperature: float
) -> Panel:

    text = Text(
        f"""
model_id: {model_id}
temperature: {temperature}
"""
    )

    panel = Panel(
        Group(
            text, function_descriptions_panel(tool_desc_list), messages_table(messages)
        ),
        title=title,
    )

    return panel
```
proscenium/verbs/extract.py
@@ -0,0 +1,61 @@
```python
import logging
from string import Formatter

import json
from pydantic import BaseModel

from proscenium.verbs.complete import complete_simple

extraction_system_prompt = "You are an entity extractor"


class PartialFormatter(Formatter):
    """A formatter that leaves unknown keys in place for later filling."""

    def get_value(self, key, args, kwargs):
        try:
            return super().get_value(key, args, kwargs)
        except KeyError:
            return "{" + key + "}"


partial_formatter = PartialFormatter()

raw_extraction_template = """\
Below is a description of a data class for storing information extracted from text:

{extraction_description}

Find the information in the following text, and provide it in the specified JSON response format.
Only answer in JSON:

{text}
"""


def extract_to_pydantic_model(
    extraction_model_id: str,
    extraction_template: str,
    clazz: type[BaseModel],
    text: str,
    verbose: bool = False,
) -> BaseModel:

    extract_str = complete_simple(
        extraction_model_id,
        extraction_system_prompt,
        extraction_template.format(text=text),
        response_format={
            "type": "json_object",
            "schema": clazz.model_json_schema(),
        },
        rich_output=verbose,
    )

    logging.info("extract_to_pydantic_model: extract_str = <<<%s>>>", extract_str)

    try:
        extract_dict = json.loads(extract_str)
        return clazz.model_construct(**extract_dict)
    except Exception as e:
        logging.error("extract_to_pydantic_model: Exception: %s", e)

    # falls through to None if the model's output was not valid JSON
    return None
```
proscenium/verbs/invoke.py
@@ -0,0 +1,8 @@
```python
from gofannon.base import BaseTool


def process_tools(tools: list[type[BaseTool]]) -> tuple[dict, list]:
    applied_tools = [F() for F in tools]  # instantiate each tool class
    tool_map = {f.name: f.fn for f in applied_tools}
    tool_desc_list = [f.definition for f in applied_tools]
    return tool_map, tool_desc_list
```
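A sketch of wiring `process_tools` output into the tool-call loop from `complete.py`; the `Addition` import path is an assumption about the gofannon package layout, and the model id and temperature are illustrative:

```python
from gofannon.basic_math.addition import Addition  # assumed path

from proscenium.verbs.complete import (
    complete_for_tool_applications,
    complete_with_tool_results,
    evaluate_tool_calls,
)
from proscenium.verbs.invoke import process_tools

tool_map, tool_desc_list = process_tools([Addition])

messages = [{"role": "user", "content": "What is 2 + 3?"}]

# First pass: let the model decide which tools to call.
response = complete_for_tool_applications(
    "openai:gpt-4o", messages, tool_desc_list, temperature=0.75
)
tool_call_message = response.choices[0].message

# Evaluate the calls, then complete again with the results appended.
tool_evaluation_messages = evaluate_tool_calls(tool_call_message, tool_map)

answer = complete_with_tool_results(
    "openai:gpt-4o",
    messages,
    tool_call_message,
    tool_evaluation_messages,
    tool_desc_list,
    temperature=0.75,
)
```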
proscenium/verbs/know.py
ADDED
proscenium/verbs/read.py
ADDED
@@ -0,0 +1,55 @@
```python
from typing import List

import os
import logging

from langchain_core.documents.base import Document

from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders.hugging_face_dataset import (
    HuggingFaceDatasetLoader,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.getLogger("langchain_text_splitters.base").setLevel(logging.ERROR)


def load_file(filename: str) -> List[Document]:

    loader = TextLoader(filename)
    documents = loader.load()

    return documents


def load_hugging_face_dataset(
    dataset_name: str, page_content_column: str = "text"
) -> List[Document]:

    loader = HuggingFaceDatasetLoader(
        dataset_name, page_content_column=page_content_column
    )
    documents = loader.load()

    return documents


import httpx
from pydantic.networks import HttpUrl
from pathlib import Path


async def url_to_file(url: HttpUrl, data_file: Path, overwrite: bool = False):

    if data_file.exists() and not overwrite:
        # print(f"File {data_file} exists. Use overwrite=True to replace.")
        return

    async with httpx.AsyncClient() as client:

        # print(f"Downloading {url} to {data_file}...")
        response = await client.get(url)
        response.raise_for_status()

        with open(data_file, "wb") as file:
            file.write(response.content)
```
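`url_to_file` is a coroutine, so a small driver is needed; the URL here is hypothetical:

```python
import asyncio
from pathlib import Path

from proscenium.verbs.read import load_file, url_to_file

# The url parameter is annotated HttpUrl, but a plain string that
# httpx accepts works at runtime since the annotation is not enforced.
asyncio.run(url_to_file("https://example.com/corpus.txt", Path("corpus.txt")))
docs = load_file("corpus.txt")
```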