proscenium 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proscenium/__init__.py +3 -0
- proscenium/admin/__init__.py +37 -0
- proscenium/bin/bot.py +142 -0
- proscenium/core/__init__.py +152 -0
- proscenium/interfaces/__init__.py +3 -0
- proscenium/interfaces/slack.py +265 -0
- proscenium/patterns/__init__.py +3 -0
- proscenium/patterns/chunk_space.py +51 -0
- proscenium/{scripts → patterns}/document_enricher.py +15 -11
- proscenium/{scripts → patterns}/entity_resolver.py +24 -18
- proscenium/patterns/graph_rag.py +61 -0
- proscenium/{scripts → patterns}/knowledge_graph.py +4 -2
- proscenium/{scripts → patterns}/rag.py +6 -12
- proscenium/{scripts → patterns}/tools.py +13 -45
- proscenium/verbs/__init__.py +3 -0
- proscenium/verbs/chunk.py +2 -0
- proscenium/verbs/complete.py +24 -28
- proscenium/verbs/display/__init__.py +1 -1
- proscenium/verbs/display.py +3 -0
- proscenium/verbs/extract.py +8 -4
- proscenium/verbs/invoke.py +3 -0
- proscenium/verbs/read.py +6 -8
- proscenium/verbs/remember.py +5 -0
- proscenium/verbs/vector_database.py +13 -20
- proscenium/verbs/write.py +3 -0
- {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/METADATA +5 -8
- proscenium-0.0.3.dist-info/RECORD +34 -0
- {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/WHEEL +1 -1
- proscenium-0.0.3.dist-info/entry_points.txt +3 -0
- proscenium/scripts/__init__.py +0 -0
- proscenium/scripts/chunk_space.py +0 -33
- proscenium/scripts/graph_rag.py +0 -43
- proscenium/verbs/display/huggingface.py +0 -0
- proscenium/verbs/know.py +0 -9
- proscenium-0.0.1.dist-info/RECORD +0 -30
- {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/LICENSE +0 -0
proscenium/{scripts → patterns}/document_enricher.py
CHANGED
@@ -1,12 +1,13 @@
 from typing import List
 from typing import Callable
-from typing import
+from typing import Optional

 import time
+import logging
 from pydantic import BaseModel

-from rich import print
 from rich.panel import Panel
+from rich.console import Console
 from rich.progress import Progress

 from langchain_core.documents.base import Document
@@ -14,6 +15,8 @@ from langchain_core.documents.base import Document
 from proscenium.verbs.chunk import documents_to_chunks_by_tokens
 from proscenium.verbs.extract import extract_to_pydantic_model

+log = logging.getLogger(__name__)
+

 def extract_from_document_chunks(
     doc: Document,
@@ -22,11 +25,12 @@ def extract_from_document_chunks(
     chunk_extraction_template: str,
     chunk_extract_clazz: type[BaseModel],
     delay: float,
-
+    console: Optional[Console] = None,
 ) -> List[BaseModel]:

-
-
+    if console is not None:
+        console.print(doc_as_rich(doc))
+        console.print()

     extract_models = []

@@ -40,9 +44,9 @@ def extract_from_document_chunks(
             chunk.page_content,
         )

-
-
-        print(Panel(str(ce)))
+        log.info("Extract model in chunk %s of %s", i + 1, len(chunks))
+        if console is not None:
+            console.print(Panel(str(ce)))

         extract_models.append(ce)
         time.sleep(delay)
@@ -55,7 +59,7 @@ def enrich_documents(
     extract_from_doc_chunks: Callable[[Document], List[BaseModel]],
     doc_enrichments: Callable[[Document, list[BaseModel]], BaseModel],
     enrichments_jsonl_file: str,
-
+    console: Optional[Console] = None,
 ) -> None:

     docs = retrieve_documents()
@@ -70,11 +74,11 @@ def enrich_documents(

         for doc in docs:

-            chunk_extract_models = extract_from_doc_chunks(doc
+            chunk_extract_models = extract_from_doc_chunks(doc)
            enrichments = doc_enrichments(doc, chunk_extract_models)
            enrichments_json = enrichments.model_dump_json()
            f.write(enrichments_json + "\n")

            progress.update(task_enrich, advance=1)

-
+    log.info("Wrote document enrichments to %s", enrichments_jsonl_file)
proscenium/{scripts → patterns}/entity_resolver.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Optional
-
+import logging

+from rich.console import Console
 from langchain_core.documents.base import Document
 from neo4j import Driver

@@ -13,6 +14,8 @@ from proscenium.verbs.vector_database import add_chunks_to_vector_db
 from proscenium.verbs.vector_database import embedding_function
 from proscenium.verbs.display.milvus import collection_panel

+log = logging.getLogger(__name__)
+

 class Resolver:

@@ -21,27 +24,27 @@ class Resolver:
         cypher: str,
         field_name: str,
         collection_name: str,
-        embedding_model_id: str,
     ):
         self.cypher = cypher
         self.field_name = field_name
         self.collection_name = collection_name
-        self.embedding_model_id = embedding_model_id


 def load_entity_resolver(
     driver: Driver,
     resolvers: list[Resolver],
+    embedding_model_id: str,
     milvus_uri: str,
+    console: Optional[Console] = None,
 ) -> None:

-    vector_db_client = vector_db(milvus_uri
-
+    vector_db_client = vector_db(milvus_uri)
+    log.info("Vector db stored at %s", milvus_uri)

-
+    embedding_fn = embedding_function(embedding_model_id)
+    log.info("Embedding model %s", embedding_model_id)

-
-        print("Embedding model", resolver.embedding_model_id)
+    for resolver in resolvers:

         values = []
         with driver.session() as session:
@@ -49,15 +52,16 @@ def load_entity_resolver(
             new_values = [Document(record[resolver.field_name]) for record in result]
             values.extend(new_values)

-
-        create_collection(
-
-        )
+        log.info("Loading entity resolver into vector db %s", resolver.collection_name)
+        create_collection(vector_db_client, embedding_fn, resolver.collection_name)
+
         info = add_chunks_to_vector_db(
             vector_db_client, embedding_fn, values, resolver.collection_name
         )
-
-
+        log.info("%s chunks inserted ", info["insert_count"])
+
+        if console is not None:
+            console.print(collection_panel(vector_db_client, resolver.collection_name))

     vector_db_client.close()

@@ -68,10 +72,12 @@ def find_matching_objects(
     resolver: Resolver,
 ) -> Optional[str]:

-
+    log.info("Loading collection", resolver.collection_name)
     vector_db_client.load_collection(resolver.collection_name)

-
+    log.info(
+        "Finding entity matches for", approximate, "using", resolver.collection_name
+    )

     hits = closest_chunks(
         vector_db_client,
@@ -82,8 +88,8 @@ def find_matching_objects(
     )
     # TODO apply distance threshold
     for match in [head["entity"]["text"] for head in hits[:1]]:
-
+        log.info("Closest match:", match)
         return match

-
+    log.info("No match found")
     return None
proscenium/patterns/graph_rag.py
ADDED
@@ -0,0 +1,61 @@
+from typing import Callable
+from typing import Optional
+
+import logging
+
+from rich.console import Console
+
+from pydantic import BaseModel
+from uuid import uuid4, UUID
+from neo4j import Driver
+
+log = logging.getLogger(__name__)
+
+
+def query_to_prompts(
+    query: str,
+    query_extraction_model_id: str,
+    milvus_uri: str,
+    driver: Driver,
+    query_extract: Callable[
+        [str, str], BaseModel
+    ],  # (query_str, query_extraction_model_id) -> QueryExtractions
+    query_extract_to_graph: Callable[
+        [str, UUID, BaseModel], None
+    ],  # query, query_id, extract
+    query_extract_to_context: Callable[
+        [BaseModel, str, Driver, str, Optional[Console]], BaseModel
+    ],  # (QueryExtractions, query_str, Driver, milvus_uri) -> Context
+    context_to_prompts: Callable[
+        [BaseModel], tuple[str, str]
+    ],  # Context -> (system_prompt, user_prompt)
+    console: Optional[Console] = None,
+) -> Optional[tuple[str, str]]:
+
+    query_id = uuid4()
+
+    log.info("Extracting information from the question")
+
+    extract = query_extract(query, query_extraction_model_id)
+    if extract is None:
+        log.info("Unable to extract information from that question")
+        return None
+
+    log.info("Extract: %s", extract)
+
+    log.info("Storing the extracted information in the graph")
+    query_extract_to_graph(query, query_id, extract, driver)
+
+    log.info("Forming context from the extracted information")
+    context = query_extract_to_context(
+        extract, query, driver, milvus_uri, console=console
+    )
+    if context is None:
+        log.info("Unable to form context from the extracted information")
+        return None
+
+    log.info("Context: %s", context)
+
+    prompts = context_to_prompts(context)
+
+    return prompts
proscenium/{scripts → patterns}/knowledge_graph.py
CHANGED
@@ -1,14 +1,16 @@
 from typing import Callable
 from typing import Any

+import logging
 import json
 from pydantic import BaseModel

-from rich import print
 from rich.progress import Progress

 from neo4j import Driver

+log = logging.getLogger(__name__)
+

 def load_knowledge_graph(
     driver: Driver,
@@ -17,7 +19,7 @@ def load_knowledge_graph(
     doc_enrichments_to_graph: Callable[[Any, BaseModel], None],
 ) -> None:

-
+    log.info("Parsing enrichments from %s", enrichments_jsonl_file)

     enrichmentss = []
     with open(enrichments_jsonl_file, "r") as f:
proscenium/{scripts → patterns}/rag.py
CHANGED
@@ -1,7 +1,5 @@
 from typing import List, Dict
-
-from rich import print
-from rich.panel import Panel
+import logging

 from pymilvus import MilvusClient
 from pymilvus import model
@@ -10,6 +8,7 @@ from proscenium.verbs.complete import complete_simple
 from proscenium.verbs.display.milvus import chunk_hits_table
 from proscenium.verbs.vector_database import closest_chunks

+log = logging.getLogger(__name__)

 rag_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."

@@ -44,20 +43,15 @@ def answer_question(
     vector_db_client: MilvusClient,
     embedding_fn: model.dense.SentenceTransformerEmbeddingFunction,
     collection_name: str,
-    verbose: bool = False,
 ) -> str:

-    print(Panel(query, title="User"))
-
     chunks = closest_chunks(vector_db_client, embedding_fn, query, collection_name)
-
-
-    print(chunk_hits_table(chunks))
+    log.info("Found %s closest chunks", len(chunks))
+    log.info(chunk_hits_table(chunks))

     prompt = rag_prompt(chunks, query)
-
-    print("RAG prompt created. Calling inference at", model_id, "\n\n")
+    log.info("RAG prompt created. Calling inference at %s", model_id)

-    answer = complete_simple(model_id, rag_system_prompt, prompt
+    answer = complete_simple(model_id, rag_system_prompt, prompt)

     return answer
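With `verbose` gone, `answer_question` no longer prints the query panel or hit table; both go to the logger. The retrieval helpers it builds on keep their signatures, as in this sketch (URIs and the model id are placeholders, and it assumes `vector_db` is exported from `proscenium.verbs.vector_database` alongside the helpers the diff imports from there):

import logging

from proscenium.verbs.vector_database import (
    closest_chunks,
    embedding_function,
    vector_db,
)

logging.basicConfig(level=logging.INFO)  # surface the hit-count and prompt logs

client = vector_db("milvus_demo.db")
embedding_fn = embedding_function("all-MiniLM-L6-v2")
chunks = closest_chunks(client, embedding_fn, "What is Proscenium?", "chunks")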
proscenium/{scripts → patterns}/tools.py
CHANGED
@@ -1,47 +1,17 @@
-from typing import
+from typing import Optional
+import logging

-from rich import
+from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
-from thespian.actors import Actor
-
-from gofannon.base import BaseTool

 from proscenium.verbs.complete import (
     complete_for_tool_applications,
     evaluate_tool_calls,
     complete_with_tool_results,
 )
-from proscenium.verbs.invoke import process_tools
-
-
-def tool_applier_actor_class(
-    tools: List[BaseTool],
-    system_message: str,
-    model_id: str,
-    temperature: float = 0.75,
-    rich_output: bool = False,
-):
-
-    tool_map, tool_desc_list = process_tools(tools)

-
-
-    def receiveMessage(self, message, sender):
-
-        response = apply_tools(
-            model_id=model_id,
-            system_message=system_message,
-            message=message,
-            tool_desc_list=tool_desc_list,
-            tool_map=tool_map,
-            temperature=temperature,
-            rich_output=rich_output,
-        )
-
-        self.send(sender, response)
-
-    return ToolApplier
+log = logging.getLogger(__name__)


 def apply_tools(
@@ -51,7 +21,7 @@ def apply_tools(
     tool_desc_list: list,
     tool_map: dict,
     temperature: float = 0.75,
-
+    console: Optional[Console] = None,
 ) -> str:

     messages = [
@@ -60,35 +30,33 @@ def apply_tools(
     ]

     response = complete_for_tool_applications(
-        model_id, messages, tool_desc_list, temperature,
+        model_id, messages, tool_desc_list, temperature, console
     )

     tool_call_message = response.choices[0].message

     if tool_call_message.tool_calls is None or len(tool_call_message.tool_calls) == 0:

-        if
-            print(
+        if console is not None:
+            console.print(
                 Panel(
                     Text(str(tool_call_message.content)),
                     title="Tool Application Response",
                 )
             )

-
+        log.info("No tool applications detected")

         return tool_call_message.content

     else:

-        if
-            print(
+        if console is not None:
+            console.print(
                 Panel(Text(str(tool_call_message)), title="Tool Application Response")
             )

-        tool_evaluation_messages = evaluate_tool_calls(
-            tool_call_message, tool_map, rich_output
-        )
+        tool_evaluation_messages = evaluate_tool_calls(tool_call_message, tool_map)

         result = complete_with_tool_results(
             model_id,
@@ -97,7 +65,7 @@ def apply_tools(
             tool_evaluation_messages,
             tool_desc_list,
             temperature,
-
+            console,
         )

         return result
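With the thespian actor wrapper removed, `apply_tools` is called directly. A sketch based on the keyword arguments the deleted actor passed (`rich_output` replaced by `console`); the model id is illustrative and `tools` must be populated with gofannon `BaseTool` instances:

from rich.console import Console

from proscenium.patterns.tools import apply_tools
from proscenium.verbs.invoke import process_tools

tools = []  # populate with gofannon BaseTool instances
tool_map, tool_desc_list = process_tools(tools)

response = apply_tools(
    model_id="ollama:granite3.1-dense:2b",
    system_message="You are a helpful assistant.",
    message="What is 2 + 2?",
    tool_desc_list=tool_desc_list,
    tool_map=tool_map,
    temperature=0.75,
    console=Console(),  # omit for quiet operation
)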
proscenium/verbs/__init__.py
CHANGED
proscenium/verbs/chunk.py
CHANGED
@@ -8,6 +8,8 @@ from langchain_core.documents.base import Document
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.text_splitter import TokenTextSplitter

+log = logging.getLogger(__name__)
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 logging.getLogger("langchain_text_splitters.base").setLevel(logging.ERROR)

proscenium/verbs/complete.py
CHANGED
@@ -37,10 +37,12 @@ Valid model ids:
 - `ollama:granite3.1-dense:2b`
 """

+from typing import Optional
 from typing import Any
-
+import logging
 import json
-
+
+from rich.console import Console
 from rich.console import Group
 from rich.panel import Panel
 from rich.table import Table
@@ -51,6 +53,8 @@ from aisuite.framework.message import ChatCompletionMessageToolCall

 from proscenium.verbs.display.tools import complete_with_tools_panel

+log = logging.getLogger(__name__)
+
 provider_configs = {
     # TODO expose this
     "ollama": {"timeout": 180},
@@ -63,14 +67,14 @@ def complete_simple(
     model_id: str, system_prompt: str, user_prompt: str, **kwargs
 ) -> str:

-
+    console = kwargs.pop("console", None)

     messages = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": user_prompt},
     ]

-    if
+    if console is not None:

         kwargs_text = "\n".join([str(k) + ": " + str(v) for k, v in kwargs.items()])

@@ -90,34 +94,30 @@ model_id: {model_id}
         call_panel = Panel(
             Group(params_text, messages_table), title="complete_simple call"
         )
-        print(call_panel)
+        console.print(call_panel)

     response = client.chat.completions.create(
         model=model_id, messages=messages, **kwargs
     )
     response = response.choices[0].message.content

-    if
-        print(Panel(response, title="Response"))
+    if console is not None:
+        console.print(Panel(response, title="Response"))

     return response


-def evaluate_tool_call(
-    tool_map: dict, tool_call: ChatCompletionMessageToolCall, rich_output: bool = False
-) -> Any:
+def evaluate_tool_call(tool_map: dict, tool_call: ChatCompletionMessageToolCall) -> Any:

     function_name = tool_call.function.name
     # TODO validate the arguments?
     function_args = json.loads(tool_call.function.arguments)

-
-    print(f"Evaluating tool call: {function_name} with args {function_args}")
+    log.info(f"Evaluating tool call: {function_name} with args {function_args}")

     function_response = tool_map[function_name](**function_args)

-
-    print(f"  Response: {function_response}")
+    log.info(f"  Response: {function_response}")

     return function_response

@@ -134,23 +134,19 @@ def tool_response_message(
     }


-def evaluate_tool_calls(
-    tool_call_message, tool_map: dict, rich_output: bool = False
-) -> list[dict]:
+def evaluate_tool_calls(tool_call_message, tool_map: dict) -> list[dict]:

     tool_call: ChatCompletionMessageToolCall

-
-    print("Evaluating tool calls")
+    log.info("Evaluating tool calls")

     new_messages: list[dict] = []

     for tool_call in tool_call_message.tool_calls:
-        function_response = evaluate_tool_call(tool_map, tool_call
+        function_response = evaluate_tool_call(tool_map, tool_call)
         new_messages.append(tool_response_message(tool_call, function_response))

-
-    print("Tool calls evaluated")
+    log.info("Tool calls evaluated")

     return new_messages

@@ -160,10 +156,10 @@ def complete_for_tool_applications(
     messages: list,
     tool_desc_list: list,
     temperature: float,
-
+    console: Optional[Console] = None,
 ):

-    if
+    if console is not None:
         panel = complete_with_tools_panel(
             "complete for tool applications",
             model_id,
@@ -171,7 +167,7 @@ def complete_for_tool_applications(
             messages,
             temperature,
         )
-        print(panel)
+        console.print(panel)

     response = client.chat.completions.create(
         model=model_id,
@@ -190,13 +186,13 @@ def complete_with_tool_results(
     tool_evaluation_messages: list[dict],
     tool_desc_list: list,
     temperature: float,
-
+    console: Optional[Console] = None,
 ):

     messages.append(tool_call_message)
     messages.extend(tool_evaluation_messages)

-    if
+    if console is not None:
         panel = complete_with_tools_panel(
             "complete call with tool results",
             model_id,
@@ -204,7 +200,7 @@ def complete_with_tool_results(
             messages,
             temperature,
         )
-        print(panel)
+        console.print(panel)

     response = client.chat.completions.create(
         model=model_id,
proscenium/verbs/display/__init__.py
CHANGED
@@ -4,6 +4,6 @@ from rich.text import Text
 def header() -> Text:
     text = Text()
     text.append("Proscenium 🎭\n", style="bold")
-    text.append("
+    text.append("https://the-ai-alliance.github.io/proscenium/\n")
     # TODO version, timestamp, ...
     return text
proscenium/verbs/display.py
CHANGED
proscenium/verbs/extract.py
CHANGED
@@ -1,4 +1,6 @@
+from typing import Optional
 import logging
+from rich.console import Console
 from string import Formatter

 import json
@@ -6,6 +8,8 @@ from pydantic import BaseModel

 from proscenium.verbs.complete import complete_simple

+log = logging.getLogger(__name__)
+
 extraction_system_prompt = "You are an entity extractor"


@@ -36,7 +40,7 @@ def extract_to_pydantic_model(
     extraction_template: str,
     clazz: type[BaseModel],
     text: str,
-
+    console: Optional[Console] = None,
 ) -> BaseModel:

     extract_str = complete_simple(
@@ -47,15 +51,15 @@ def extract_to_pydantic_model(
             "type": "json_object",
             "schema": clazz.model_json_schema(),
         },
-
+        console=console,
     )

-
+    log.info("complete_to_pydantic_model: extract_str = <<<%s>>>", extract_str)

     try:
         extract_dict = json.loads(extract_str)
         return clazz.model_construct(**extract_dict)
     except Exception as e:
-
+        log.error("complete_to_pydantic_model: Exception: %s", e)

     return None
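The extractor constrains `complete_simple` with a JSON schema derived from the target pydantic class, then reconstructs the model from the reply. A minimal sketch of that same pattern outside the package (model id and prompt text are illustrative):

import json

from pydantic import BaseModel

from proscenium.verbs.complete import complete_simple


class Entities(BaseModel):
    names: list[str]


raw = complete_simple(
    "ollama:granite3.1-dense:2b",
    "You are an entity extractor",
    "List the names in: Hamlet was written by Shakespeare.",
    response_format={
        "type": "json_object",
        "schema": Entities.model_json_schema(),
    },
)
entities = Entities.model_construct(**json.loads(raw))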