proscenium 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. proscenium/__init__.py +3 -0
  2. proscenium/admin/__init__.py +37 -0
  3. proscenium/bin/bot.py +142 -0
  4. proscenium/core/__init__.py +152 -0
  5. proscenium/interfaces/__init__.py +3 -0
  6. proscenium/interfaces/slack.py +265 -0
  7. proscenium/patterns/__init__.py +3 -0
  8. proscenium/patterns/chunk_space.py +51 -0
  9. proscenium/{scripts → patterns}/document_enricher.py +15 -11
  10. proscenium/{scripts → patterns}/entity_resolver.py +24 -18
  11. proscenium/patterns/graph_rag.py +61 -0
  12. proscenium/{scripts → patterns}/knowledge_graph.py +4 -2
  13. proscenium/{scripts → patterns}/rag.py +6 -12
  14. proscenium/{scripts → patterns}/tools.py +13 -45
  15. proscenium/verbs/__init__.py +3 -0
  16. proscenium/verbs/chunk.py +2 -0
  17. proscenium/verbs/complete.py +24 -28
  18. proscenium/verbs/display/__init__.py +1 -1
  19. proscenium/verbs/display.py +3 -0
  20. proscenium/verbs/extract.py +8 -4
  21. proscenium/verbs/invoke.py +3 -0
  22. proscenium/verbs/read.py +6 -8
  23. proscenium/verbs/remember.py +5 -0
  24. proscenium/verbs/vector_database.py +13 -20
  25. proscenium/verbs/write.py +3 -0
  26. {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/METADATA +5 -8
  27. proscenium-0.0.3.dist-info/RECORD +34 -0
  28. {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/WHEEL +1 -1
  29. proscenium-0.0.3.dist-info/entry_points.txt +3 -0
  30. proscenium/scripts/__init__.py +0 -0
  31. proscenium/scripts/chunk_space.py +0 -33
  32. proscenium/scripts/graph_rag.py +0 -43
  33. proscenium/verbs/display/huggingface.py +0 -0
  34. proscenium/verbs/know.py +0 -9
  35. proscenium-0.0.1.dist-info/RECORD +0 -30
  36. {proscenium-0.0.1.dist-info → proscenium-0.0.3.dist-info}/LICENSE +0 -0
@@ -1,12 +1,13 @@
1
1
  from typing import List
2
2
  from typing import Callable
3
- from typing import Any
3
+ from typing import Optional
4
4
 
5
5
  import time
6
+ import logging
6
7
  from pydantic import BaseModel
7
8
 
8
- from rich import print
9
9
  from rich.panel import Panel
10
+ from rich.console import Console
10
11
  from rich.progress import Progress
11
12
 
12
13
  from langchain_core.documents.base import Document
@@ -14,6 +15,8 @@ from langchain_core.documents.base import Document
14
15
  from proscenium.verbs.chunk import documents_to_chunks_by_tokens
15
16
  from proscenium.verbs.extract import extract_to_pydantic_model
16
17
 
18
+ log = logging.getLogger(__name__)
19
+
17
20
 
18
21
  def extract_from_document_chunks(
19
22
  doc: Document,
@@ -22,11 +25,12 @@ def extract_from_document_chunks(
22
25
  chunk_extraction_template: str,
23
26
  chunk_extract_clazz: type[BaseModel],
24
27
  delay: float,
25
- verbose: bool = False,
28
+ console: Optional[Console] = None,
26
29
  ) -> List[BaseModel]:
27
30
 
28
- print(doc_as_rich(doc))
29
- print()
31
+ if console is not None:
32
+ console.print(doc_as_rich(doc))
33
+ console.print()
30
34
 
31
35
  extract_models = []
32
36
 
@@ -40,9 +44,9 @@ def extract_from_document_chunks(
40
44
  chunk.page_content,
41
45
  )
42
46
 
43
- if verbose:
44
- print("Extract model in chunk", i + 1, "of", len(chunks))
45
- print(Panel(str(ce)))
47
+ log.info("Extract model in chunk %s of %s", i + 1, len(chunks))
48
+ if console is not None:
49
+ console.print(Panel(str(ce)))
46
50
 
47
51
  extract_models.append(ce)
48
52
  time.sleep(delay)
@@ -55,7 +59,7 @@ def enrich_documents(
55
59
  extract_from_doc_chunks: Callable[[Document], List[BaseModel]],
56
60
  doc_enrichments: Callable[[Document, list[BaseModel]], BaseModel],
57
61
  enrichments_jsonl_file: str,
58
- verbose: bool = False,
62
+ console: Optional[Console] = None,
59
63
  ) -> None:
60
64
 
61
65
  docs = retrieve_documents()
@@ -70,11 +74,11 @@ def enrich_documents(
70
74
 
71
75
  for doc in docs:
72
76
 
73
- chunk_extract_models = extract_from_doc_chunks(doc, verbose)
77
+ chunk_extract_models = extract_from_doc_chunks(doc)
74
78
  enrichments = doc_enrichments(doc, chunk_extract_models)
75
79
  enrichments_json = enrichments.model_dump_json()
76
80
  f.write(enrichments_json + "\n")
77
81
 
78
82
  progress.update(task_enrich, advance=1)
79
83
 
80
- print("Wrote document enrichments to", enrichments_jsonl_file)
84
+ log.info("Wrote document enrichments to %s", enrichments_jsonl_file)
@@ -1,6 +1,7 @@
1
1
  from typing import Optional
2
- from rich import print
2
+ import logging
3
3
 
4
+ from rich.console import Console
4
5
  from langchain_core.documents.base import Document
5
6
  from neo4j import Driver
6
7
 
@@ -13,6 +14,8 @@ from proscenium.verbs.vector_database import add_chunks_to_vector_db
13
14
  from proscenium.verbs.vector_database import embedding_function
14
15
  from proscenium.verbs.display.milvus import collection_panel
15
16
 
17
+ log = logging.getLogger(__name__)
18
+
16
19
 
17
20
  class Resolver:
18
21
 
@@ -21,27 +24,27 @@ class Resolver:
21
24
  cypher: str,
22
25
  field_name: str,
23
26
  collection_name: str,
24
- embedding_model_id: str,
25
27
  ):
26
28
  self.cypher = cypher
27
29
  self.field_name = field_name
28
30
  self.collection_name = collection_name
29
- self.embedding_model_id = embedding_model_id
30
31
 
31
32
 
32
33
  def load_entity_resolver(
33
34
  driver: Driver,
34
35
  resolvers: list[Resolver],
36
+ embedding_model_id: str,
35
37
  milvus_uri: str,
38
+ console: Optional[Console] = None,
36
39
  ) -> None:
37
40
 
38
- vector_db_client = vector_db(milvus_uri, overwrite=True)
39
- print("Vector db stored at", milvus_uri)
41
+ vector_db_client = vector_db(milvus_uri)
42
+ log.info("Vector db stored at %s", milvus_uri)
40
43
 
41
- for resolver in resolvers:
44
+ embedding_fn = embedding_function(embedding_model_id)
45
+ log.info("Embedding model %s", embedding_model_id)
42
46
 
43
- embedding_fn = embedding_function(resolver.embedding_model_id)
44
- print("Embedding model", resolver.embedding_model_id)
47
+ for resolver in resolvers:
45
48
 
46
49
  values = []
47
50
  with driver.session() as session:
@@ -49,15 +52,16 @@ def load_entity_resolver(
49
52
  new_values = [Document(record[resolver.field_name]) for record in result]
50
53
  values.extend(new_values)
51
54
 
52
- print("Loading entity resolver into vector db", resolver.collection_name)
53
- create_collection(
54
- vector_db_client, embedding_fn, resolver.collection_name, overwrite=True
55
- )
55
+ log.info("Loading entity resolver into vector db %s", resolver.collection_name)
56
+ create_collection(vector_db_client, embedding_fn, resolver.collection_name)
57
+
56
58
  info = add_chunks_to_vector_db(
57
59
  vector_db_client, embedding_fn, values, resolver.collection_name
58
60
  )
59
- print(info["insert_count"], "chunks inserted")
60
- print(collection_panel(vector_db_client, resolver.collection_name))
61
+ log.info("%s chunks inserted ", info["insert_count"])
62
+
63
+ if console is not None:
64
+ console.print(collection_panel(vector_db_client, resolver.collection_name))
61
65
 
62
66
  vector_db_client.close()
63
67
 
@@ -68,10 +72,12 @@ def find_matching_objects(
68
72
  resolver: Resolver,
69
73
  ) -> Optional[str]:
70
74
 
71
- print("Loading collection", resolver.collection_name)
75
+ log.info("Loading collection", resolver.collection_name)
72
76
  vector_db_client.load_collection(resolver.collection_name)
73
77
 
74
- print("Finding entity matches for", approximate, "using", resolver.collection_name)
78
+ log.info(
79
+ "Finding entity matches for", approximate, "using", resolver.collection_name
80
+ )
75
81
 
76
82
  hits = closest_chunks(
77
83
  vector_db_client,
@@ -82,8 +88,8 @@ def find_matching_objects(
82
88
  )
83
89
  # TODO apply distance threshold
84
90
  for match in [head["entity"]["text"] for head in hits[:1]]:
85
- print("Closest match:", match)
91
+ log.info("Closest match:", match)
86
92
  return match
87
93
 
88
- print("No match found")
94
+ log.info("No match found")
89
95
  return None
@@ -0,0 +1,61 @@
1
+ from typing import Callable
2
+ from typing import Optional
3
+
4
+ import logging
5
+
6
+ from rich.console import Console
7
+
8
+ from pydantic import BaseModel
9
+ from uuid import uuid4, UUID
10
+ from neo4j import Driver
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ def query_to_prompts(
16
+ query: str,
17
+ query_extraction_model_id: str,
18
+ milvus_uri: str,
19
+ driver: Driver,
20
+ query_extract: Callable[
21
+ [str, str], BaseModel
22
+ ], # (query_str, query_extraction_model_id) -> QueryExtractions
23
+ query_extract_to_graph: Callable[
24
+ [str, UUID, BaseModel], None
25
+ ], # query, query_id, extract
26
+ query_extract_to_context: Callable[
27
+ [BaseModel, str, Driver, str, Optional[Console]], BaseModel
28
+ ], # (QueryExtractions, query_str, Driver, milvus_uri) -> Context
29
+ context_to_prompts: Callable[
30
+ [BaseModel], tuple[str, str]
31
+ ], # Context -> (system_prompt, user_prompt)
32
+ console: Optional[Console] = None,
33
+ ) -> Optional[tuple[str, str]]:
34
+
35
+ query_id = uuid4()
36
+
37
+ log.info("Extracting information from the question")
38
+
39
+ extract = query_extract(query, query_extraction_model_id)
40
+ if extract is None:
41
+ log.info("Unable to extract information from that question")
42
+ return None
43
+
44
+ log.info("Extract: %s", extract)
45
+
46
+ log.info("Storing the extracted information in the graph")
47
+ query_extract_to_graph(query, query_id, extract, driver)
48
+
49
+ log.info("Forming context from the extracted information")
50
+ context = query_extract_to_context(
51
+ extract, query, driver, milvus_uri, console=console
52
+ )
53
+ if context is None:
54
+ log.info("Unable to form context from the extracted information")
55
+ return None
56
+
57
+ log.info("Context: %s", context)
58
+
59
+ prompts = context_to_prompts(context)
60
+
61
+ return prompts
@@ -1,14 +1,16 @@
1
1
  from typing import Callable
2
2
  from typing import Any
3
3
 
4
+ import logging
4
5
  import json
5
6
  from pydantic import BaseModel
6
7
 
7
- from rich import print
8
8
  from rich.progress import Progress
9
9
 
10
10
  from neo4j import Driver
11
11
 
12
+ log = logging.getLogger(__name__)
13
+
12
14
 
13
15
  def load_knowledge_graph(
14
16
  driver: Driver,
@@ -17,7 +19,7 @@ def load_knowledge_graph(
17
19
  doc_enrichments_to_graph: Callable[[Any, BaseModel], None],
18
20
  ) -> None:
19
21
 
20
- print("Parsing enrichments from", enrichments_jsonl_file)
22
+ log.info("Parsing enrichments from %s", enrichments_jsonl_file)
21
23
 
22
24
  enrichmentss = []
23
25
  with open(enrichments_jsonl_file, "r") as f:
@@ -1,7 +1,5 @@
1
1
  from typing import List, Dict
2
-
3
- from rich import print
4
- from rich.panel import Panel
2
+ import logging
5
3
 
6
4
  from pymilvus import MilvusClient
7
5
  from pymilvus import model
@@ -10,6 +8,7 @@ from proscenium.verbs.complete import complete_simple
10
8
  from proscenium.verbs.display.milvus import chunk_hits_table
11
9
  from proscenium.verbs.vector_database import closest_chunks
12
10
 
11
+ log = logging.getLogger(__name__)
13
12
 
14
13
  rag_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
15
14
 
@@ -44,20 +43,15 @@ def answer_question(
44
43
  vector_db_client: MilvusClient,
45
44
  embedding_fn: model.dense.SentenceTransformerEmbeddingFunction,
46
45
  collection_name: str,
47
- verbose: bool = False,
48
46
  ) -> str:
49
47
 
50
- print(Panel(query, title="User"))
51
-
52
48
  chunks = closest_chunks(vector_db_client, embedding_fn, query, collection_name)
53
- if verbose:
54
- print("Found", len(chunks), "closest chunks")
55
- print(chunk_hits_table(chunks))
49
+ log.info("Found %s closest chunks", len(chunks))
50
+ log.info(chunk_hits_table(chunks))
56
51
 
57
52
  prompt = rag_prompt(chunks, query)
58
- if verbose:
59
- print("RAG prompt created. Calling inference at", model_id, "\n\n")
53
+ log.info("RAG prompt created. Calling inference at %s", model_id)
60
54
 
61
- answer = complete_simple(model_id, rag_system_prompt, prompt, rich_output=verbose)
55
+ answer = complete_simple(model_id, rag_system_prompt, prompt)
62
56
 
63
57
  return answer
@@ -1,47 +1,17 @@
1
- from typing import List
1
+ from typing import Optional
2
+ import logging
2
3
 
3
- from rich import print
4
+ from rich.console import Console
4
5
  from rich.panel import Panel
5
6
  from rich.text import Text
6
- from thespian.actors import Actor
7
-
8
- from gofannon.base import BaseTool
9
7
 
10
8
  from proscenium.verbs.complete import (
11
9
  complete_for_tool_applications,
12
10
  evaluate_tool_calls,
13
11
  complete_with_tool_results,
14
12
  )
15
- from proscenium.verbs.invoke import process_tools
16
-
17
-
18
- def tool_applier_actor_class(
19
- tools: List[BaseTool],
20
- system_message: str,
21
- model_id: str,
22
- temperature: float = 0.75,
23
- rich_output: bool = False,
24
- ):
25
-
26
- tool_map, tool_desc_list = process_tools(tools)
27
13
 
28
- class ToolApplier(Actor):
29
-
30
- def receiveMessage(self, message, sender):
31
-
32
- response = apply_tools(
33
- model_id=model_id,
34
- system_message=system_message,
35
- message=message,
36
- tool_desc_list=tool_desc_list,
37
- tool_map=tool_map,
38
- temperature=temperature,
39
- rich_output=rich_output,
40
- )
41
-
42
- self.send(sender, response)
43
-
44
- return ToolApplier
14
+ log = logging.getLogger(__name__)
45
15
 
46
16
 
47
17
  def apply_tools(
@@ -51,7 +21,7 @@ def apply_tools(
51
21
  tool_desc_list: list,
52
22
  tool_map: dict,
53
23
  temperature: float = 0.75,
54
- rich_output: bool = False,
24
+ console: Optional[Console] = None,
55
25
  ) -> str:
56
26
 
57
27
  messages = [
@@ -60,35 +30,33 @@ def apply_tools(
60
30
  ]
61
31
 
62
32
  response = complete_for_tool_applications(
63
- model_id, messages, tool_desc_list, temperature, rich_output
33
+ model_id, messages, tool_desc_list, temperature, console
64
34
  )
65
35
 
66
36
  tool_call_message = response.choices[0].message
67
37
 
68
38
  if tool_call_message.tool_calls is None or len(tool_call_message.tool_calls) == 0:
69
39
 
70
- if rich_output:
71
- print(
40
+ if console is not None:
41
+ console.print(
72
42
  Panel(
73
43
  Text(str(tool_call_message.content)),
74
44
  title="Tool Application Response",
75
45
  )
76
46
  )
77
47
 
78
- print("No tool applications detected")
48
+ log.info("No tool applications detected")
79
49
 
80
50
  return tool_call_message.content
81
51
 
82
52
  else:
83
53
 
84
- if rich_output:
85
- print(
54
+ if console is not None:
55
+ console.print(
86
56
  Panel(Text(str(tool_call_message)), title="Tool Application Response")
87
57
  )
88
58
 
89
- tool_evaluation_messages = evaluate_tool_calls(
90
- tool_call_message, tool_map, rich_output
91
- )
59
+ tool_evaluation_messages = evaluate_tool_calls(tool_call_message, tool_map)
92
60
 
93
61
  result = complete_with_tool_results(
94
62
  model_id,
@@ -97,7 +65,7 @@ def apply_tools(
97
65
  tool_evaluation_messages,
98
66
  tool_desc_list,
99
67
  temperature,
100
- rich_output,
68
+ console,
101
69
  )
102
70
 
103
71
  return result
@@ -0,0 +1,3 @@
1
+ import logging
2
+
3
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
proscenium/verbs/chunk.py CHANGED
@@ -8,6 +8,8 @@ from langchain_core.documents.base import Document
8
8
  from langchain.text_splitter import CharacterTextSplitter
9
9
  from langchain.text_splitter import TokenTextSplitter
10
10
 
11
+ log = logging.getLogger(__name__)
12
+
11
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
14
  logging.getLogger("langchain_text_splitters.base").setLevel(logging.ERROR)
13
15
 
@@ -37,10 +37,12 @@ Valid model ids:
37
37
  - `ollama:granite3.1-dense:2b`
38
38
  """
39
39
 
40
+ from typing import Optional
40
41
  from typing import Any
41
-
42
+ import logging
42
43
  import json
43
- from rich import print
44
+
45
+ from rich.console import Console
44
46
  from rich.console import Group
45
47
  from rich.panel import Panel
46
48
  from rich.table import Table
@@ -51,6 +53,8 @@ from aisuite.framework.message import ChatCompletionMessageToolCall
51
53
 
52
54
  from proscenium.verbs.display.tools import complete_with_tools_panel
53
55
 
56
+ log = logging.getLogger(__name__)
57
+
54
58
  provider_configs = {
55
59
  # TODO expose this
56
60
  "ollama": {"timeout": 180},
@@ -63,14 +67,14 @@ def complete_simple(
63
67
  model_id: str, system_prompt: str, user_prompt: str, **kwargs
64
68
  ) -> str:
65
69
 
66
- rich_output = kwargs.pop("rich_output", False)
70
+ console = kwargs.pop("console", None)
67
71
 
68
72
  messages = [
69
73
  {"role": "system", "content": system_prompt},
70
74
  {"role": "user", "content": user_prompt},
71
75
  ]
72
76
 
73
- if rich_output:
77
+ if console is not None:
74
78
 
75
79
  kwargs_text = "\n".join([str(k) + ": " + str(v) for k, v in kwargs.items()])
76
80
 
@@ -90,34 +94,30 @@ model_id: {model_id}
90
94
  call_panel = Panel(
91
95
  Group(params_text, messages_table), title="complete_simple call"
92
96
  )
93
- print(call_panel)
97
+ console.print(call_panel)
94
98
 
95
99
  response = client.chat.completions.create(
96
100
  model=model_id, messages=messages, **kwargs
97
101
  )
98
102
  response = response.choices[0].message.content
99
103
 
100
- if rich_output:
101
- print(Panel(response, title="Response"))
104
+ if console is not None:
105
+ console.print(Panel(response, title="Response"))
102
106
 
103
107
  return response
104
108
 
105
109
 
106
- def evaluate_tool_call(
107
- tool_map: dict, tool_call: ChatCompletionMessageToolCall, rich_output: bool = False
108
- ) -> Any:
110
+ def evaluate_tool_call(tool_map: dict, tool_call: ChatCompletionMessageToolCall) -> Any:
109
111
 
110
112
  function_name = tool_call.function.name
111
113
  # TODO validate the arguments?
112
114
  function_args = json.loads(tool_call.function.arguments)
113
115
 
114
- if rich_output:
115
- print(f"Evaluating tool call: {function_name} with args {function_args}")
116
+ log.info(f"Evaluating tool call: {function_name} with args {function_args}")
116
117
 
117
118
  function_response = tool_map[function_name](**function_args)
118
119
 
119
- if rich_output:
120
- print(f" Response: {function_response}")
120
+ log.info(f" Response: {function_response}")
121
121
 
122
122
  return function_response
123
123
 
@@ -134,23 +134,19 @@ def tool_response_message(
134
134
  }
135
135
 
136
136
 
137
- def evaluate_tool_calls(
138
- tool_call_message, tool_map: dict, rich_output: bool = False
139
- ) -> list[dict]:
137
+ def evaluate_tool_calls(tool_call_message, tool_map: dict) -> list[dict]:
140
138
 
141
139
  tool_call: ChatCompletionMessageToolCall
142
140
 
143
- if rich_output:
144
- print("Evaluating tool calls")
141
+ log.info("Evaluating tool calls")
145
142
 
146
143
  new_messages: list[dict] = []
147
144
 
148
145
  for tool_call in tool_call_message.tool_calls:
149
- function_response = evaluate_tool_call(tool_map, tool_call, rich_output)
146
+ function_response = evaluate_tool_call(tool_map, tool_call)
150
147
  new_messages.append(tool_response_message(tool_call, function_response))
151
148
 
152
- if rich_output:
153
- print("Tool calls evaluated")
149
+ log.info("Tool calls evaluated")
154
150
 
155
151
  return new_messages
156
152
 
@@ -160,10 +156,10 @@ def complete_for_tool_applications(
160
156
  messages: list,
161
157
  tool_desc_list: list,
162
158
  temperature: float,
163
- rich_output: bool = False,
159
+ console: Optional[Console] = None,
164
160
  ):
165
161
 
166
- if rich_output:
162
+ if console is not None:
167
163
  panel = complete_with_tools_panel(
168
164
  "complete for tool applications",
169
165
  model_id,
@@ -171,7 +167,7 @@ def complete_for_tool_applications(
171
167
  messages,
172
168
  temperature,
173
169
  )
174
- print(panel)
170
+ console.print(panel)
175
171
 
176
172
  response = client.chat.completions.create(
177
173
  model=model_id,
@@ -190,13 +186,13 @@ def complete_with_tool_results(
190
186
  tool_evaluation_messages: list[dict],
191
187
  tool_desc_list: list,
192
188
  temperature: float,
193
- rich_output: bool = False,
189
+ console: Optional[Console] = None,
194
190
  ):
195
191
 
196
192
  messages.append(tool_call_message)
197
193
  messages.extend(tool_evaluation_messages)
198
194
 
199
- if rich_output:
195
+ if console is not None:
200
196
  panel = complete_with_tools_panel(
201
197
  "complete call with tool results",
202
198
  model_id,
@@ -204,7 +200,7 @@ def complete_with_tool_results(
204
200
  messages,
205
201
  temperature,
206
202
  )
207
- print(panel)
203
+ console.print(panel)
208
204
 
209
205
  response = client.chat.completions.create(
210
206
  model=model_id,
@@ -4,6 +4,6 @@ from rich.text import Text
4
4
  def header() -> Text:
5
5
  text = Text()
6
6
  text.append("Proscenium 🎭\n", style="bold")
7
- text.append("The AI Alliance\n", style="bold")
7
+ text.append("https://the-ai-alliance.github.io/proscenium/\n")
8
8
  # TODO version, timestamp, ...
9
9
  return text
@@ -1,5 +1,8 @@
1
+ import logging
1
2
  from rich.text import Text
2
3
 
4
+ log = logging.getLogger(__name__)
5
+
3
6
 
4
7
  def header() -> Text:
5
8
  text = Text(
@@ -1,4 +1,6 @@
1
+ from typing import Optional
1
2
  import logging
3
+ from rich.console import Console
2
4
  from string import Formatter
3
5
 
4
6
  import json
@@ -6,6 +8,8 @@ from pydantic import BaseModel
6
8
 
7
9
  from proscenium.verbs.complete import complete_simple
8
10
 
11
+ log = logging.getLogger(__name__)
12
+
9
13
  extraction_system_prompt = "You are an entity extractor"
10
14
 
11
15
 
@@ -36,7 +40,7 @@ def extract_to_pydantic_model(
36
40
  extraction_template: str,
37
41
  clazz: type[BaseModel],
38
42
  text: str,
39
- verbose: bool = False,
43
+ console: Optional[Console] = None,
40
44
  ) -> BaseModel:
41
45
 
42
46
  extract_str = complete_simple(
@@ -47,15 +51,15 @@ def extract_to_pydantic_model(
47
51
  "type": "json_object",
48
52
  "schema": clazz.model_json_schema(),
49
53
  },
50
- rich_output=verbose,
54
+ console=console,
51
55
  )
52
56
 
53
- logging.info("complete_to_pydantic_model: extract_str = <<<%s>>>", extract_str)
57
+ log.info("complete_to_pydantic_model: extract_str = <<<%s>>>", extract_str)
54
58
 
55
59
  try:
56
60
  extract_dict = json.loads(extract_str)
57
61
  return clazz.model_construct(**extract_dict)
58
62
  except Exception as e:
59
- logging.error("complete_to_pydantic_model: Exception: %s", e)
63
+ log.error("complete_to_pydantic_model: Exception: %s", e)
60
64
 
61
65
  return None
@@ -1,5 +1,8 @@
1
+ import logging
1
2
  from gofannon.base import BaseTool
2
3
 
4
+ log = logging.getLogger(__name__)
5
+
3
6
 
4
7
  def process_tools(tools: list[BaseTool]) -> tuple[dict, list]:
5
8
  applied_tools = [F() for F in tools]