symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. symai/__init__.py +269 -173
  2. symai/backend/base.py +123 -110
  3. symai/backend/engines/drawing/engine_bfl.py +45 -44
  4. symai/backend/engines/drawing/engine_gpt_image.py +112 -97
  5. symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
  6. symai/backend/engines/embedding/engine_openai.py +25 -21
  7. symai/backend/engines/execute/engine_python.py +19 -18
  8. symai/backend/engines/files/engine_io.py +104 -95
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
  11. symai/backend/engines/index/engine_pinecone.py +124 -97
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +84 -56
  14. symai/backend/engines/lean/engine_lean4.py +96 -52
  15. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
  21. symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
  26. symai/backend/engines/ocr/engine_apilayer.py +23 -27
  27. symai/backend/engines/output/engine_stdout.py +10 -13
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
  29. symai/backend/engines/search/engine_openai.py +100 -88
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +44 -45
  32. symai/backend/engines/search/engine_serpapi.py +37 -34
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
  35. symai/backend/engines/text_to_speech/engine_openai.py +20 -26
  36. symai/backend/engines/text_vision/engine_clip.py +39 -37
  37. symai/backend/engines/userinput/engine_console.py +5 -6
  38. symai/backend/mixin/__init__.py +13 -0
  39. symai/backend/mixin/anthropic.py +48 -38
  40. symai/backend/mixin/deepseek.py +6 -5
  41. symai/backend/mixin/google.py +7 -4
  42. symai/backend/mixin/groq.py +2 -4
  43. symai/backend/mixin/openai.py +140 -110
  44. symai/backend/settings.py +87 -20
  45. symai/chat.py +216 -123
  46. symai/collect/__init__.py +7 -1
  47. symai/collect/dynamic.py +80 -70
  48. symai/collect/pipeline.py +67 -51
  49. symai/collect/stats.py +161 -109
  50. symai/components.py +707 -360
  51. symai/constraints.py +24 -12
  52. symai/core.py +1857 -1233
  53. symai/core_ext.py +83 -80
  54. symai/endpoints/api.py +166 -104
  55. symai/extended/.DS_Store +0 -0
  56. symai/extended/__init__.py +46 -12
  57. symai/extended/api_builder.py +29 -21
  58. symai/extended/arxiv_pdf_parser.py +23 -14
  59. symai/extended/bibtex_parser.py +9 -6
  60. symai/extended/conversation.py +156 -126
  61. symai/extended/document.py +50 -30
  62. symai/extended/file_merger.py +57 -14
  63. symai/extended/graph.py +51 -32
  64. symai/extended/html_style_template.py +18 -14
  65. symai/extended/interfaces/blip_2.py +2 -3
  66. symai/extended/interfaces/clip.py +4 -3
  67. symai/extended/interfaces/console.py +9 -1
  68. symai/extended/interfaces/dall_e.py +4 -2
  69. symai/extended/interfaces/file.py +2 -0
  70. symai/extended/interfaces/flux.py +4 -2
  71. symai/extended/interfaces/gpt_image.py +16 -7
  72. symai/extended/interfaces/input.py +2 -1
  73. symai/extended/interfaces/llava.py +1 -2
  74. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
  75. symai/extended/interfaces/naive_vectordb.py +9 -10
  76. symai/extended/interfaces/ocr.py +5 -3
  77. symai/extended/interfaces/openai_search.py +2 -0
  78. symai/extended/interfaces/parallel.py +30 -0
  79. symai/extended/interfaces/perplexity.py +2 -0
  80. symai/extended/interfaces/pinecone.py +12 -9
  81. symai/extended/interfaces/python.py +2 -0
  82. symai/extended/interfaces/serpapi.py +3 -1
  83. symai/extended/interfaces/terminal.py +2 -4
  84. symai/extended/interfaces/tts.py +3 -2
  85. symai/extended/interfaces/whisper.py +3 -2
  86. symai/extended/interfaces/wolframalpha.py +2 -1
  87. symai/extended/metrics/__init__.py +11 -1
  88. symai/extended/metrics/similarity.py +14 -13
  89. symai/extended/os_command.py +39 -29
  90. symai/extended/packages/__init__.py +29 -3
  91. symai/extended/packages/symdev.py +51 -43
  92. symai/extended/packages/sympkg.py +41 -35
  93. symai/extended/packages/symrun.py +63 -50
  94. symai/extended/repo_cloner.py +14 -12
  95. symai/extended/seo_query_optimizer.py +15 -13
  96. symai/extended/solver.py +116 -91
  97. symai/extended/summarizer.py +12 -10
  98. symai/extended/taypan_interpreter.py +17 -18
  99. symai/extended/vectordb.py +122 -92
  100. symai/formatter/__init__.py +9 -1
  101. symai/formatter/formatter.py +51 -47
  102. symai/formatter/regex.py +70 -69
  103. symai/functional.py +325 -176
  104. symai/imports.py +190 -147
  105. symai/interfaces.py +57 -28
  106. symai/memory.py +45 -35
  107. symai/menu/screen.py +28 -19
  108. symai/misc/console.py +66 -56
  109. symai/misc/loader.py +8 -5
  110. symai/models/__init__.py +17 -1
  111. symai/models/base.py +395 -236
  112. symai/models/errors.py +1 -2
  113. symai/ops/__init__.py +32 -22
  114. symai/ops/measures.py +24 -25
  115. symai/ops/primitives.py +1149 -731
  116. symai/post_processors.py +58 -50
  117. symai/pre_processors.py +86 -82
  118. symai/processor.py +21 -13
  119. symai/prompts.py +764 -685
  120. symai/server/huggingface_server.py +135 -49
  121. symai/server/llama_cpp_server.py +21 -11
  122. symai/server/qdrant_server.py +206 -0
  123. symai/shell.py +100 -42
  124. symai/shellsv.py +700 -492
  125. symai/strategy.py +630 -346
  126. symai/symbol.py +368 -322
  127. symai/utils.py +100 -78
  128. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
  129. symbolicai-1.1.0.dist-info/RECORD +168 -0
  130. symbolicai-0.21.0.dist-info/RECORD +0 -162
  131. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  132. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  133. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  134. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/shellsv.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import glob
3
2
  import json
4
3
  import logging
5
4
  import os
@@ -10,32 +9,37 @@ import subprocess
10
9
  import sys
11
10
  import time
12
11
  import traceback
12
+ from collections.abc import Iterable
13
13
  from pathlib import Path
14
- from typing import Iterable, Tuple
14
+ from types import SimpleNamespace
15
15
 
16
- #@TODO: refactor to use rich instead of prompt_toolkit
16
+ # @TODO: refactor to use rich instead of prompt_toolkit
17
17
  from prompt_toolkit import HTML, PromptSession, print_formatted_text
18
18
  from prompt_toolkit.completion import Completer, Completion, WordCompleter
19
19
  from prompt_toolkit.history import History
20
20
  from prompt_toolkit.key_binding import KeyBindings
21
21
  from prompt_toolkit.keys import Keys
22
- from prompt_toolkit.lexers import PygmentsLexer
23
22
  from prompt_toolkit.patch_stdout import patch_stdout
24
23
  from prompt_toolkit.shortcuts import CompleteStyle, ProgressBar
25
24
  from prompt_toolkit.styles import Style
26
- from pygments.lexers.shell import BashLexer
27
25
 
28
26
  from .backend.settings import HOME_PATH, SYMSH_CONFIG
29
27
  from .components import FileReader, Function, Indexer
30
- from .extended import (ArxivPdfParser, Conversation, DocumentRetriever,
31
- FileMerger, RepositoryCloner,
32
- RetrievalAugmentedConversation)
28
+ from .extended import (
29
+ ArxivPdfParser,
30
+ Conversation,
31
+ DocumentRetriever,
32
+ FileMerger,
33
+ RepositoryCloner,
34
+ RetrievalAugmentedConversation,
35
+ )
33
36
  from .imports import Import
34
37
  from .interfaces import Interface
35
38
  from .menu.screen import show_intro_menu
36
39
  from .misc.console import ConsoleStyle
37
40
  from .misc.loader import Loader
38
41
  from .symbol import Symbol
42
+ from .utils import UserMessage
39
43
 
40
44
  logging.getLogger("prompt_toolkit").setLevel(logging.ERROR)
41
45
  logging.getLogger("asyncio").setLevel(logging.ERROR)
@@ -43,30 +47,37 @@ logging.getLogger("subprocess").setLevel(logging.ERROR)
43
47
 
44
48
  # load json config from home directory root
45
49
  home_path = HOME_PATH
46
- config_path = os.path.join(home_path, 'symsh.config.json')
50
+ config_path = home_path / "symsh.config.json"
47
51
  # migrate config from old path
48
- if 'colors' not in SYMSH_CONFIG:
52
+ if "colors" not in SYMSH_CONFIG:
49
53
  __new_config__ = {"colors": SYMSH_CONFIG}
50
54
  # add command in config
51
55
  SYMSH_CONFIG = __new_config__
52
56
  # save config
53
- with open(config_path, 'w') as f:
57
+ with config_path.open("w") as f:
54
58
  json.dump(__new_config__, f, indent=4)
55
59
 
56
60
  # make sure map-nt-cmd is in config
57
- if 'map-nt-cmd' not in SYMSH_CONFIG:
61
+ if "map-nt-cmd" not in SYMSH_CONFIG:
58
62
  # add command in config
