proscenium 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
proscenium/__init__.py ADDED
File without changes
@@ -0,0 +1,33 @@
+ from rich import print
+
+ from pymilvus import MilvusClient
+ from pymilvus import model
+
+ from proscenium.verbs.read import load_file
+ from proscenium.verbs.chunk import documents_to_chunks_by_characters
+ from proscenium.verbs.vector_database import create_collection
+ from proscenium.verbs.vector_database import add_chunks_to_vector_db
+ from proscenium.verbs.display.milvus import collection_panel
+
+
+ def build_vector_db(
+     data_files: list[str],
+     vector_db_client: MilvusClient,
+     embedding_fn: model.dense.SentenceTransformerEmbeddingFunction,
+     collection_name: str,
+ ):
+
+     create_collection(vector_db_client, embedding_fn, collection_name, overwrite=True)
+
+     for data_file in data_files:
+
+         documents = load_file(data_file)
+         chunks = documents_to_chunks_by_characters(documents)
+         print("Data file", data_file, "has", len(chunks), "chunks")
+
+         info = add_chunks_to_vector_db(
+             vector_db_client, embedding_fn, chunks, collection_name
+         )
+         print(info["insert_count"], "chunks inserted")
+
+     print(collection_panel(vector_db_client, collection_name))
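Usage sketch for build_vector_db (illustrative, not part of the package diff): the URI, embedding model id, and data file are placeholders, and vector_db / embedding_function are the helpers imported by the entity-resolver module further down in this diff.

    from proscenium.verbs.vector_database import vector_db, embedding_function

    client = vector_db("milvus_demo.db", overwrite=True)   # placeholder URI
    embedding_fn = embedding_function("all-MiniLM-L6-v2")  # placeholder model id
    build_vector_db(["data/notes.txt"], client, embedding_fn, "notes")
    client.close()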
@@ -0,0 +1,79 @@
+ from typing import List
+ from typing import Callable
+
+ import time
+ from pydantic import BaseModel
+
+ from rich import print
+ from rich.panel import Panel
+ from rich.progress import Progress
+
+ from langchain_core.documents.base import Document
+
+ from proscenium.verbs.chunk import documents_to_chunks_by_tokens
+ from proscenium.verbs.extract import extract_to_pydantic_model
+
+
+ def extract_from_document_chunks(
+     doc: Document,
+     doc_as_rich: Callable[[Document], Panel],
+     chunk_extraction_model_id: str,
+     chunk_extraction_template: str,
+     chunk_extract_clazz: type[BaseModel],
+     delay: float,
+     verbose: bool = False,
+ ) -> List[BaseModel]:
+
+     print(doc_as_rich(doc))
+     print()
+
+     extract_models = []
+
+     chunks = documents_to_chunks_by_tokens([doc], chunk_size=1000, chunk_overlap=0)
+     for i, chunk in enumerate(chunks):
+
+         ce = extract_to_pydantic_model(
+             chunk_extraction_model_id,
+             chunk_extraction_template,
+             chunk_extract_clazz,
+             chunk.page_content,
+         )
+
+         if verbose:
+             print("Extract model in chunk", i + 1, "of", len(chunks))
+             print(Panel(str(ce)))
+
+         extract_models.append(ce)
+         time.sleep(delay)
+
+     return extract_models
+
+
+ def enrich_documents(
+     retrieve_documents: Callable[[], List[Document]],
+     extract_from_doc_chunks: Callable[[Document, bool], List[BaseModel]],
+     doc_enrichments: Callable[[Document, list[BaseModel]], BaseModel],
+     enrichments_jsonl_file: str,
+     verbose: bool = False,
+ ) -> None:
+
+     docs = retrieve_documents()
+
+     with Progress() as progress:
+
+         task_enrich = progress.add_task(
+             "[green]Enriching documents...", total=len(docs)
+         )
+
+         with open(enrichments_jsonl_file, "wt") as f:
+
+             for doc in docs:
+
+                 chunk_extract_models = extract_from_doc_chunks(doc, verbose)
+                 enrichments = doc_enrichments(doc, chunk_extract_models)
+                 enrichments_json = enrichments.model_dump_json()
+                 f.write(enrichments_json + "\n")
+
+                 progress.update(task_enrich, advance=1)
+
+     print("Wrote document enrichments to", enrichments_jsonl_file)
@@ -0,0 +1,89 @@
+ from typing import Optional
+ from rich import print
+
+ from langchain_core.documents.base import Document
+ from neo4j import Driver
+
+ from pymilvus import MilvusClient
+
+ from proscenium.verbs.vector_database import vector_db
+ from proscenium.verbs.vector_database import create_collection
+ from proscenium.verbs.vector_database import closest_chunks
+ from proscenium.verbs.vector_database import add_chunks_to_vector_db
+ from proscenium.verbs.vector_database import embedding_function
+ from proscenium.verbs.display.milvus import collection_panel
+
+
+ class Resolver:
+
+     def __init__(
+         self,
+         cypher: str,
+         field_name: str,
+         collection_name: str,
+         embedding_model_id: str,
+     ):
+         self.cypher = cypher
+         self.field_name = field_name
+         self.collection_name = collection_name
+         self.embedding_model_id = embedding_model_id
+
+
+ def load_entity_resolver(
+     driver: Driver,
+     resolvers: list[Resolver],
+     milvus_uri: str,
+ ) -> None:
+
+     vector_db_client = vector_db(milvus_uri, overwrite=True)
+     print("Vector db stored at", milvus_uri)
+
+     for resolver in resolvers:
+
+         embedding_fn = embedding_function(resolver.embedding_model_id)
+         print("Embedding model", resolver.embedding_model_id)
+
+         values = []
+         with driver.session() as session:
+             result = session.run(resolver.cypher)
+             new_values = [Document(record[resolver.field_name]) for record in result]
+             values.extend(new_values)
+
+         print("Loading entity resolver into vector db", resolver.collection_name)
+         create_collection(
+             vector_db_client, embedding_fn, resolver.collection_name, overwrite=True
+         )
+         info = add_chunks_to_vector_db(
+             vector_db_client, embedding_fn, values, resolver.collection_name
+         )
+         print(info["insert_count"], "chunks inserted")
+         print(collection_panel(vector_db_client, resolver.collection_name))
+
+     vector_db_client.close()
+
+
+ def find_matching_objects(
+     vector_db_client: MilvusClient,
+     approximate: str,
+     resolver: Resolver,
+ ) -> Optional[str]:
+
+     print("Loading collection", resolver.collection_name)
+     vector_db_client.load_collection(resolver.collection_name)
+
+     print("Finding entity matches for", approximate, "using", resolver.collection_name)
+
+     hits = closest_chunks(
+         vector_db_client,
+         embedding_function(resolver.embedding_model_id),
+         approximate,
+         resolver.collection_name,
+         k=5,
+     )
+     # TODO apply distance threshold
+     for match in [head["entity"]["text"] for head in hits[:1]]:
+         print("Closest match:", match)
+         return match
+
+     print("No match found")
+     return None
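A hedged sketch of resolver setup and lookup (the Cypher query, names, URI, and model id are illustrative; driver is an open neo4j Driver, and this assumes vector_db can reopen an existing store when the overwrite flag is omitted):

    resolver = Resolver(
        cypher="MATCH (p:Person) RETURN p.name AS name",  # illustrative query
        field_name="name",
        collection_name="person_names",
        embedding_model_id="all-MiniLM-L6-v2",            # placeholder
    )

    load_entity_resolver(driver, [resolver], "milvus_entities.db")  # placeholder URI

    client = vector_db("milvus_entities.db")
    match = find_matching_objects(client, "J. Smith", resolver)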
@@ -0,0 +1,41 @@
+ from typing import Callable
+ from typing import Optional
+
+ from pydantic import BaseModel
+
+ from rich import print
+
+ from neo4j import Driver
+
+
+ def query_to_prompts(
+     question: str,
+     query_extraction_model_id: str,
+     milvus_uri: str,
+     driver: Driver,
+     query_extract: Callable[
+         [str, str, bool], BaseModel
+     ],  # (query_str, query_extraction_model_id, verbose) -> QueryExtractions
+     extract_to_context: Callable[
+         [BaseModel, str, Driver, str, bool], BaseModel
+     ],  # (QueryExtractions, query_str, Driver, milvus_uri, verbose) -> Context
+     context_to_prompts: Callable[
+         [BaseModel, bool], tuple[str, str]
+     ],  # (Context, verbose) -> (system_prompt, user_prompt)
+     verbose: bool = False,
+ ) -> Optional[tuple[str, str]]:
+
+     print("Extracting information from the question")
+     extract = query_extract(question, query_extraction_model_id, verbose)
+     if extract is None:
+         print("Unable to extract information from that question")
+         return None
+     print("Extract:", extract)
+
+     print("Forming context from the extracted information")
+     context = extract_to_context(extract, question, driver, milvus_uri, verbose)
+     print("Context:", context)
+
+     prompts = context_to_prompts(context, verbose)
+
+     return prompts
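The returned pair is meant to drive a chat completion. A sketch reusing complete_simple from proscenium.verbs.complete (the question, model id, URI, and the three callables are placeholders):

    from proscenium.verbs.complete import complete_simple

    prompts = query_to_prompts(
        "Who represented the appellant?",  # illustrative question
        "model-id",                        # placeholder
        "milvus_entities.db",              # placeholder URI
        driver,                            # an open neo4j Driver
        my_query_extract,                  # hypothetical callables matching
        my_extract_to_context,             # the signatures documented above
        my_context_to_prompts,
    )
    if prompts is not None:
        system_prompt, user_prompt = prompts
        answer = complete_simple("model-id", system_prompt, user_prompt)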
@@ -0,0 +1,39 @@
+ from typing import Callable
+ from typing import Any
+
+ import json
+ from pydantic import BaseModel
+
+ from rich import print
+ from rich.progress import Progress
+
+ from neo4j import Driver
+
+
+ def load_knowledge_graph(
+     driver: Driver,
+     enrichments_jsonl_file: str,
+     enrichments_clazz: type[BaseModel],
+     doc_enrichments_to_graph: Callable[[Any, BaseModel], None],
+ ) -> None:
+
+     print("Parsing enrichments from", enrichments_jsonl_file)
+
+     enrichments = []
+     with open(enrichments_jsonl_file, "r") as f:
+         for line in f:
+             e = enrichments_clazz.model_construct(**json.loads(line))
+             enrichments.append(e)
+
+     with Progress() as progress:
+
+         task_load = progress.add_task(
+             f"Loading {len(enrichments)} enriched documents into graph...",
+             total=len(enrichments),
+         )
+
+         with driver.session() as session:
+             session.run("MATCH (n) DETACH DELETE n")  # empty graph
+             for e in enrichments:
+                 session.execute_write(doc_enrichments_to_graph, e)
+                 progress.update(task_load, advance=1)
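Because each enrichment is passed to session.execute_write, the writer callable receives a managed transaction as its first argument. A sketch with placeholder credentials and a hypothetical enrichment model:

    from neo4j import GraphDatabase
    from pydantic import BaseModel

    class DocEnrichments(BaseModel):  # hypothetical
        name: str = ""

    def enrichments_to_graph(tx, e: DocEnrichments) -> None:
        tx.run("MERGE (d:Document {name: $name})", name=e.name)  # illustrative Cypher

    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))  # placeholders
    load_knowledge_graph(driver, "enrichments.jsonl", DocEnrichments, enrichments_to_graph)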
@@ -0,0 +1,63 @@
+ from typing import List, Dict
+
+ from rich import print
+ from rich.panel import Panel
+
+ from pymilvus import MilvusClient
+ from pymilvus import model
+
+ from proscenium.verbs.complete import complete_simple
+ from proscenium.verbs.display.milvus import chunk_hits_table
+ from proscenium.verbs.vector_database import closest_chunks
+
+
+ rag_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
+
+ rag_prompt_template = """
+ The document chunks that are most similar to the query are:
+
+ {context}
+
+ Question:
+
+ {query}
+
+ Answer:
+ """
+
+
+ def rag_prompt(chunks: List[Dict], query: str) -> str:
+
+     context = "\n\n".join(
+         [
+             f"CHUNK {chunk['id']}. {chunk['entity']['text']}"
+             for chunk in chunks
+         ]
+     )
+
+     return rag_prompt_template.format(context=context, query=query)
+
+
+ def answer_question(
+     query: str,
+     model_id: str,
+     vector_db_client: MilvusClient,
+     embedding_fn: model.dense.SentenceTransformerEmbeddingFunction,
+     collection_name: str,
+     verbose: bool = False,
+ ) -> str:
+
+     print(Panel(query, title="User"))
+
+     chunks = closest_chunks(vector_db_client, embedding_fn, query, collection_name)
+     if verbose:
+         print("Found", len(chunks), "closest chunks")
+         print(chunk_hits_table(chunks))
+
+     prompt = rag_prompt(chunks, query)
+     if verbose:
+         print("RAG prompt created. Calling inference at", model_id, "\n\n")
+
+     answer = complete_simple(model_id, rag_system_prompt, prompt, rich_output=verbose)
+
+     return answer
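An end-to-end sketch pairing this with build_vector_db from the first module above (imports of the package modules are omitted because their file paths are not shown in this diff; URIs and ids are placeholders):

    client = vector_db("milvus_rag.db", overwrite=True)    # placeholder URI
    embedding_fn = embedding_function("all-MiniLM-L6-v2")  # placeholder model id
    build_vector_db(["data/notes.txt"], client, embedding_fn, "notes")

    answer = answer_question(
        "What is in the notes?", "model-id", client, embedding_fn, "notes", verbose=True
    )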
@@ -0,0 +1,103 @@
+ from typing import List
+
+ from rich import print
+ from rich.panel import Panel
+ from rich.text import Text
+ from thespian.actors import Actor
+
+ from gofannon.base import BaseTool
+
+ from proscenium.verbs.complete import (
+     complete_for_tool_applications,
+     evaluate_tool_calls,
+     complete_with_tool_results,
+ )
+ from proscenium.verbs.invoke import process_tools
+
+
+ def tool_applier_actor_class(
+     tools: List[BaseTool],
+     system_message: str,
+     model_id: str,
+     temperature: float = 0.75,
+     rich_output: bool = False,
+ ):
+
+     tool_map, tool_desc_list = process_tools(tools)
+
+     class ToolApplier(Actor):
+
+         def receiveMessage(self, message, sender):
+
+             response = apply_tools(
+                 model_id=model_id,
+                 system_message=system_message,
+                 message=message,
+                 tool_desc_list=tool_desc_list,
+                 tool_map=tool_map,
+                 temperature=temperature,
+                 rich_output=rich_output,
+             )
+
+             self.send(sender, response)
+
+     return ToolApplier
+
+
+ def apply_tools(
+     model_id: str,
+     system_message: str,
+     message: str,
+     tool_desc_list: list,
+     tool_map: dict,
+     temperature: float = 0.75,
+     rich_output: bool = False,
+ ) -> str:
+
+     messages = [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": message},
+     ]
+
+     response = complete_for_tool_applications(
+         model_id, messages, tool_desc_list, temperature, rich_output
+     )
+
+     tool_call_message = response.choices[0].message
+
+     if tool_call_message.tool_calls is None or len(tool_call_message.tool_calls) == 0:
+
+         if rich_output:
+             print(
+                 Panel(
+                     Text(str(tool_call_message.content)),
+                     title="Tool Application Response",
+                 )
+             )
+
+         print("No tool applications detected")
+
+         return tool_call_message.content
+
+     else:
+
+         if rich_output:
+             print(
+                 Panel(Text(str(tool_call_message)), title="Tool Application Response")
+             )
+
+         tool_evaluation_messages = evaluate_tool_calls(
+             tool_call_message, tool_map, rich_output
+         )
+
+         result = complete_with_tool_results(
+             model_id,
+             messages,
+             tool_call_message,
+             tool_evaluation_messages,
+             tool_desc_list,
+             temperature,
+             rich_output,
+         )
+
+         return result
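The returned class plugs into a thespian actor system; a sketch with a hypothetical gofannon tool (the tool, system message, model id, and timeout are illustrative):

    from thespian.actors import ActorSystem

    applier_class = tool_applier_actor_class(
        tools=[MyTool()],                      # hypothetical gofannon BaseTool
        system_message="Use tools when they help.",
        model_id="model-id",                   # placeholder
    )

    system = ActorSystem()
    applier = system.createActor(applier_class)
    response = system.ask(applier, "What is 2 + 2?", 60)  # 60-second timeout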
File without changes
@@ -0,0 +1,40 @@
+ import logging
+ import os
+ from typing import List
+ from typing import Iterable
+
+ from langchain_core.documents.base import Document
+
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.text_splitter import TokenTextSplitter
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ logging.getLogger("langchain_text_splitters.base").setLevel(logging.ERROR)
+
+ # Each text chunk inherits the metadata from the document.
+
+
+ def documents_to_chunks_by_characters(
+     documents: Iterable[Document], chunk_size: int = 1000, chunk_overlap: int = 0
+ ) -> List[Document]:
+
+     text_splitter = CharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+
+     chunks = text_splitter.split_documents(documents)
+
+     return chunks
+
+
+ def documents_to_chunks_by_tokens(
+     documents: Iterable[Document], chunk_size: int = 1000, chunk_overlap: int = 0
+ ) -> List[Document]:
+
+     text_splitter = TokenTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+
+     chunks = text_splitter.split_documents(documents)
+
+     return chunks
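A quick sketch of the two splitters on an in-memory Document; as the comment above notes, both preserve the document's metadata on every chunk (the text and metadata here are illustrative):

    from langchain_core.documents.base import Document

    doc = Document("long text ...", metadata={"source": "example"})
    by_chars = documents_to_chunks_by_characters([doc], chunk_size=500)
    by_tokens = documents_to_chunks_by_tokens([doc], chunk_size=500)
    print(len(by_chars), len(by_tokens), by_chars[0].metadata)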