symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. symai/__init__.py +269 -173
  2. symai/backend/base.py +123 -110
  3. symai/backend/engines/drawing/engine_bfl.py +45 -44
  4. symai/backend/engines/drawing/engine_gpt_image.py +112 -97
  5. symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
  6. symai/backend/engines/embedding/engine_openai.py +25 -21
  7. symai/backend/engines/execute/engine_python.py +19 -18
  8. symai/backend/engines/files/engine_io.py +104 -95
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
  11. symai/backend/engines/index/engine_pinecone.py +124 -97
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +84 -56
  14. symai/backend/engines/lean/engine_lean4.py +96 -52
  15. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
  21. symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
  26. symai/backend/engines/ocr/engine_apilayer.py +23 -27
  27. symai/backend/engines/output/engine_stdout.py +10 -13
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
  29. symai/backend/engines/search/engine_openai.py +100 -88
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +44 -45
  32. symai/backend/engines/search/engine_serpapi.py +37 -34
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
  35. symai/backend/engines/text_to_speech/engine_openai.py +20 -26
  36. symai/backend/engines/text_vision/engine_clip.py +39 -37
  37. symai/backend/engines/userinput/engine_console.py +5 -6
  38. symai/backend/mixin/__init__.py +13 -0
  39. symai/backend/mixin/anthropic.py +48 -38
  40. symai/backend/mixin/deepseek.py +6 -5
  41. symai/backend/mixin/google.py +7 -4
  42. symai/backend/mixin/groq.py +2 -4
  43. symai/backend/mixin/openai.py +140 -110
  44. symai/backend/settings.py +87 -20
  45. symai/chat.py +216 -123
  46. symai/collect/__init__.py +7 -1
  47. symai/collect/dynamic.py +80 -70
  48. symai/collect/pipeline.py +67 -51
  49. symai/collect/stats.py +161 -109
  50. symai/components.py +707 -360
  51. symai/constraints.py +24 -12
  52. symai/core.py +1857 -1233
  53. symai/core_ext.py +83 -80
  54. symai/endpoints/api.py +166 -104
  55. symai/extended/.DS_Store +0 -0
  56. symai/extended/__init__.py +46 -12
  57. symai/extended/api_builder.py +29 -21
  58. symai/extended/arxiv_pdf_parser.py +23 -14
  59. symai/extended/bibtex_parser.py +9 -6
  60. symai/extended/conversation.py +156 -126
  61. symai/extended/document.py +50 -30
  62. symai/extended/file_merger.py +57 -14
  63. symai/extended/graph.py +51 -32
  64. symai/extended/html_style_template.py +18 -14
  65. symai/extended/interfaces/blip_2.py +2 -3
  66. symai/extended/interfaces/clip.py +4 -3
  67. symai/extended/interfaces/console.py +9 -1
  68. symai/extended/interfaces/dall_e.py +4 -2
  69. symai/extended/interfaces/file.py +2 -0
  70. symai/extended/interfaces/flux.py +4 -2
  71. symai/extended/interfaces/gpt_image.py +16 -7
  72. symai/extended/interfaces/input.py +2 -1
  73. symai/extended/interfaces/llava.py +1 -2
  74. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
  75. symai/extended/interfaces/naive_vectordb.py +9 -10
  76. symai/extended/interfaces/ocr.py +5 -3
  77. symai/extended/interfaces/openai_search.py +2 -0
  78. symai/extended/interfaces/parallel.py +30 -0
  79. symai/extended/interfaces/perplexity.py +2 -0
  80. symai/extended/interfaces/pinecone.py +12 -9
  81. symai/extended/interfaces/python.py +2 -0
  82. symai/extended/interfaces/serpapi.py +3 -1
  83. symai/extended/interfaces/terminal.py +2 -4
  84. symai/extended/interfaces/tts.py +3 -2
  85. symai/extended/interfaces/whisper.py +3 -2
  86. symai/extended/interfaces/wolframalpha.py +2 -1
  87. symai/extended/metrics/__init__.py +11 -1
  88. symai/extended/metrics/similarity.py +14 -13
  89. symai/extended/os_command.py +39 -29
  90. symai/extended/packages/__init__.py +29 -3
  91. symai/extended/packages/symdev.py +51 -43
  92. symai/extended/packages/sympkg.py +41 -35
  93. symai/extended/packages/symrun.py +63 -50
  94. symai/extended/repo_cloner.py +14 -12
  95. symai/extended/seo_query_optimizer.py +15 -13
  96. symai/extended/solver.py +116 -91
  97. symai/extended/summarizer.py +12 -10
  98. symai/extended/taypan_interpreter.py +17 -18
  99. symai/extended/vectordb.py +122 -92
  100. symai/formatter/__init__.py +9 -1
  101. symai/formatter/formatter.py +51 -47
  102. symai/formatter/regex.py +70 -69
  103. symai/functional.py +325 -176
  104. symai/imports.py +190 -147
  105. symai/interfaces.py +57 -28
  106. symai/memory.py +45 -35
  107. symai/menu/screen.py +28 -19
  108. symai/misc/console.py +66 -56
  109. symai/misc/loader.py +8 -5
  110. symai/models/__init__.py +17 -1
  111. symai/models/base.py +395 -236
  112. symai/models/errors.py +1 -2
  113. symai/ops/__init__.py +32 -22
  114. symai/ops/measures.py +24 -25
  115. symai/ops/primitives.py +1149 -731
  116. symai/post_processors.py +58 -50
  117. symai/pre_processors.py +86 -82
  118. symai/processor.py +21 -13
  119. symai/prompts.py +764 -685
  120. symai/server/huggingface_server.py +135 -49
  121. symai/server/llama_cpp_server.py +21 -11
  122. symai/server/qdrant_server.py +206 -0
  123. symai/shell.py +100 -42
  124. symai/shellsv.py +700 -492
  125. symai/strategy.py +630 -346
  126. symai/symbol.py +368 -322
  127. symai/utils.py +100 -78
  128. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
  129. symbolicai-1.1.0.dist-info/RECORD +168 -0
  130. symbolicai-0.21.0.dist-info/RECORD +0 -162
  131. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  132. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  133. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  134. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
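Among the renames in this release, the scraping engine moves from backend/engines/webscraping/ to backend/engines/scrape/ and the naive_webscraping interface becomes naive_scrape (items 28 and 74 above; the conversation.py hunk further down shows the rename at a call site). A minimal sketch of the call-site impact, assuming the Interface factory is otherwise unchanged; the URL is a placeholder:

    from symai.interfaces import Interface

    scraper = Interface("naive_scrape")       # was Interface('naive_webscraping') in 0.21.0
    content = scraper("https://example.com")  # placeholder URL; returns the scraped page content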
symai/extended/api_builder.py

@@ -1,9 +1,9 @@
  from .. import core
- from ..pre_processors import PreProcessor
+ from ..components import Execute
  from ..post_processors import CodeExtractPostProcessor
+ from ..pre_processors import PreProcessor
  from ..symbol import Expression, Symbol
- from ..components import Execute
-
+ from ..utils import UserMessage
 
  API_BUILDER_DESCRIPTION = """[Description]
  You are an API coding tool for Python that creates API calls to any web URL based on user requests.
@@ -64,7 +64,7 @@ res = run(value) # [MANAGED] must contain this line, do not change
 
  class APIBuilderPreProcessor(PreProcessor):
      def __call__(self, argument):
-         return '$> {} =>'.format(str(argument.args[0]))
+         return f"$> {argument.args[0]!s} =>"
 
 
  class APIBuilder(Expression):
@@ -77,9 +77,12 @@ class APIBuilder(Expression):
          self.sym_return_type = APIBuilder
 
      def forward(self, sym: Symbol, **kwargs) -> Symbol:
-         @core.zero_shot(prompt="Build the API call code:\n",
-                         pre_processors=[APIBuilderPreProcessor()],
-                         post_processors=[CodeExtractPostProcessor()], **kwargs)
+         @core.zero_shot(
+             prompt="Build the API call code:\n",
+             pre_processors=[APIBuilderPreProcessor()],
+             post_processors=[CodeExtractPostProcessor()],
+             **kwargs,
+         )
          def _func(_, text) -> str:
              pass
 
@@ -96,18 +99,20 @@ class StackTraceRetryExecutor(Expression):
      def forward(self, code: Symbol, request: Symbol, **kwargs) -> Symbol:
          code = str(code)
          # Set value that gets passed on to the 'run' function in the generated code
-         value = request.value # do not remove this line
+         value = request.value  # do not remove this line
          # Create the 'run' function
          self._runnable = self.executor(code, locals=locals().copy(), globals=globals().copy())
-         result = self._runnable['locals']['run'](value)
+         result = self._runnable["locals"]["run"](value)
          retry = 0
          # Retry if there is a 'Traceback' in the result
-         while 'Traceback' in result and retry <= self.max_retries:
-             self._runnable = self.executor(code, payload=result, locals=locals().copy(), globals=globals().copy(), **kwargs)
-             result = self._runnable['locals']['run'](value)
+         while "Traceback" in result and retry <= self.max_retries:
+             self._runnable = self.executor(
+                 code, payload=result, locals=locals().copy(), globals=globals().copy(), **kwargs
+             )
+             result = self._runnable["locals"]["run"](value)
              retry += 1
-         if 'locals_res' in self._runnable:
-             result = self._runnable['locals_res']
+         if "locals_res" in self._runnable:
+             result = self._runnable["locals_res"]
          return result
 
 
@@ -126,14 +131,17 @@ class APIExecutor(Expression):
      def _runnable(self):
          return self.executor._runnable
 
-     def forward(self, request: Symbol, **kwargs) -> Symbol:
+     def forward(self, request: Symbol, **_kwargs) -> Symbol:
          self._request = self._to_symbol(request)
-         if self._verbose: print('[REQUEST]', self._request)
+         if self._verbose:
+             UserMessage(f"[REQUEST] {self._request}")
          # Generate the code to implement the API call
-         self._code = self.builder(self._request)
-         if self._verbose: print('[GENERATED_CODE]', self._code)
+         self._code = self.builder(self._request)
+         if self._verbose:
+             UserMessage(f"[GENERATED_CODE] {self._code}")
          # Execute the code to define the 'run' function
-         self._result = self.executor(self._code, request=self._request)
-         if self._verbose: print('[RESULT]:', self._result)
-         self._value = self._result
+         self._result = self.executor(self._code, request=self._request)
+         if self._verbose:
+             UserMessage(f"[RESULT]: {self._result}")
+         self._value = self._result
          return self
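These hunks also replace print-based verbose output with a UserMessage helper imported from symai.utils. Judging only from the call sites in this diff (here and in the conversation.py hunks below), it accepts a message plus optional raise_with= and style= keyword arguments; a hedged sketch of the three patterns, with made-up message strings:

    from symai.utils import UserMessage

    UserMessage("[REQUEST] build a call to the weather API")  # plain informational message
    UserMessage("final response text", style="text")          # styled output, as used for auto_print
    UserMessage("File is empty.", raise_with=Exception)       # appears to raise the given exception type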
symai/extended/arxiv_pdf_parser.py

@@ -1,17 +1,20 @@
- import os
  import re
  import shutil
- import requests
-
  from concurrent.futures import ThreadPoolExecutor, as_completed
+ from pathlib import Path
+
+ import requests
 
+ from ..backend.settings import HOME_PATH
  from ..symbol import Expression, Symbol
+ from ..utils import UserMessage
  from .file_merger import FileMerger
- from ..backend.settings import HOME_PATH
 
 
  class ArxivPdfParser(Expression):
-     def __init__(self, url_pattern: str = r'https://arxiv.org/(?:pdf|abs)/(\d+.\d+)(?:\.pdf)?', **kwargs):
+     def __init__(
+         self, url_pattern: str = r"https://arxiv.org/(?:pdf|abs)/(\d+.\d+)(?:\.pdf)?", **kwargs
+     ):
          super().__init__(**kwargs)
          self.url_pattern = url_pattern
          self.merger = FileMerger()
@@ -21,28 +24,34 @@ class ArxivPdfParser(Expression):
          urls = re.findall(self.url_pattern, str(data))
 
          # Convert all urls to pdf urls
-         pdf_urls = [f"https://arxiv.org/pdf/" + (f"{url.split('/')[-1]}.pdf" if 'pdf' not in url else {url.split('/')[-1]}) for url in urls]
+         pdf_urls = [
+             "https://arxiv.org/pdf/"
+             + (f"{url.split('/')[-1]}.pdf" if "pdf" not in url else {url.split("/")[-1]})
+             for url in urls
+         ]
 
          # Create temporary folder in the home directory
-         output_path = os.path.join(HOME_PATH, "temp/downloads")
-         os.makedirs(output_path, exist_ok=True)
+         output_path = HOME_PATH / "temp" / "downloads"
+         output_path.mkdir(parents=True, exist_ok=True)
 
          pdf_files = []
          with ThreadPoolExecutor() as executor:
              # Download all pdfs in parallel
-             future_to_url = {executor.submit(self.download_pdf, url, output_path): url for url in pdf_urls}
+             future_to_url = {
+                 executor.submit(self.download_pdf, url, output_path): url for url in pdf_urls
+             }
              for future in as_completed(future_to_url):
                  url = future_to_url[future]
                  try:
                      pdf_files.append(future.result())
                  except Exception as exc:
-                     print('%r generated an exception: %s' % (url, exc))
+                     UserMessage(f"{url!r} generated an exception: {exc}")
 
          if len(pdf_files) == 0:
              return None
 
          # Merge all pdfs into one file
-         merged_file = self.merger(output_path, **kwargs)
+         merged_file = self.merger(str(output_path), **kwargs)
 
          # Return the merged file as a Symbol
          return_file = self._to_symbol(merged_file)
@@ -55,7 +64,7 @@ class ArxivPdfParser(Expression):
      def download_pdf(self, url, output_path):
          # Download pdfs
          response = requests.get(url)
-         file = os.path.join(output_path, f'{url.split("/")[-1]}')
-         with open(file, 'wb') as f:
+         file_path = Path(output_path) / f"{url.split('/')[-1]}"
+         with file_path.open("wb") as f:
              f.write(response.content)
-         return file
+         return str(file_path)
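The parser also migrates from os.path to pathlib. Note that HOME_PATH / "temp" / "downloads" only works if HOME_PATH is itself a pathlib.Path (or supports the / operator) in symai.backend.settings, which this diff implies but does not show. A small sketch of the pattern under that assumption, with a stand-in home path and a hypothetical file name:

    from pathlib import Path

    home_path = Path.home() / ".symai"              # stand-in for HOME_PATH; its real value is not shown here
    output_path = home_path / "temp" / "downloads"
    output_path.mkdir(parents=True, exist_ok=True)  # replaces os.makedirs(..., exist_ok=True)
    file_path = output_path / "2301.00001.pdf"      # hypothetical arXiv id
    with file_path.open("wb") as f:                 # replaces open(os.path.join(...), 'wb')
        f.write(b"%PDF-1.4")                        # dummy payload for illustration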
symai/extended/bibtex_parser.py

@@ -1,8 +1,7 @@
  from .. import core
+ from ..post_processors import CodeExtractPostProcessor
  from ..pre_processors import PreProcessor
  from ..symbol import Expression, Symbol
- from ..post_processors import CodeExtractPostProcessor
-
 
  BIB_DESCRIPTION = """[Description]
  You take in a text with references to papers and return a list of biblatex entries.
@@ -69,7 +68,7 @@ Multimodal Few-Shot Learning with Frozen Language Models Maria Tsimpoukelli
 
  class BibTexPreProcessor(PreProcessor):
      def __call__(self, argument):
-         return '>>>\n{}\n\n<<<\n'.format(str(argument.args[0]))
+         return f">>>\n{argument.args[0]!s}\n\n<<<\n"
 
 
  class BibTexParser(Expression):
@@ -82,9 +81,13 @@ class BibTexParser(Expression):
          self.sym_return_type = BibTexParser
 
      def forward(self, sym: Symbol, **kwargs) -> Symbol:
-         @core.zero_shot(prompt="Create bibtex entries:\n",
-                         pre_processors=[BibTexPreProcessor()],
-                         post_processors=[CodeExtractPostProcessor()], **kwargs)
+         @core.zero_shot(
+             prompt="Create bibtex entries:\n",
+             pre_processors=[BibTexPreProcessor()],
+             post_processors=[CodeExtractPostProcessor()],
+             **kwargs,
+         )
          def _func(_, text) -> str:
              pass
+
          return _func(self, sym)
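Both APIBuilder and BibTexParser keep the same core.zero_shot pattern, only reformatted: a decorated stub whose body is pass, which the neuro-symbolic engine fills in at call time. A hedged sketch of that pattern; the OneLiner class and prompt below are illustrative and not part of the package, and the pre-/post-processor arguments are omitted on the assumption that they are optional:

    from symai import core
    from symai.symbol import Expression, Symbol

    class OneLiner(Expression):  # illustrative, not in the package
        def forward(self, sym: Symbol, **kwargs) -> Symbol:
            @core.zero_shot(prompt="Rewrite as a single sentence:\n", **kwargs)
            def _func(_, text) -> str:
                pass  # body is supplied by the engine at call time
            return _func(self, sym)

    # OneLiner()(Symbol("some longer text ..."))  # would require a configured engine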
symai/extended/conversation.py

@@ -1,38 +1,43 @@
- import os
  import pickle
+ from collections.abc import Callable
  from datetime import datetime
  from pathlib import Path
- from typing import Any, Callable, List, Optional
+ from typing import Any
 
- from ..components import FileReader, Indexer
+ from ..components import FileReader
  from ..formatter import TextContainerFormatter
  from ..interfaces import Interface
  from ..memory import SlidingWindowStringConcatMemory
  from ..symbol import Symbol
- from ..utils import CustomUserWarning, deprecated
+ from ..utils import UserMessage, deprecated
  from .document import DocumentRetriever
  from .seo_query_optimizer import SEOQueryOptimizer
 
+ _DEFAULT_TEXT_CONTAINER_FORMATTER = TextContainerFormatter(text_split=4)
+
 
  class CodeFormatter:
-     def __call__(self, value: str, *args: Any, **kwds: Any) -> Any:
+     def __call__(self, value: str, *_args: Any, **_kwds: Any) -> Any:
          # extract code from chat conversations or ```<language>\n{code}\n``` blocks
-         return Symbol(value).extract('Only extract code without ``` block markers or chat conversations')
+         return Symbol(value).extract(
+             "Only extract code without ``` block markers or chat conversations"
+         )
 
 
  class Conversation(SlidingWindowStringConcatMemory):
      def __init__(
-             self,
-             init: Optional[str] = None,
-             file_link: Optional[List[str]] = None,
-             url_link: Optional[List[str]] = None,
-             index_name: Optional[str] = None,
-             auto_print: bool = True,
-             truncation_percentage: float = 0.8,
-             truncation_type: str = 'head',
-             with_metadata: bool = False,
-             *args, **kwargs
-         ):
+         self,
+         init: str | None = None,
+         file_link: list[str] | None = None,
+         url_link: list[str] | None = None,
+         index_name: str | None = None,
+         auto_print: bool = True,
+         truncation_percentage: float = 0.8,
+         truncation_type: str = "head",
+         with_metadata: bool = False,
+         *args,
+         **kwargs,
+     ):
          super().__init__(*args, **kwargs)
          self.truncation_percentage = truncation_percentage
          self.truncation_type = truncation_type
@@ -46,9 +51,9 @@ class Conversation(SlidingWindowStringConcatMemory):
          self.index_name = index_name
          self.seo_opt = SEOQueryOptimizer()
          self.reader = FileReader(with_metadata=with_metadata)
-         self.scraper = Interface('naive_webscraping')
-         self.user_tag = 'USER::'
-         self.bot_tag = 'ASSISTANT::'
+         self.scraper = Interface("naive_scrape")
+         self.user_tag = "USER::"
+         self.bot_tag = "ASSISTANT::"
 
          if init is not None:
              self.store_system_message(init, *args, **kwargs)
@@ -61,14 +66,16 @@ class Conversation(SlidingWindowStringConcatMemory):
          self.indexer = None
          self.index = None
          if index_name is not None:
-             CustomUserWarning("Index not supported for conversation class.", raise_with=NotImplementedError)
+             UserMessage(
+                 "Index not supported for conversation class.", raise_with=NotImplementedError
+             )
 
      def __getstate__(self):
          state = super().__getstate__().copy()
-         state.pop('seo_opt', None)
-         state.pop('indexer', None)
-         state.pop('index', None)
-         state.pop('reader', None)
+         state.pop("seo_opt", None)
+         state.pop("indexer", None)
+         state.pop("index", None)
+         state.pop("reader", None)
          return state
 
      def __setstate__(self, state):
@@ -76,41 +83,44 @@
          self.seo_opt = SEOQueryOptimizer()
          self.reader = FileReader()
          if self.index_name is not None:
-             CustomUserWarning("Index not supported for conversation class.", raise_with=NotImplementedError)
+             UserMessage(
+                 "Index not supported for conversation class.", raise_with=NotImplementedError
+             )
 
-     def store_system_message(self, message: str, *args, **kwargs):
-         val = f"[SYSTEM_INSTRUCTION::]: <<<\n{str(message)}\n>>>\n"
+     def store_system_message(self, message: str, *_args, **_kwargs):
+         val = f"[SYSTEM_INSTRUCTION::]: <<<\n{message!s}\n>>>\n"
          self.store(val)
 
-     def store_file(self, file_path: str, *args, **kwargs):
+     def store_file(self, file_path: str, *_args, **_kwargs):
          content = self.reader(file_path)
-         val = f"[DATA::{file_path}]: <<<\n{str(content)}\n>>>\n"
+         val = f"[DATA::{file_path}]: <<<\n{content!s}\n>>>\n"
          self.store(val)
 
-     def store_url(self, url: str, *args, **kwargs):
+     def store_url(self, url: str, *_args, **_kwargs):
          content = self.scraper(url)
-         val = f"[DATA::{url}]: <<<\n{str(content)}\n>>>\n"
+         val = f"[DATA::{url}]: <<<\n{content!s}\n>>>\n"
          self.store(val)
 
      @staticmethod
      def save_conversation_state(conversation: "Conversation", file_path: str) -> None:
          # Check if path exists and create it if it doesn't
-         dir_path = os.path.dirname(file_path)
-         os.makedirs(dir_path, exist_ok=True)
+         path_obj = Path(file_path)
+         path_obj.parent.mkdir(parents=True, exist_ok=True)
          # Save the conversation object as a pickle file
-         with open(file_path, 'wb') as handle:
+         with path_obj.open("wb") as handle:
              pickle.dump(conversation, handle, protocol=pickle.HIGHEST_PROTOCOL)
 
      def load_conversation_state(self, path: str) -> "Conversation":
          # Check if the file exists and it's not empty
-         if os.path.exists(path):
-             if os.path.getsize(path) <= 0:
-                 raise Exception("File is empty.")
+         path_obj = Path(path)
+         if path_obj.exists():
+             if path_obj.stat().st_size <= 0:
+                 UserMessage("File is empty.", raise_with=Exception)
              # Load the conversation object from a pickle file
-             with open(path, 'rb') as handle:
+             with path_obj.open("rb") as handle:
                  conversation_state = pickle.load(handle)
          else:
-             raise Exception("File does not exist or is empty.")
+             UserMessage("File does not exist or is empty.", raise_with=Exception)
 
          # Create a new instance of the `Conversation` class and restore
          # the state from the saved conversation
@@ -120,17 +130,19 @@ class Conversation(SlidingWindowStringConcatMemory):
          self._memory = conversation_state._memory
          self.truncation_percentage = conversation_state.truncation_percentage
          self.truncation_type = conversation_state.truncation_type
-         self.auto_print = conversation_state.auto_print
+         self.auto_print = conversation_state.auto_print
          self.file_link = conversation_state.file_link
          self.url_link = conversation_state.url_link
-         self.index_name = conversation_state.index_name
+         self.index_name = conversation_state.index_name
          self.seo_opt = SEOQueryOptimizer()
          self.reader = FileReader()
          if self.index_name is not None:
-             CustomUserWarning("Index not supported for conversation class.", raise_with=NotImplementedError)
+             UserMessage(
+                 "Index not supported for conversation class.", raise_with=NotImplementedError
+             )
          return self
 
-     def commit(self, target_file: str = None, formatter: Optional[Callable] = None):
+     def commit(self, target_file: str | None = None, formatter: Callable | None = None):
          if target_file and isinstance(target_file, str):
              file_link = target_file
          else:
@@ -140,20 +152,22 @@ class Conversation(SlidingWindowStringConcatMemory):
              elif isinstance(file_link, list) and len(file_link) == 1:
                  file_link = file_link[0]
              else:
-                 file_link = None # cannot commit to multiple files
-                 raise Exception('Cannot commit to multiple files.')
+                 file_link = None  # cannot commit to multiple files
+                 UserMessage("Cannot commit to multiple files.", raise_with=Exception)
          if file_link:
              # if file extension is .py, then format code
              format_ = formatter
-             formatter = CodeFormatter() if format_ is None and file_link.endswith('.py') else formatter
+             formatter = (
+                 CodeFormatter() if format_ is None and file_link.endswith(".py") else formatter
+             )
              val = self.value
              if formatter:
                  val = formatter(val)
              # if file does not exist, create it
-             with open(file_link, 'w') as file:
+             with Path(file_link).open("w") as file:
                  file.write(str(val))
          else:
-             raise Exception('File link is not set or a set of files.')
+             UserMessage("File link is not set or a set of files.", raise_with=Exception)
 
      def save(self, path: str, replace: bool = False) -> Symbol:
          return Symbol(self._memory).save(path, replace=replace)
@@ -161,77 +175,87 @@
      def build_tag(self, tag: str, query: str) -> str:
          # get timestamp in string format
          timestamp = datetime.now().strftime("%d/%m/%Y %H:%M:%S:%f")
-         return str(f"[{tag}{timestamp}]: <<<\n{str(query)}\n>>>\n")
+         return str(f"[{tag}{timestamp}]: <<<\n{query!s}\n>>>\n")
 
      def forward(self, query: str, *args, **kwargs):
-         # dynamic takes precedence over static
-         dynamic_truncation_percentage = kwargs.get('truncation_percentage', self.truncation_percentage)
-         dynamic_truncation_type = kwargs.get('truncation_type', self.truncation_type)
-         kwargs = {**kwargs, 'truncation_percentage': dynamic_truncation_percentage, 'truncation_type': dynamic_truncation_type}
-
+         kwargs = self._apply_truncation_overrides(kwargs)
          query = self._to_symbol(query)
-         memory = None
-
-         if self.index is not None:
-             memory_split = self._memory.split(self.marker)
-             memory_shards = []
-             for ms in memory_split:
-                 if ms.strip() == '':
-                     continue
-                 memory_shards.append(ms)
-
-             length_memory_shards = len(memory_shards)
-             if length_memory_shards <= 3:
-                 memory_shards = memory_shards
-             elif length_memory_shards <= 5:
-                 memory_shards = memory_shards[:2] + memory_shards[-(length_memory_shards-2):]
-             else:
-                 memory_shards = memory_shards[:2] + memory_shards[-3:]
-
-             search_query = query | '\n' | '\n'.join(memory_shards)
-             if kwargs.get('use_seo_opt'):
-                 search_query = self.seo_opt(f'[Query]:' | search_query)
-             memory = self.index(search_query, *args, **kwargs)
-
-         if 'raw_result' in kwargs:
-             print(memory)
-
-         payload = ''
-         # if payload is set, then add it to the memory
-         if 'payload' in kwargs:
-             payload = f"[Conversation Payload]:\n{kwargs.pop('payload')}\n"
-
-         index_memory = ''
-         # if index is set, then add it to the memory
-         if memory:
-             index_memory = f'[Index Retrieval]:\n{str(memory)[:1500]}\n'
-
-         payload = f'{index_memory}{payload}'
-         # perform a recall function using the query
+         memory = self._retrieve_index_memory(query, args, kwargs)
+         payload = self._build_payload(kwargs, memory)
          res = self.recall(query, *args, payload=payload, **kwargs)
 
          # if user is requesting to preview the response, then return only the preview result
-         if 'preview' in kwargs and kwargs['preview']:
+         if kwargs.get("preview"):
              if self.auto_print:
-                 print(res)
+                 UserMessage(str(res), style="text")
              return res
 
          ### --- asses memory update --- ###
 
-         # append the bot prompt to the memory
-         prompt = self.build_tag(self.user_tag, query)
-         self.store(prompt)
-
-         self._value = res.value # save last response
-         val = self.build_tag(self.bot_tag, res)
-         self.store(val)
+         self._append_interaction_to_memory(query, res)
 
          # WARN: DO NOT PROCESS THE RES BY REMOVING `<<<` AND `>>>` TAGS
 
          if self.auto_print:
-             print(res)
+             UserMessage(str(res), style="text")
          return res
 
+     def _apply_truncation_overrides(self, kwargs: dict[str, Any]) -> dict[str, Any]:
+         dynamic_truncation_percentage = kwargs.get(
+             "truncation_percentage", self.truncation_percentage
+         )
+         dynamic_truncation_type = kwargs.get("truncation_type", self.truncation_type)
+         return {
+             **kwargs,
+             "truncation_percentage": dynamic_truncation_percentage,
+             "truncation_type": dynamic_truncation_type,
+         }
+
+     def _retrieve_index_memory(self, query: Symbol, args: tuple[Any, ...], kwargs: dict[str, Any]):
+         if self.index is None:
+             return None
+
+         memory_split = self._memory.split(self.marker)
+         memory_shards = []
+         for shard in memory_split:
+             if shard.strip() == "":
+                 continue
+             memory_shards.append(shard)
+
+         length_memory_shards = len(memory_shards)
+         if length_memory_shards > 5:
+             memory_shards = memory_shards[:2] + memory_shards[-3:]
+         elif length_memory_shards > 3:
+             retained = memory_shards[-(length_memory_shards - 2) :]
+             memory_shards = memory_shards[:2] + retained
+
+         search_query = query | "\n" | "\n".join(memory_shards)
+         if kwargs.get("use_seo_opt"):
+             search_query = self.seo_opt("[Query]:" | search_query)
+         memory = self.index(search_query, *args, **kwargs)
+
+         if "raw_result" in kwargs:
+             UserMessage(str(memory), style="text")
+         return memory
+
+     def _build_payload(self, kwargs: dict[str, Any], memory) -> str:
+         payload = ""
+         if "payload" in kwargs:
+             payload = f"[Conversation Payload]:\n{kwargs.pop('payload')}\n"
+
+         index_memory = ""
+         if memory:
+             index_memory = f"[Index Retrieval]:\n{str(memory)[:1500]}\n"
+         return f"{index_memory}{payload}"
+
+     def _append_interaction_to_memory(self, query: Symbol, res: Symbol) -> None:
+         prompt = self.build_tag(self.user_tag, query)
+         self.store(prompt)
+
+         self._value = res.value  # save last response
+         val = self.build_tag(self.bot_tag, res)
+         self.store(val)
+
 
  RETRIEVAL_CONTEXT = """[Description]
  This is a conversation between a retrieval augmented indexing program and a user. The system combines document retrieval with conversational AI to provide context-aware responses. It can:
@@ -265,27 +289,33 @@ Responses should be:
  - Referenced to source when applicable
  """
 
+
  @deprecated("Use `Conversation` instead for now. This will be removed/fixed in the future.")
  class RetrievalAugmentedConversation(Conversation):
      def __init__(
-             self,
-             folder_path: Optional[str] = None,
-             *,
-             index_name: Optional[str] = None,
-             max_depth: Optional[int] = 0,
-             auto_print: bool = True,
-             top_k: int = 5,
-             formatter: Callable = TextContainerFormatter(text_split=4),
-             overwrite: bool = False,
-             truncation_percentage: float = 0.8,
-             truncation_type: str = 'head',
-             with_metadata: bool = False,
-             raw_result: Optional[bool] = False,
-             new_dim: Optional[int] = None,
-             **kwargs
-         ):
-
-         super().__init__(auto_print=auto_print, truncation_percentage=truncation_percentage, truncation_type=truncation_type, with_metadata=with_metadata, *kwargs)
+         self,
+         folder_path: str | None = None,
+         *,
+         index_name: str | None = None,
+         max_depth: int | None = 0,
+         auto_print: bool = True,
+         top_k: int = 5,
+         formatter: Callable = _DEFAULT_TEXT_CONTAINER_FORMATTER,
+         overwrite: bool = False,
+         truncation_percentage: float = 0.8,
+         truncation_type: str = "head",
+         with_metadata: bool = False,
+         raw_result: bool | None = False,
+         new_dim: int | None = None,
+         **kwargs,
+     ):
+         super().__init__(
+             auto_print=auto_print,
+             truncation_percentage=truncation_percentage,
+             truncation_type=truncation_type,
+             with_metadata=with_metadata,
+             **kwargs,
+         )
 
          self.retriever = DocumentRetriever(
              source=folder_path,
@@ -297,7 +327,7 @@ class RetrievalAugmentedConversation(Conversation):
              with_metadata=with_metadata,
              raw_result=raw_result,
              new_dim=new_dim,
-             **kwargs
+             **kwargs,
          )
 
          self.index = self.retriever.index
@@ -322,14 +352,14 @@
 
          memory = self.index(query, *args, **kwargs)
 
-         if 'raw_result' in kwargs:
-             print(memory)
+         if "raw_result" in kwargs:
+             UserMessage(str(memory), style="text")
              return memory
 
          prompt = self.build_tag(self.user_tag, query)
          self.store(prompt)
 
-         payload = f'[Index Retrieval]:\n{str(memory)[:1500]}\n'
+         payload = f"[Index Retrieval]:\n{str(memory)[:1500]}\n"
 
          res = self.recall(query, *args, payload=payload, **kwargs)
 
@@ -338,5 +368,5 @@ class RetrievalAugmentedConversation(Conversation):
          self.store(val)
 
          if self.auto_print:
-             print(res)
+             UserMessage(str(res), style="text")
          return res
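For context on the signature changes above: the Conversation constructor now uses built-in generics and X | None annotations instead of typing.Optional/List, and forward() is decomposed into the private helpers shown in the hunk. A hedged usage sketch against the 1.1.0 signature; it assumes a neuro-symbolic engine is already configured and that the class remains callable, dispatching to forward():

    from symai.extended.conversation import Conversation

    conv = Conversation(init="You are a terse assistant.", auto_print=False)
    conv.store_file("notes.md")                         # hypothetical local file added to the context
    res = conv("Summarize the notes in one sentence.")  # assumed to dispatch to forward()
    print(res)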