59
- SYMSH_CONFIG['map-nt-cmd'] = True
63
+ SYMSH_CONFIG["map-nt-cmd"] = True
60
64
  # save config
61
- with open(config_path, 'w') as f:
65
+ with config_path.open("w") as f:
62
66
  json.dump(SYMSH_CONFIG, f, indent=4)
63
67
 
64
- print = print_formatted_text
65
- FunctionType = Function
66
- ConversationType = Conversation
67
- RetrievalConversationType = RetrievalAugmentedConversation
68
- use_styles = False
69
- map_nt_cmd_enabled = SYMSH_CONFIG['map-nt-cmd']
68
+ print = print_formatted_text # noqa
69
+ map_nt_cmd_enabled = SYMSH_CONFIG["map-nt-cmd"]
70
+
71
+ _shell_state = SimpleNamespace(
72
+ function_type=Function,
73
+ conversation_type=Conversation,
74
+ retrieval_conversation_type=RetrievalAugmentedConversation,
75
+ use_styles=False,
76
+ stateful_conversation=None,
77
+ previous_kwargs=None,
78
+ previous_prefix=None,
79
+ exec_prefix="default",
80
+ )
70
81
 
71
82
  SHELL_CONTEXT = """[Description]
72
83
  This shell program is the command interpreter on the Linux systems, MacOS and Windows PowerShell.
@@ -80,8 +91,6 @@ If additional instructions are provided the follow the user query to produce the
80
91
  A well related and helpful answer with suggested improvements is preferred over "I don't know" or "I don't understand" answers or stating the obvious.
81
92
  """
82
93
 
83
- stateful_conversation = None
84
- previous_kwargs = None
85
94
 
86
95
  def supports_ansi_escape():
87
96
  try:
@@ -90,15 +99,24 @@ def supports_ansi_escape():
90
99
  except OSError:
91
100
  return False
92
101
 
102
+
93
103
  class PathCompleter(Completer):
94
- def get_completions(self, document, complete_event):
104
+ def get_completions(self, document, _complete_event):
95
105
  complete_word = document.get_word_before_cursor(WORD=True)
96
106
  sep = os.path.sep
97
- if complete_word.startswith(f'~{sep}'):
107
+ if complete_word.startswith(f"~{sep}"):
98
108
  complete_word = FileReader.expand_user_path(complete_word)
99
109
 
100
110
  # list all files and directories in current directory
101
- files = list(glob.glob(complete_word + '*'))
111
+ complete_path = Path(complete_word)
112
+ if complete_word.endswith(sep):
113
+ parent = complete_path
114
+ pattern = "*"
115
+ else:
116
+ baseline = Path()
117
+ parent = complete_path.parent if complete_path.parent != baseline else baseline
118
+ pattern = f"{complete_path.name}*" if complete_path.name else "*"
119
+ files = [str(path) for path in parent.glob(pattern)]
102
120
  if len(files) == 0:
103
121
  return None
104
122
 
@@ -106,46 +124,56 @@ class PathCompleter(Completer):
106
124
  files_ = []
107
125
 
108
126
  for file in files:
127
+ path_obj = Path(file)
109
128
  # split the command into words by space (ignore escaped spaces)
110
- command_words = document.text.split(' ')
129
+ command_words = document.text.split(" ")
111
130
  if len(command_words) > 1:
112
131
  # Calculate start position of the completion
113
- start_position = len(document.text) - len(' '.join(command_words[:-1])) - 1
132
+ start_position = len(document.text) - len(" ".join(command_words[:-1])) - 1
114
133
  start_position = max(0, start_position)
115
134
  else:
116
135
  start_position = len(document.text)
117
136
  # if there is a space in the file name, then escape it
118
- if ' ' in file:
119
- file = file.replace(' ', '\\ ')
120
- if (document.text.startswith('cd') or document.text.startswith('mkdir')) and os.path.isfile(file):
137
+ display_name = file.replace(" ", "\\ ") if " " in file else file
138
+ if (
139
+ document.text.startswith("cd") or document.text.startswith("mkdir")
140
+ ) and path_obj.is_file():
121
141
  continue
122
- if os.path.isdir(file):
123
- dirs_.append(file)
142
+ if path_obj.is_dir():
143
+ dirs_.append(display_name)
124
144
  else:
125
- files_.append(file)
145
+ files_.append(display_name)
126
146
 
127
147
  for d in dirs_:
128
148
  # if starts with home directory, then replace it with ~
129
- d = FileReader.expand_user_path(d)
130
- yield Completion(d, start_position=-start_position,
131
- style='class:path-completion',
132
- selected_style='class:path-completion-selected')
149
+ directory_completion = FileReader.expand_user_path(d)
150
+ yield Completion(
151
+ directory_completion,
152
+ start_position=-start_position,
153
+ style="class:path-completion",
154
+ selected_style="class:path-completion-selected",
155
+ )
133
156
 
134
157
  for f in files_:
135
158
  # if starts with home directory, then replace it with ~
136
- f = FileReader.expand_user_path(f)
137
- yield Completion(f, start_position=-start_position,
138
- style='class:file-completion',
139
- selected_style='class:file-completion-selected')
159
+ file_completion = FileReader.expand_user_path(f)
160
+ yield Completion(
161
+ file_completion,
162
+ start_position=-start_position,
163
+ style="class:file-completion",
164
+ selected_style="class:file-completion-selected",
165
+ )
166
+
140
167
 
141
168
  class HistoryCompleter(WordCompleter):
142
169
  def get_completions(self, document, complete_event):
143
170
  completions = super().get_completions(document, complete_event)
144
171
  for completion in completions:
145
- completion.style = 'class:history-completion'
146
- completion.selected_style = 'class:history-completion-selected'
172
+ completion.style = "class:history-completion"
173
+ completion.selected_style = "class:history-completion-selected"
147
174
  yield completion
148
175
 
176
+
149
177
  class MergedCompleter(Completer):
150
178
  def __init__(self, path_completer, history_completer):
151
179
  self.path_completer = path_completer
@@ -154,53 +182,57 @@ class MergedCompleter(Completer):
154
182
  def get_completions(self, document, complete_event):
155
183
  text = document.text
156
184
 
157
- if text.startswith('cd ') or\
158
- text.startswith('ls ') or\
159
- text.startswith('touch ') or\
160
- text.startswith('cat ') or\
161
- text.startswith('mkdir ') or\
162
- text.startswith('open ') or\
163
- text.startswith('rm ') or\
164
- text.startswith('git ') or\
165
- text.startswith('vi ') or\
166
- text.startswith('nano ') or\
167
- text.startswith('*') or\
168
- text.startswith(r'.\\') or\
169
- text.startswith(r'~\\') or\
170
- text.startswith(r'\\') or\
171
- text.startswith('.\\') or\
172
- text.startswith('~\\') or\
173
- text.startswith('\\') or\
174
- text.startswith('./') or\
175
- text.startswith('~/') or\
176
- text.startswith('/'):
185
+ if (
186
+ text.startswith("cd ")
187
+ or text.startswith("ls ")
188
+ or text.startswith("touch ")
189
+ or text.startswith("cat ")
190
+ or text.startswith("mkdir ")
191
+ or text.startswith("open ")
192
+ or text.startswith("rm ")
193
+ or text.startswith("git ")
194
+ or text.startswith("vi ")
195
+ or text.startswith("nano ")
196
+ or text.startswith("*")
197
+ or text.startswith(r".\\")
198
+ or text.startswith(r"~\\")
199
+ or text.startswith(r"\\")
200
+ or text.startswith(".\\")
201
+ or text.startswith("~\\")
202
+ or text.startswith("\\")
203
+ or text.startswith("./")
204
+ or text.startswith("~/")
205
+ or text.startswith("/")
206
+ ):
177
207
  yield from self.path_completer.get_completions(document, complete_event)
178
208
  yield from self.history_completer.get_completions(document, complete_event)
179
209
  else:
180
210
  yield from self.history_completer.get_completions(document, complete_event)
181
211
  yield from self.path_completer.get_completions(document, complete_event)
182
212
 
213
+
183
214
  # Create custom keybindings
184
215
  bindings = KeyBindings()
185
- previous_prefix = None
186
- exec_prefix = 'default'
187
216
  # Get a copy of the current environment
188
217
  default_env = os.environ.copy()
189
218
 
219
+
190
220
  def get_exec_prefix():
191
- return sys.exec_prefix if exec_prefix == 'default' else exec_prefix
221
+ exec_prefix = _shell_state.exec_prefix
222
+ return sys.exec_prefix if exec_prefix == "default" else exec_prefix
223
+
192
224
 
193
225
  def get_conda_env():
194
226
  # what conda env am I in (e.g., where is my Python process from)?
195
227
  ENVBIN = get_exec_prefix()
196
- env_name = os.path.basename(ENVBIN)
197
- return env_name
228
+ return Path(ENVBIN).name
229
+
198
230
 
199
231
  # bind to 'Ctrl' + 'Space'
200
232
  @bindings.add(Keys.ControlSpace)
201
233
  def _(event):
202
234
  current_user_input = event.current_buffer.document.text
203
- func = FunctionType(SHELL_CONTEXT)
235
+ func = _shell_state.function_type(SHELL_CONTEXT)
204
236
 
205
237
  bottom_toolbar = HTML(' <b>[f]</b> Print "f" <b>[x]</b> Abort.')
206
238
 
@@ -208,61 +240,64 @@ def _(event):
208
240
  kb = KeyBindings()
209
241
 
210
242
  cancel = [False]
211
- @kb.add('f')
212
- def _(event):
213
- print('You pressed `f`.')
214
243
 
215
- @kb.add('x')
216
- def _(event):
217
- " Send Abort (control-c) signal. "
244
+ @kb.add("f")
245
+ def _(_event):
246
+ UserMessage("You pressed `f`.", style="alert")
247
+
248
+ @kb.add("x")
249
+ def _(_event):
250
+ "Send Abort (control-c) signal."
218
251
  cancel[0] = True
219
252
  os.kill(os.getpid(), signal.SIGINT)
220
253
 
221
254
  # Use `patch_stdout`, to make sure that prints go above the
222
255
  # application.
223
- with patch_stdout():
224
- with ProgressBar(key_bindings=kb, bottom_toolbar=bottom_toolbar) as pb:
225
- # TODO: hack to simulate progress bar of indeterminate length of an synchronous function
226
- for i in pb(range(100)):
227
- if i > 50 and i < 70:
228
- time.sleep(.01)
256
+ with patch_stdout(), ProgressBar(key_bindings=kb, bottom_toolbar=bottom_toolbar) as pb:
257
+ # TODO: hack to simulate progress bar of indeterminate length of an synchronous function
258
+ for i in pb(range(100)):
259
+ if i > 50 and i < 70:
260
+ time.sleep(0.01)
229
261
 
230
- if i == 60:
231
- res = func(current_user_input) # hack to see progress bar
262
+ if i == 60:
263
+ res = func(current_user_input) # hack to see progress bar
232
264
 
233
- # Stop when the cancel flag has been set.
234
- if cancel[0]:
235
- break
265
+ # Stop when the cancel flag has been set.
266
+ if cancel[0]:
267
+ break
236
268
 
237
- with ConsoleStyle('code') as console:
269
+ with ConsoleStyle("code") as console:
238
270
  console.print(res)
239
271
 
272
+
240
273
  @bindings.add(Keys.PageUp)
241
274
  def _(event):
242
275
  # Moving up for 5 lines
243
276
  for _ in range(5):
244
277
  event.current_buffer.auto_up()
245
278
 
279
+
246
280
  @bindings.add(Keys.PageDown)
247
281
  def _(event):
248
282
  # Moving down for 5 lines
249
283
  for _ in range(5):
250
284
  event.current_buffer.auto_down()
251
285
 
286
+
252
287
  class FileHistory(History):
253
- '''
288
+ """
254
289
  :class:`.History` class that stores all strings in a file.
255
- '''
290
+ """
256
291
 
257
292
  def __init__(self, filename: str) -> None:
258
- self.filename = filename
293
+ self.filename = Path(filename)
259
294
  super().__init__()
260
295
 
261
296
  def load_history_strings(self) -> Iterable[str]:
262
297
  lines: list[str] = []
263
298
 
264
- if os.path.exists(self.filename):
265
- with open(self.filename, "r") as f:
299
+ if self.filename.exists():
300
+ with self.filename.open() as f:
266
301
  lines = f.readlines()
267
302
  # Remove comments and empty lines.
268
303
  lines = [line for line in lines if line.strip() and not line.startswith("#")]
@@ -276,33 +311,40 @@ class FileHistory(History):
276
311
 
277
312
  def store_string(self, string: str) -> None:
278
313
  # Save to file.
279
- with open(self.filename, "ab") as f:
314
+ with self.filename.open("ab") as f:
280
315
 
281
316
  def write(t: str) -> None:
282
317
  f.write(t.encode("utf-8"))
283
318
 
284
319
  for line in string.split("\n"):
285
- write("%s\n" % line)
320
+ write(f"{line}\n")
321
+
286
322
 
287
323
  # Defining commands history
288
- def load_history(home_path=HOME_PATH, history_file='.bash_history'):
289
- history_file_path = os.path.join(home_path, history_file)
324
+ def load_history(home_path=HOME_PATH, history_file=".bash_history"):
325
+ history_file_path = home_path / history_file
290
326
  history = FileHistory(history_file_path)
291
327
  return history, list(history.load_history_strings())
292
328
 
329
+
293
330
  # Function to check if current directory is a git directory
294
331
  def get_git_branch():
295
332
  try:
296
- git_process = subprocess.Popen(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
297
- stdout, stderr = git_process.communicate()
333
+ git_process = subprocess.Popen(
334
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
335
+ stdout=subprocess.PIPE,
336
+ stderr=subprocess.PIPE,
337
+ )
338
+ stdout, _stderr = git_process.communicate()
298
339
  if git_process.returncode == 0:
299
- return stdout.strip().decode('utf-8')
340
+ return stdout.strip().decode("utf-8")
300
341
  except FileNotFoundError:
301
342
  pass
302
343
  return None
303
344
 
304
- def disambiguate(cmds: str) -> Tuple[str, int]:
305
- '''
345
+
346
+ def disambiguate(cmds: str) -> tuple[str, int]:
347
+ """
306
348
  Ok, so, possible options for now:
307
349
  1. query | cmd
308
350
  2. query | file [file ...]
@@ -310,174 +352,238 @@ def disambiguate(cmds: str) -> Tuple[str, int]:
310
352
  3. query | cmd | file
311
353
  4. query | cmd cmd ...
312
354
  5. query | file | cmd
313
- '''
314
- has_at_least_one_cmd = any([shutil.which(cmd) is not None for cmd in cmds.split(' ')])
315
- maybe_cmd = cmds.split(' ')[0].strip() # get first command
355
+ """
356
+ has_at_least_one_cmd = any(shutil.which(cmd) is not None for cmd in cmds.split(" "))
357
+ maybe_cmd = cmds.split(" ")[0].strip() # get first command
316
358
  maybe_files = FileReader.extract_files(cmds)
317
359
  # if cmd follows file(s) or file(s) follows cmd throw error as not supported
318
360
  if maybe_files is not None and has_at_least_one_cmd:
319
- raise ValueError('Cannot disambiguate commands that have both files and commands or multiple commands. Please provide correct order of commands. '
320
- 'Supported are: '
321
- 'query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
322
- 'and '
323
- 'query | cmd (e.g. "what flags can I use with rg?" | rg --help)')
361
+ msg = (
362
+ "Cannot disambiguate commands that have both files and commands or multiple commands. Please provide "
363
+ 'correct order of commands. Supported are: query | file [file ...] (e.g. "what do these files have in '
364
+ 'common?" | file1 [file2 ...]) and query | cmd (e.g. "what flags can I use with rg?" | rg --help)'
365
+ )
366
+ UserMessage(msg, raise_with=ValueError)
324
367
  # now check order of commands and keep correct order
325
368
  if shutil.which(maybe_cmd) is not None:
326
- cmd_out = subprocess.run(cmds, capture_output=True, text=True, shell=True)
369
+ cmd_out = subprocess.run(cmds, check=False, capture_output=True, text=True, shell=True)
327
370
  if not cmd_out.stdout:
328
- raise ValueError(f'Command not found or failed. Error: {cmd_out.stderr}')
371
+ msg = f"Command not found or failed. Error: {cmd_out.stderr}"
372
+ UserMessage(msg, raise_with=ValueError)
329
373
  return cmd_out.stdout, 1
330
374
  if maybe_files is not None:
331
375
  return maybe_files, 2
376
+ return None
377
+
332
378
 
333
379
  # query language model
334
- def query_language_model(query: str, res=None, *args, **kwargs):
335
- global stateful_conversation, previous_kwargs
336
- home_path = HOME_PATH
337
- symai_path = os.path.join(home_path, '.conversation_state')
338
- plugin = SYMSH_CONFIG.get('plugin_prefix')
339
-
340
- # check and extract kwargs from query if any
341
- # format --kwargs key1=value1,key2=value2,key3=value3,...keyN=valueN
342
- if '--kwargs' in query or '-kw' in query:
343
- splitter = '--kwargs' if '--kwargs' in query else '-kw'
344
- # check if kwargs format is last in query otherwise raise error
345
- splits = query.split(f'{splitter}')
346
- if previous_kwargs is None and '=' not in splits[-1] and ',' not in splits[-1]:
347
- raise ValueError('Kwargs format must be last in query.')
348
- elif previous_kwargs is not None and '=' not in splits[-1] and ',' not in splits[-1]:
349
- # use previous kwargs
350
- cmd_kwargs = previous_kwargs
351
- else:
352
- # remove kwargs from query
353
- query = splits[0].strip()
354
- # extract kwargs
355
- kwargs_str = splits[-1].strip()
356
- cmd_kwargs = dict([kw.split('=') for kw in kwargs_str.split(',')])
357
- cmd_kwargs = {k.strip(): Symbol(v.strip()).ast() for k, v in cmd_kwargs.items()}
358
-
359
- previous_kwargs = cmd_kwargs
360
- # unpack cmd_kwargs to kwargs
361
- kwargs = {**kwargs, **cmd_kwargs}
362
-
363
- # Handle stateful conversations:
364
- # 1. If query starts with !" (new conversation), create new conversation state
365
- # 2. If query starts with ." (follow-up), either:
366
- # - Load existing conversation state if it exists
367
- # - Create new conversation state if none exists
368
- if (query.startswith('!"') or query.startswith("!'") or query.startswith('!`')):
369
- os.makedirs(os.path.dirname(symai_path), exist_ok=True)
370
- stateful_conversation = ConversationType(auto_print=False)
371
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
372
- # Special case: if query starts with !" and has a prefix, run the prefix command and store the output
373
- if plugin is not None:
374
- with Loader(desc="Inference ...", end=""):
375
- cmd = query[1:].strip('\'"')
376
- cmd = f"symrun {plugin} '{cmd}' --disable-pbar"
377
- cmd_out = run_shell_command(cmd, auto_query_on_error=True)
378
- stateful_conversation.store(cmd_out)
379
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
380
- return cmd_out
381
- elif query.startswith('."') or query.startswith(".'") or query.startswith('.`'):
382
- try:
383
- stateful_conversation = stateful_conversation.load_conversation_state(symai_path)
384
- except Exception:
385
- with ConsoleStyle('error') as console:
386
- console.print('No conversation state found. Please start a new conversation.')
387
- return
388
- if plugin is not None:
389
- with Loader(desc="Inference ...", end=""):
390
- query = query[1:].strip('\'"')
391
- answer = stateful_conversation(query).value
392
- cmd = f"symrun {plugin} '{answer}' --disable-pbar"
393
- cmd_out = run_shell_command(cmd, auto_query_on_error=True)
394
- stateful_conversation.store(cmd_out)
395
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
396
- return cmd_out
397
- cmd = None
398
- if '|' in query:
399
- cmds = query.split('|')
400
- if len(cmds) > 2:
401
- raise ValueError(('Cannot disambiguate commands that have more than 1 pipes. Please provide correct order of commands. '
402
- 'Supported are: '
403
- 'query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
404
- 'and '
405
- 'query | cmd (e.g. "what flags can I use with rg?" | rg --help)'))
406
- query = cmds[0]
407
- payload, order = disambiguate(cmds[1].strip())
408
- # check if we're in a stateful conversation
409
- is_stateful = query.startswith(('.', '!')) and any(query.startswith(f"{prefix}{quote}")
410
- for prefix in ['.', '!']
411
- for quote in ['"', "'", '`'])
412
-
413
- if is_stateful:
414
- func = stateful_conversation
415
- else:
416
- func = FunctionType(payload) if order == 1 else ConversationType(file_link=payload, auto_print=False)
417
-
418
- if is_stateful:
419
- if order == 1:
420
- func.store(payload)
421
- elif order == 2:
422
- for file in payload:
423
- func.store_file(file)
380
+ def _starts_with_prefix(query: str, prefix: str) -> bool:
381
+ return (
382
+ query.startswith(f'{prefix}"')
383
+ or query.startswith(f"{prefix}'")
384
+ or query.startswith(f"{prefix}`")
385
+ )
386
+
387
+
388
+ def _is_new_conversation_query(query: str) -> bool:
389
+ return _starts_with_prefix(query, "!")
390
+
391
+
392
+ def _is_followup_conversation_query(query: str) -> bool:
393
+ return _starts_with_prefix(query, ".")
394
+
395
+
396
+ def _is_stateful_query(query: str) -> bool:
397
+ return any(_starts_with_prefix(query, prefix) for prefix in [".", "!"])
398
+
399
+
400
+ def _extract_query_kwargs(query: str, previous_kwargs, existing_kwargs):
401
+ if "--kwargs" not in query and "-kw" not in query:
402
+ return query, existing_kwargs, previous_kwargs
403
+
404
+ splitter = "--kwargs" if "--kwargs" in query else "-kw"
405
+ splits = query.split(splitter)
406
+ suffix = splits[-1]
407
+ if previous_kwargs is None and "=" not in suffix and "," not in suffix:
408
+ msg = "Kwargs format must be last in query."
409
+ UserMessage(msg, raise_with=ValueError)
410
+ if previous_kwargs is not None and "=" not in suffix and "," not in suffix:
411
+ cmd_kwargs = previous_kwargs
424
412
  else:
425
- if query.startswith('."') or query.startswith(".'") or query.startswith('.`') or\
426
- query.startswith('!"') or query.startswith("!'") or query.startswith('!`'):
427
- func = stateful_conversation
428
- else:
429
- func = FunctionType(SHELL_CONTEXT)
413
+ query = splits[0].strip()
414
+ kwargs_str = suffix.strip()
415
+ cmd_kwargs = dict([kw.split("=") for kw in kwargs_str.split(",")])
416
+ cmd_kwargs = {k.strip(): Symbol(v.strip()).ast() for k, v in cmd_kwargs.items()}
417
+
418
+ previous_kwargs = cmd_kwargs
419
+ merged_kwargs = {**existing_kwargs, **cmd_kwargs}
420
+ return query, merged_kwargs, previous_kwargs
421
+
422
+
423
def _process_new_conversation(query, conversation_cls, symai_path, plugin, previous_kwargs, state):
    """Start a fresh stateful conversation and, if a plugin is set, run it.

    Args:
        query: The user query including its one-character prefix.
        conversation_cls: Class used to build and persist the conversation.
        symai_path: ``Path`` of the conversation state file.
        plugin: Configured plugin name, or ``None``.
        previous_kwargs: Kwargs carried over between queries.
        state: Shell state object that is updated in place.

    Returns:
        ``(conversation, previous_kwargs, result, should_return)``.
        ``should_return`` is ``False`` when no plugin is configured, so the
        caller continues with its normal inference path.
    """
    symai_path.parent.mkdir(parents=True, exist_ok=True)
    conversation = conversation_cls(auto_print=False)
    conversation_cls.save_conversation_state(conversation, symai_path)
    state.stateful_conversation = conversation
    if plugin is None:
        return conversation, previous_kwargs, None, False

    with Loader(desc="Inference ...", end=""):
        cmd = query[1:].strip("'\"")
        # The text is placed inside a single-quoted shell argument; an
        # embedded apostrophe would otherwise terminate the quoting and break
        # (or inject into) the command. Quote-free input is unchanged.
        cmd = cmd.replace("'", "'\"'\"'")
        cmd = f"symrun {plugin} '{cmd}' --disable-pbar"
        cmd_out = run_shell_command(cmd, auto_query_on_error=True)
        conversation.store(cmd_out)
        conversation_cls.save_conversation_state(conversation, symai_path)
        state.stateful_conversation = conversation
        state.previous_kwargs = previous_kwargs
    return conversation, previous_kwargs, cmd_out, True
439
+
440
+
441
def _process_followup_conversation(
    query, conversation, conversation_cls, symai_path, plugin, previous_kwargs, state
):
    """Continue the persisted conversation and, if a plugin is set, run it.

    Args:
        query: The user query including its one-character prefix.
        conversation: Current conversation object used to load saved state.
        conversation_cls: Class used to persist the conversation.
        symai_path: ``Path`` of the conversation state file.
        plugin: Configured plugin name, or ``None``.
        previous_kwargs: Kwargs carried over between queries.
        state: Shell state object that is updated in place.

    Returns:
        ``(conversation, previous_kwargs, result, should_return)``.
        ``should_return`` is ``True`` when loading failed (an error has been
        printed) or when the plugin produced the final output.
    """
    try:
        conversation = conversation.load_conversation_state(symai_path)
        state.stateful_conversation = conversation
    except Exception:
        # Any failure to load is reported as "no state" (best effort).
        with ConsoleStyle("error") as console:
            console.print("No conversation state found. Please start a new conversation.")
        return conversation, previous_kwargs, None, True

    if plugin is None:
        return conversation, previous_kwargs, None, False

    with Loader(desc="Inference ...", end=""):
        trimmed_query = query[1:].strip("'\"")
        answer = conversation(trimmed_query).value
        # The model's answer is untrusted text that ends up inside a
        # single-quoted shell argument: escape apostrophes so it cannot close
        # the quoting. Quote-free answers are unchanged.
        safe_answer = str(answer).replace("'", "'\"'\"'")
        cmd = f"symrun {plugin} '{safe_answer}' --disable-pbar"
        cmd_out = run_shell_command(cmd, auto_query_on_error=True)
        conversation.store(cmd_out)
        conversation_cls.save_conversation_state(conversation, symai_path)
        state.stateful_conversation = conversation
        state.previous_kwargs = previous_kwargs
    return conversation, previous_kwargs, cmd_out, True
463
+
464
+
465
def _handle_piped_query(query, conversation, state):
    """Resolve a ``query | payload`` command into ``(callable, base_query)``.

    The right-hand side is classified by ``disambiguate`` into a payload and
    an order (1 or 2). Stateful queries feed the payload into the running
    conversation; other queries get a fresh function or file conversation.

    Raises:
        ValueError: via ``UserMessage`` when more than one pipe is present.
    """
    parts = query.split("|")
    if len(parts) > 2:
        msg = (
            "Cannot disambiguate commands that have more than 1 pipes. Please provide correct order of commands. "
            'Supported are: query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
            'and query | cmd (e.g. "what flags can I use with rg?" | rg --help)'
        )
        UserMessage(msg, raise_with=ValueError)

    base_query = parts[0]
    payload, order = disambiguate(parts[1].strip())

    if _is_stateful_query(base_query):
        # Reuse the live conversation and attach the piped material to it.
        func = conversation
        if order == 1:
            func.store(payload)
        elif order == 2:
            for file in payload:
                func.store_file(file)
        return func, base_query

    if order == 1:
        func = state.function_type(payload)
    else:
        func = state.conversation_type(file_link=payload, auto_print=False)
    return func, base_query
492
+
493
+
494
def _select_function_for_query(query, conversation, state):
    """Pick the callable that should answer *query*.

    Returns:
        ``(callable, query)``. Piped queries are delegated to
        ``_handle_piped_query`` (which may shorten the query); stateful
        queries use the running conversation; everything else uses a
        function built from ``SHELL_CONTEXT``.
    """
    if "|" in query:
        return _handle_piped_query(query, conversation, state)
    func = conversation if _is_stateful_query(query) else state.function_type(SHELL_CONTEXT)
    return func, query
500
+
501
+
502
+ def _should_save_conversation(conversation, query):
503
+ if conversation is None:
504
+ return False
505
+ return _is_stateful_query(query)
430
506
 
507
+
508
def query_language_model(query: str, res=None, *args, **kwargs):
    """Send *query* to the language model and return its answer.

    Args:
        query: Raw user query; may carry a conversation prefix, a pipe and a
            trailing ``--kwargs`` section.
        res: Optional context (e.g. output of a previous command) that is
            prepended to the query.
        *args, **kwargs: Forwarded to the selected callable; ``kwargs`` is
            merged with kwargs parsed from the query.

    Returns:
        The model answer, or the plugin result when a plugin handled the
        conversation query.
    """
    shell = _shell_state
    conversation = shell.stateful_conversation
    previous_kwargs = shell.previous_kwargs
    conversation_cls = shell.conversation_type
    state_file = HOME_PATH / ".conversation_state"
    plugin = SYMSH_CONFIG.get("plugin_prefix")

    query, kwargs, previous_kwargs = _extract_query_kwargs(query, previous_kwargs, kwargs)

    handled = None
    if _is_new_conversation_query(query):
        handled = _process_new_conversation(
            query, conversation_cls, state_file, plugin, previous_kwargs, shell
        )
    elif _is_followup_conversation_query(query):
        handled = _process_followup_conversation(
            query, conversation, conversation_cls, state_file, plugin, previous_kwargs, shell
        )
    if handled is not None:
        conversation, previous_kwargs, result, should_return = handled
        if should_return:
            return result

    func, query = _select_function_for_query(query, conversation, shell)
    with Loader(desc="Inference ...", end=""):
        has_context = res is not None
        query_to_execute = f"[Context]\n{res}\n\n[Query]\n{query}" if has_context else query
        msg = func(query_to_execute, *args, **kwargs)
        if has_context:
            # NOTE(review): from here on `query` is the context-wrapped text,
            # so the save check below presumably no longer sees the stateful
            # prefix -- confirm this is intended.
            query = query_to_execute

    if _should_save_conversation(conversation, query):
        conversation_cls.save_conversation_state(conversation, state_file)

    shell.stateful_conversation = conversation
    shell.previous_kwargs = previous_kwargs
    return msg
443
546
 
444
- def retrieval_augmented_indexing(query: str, index_name = None, *args, **kwargs):
445
- global stateful_conversation
547
+
548
+ def retrieval_augmented_indexing(query: str, index_name=None, *_args, **_kwargs):
549
+ state = _shell_state
446
550
  sep = os.path.sep
447
551
  path = query
552
+ home_path = HOME_PATH
448
553
 
449
554
  # check if path contains overwrite flag
450
555
  overwrite = False
451
- if path.startswith('!'):
556
+ if path.startswith("!"):
452
557
  overwrite = True
453
558
  path = path[1:]
454
559
 
455
560
  # check if request use of specific index
456
- use_index_name = False
457
- if path.startswith('index:'):
561
+ use_index_name = False
562
+ if path.startswith("index:"):
458
563
  use_index_name = True
459
564
  # continue conversation with specific index
460
- index_name = path.split('index:')[-1].strip()
565
+ index_name = path.split("index:")[-1].strip()
461
566
  else:
462
567
  parse_arxiv = False
463
568
 
464
569
  # check if path contains arxiv flag
465
- if path.startswith('arxiv:'):
570
+ if path.startswith("arxiv:"):
466
571
  parse_arxiv = True
467
572
 
468
573
  # check if path contains git flag
469
- if path.startswith('git@'):
574
+ if path.startswith("git@"):
470
575
  overwrite = True
471
- repo_path = os.path.join(home_path, 'temp')
472
- cloner = RepositoryCloner(repo_path=repo_path)
473
- url = path[4:]
474
- if 'http' not in url:
475
- url = 'https://' + url
476
- url = url.replace('.com:', '.com/')
477
- # if ends with '.git' then remove it
478
- if url.endswith('.git'):
479
- url = url[:-4]
480
- path = cloner(url)
576
+ repo_path = home_path / "temp"
577
+ with Loader(desc="Cloning repo ...", end=""):
578
+ cloner = RepositoryCloner(repo_path=str(repo_path))
579
+ url = path[4:]
580
+ if "http" not in url:
581
+ url = "https://" + url
582
+ url = url.replace(".com:", ".com/")
583
+ # if ends with '.git' then remove it
584
+ if url.endswith(".git"):
585
+ url = url[:-4]
586
+ path = cloner(url)
481
587
 
482
588
  # merge files
483
589
  merger = FileMerger()
@@ -487,112 +593,125 @@ def retrieval_augmented_indexing(query: str, index_name = None, *args, **kwargs)
487
593
  arxiv = ArxivPdfParser()
488
594
  pdf_file = arxiv(file)
489
595
  if pdf_file is not None:
490
- file = file |'\n'| pdf_file
596
+ file = file | "\n" | pdf_file
491
597
 
492
598
  index_name = path.split(sep)[-1] if index_name is None else index_name
493
599
  index_name = Indexer.replace_special_chars(index_name)
494
- print(f'Indexing {index_name} ...')
600
+ UserMessage(f"Indexing {index_name} ...", style="extensity")
495
601
 
496
602
  # creates index if not exists
497
603
  DocumentRetriever(index_name=index_name, file=file, overwrite=overwrite)
498
604
 
499
- home_path = HOME_PATH
500
- symai_path = os.path.join(home_path, '.conversation_state')
501
- os.makedirs(os.path.dirname(symai_path), exist_ok=True)
502
- stateful_conversation = RetrievalConversationType(auto_print=False, index_name=index_name)
605
+ symai_path = home_path / ".conversation_state"
606
+ symai_path.parent.mkdir(parents=True, exist_ok=True)
607
+ stateful_conversation = state.retrieval_conversation_type(
608
+ auto_print=False, index_name=index_name
609
+ )
610
+ state.stateful_conversation = stateful_conversation
503
611
  Conversation.save_conversation_state(stateful_conversation, symai_path)
504
612
  if use_index_name:
505
- message = 'New session '
613
+ message = "New session "
506
614
  else:
507
- message = f'Repository {url} cloned and ' if query.startswith('git@') or query.startswith('git:') else f'Directory {path} '
508
- msg = f'{message}successfully indexed: {index_name}'
509
- return msg
615
+ message = (
616
+ f"Repository {url} cloned and "
617
+ if query.startswith("git@") or query.startswith("git:")
618
+ else f"Directory {path} "
619
+ )
620
+ return f"{message}successfully indexed: {index_name}"
621
+
510
622
 
511
- def search_engine(query: str, res=None, *args, **kwargs):
512
- search = Interface('serpapi')
623
def search_engine(query: str, res=None, *_args, **_kwargs):
    """Answer *query* using web search results as context.

    The query is first condensed into a search-engine-friendly form, the
    ``serpapi`` interface is queried, and the results are handed to the
    language model. Query, results and answer are also written to
    ``HOME_PATH / ".search_dump"``.

    Args:
        query: The user question.
        res: Ignored on input; overwritten with the search results.

    Returns:
        The model's answer.
    """
    search = Interface("serpapi")
    with Loader(desc="Searching ...", end=""):
        search_query = Symbol(query).extract("search engine optimized query")
        res = search(search_query)
    with Loader(desc="Inference ...", end=""):
        func = _shell_state.function_type(query)
        msg = func(res, payload=res)
        # write a temp dump file with the query and results
        home_path = HOME_PATH
        symai_path = home_path / ".search_dump"
        symai_path.parent.mkdir(parents=True, exist_ok=True)
        # Search results routinely contain non-ASCII text; an explicit UTF-8
        # encoding avoids UnicodeEncodeError on platforms whose default
        # encoding cannot represent it.
        with symai_path.open("w", encoding="utf-8") as f:
            f.write(f"[SEARCH_QUERY]:\n{search_query}\n[RESULTS]\n{res}\n[MESSAGE]\n{msg}")
    return msg
526
638
 
639
+
527
640
def set_default_module(cmd: str):
    """Handle ``set-plugin <name>``, ``unset-plugin`` and ``get-plugin``.

    Setting or unsetting updates ``SYMSH_CONFIG["plugin_prefix"]`` and
    rewrites the config file; every variant prints a confirmation message.
    Nothing is returned.
    """

    def _persist_config():
        # Write the whole configuration back to disk.
        with config_path.open("w") as f:
            json.dump(SYMSH_CONFIG, f, indent=4)

    if cmd.startswith("set-plugin"):
        module = cmd.split("set-plugin")[-1].strip()
        SYMSH_CONFIG["plugin_prefix"] = module
        _persist_config()
        msg = f"Default plugin set to '{module}'"
    elif cmd == "unset-plugin":
        SYMSH_CONFIG["plugin_prefix"] = None
        _persist_config()
        msg = "Default plugin unset"
    elif cmd == "get-plugin":
        msg = f"Default plugin is '{SYMSH_CONFIG['plugin_prefix']}'"

    with ConsoleStyle("success") as console:
        console.print(msg)
544
657
 
658
+
545
659
def handle_error(cmd, res, message, auto_query_on_error):
    """Turn a failed shell command into a message for the user.

    Args:
        cmd: The command line that was executed.
        res: The ``subprocess`` result of the failed command.
        message: Output already decoded by the caller (may be ``None``).
        auto_query_on_error: When true and stderr was captured, the error is
            forwarded to the language model for an explanation.

    Returns:
        The decoded stderr, the language model's answer, or *message*.
    """
    res_text = str(res)
    stderr = res.stderr
    msg = Symbol(cmd) | f"\n{res_text}"

    if (
        "command not found" in res_text
        or "not recognized as an internal or external command" in res_text
    ):
        # stderr is None when it was not captured; fall back to what we have.
        return stderr.decode("utf-8") if stderr is not None else message

    if stderr and auto_query_on_error:
        rsp = stderr.decode("utf-8")
        UserMessage(rsp, style="alert")
        msg = msg | f"\n{rsp}"
        if "usage:" in rsp:
            try:
                cmd = cmd.split("usage: ")[-1].split(" ")[0]
                # get man page result for command
                res = subprocess.run(
                    f"man -P cat {cmd}", check=False, shell=True, stdout=subprocess.PIPE
                )
                stdout = res.stdout
                if stdout:
                    rsp = stdout.decode("utf-8")[:500]
                    msg = msg | f"\n{rsp}"
            except Exception:
                # The man page is optional extra context; ignore failures.
                pass

        return query_language_model(msg)

    # Guard against stderr not having been captured (None) before decoding.
    if res.stdout and stderr is not None:
        message = stderr.decode("utf-8")
    return message
689
+
575
690
 
576
691
  # run shell command
577
- def run_shell_command(cmd: str, prev=None, auto_query_on_error: bool=False, stdout=None, stderr=None):
692
+ def run_shell_command(
693
+ cmd: str, prev=None, auto_query_on_error: bool = False, stdout=None, stderr=None
694
+ ):
578
695
  if prev is not None:
579
- cmd = prev + ' && ' + cmd
696
+ cmd = prev + " && " + cmd
580
697
  message = None
581
698
  conda_env = get_exec_prefix()
582
699
  # copy default_env
583
700
  new_env = default_env.copy()
584
- if exec_prefix != 'default':
701
+ if _shell_state.exec_prefix != "default":
585
702
  # remove current env from PATH
586
703
  new_env["PATH"] = new_env["PATH"].replace(sys.exec_prefix, conda_env)
587
704
  # Execute the command
588
705
  try:
589
706
  stdout = subprocess.PIPE if auto_query_on_error else stdout
590
707
  stderr = subprocess.PIPE if auto_query_on_error else stderr
591
- res = subprocess.run(cmd, shell=True, stdout=stdout, stderr=stderr, env=new_env)
708
+ res = subprocess.run(
709
+ cmd, check=False, shell=True, stdout=stdout, stderr=stderr, env=new_env
710
+ )
592
711
  if res and stdout and res.stdout:
593
- message = res.stdout.decode('utf-8')
712
+ message = res.stdout.decode("utf-8")
594
713
  elif res and stderr and res.stderr:
595
- message = res.stderr.decode('utf-8')
714
+ message = res.stderr.decode("utf-8")
596
715
  except FileNotFoundError as e:
597
716
  return e
598
717
  except PermissionError as e:
@@ -602,49 +721,61 @@ def run_shell_command(cmd: str, prev=None, auto_query_on_error: bool=False, stdo
602
721
  if res.returncode == 0:
603
722
  return message
604
723
  # If command not found, then try to query language model
605
- else:
606
- return handle_error(cmd, res, message, auto_query_on_error)
724
+ return handle_error(cmd, res, message, auto_query_on_error)
725
+
607
726
 
608
727
def is_llm_request(cmd: str):
    """Return True when *cmd* is addressed to the language model.

    A request starts with a quote character (``"``, ``'`` or a backtick),
    optionally preceded by ``.``, ``!`` or ``?``, or with the literal ``!(``.
    """
    quotes = ('"', "'", "`")
    modifiers = ("", ".", "!", "?")
    prefixes = tuple(mod + quote for quote in quotes for mod in modifiers) + ("!(",)
    return cmd.startswith(prefixes)
743
+
613
744
 
614
745
  def map_nt_cmd(cmd: str, map_nt_cmd_enabled: bool = True):
615
- if os.name.lower() == 'nt' and map_nt_cmd_enabled and not is_llm_request(cmd):
746
+ if os.name.lower() == "nt" and map_nt_cmd_enabled and not is_llm_request(cmd):
616
747
  # Mapping command replacements with regex for commands with variants
617
748
  cmd_mappings = {
618
- r'\bls\b(-[a-zA-Z]*)?' : r'dir \1', # Maps 'ls' with or without arguments
619
- r'\bmv\b\s+(.*)' : r'move \1', # Maps 'mv' with any arguments
620
- r'\bcp\b\s+(.*)' : r'copy \1', # Maps 'cp' with any arguments
621
- r'\btouch\b\s+(.*)' : r'type nul > \1', # Maps 'touch filename' to 'type nul > filename'
622
- r'\brm\b\s+(-rf)?' : r'del \1', # Maps 'rm' and 'rm -rf'
623
- r'\bdiff\b\s+(.*)' : r'fc \1', # Maps 'diff' with any arguments
624
- r'\bgrep\b\s+(.*)' : r'find \1', # Maps 'grep' with any arguments
625
- r'\bpwd\b' : 'chdir', # pwd has no arguments
626
- r'\bdate\b' : 'time', # date has no arguments
627
- r'\bmkdir\b\s+(.*)' : r'md \1', # Maps 'mkdir' with any arguments
628
- r'\bwhich\b\s+(.*)' : r'where \1', # Maps 'which' with any arguments
629
- r'\b(vim|nano)\b\s+(.*)' : r'notepad \2', # Maps 'vim' or 'nano' with any arguments
630
- r'\b(mke2fs|mformat)\b\s+(.*)' : r'format \2', # Maps 'mke2fs' or 'mformat' with any arguments
631
- r'\b(rm\s+-rf|rmdir)\b' : 'rmdir /s /q', # Matches 'rm -rf' or 'rmdir'
632
- r'\bkill\b\s+(.*)' : r'taskkill \1', # Maps 'kill' with any arguments
633
- r'\bps\b\s*(.*)?' : r'tasklist \1', # Maps 'ps' with any or no arguments
634
- r'\bexport\b\s+(.*)' : r'set \1', # Maps 'export' with any arguments
635
- r'\b(chown|chmod)\b\s+(.*)' : r'attrib +r \2', # Maps 'chown' or 'chmod' with any arguments
636
- r'\btraceroute\b\s+(.*)' : r'tracert \1', # Maps 'traceroute' with any arguments
637
- r'\bcron\b\s+(.*)' : r'at \1', # Maps 'cron' with any arguments
638
- r'\bcat\b\s+(.*)' : r'type \1', # Maps 'cat' with any arguments
639
- r'\bdu\s+-s\b' : 'chkdsk', # du -s has no arguments, chkdsk is closest in functionality
640
- r'\bls\s+-R\b' : 'tree', # ls -R has no arguments
749
+ r"\bls\b(-[a-zA-Z]*)?": r"dir \1", # Maps 'ls' with or without arguments
750
+ r"\bmv\b\s+(.*)": r"move \1", # Maps 'mv' with any arguments
751
+ r"\bcp\b\s+(.*)": r"copy \1", # Maps 'cp' with any arguments
752
+ r"\btouch\b\s+(.*)": r"type nul > \1", # Maps 'touch filename' to 'type nul > filename'
753
+ r"\brm\b\s+(-rf)?": r"del \1", # Maps 'rm' and 'rm -rf'
754
+ r"\bdiff\b\s+(.*)": r"fc \1", # Maps 'diff' with any arguments
755
+ r"\bgrep\b\s+(.*)": r"find \1", # Maps 'grep' with any arguments
756
+ r"\bpwd\b": "chdir", # pwd has no arguments
757
+ r"\bdate\b": "time", # date has no arguments
758
+ r"\bmkdir\b\s+(.*)": r"md \1", # Maps 'mkdir' with any arguments
759
+ r"\bwhich\b\s+(.*)": r"where \1", # Maps 'which' with any arguments
760
+ r"\b(vim|nano)\b\s+(.*)": r"notepad \2", # Maps 'vim' or 'nano' with any arguments
761
+ r"\b(mke2fs|mformat)\b\s+(.*)": r"format \2", # Maps 'mke2fs' or 'mformat' with any arguments
762
+ r"\b(rm\s+-rf|rmdir)\b": "rmdir /s /q", # Matches 'rm -rf' or 'rmdir'
763
+ r"\bkill\b\s+(.*)": r"taskkill \1", # Maps 'kill' with any arguments
764
+ r"\bps\b\s*(.*)?": r"tasklist \1", # Maps 'ps' with any or no arguments
765
+ r"\bexport\b\s+(.*)": r"set \1", # Maps 'export' with any arguments
766
+ r"\b(chown|chmod)\b\s+(.*)": r"attrib +r \2", # Maps 'chown' or 'chmod' with any arguments
767
+ r"\btraceroute\b\s+(.*)": r"tracert \1", # Maps 'traceroute' with any arguments
768
+ r"\bcron\b\s+(.*)": r"at \1", # Maps 'cron' with any arguments
769
+ r"\bcat\b\s+(.*)": r"type \1", # Maps 'cat' with any arguments
770
+ r"\bdu\s+-s\b": "chkdsk", # du -s has no arguments, chkdsk is closest in functionality
771
+ r"\bls\s+-R\b": "tree", # ls -R has no arguments
641
772
  }
642
773
 
643
774
  # Remove 1:1 mappings
644
775
  direct_mappings = {
645
- 'clear': 'cls',
646
- 'man' : 'help',
647
- 'mem' : 'free',
776
+ "clear": "cls",
777
+ "man": "help",
778
+ "mem": "free",
648
779
  }
649
780
 
650
781
  cmd_mappings.update(direct_mappings)
@@ -655,201 +786,264 @@ def map_nt_cmd(cmd: str, map_nt_cmd_enabled: bool = True):
655
786
  original_cmd = cmd
656
787
  cmd = re.sub(linux_cmd, windows_cmd, cmd)
657
788
  if cmd != original_cmd:
658
- print(f'symsh >> command "{original_cmd}" mapped to "{cmd}"\n')
789
+ UserMessage(
790
+ f'symsh >> command "{original_cmd}" mapped to "{cmd}"\n', style="extensity"
791
+ )
659
792
 
660
793
  return cmd
661
794
 
662
- def process_command(cmd: str, res=None, auto_query_on_error: bool=False):
663
- global exec_prefix, previous_prefix
664
795
 
665
- # map commands to windows if needed
666
- cmd = map_nt_cmd(cmd)
667
- if cmd.startswith('set-plugin') or cmd == 'unset-plugin' or cmd == 'get-plugin':
796
+ def _handle_plugin_commands(cmd: str):
797
+ if cmd.startswith("set-plugin") or cmd == "unset-plugin" or cmd == "get-plugin":
668
798
  return set_default_module(cmd)
799
+ return None
669
800
 
670
- sep = os.path.sep
671
- # check for '&&' to also preserve pipes '|' in normal shell commands
672
- if '" && ' in cmd or "' && " in cmd or '` && ' in cmd:
673
- if is_llm_request(cmd):
674
- # Process each command (the ones involving the LLM) separately
675
- cmds = cmd.split(' && ')
676
- if not is_llm_request(cmds[0]):
677
- return ValueError('The first command must be a LLM request.')
678
- # Process the first command as an LLM request
679
- res = query_language_model(cmds[0], res=res)
680
- rest = ' && '.join(cmds[1:])
681
- if '$1' in cmds[1]:
682
- res = str(res).replace('\n', r'\\n')
683
- rest = rest.replace('$1', '"%s"' % res)
684
- res = None
685
- cmd = rest
686
- # If it's a normal shell command with pipes or &&, pass it whole
687
- res = run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
688
- return res
689
-
690
- if cmd.startswith('?"') or cmd.startswith("?'") or cmd.startswith('?`'):
691
- cmd = cmd[1:]
692
- return search_engine(cmd, res=res)
693
-
694
- elif is_llm_request(cmd) or '...' in cmd:
801
+
802
+ def _handle_chained_llm_commands(cmd: str, res, auto_query_on_error: bool):
803
+ if '" && ' not in cmd and "' && " not in cmd and "` && " not in cmd:
804
+ return None
805
+ if not is_llm_request(cmd):
806
+ return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
807
+ cmds = cmd.split(" && ")
808
+ if not is_llm_request(cmds[0]):
809
+ return ValueError("The first command must be a LLM request.")
810
+ first_res = query_language_model(cmds[0], res=res)
811
+ rest = " && ".join(cmds[1:])
812
+ if len(cmds) > 1 and "$1" in cmds[1]:
813
+ first_res_str = str(first_res).replace("\n", r"\\n")
814
+ rest = rest.replace("$1", f'"{first_res_str}"')
815
+ first_res = None
816
+ return run_shell_command(rest, prev=first_res, auto_query_on_error=auto_query_on_error)
817
+
818
+
819
def _handle_llm_or_search(cmd: str, res):
    """Route ``?``-prefixed quotes to web search and LLM requests to the model.

    Returns None when *cmd* is neither a search, an LLM request, nor contains
    ``...``.
    """
    if cmd.startswith(('?"', "?'", "?`")):
        # Drop the leading '?' before searching.
        return search_engine(cmd[1:], res=res)
    if not (is_llm_request(cmd) or "..." in cmd):
        return None
    return query_language_model(cmd, res=res)
826
+
827
+
828
+ def _handle_retrieval_commands(cmd: str):
829
+ if cmd.startswith("*"):
830
+ return retrieval_augmented_indexing(cmd[1:])
831
+ return None
832
+
696
833
 
697
- elif cmd.startswith('*'):
698
- cmd = cmd[1:]
699
- return retrieval_augmented_indexing(cmd)
700
-
701
- elif cmd.startswith('man symsh'):
702
- # read symsh.md file and print it
703
- # get symsh path
704
- pkg_path = os.path.dirname(os.path.abspath(__file__))
705
- symsh_path = os.path.join(pkg_path, 'symsh.md')
706
- with open(symsh_path, 'r', encoding="utf8") as f:
707
- return f.read()
708
-
709
- elif cmd.startswith('conda activate'):
710
- # check conda execution prefix and verify if environment exists
711
- env = sys.exec_prefix
712
- path_ = sep.join(env.split(sep)[:-1])
713
- env_base = os.path.join(sep, path_)
714
- req_env = cmd.split(' ')[2]
715
- # check if environment exists
716
- env_path = os.path.join(env_base, req_env)
717
- if not os.path.exists(env_path):
718
- return f'Environment {req_env} does not exist!'
719
- previous_prefix = exec_prefix
720
- exec_prefix = os.path.join(env_base, req_env)
721
- return exec_prefix
722
-
723
- elif cmd.startswith('conda deactivate'):
724
- if previous_prefix is not None:
725
- exec_prefix = previous_prefix
726
- if previous_prefix == 'default':
727
- previous_prefix = None
834
+ def _handle_man_command(cmd: str):
835
+ if cmd.startswith("man symsh"):
836
+ pkg_path = Path(__file__).resolve().parent
837
+ symsh_path = pkg_path / "symsh.md"
838
+ with symsh_path.open(encoding="utf8") as file_ptr:
839
+ return file_ptr.read()
840
+ return None
841
+
842
+
843
+ def _handle_conda_commands(cmd: str, state, res, auto_query_on_error: bool):
844
+ if cmd.startswith("conda activate"):
845
+ env = Path(sys.exec_prefix)
846
+ env_base = env.parent
847
+ req_env = cmd.split(" ")[2]
848
+ env_path = env_base / req_env
849
+ if not env_path.exists():
850
+ return f"Environment {req_env} does not exist!"
851
+ state.previous_prefix = state.exec_prefix
852
+ state.exec_prefix = str(env_path)
853
+ return state.exec_prefix
854
+ if cmd.startswith("conda deactivate"):
855
+ prev_prefix = state.previous_prefix
856
+ if prev_prefix is not None:
857
+ state.exec_prefix = prev_prefix
858
+ if prev_prefix == "default":
859
+ state.previous_prefix = None
728
860
  return get_exec_prefix()
861
+ if cmd.startswith("conda"):
862
+ env = Path(get_exec_prefix())
863
+ try:
864
+ env_base = env.parents[1]
865
+ except IndexError:
866
+ env_base = env.parent
867
+ cmd_rewritten = cmd.replace("conda", str(env_base / "condabin" / "conda"))
868
+ return run_shell_command(cmd_rewritten, prev=res, auto_query_on_error=auto_query_on_error)
869
+ return None
729
870
 
730
- elif cmd.startswith('conda'):
731
- env = get_exec_prefix()
732
- env_base = os.path.join(sep, *env.split(sep)[:-2])
733
- cmd = cmd.replace('conda', os.path.join(env_base, "condabin", "conda"))
734
- return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
735
871
 
736
- elif cmd.startswith('cd'):
872
def _handle_directory_navigation(cmd: str):
    """Change the working directory for ``cd <path>`` or a bare directory.

    Returns:
        The caught ``FileNotFoundError`` / ``PermissionError`` when changing
        directory fails, otherwise None (also when *cmd* is not a navigation
        command).
    """
    sep = os.path.sep
    if cmd.startswith("cd"):
        try:
            cmd_expanded = FileReader.expand_user_path(cmd)
            path = " ".join(cmd_expanded.split(" ")[1:])
            # Drop a trailing separator, but never reduce the root directory
            # ("/") to an empty string, which os.chdir rejects.
            if len(path) > 1 and path.endswith(sep):
                path = path[:-1]
            return os.chdir(path)
        except (FileNotFoundError, PermissionError) as err:
            return err

    cmd_path = FileReader.expand_user_path(cmd)
    # Test the expanded path: the raw command may still contain '~', which
    # the filesystem check does not understand.
    if Path(cmd_path).is_dir():
        try:
            os.chdir(cmd_path)
        except (FileNotFoundError, PermissionError) as err:
            return err
    return None
894
+
895
+
896
+ def _handle_ll_alias(cmd: str, res):
897
+ if not cmd.startswith("ll"):
898
+ return None
899
+ if os.name == "nt":
900
+ rewritten = cmd.replace("ll", "dir")
901
+ return run_shell_command(rewritten, prev=res)
902
+ rewritten = cmd.replace("ll", "ls -la")
903
+ return run_shell_command(rewritten, prev=res)
904
+
905
+
906
def process_command(cmd: str, res=None, auto_query_on_error: bool = False):
    """Dispatch one shell input line to the matching handler.

    Handlers are selected by inspecting *cmd* itself rather than by whether a
    handler returned something: several handlers legitimately return None
    (plugin commands only print, ``cd`` returns None, shell commands whose
    output is not captured return None), and treating that as "not handled"
    would run the same input a second time as a plain shell command.

    Args:
        cmd: The raw command line.
        res: Output of a previous command, passed on as context.
        auto_query_on_error: Ask the language model when a shell command fails.

    Returns:
        Whatever the selected handler returns (possibly None).
    """
    state = _shell_state

    # map commands to windows if needed
    cmd = map_nt_cmd(cmd)

    if cmd.startswith("set-plugin") or cmd in ("unset-plugin", "get-plugin"):
        return _handle_plugin_commands(cmd)

    # check for '&&' after a quote to preserve pipes in normal shell commands
    if '" && ' in cmd or "' && " in cmd or "` && " in cmd:
        return _handle_chained_llm_commands(cmd, res, auto_query_on_error)

    if is_llm_request(cmd) or "..." in cmd:
        return _handle_llm_or_search(cmd, res)

    if cmd.startswith("*"):
        return _handle_retrieval_commands(cmd)

    if cmd.startswith("man symsh"):
        return _handle_man_command(cmd)

    if cmd.startswith("conda"):
        return _handle_conda_commands(cmd, state, res, auto_query_on_error)

    if cmd.startswith("cd") or Path(cmd).is_dir():
        return _handle_directory_navigation(cmd)

    # Only the exact word "ll" is the alias ("lldb" etc. go to the shell).
    # The handler takes (cmd, res); passing a third positional argument
    # raised TypeError for every ordinary command.
    if cmd == "ll" or cmd.startswith("ll "):
        return _handle_ll_alias(cmd, res)

    return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
769
944
 
770
- else:
771
- return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
772
945
 
773
946
def save_conversation():
    """Persist the current stateful conversation to the state file."""
    state_file = HOME_PATH / ".conversation_state"
    Conversation.save_conversation_state(_shell_state.stateful_conversation, state_file)
950
+
951
+
952
+ def _is_exit_command(cmd: str) -> bool:
953
+ return cmd in ["quit", "exit", "q"]
954
+
955
+
956
def _format_working_directory():
    """Return the current directory as prompt markup with the last part bold.

    The result is embedded in the HTML-style prompt markup, so the path
    components are escaped: a directory name containing ``&``, ``<`` or ``>``
    would otherwise be interpreted as markup.
    """
    from html import escape

    sep = os.path.sep
    cwd = str(Path.cwd())
    if cwd.startswith(sep):
        cwd = FileReader.expand_user_path(cwd)
    parts = cwd.split(sep)
    parent = escape(sep.join(parts[:-1]), quote=False)
    leaf = escape(parts[-1], quote=False)
    if len(parts) > 1:
        return f"{parent}{sep}<b>{leaf}</b>"
    return f"<b>{leaf}</b>"
968
+
969
+
970
def _build_prompt(git_branch, conda_env, cur_working_dir_str):
    """Assemble the coloured prompt markup.

    The git section is only included when *git_branch* is truthy; the working
    directory, conda environment and ``symsh:`` marker are always shown.
    """
    location = f"<ansiblue>{cur_working_dir_str}</ansiblue>"
    if git_branch:
        location += (
            "<ansiwhite> on git:[</ansiwhite>"
            f"<ansigreen>{git_branch}</ansigreen><ansiwhite>]</ansiwhite>"
        )
    tail = (
        " <ansiwhite>conda:[</ansiwhite>"
        f"<ansimagenta>{conda_env}</ansimagenta><ansiwhite>]</ansiwhite> <ansicyan><b>symsh:</b> ></ansicyan> "
    )
    return HTML(location + tail)
981
+
982
+
983
def _handle_exit(state):
    """Save any running conversation, say goodbye and terminate the process.

    With styles enabled the farewell text is generated by the configured
    function type; otherwise a fixed message is shown. Does not return.
    """
    if state.stateful_conversation is not None:
        save_conversation()
    if state.use_styles:
        farewell = _shell_state.function_type("Give short goodbye")
        UserMessage(farewell("bye"), style="extensity")
    else:
        UserMessage("Goodbye!", style="extensity")
    os._exit(0)
992
+
777
993
 
778
994
  # Function to listen for user input and execute commands
779
- def listen(session: PromptSession, word_comp: WordCompleter, auto_query_on_error: bool=False, verbose: bool=False):
995
def listen(
    session: PromptSession,
    word_comp: WordCompleter,
    auto_query_on_error: bool = False,
    verbose: bool = False,
):
    """Run the interactive read/execute loop of the shell.

    Each iteration builds the prompt, reads one line, executes it and prints
    a non-None result. Executed commands are added to the word completer.
    Keyboard interrupts and other exceptions are reported and the loop
    continues; with *verbose* the traceback is printed as well.
    """
    shell = _shell_state
    with patch_stdout():
        while True:
            try:
                prompt = _build_prompt(
                    get_git_branch(), get_conda_env(), _format_working_directory()
                )
                cmd = session.prompt(prompt)
                if not cmd.strip():
                    continue

                if _is_exit_command(cmd):
                    _handle_exit(shell)

                output = process_command(cmd, auto_query_on_error=auto_query_on_error)
                if output is not None:
                    with ConsoleStyle("code") as console:
                        console.print(output)

                # Append the command to the word completer list
                word_comp.words.append(cmd)

            except KeyboardInterrupt:
                UserMessage("", style="alert")
            except Exception as e:
                UserMessage(str(e), style="alert")
                if verbose:
                    traceback.print_exc()
836
- pass
1029
+
837
1030
 
838
1031
def create_session(history, merged_completer):
    """Build the prompt session with history, completion and configured colours."""
    # Load style
    session_options = {
        "history": history,
        "completer": merged_completer,
        "complete_style": CompleteStyle.MULTI_COLUMN,
        "reserve_space_for_menu": 5,
        "style": Style.from_dict(SYMSH_CONFIG["colors"]),
        "key_bindings": bindings,
    }
    # Session for the auto-completion
    return PromptSession(**session_options)
851
1046
 
852
- return session
853
1047
 
854
1048
  def create_completer():
855
1049
  # Load history
@@ -863,32 +1057,46 @@ def create_completer():
863
1057
  merged_completer = MergedCompleter(custom_completer, word_comp)
864
1058
  return history, word_comp, merged_completer
865
1059
 
1060
+
866
1061
  def run(auto_query_on_error=False, conversation_style=None, verbose=False):
867
- global FunctionType, ConversationType, RetrievalConversationType, use_styles
868
- if conversation_style is not None and conversation_style != '':
869
- print('Loading style:', conversation_style)
1062
+ state = _shell_state
1063
+ if conversation_style is not None and conversation_style != "":
1064
+ UserMessage(f"Loading style: {conversation_style}", style="extensity")
870
1065
  styles_ = Import.load_module_class(conversation_style)
871
- FunctionType, ConversationType, RetrievalConversationType = styles_
872
- use_styles = True
873
-
874
- if SYMSH_CONFIG['show-splash-screen']:
1066
+ (
1067
+ state.function_type,
1068
+ state.conversation_type,
1069
+ state.retrieval_conversation_type,
1070
+ ) = styles_
1071
+ state.use_styles = True
1072
+
1073
+ if SYMSH_CONFIG["show-splash-screen"]:
875
1074
  show_intro_menu()
876
1075
  # set show splash screen to false
877
- SYMSH_CONFIG['show-splash-screen'] = False
1076
+ SYMSH_CONFIG["show-splash-screen"] = False
878
1077
  # save config
879
- _config_path = HOME_PATH / 'symsh.config.json'
880
- with open(_config_path, 'w') as f:
1078
+ _config_path = HOME_PATH / "symsh.config.json"
1079
+ with _config_path.open("w") as f:
881
1080
  json.dump(SYMSH_CONFIG, f, indent=4)
882
- if 'plugin_prefix' not in SYMSH_CONFIG:
883
- SYMSH_CONFIG['plugin_prefix'] = None
1081
+ if "plugin_prefix" not in SYMSH_CONFIG:
1082
+ SYMSH_CONFIG["plugin_prefix"] = None
884
1083
 
885
1084
  history, word_comp, merged_completer = create_completer()
886
1085
  session = create_session(history, merged_completer)
887
1086
  listen(session, word_comp, auto_query_on_error=auto_query_on_error, verbose=verbose)
888
1087
 
889
- if __name__ == '__main__':
890
- parser = argparse.ArgumentParser(description='SymSH: Symbolic Shell')
891
- parser.add_argument('--auto-query-on-error', action='store_true', help='Automatically query the language model on error.')
892
- parser.add_argument('--verbose', action='store_true', help='Print verbose errors.')
1088
+
1089
+ if __name__ == "__main__":
1090
+ parser = argparse.ArgumentParser(description="SymSH: Symbolic Shell")
1091
+ parser.add_argument(
1092
+ "--auto-query-on-error",
1093
+ action="store_true",
1094
+ help="Automatically query the language model on error.",
1095
+ )
1096
+ parser.add_argument("--verbose", action="store_true", help="Print verbose errors.")
893
1097
  args = parser.parse_args()
894
- run(auto_query_on_error=args.auto_query_on_error, conversation_style=args.conversation_style, verbose=args.verbose)
1098
+ run(
1099
+ auto_query_on_error=args.auto_query_on_error,
1100
+ conversation_style=args.conversation_style,
1101
+ verbose=args.verbose,
1102
+ )