symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. symai/__init__.py +96 -64
  2. symai/backend/base.py +93 -80
  3. symai/backend/engines/drawing/engine_bfl.py +12 -11
  4. symai/backend/engines/drawing/engine_gpt_image.py +108 -87
  5. symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
  6. symai/backend/engines/embedding/engine_openai.py +3 -5
  7. symai/backend/engines/execute/engine_python.py +6 -5
  8. symai/backend/engines/files/engine_io.py +74 -67
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
  11. symai/backend/engines/index/engine_pinecone.py +23 -24
  12. symai/backend/engines/index/engine_vectordb.py +16 -14
  13. symai/backend/engines/lean/engine_lean4.py +38 -34
  14. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  15. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
  17. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
  18. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
  19. symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
  20. symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
  21. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
  22. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
  23. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
  24. symai/backend/engines/ocr/engine_apilayer.py +6 -8
  25. symai/backend/engines/output/engine_stdout.py +1 -4
  26. symai/backend/engines/search/engine_openai.py +7 -7
  27. symai/backend/engines/search/engine_perplexity.py +5 -5
  28. symai/backend/engines/search/engine_serpapi.py +12 -14
  29. symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
  30. symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
  31. symai/backend/engines/text_to_speech/engine_openai.py +5 -7
  32. symai/backend/engines/text_vision/engine_clip.py +7 -11
  33. symai/backend/engines/userinput/engine_console.py +3 -3
  34. symai/backend/engines/webscraping/engine_requests.py +81 -48
  35. symai/backend/mixin/__init__.py +13 -0
  36. symai/backend/mixin/anthropic.py +4 -2
  37. symai/backend/mixin/deepseek.py +2 -0
  38. symai/backend/mixin/google.py +2 -0
  39. symai/backend/mixin/openai.py +11 -3
  40. symai/backend/settings.py +83 -16
  41. symai/chat.py +101 -78
  42. symai/collect/__init__.py +7 -1
  43. symai/collect/dynamic.py +77 -69
  44. symai/collect/pipeline.py +35 -27
  45. symai/collect/stats.py +75 -63
  46. symai/components.py +198 -169
  47. symai/constraints.py +15 -12
  48. symai/core.py +698 -359
  49. symai/core_ext.py +32 -34
  50. symai/endpoints/api.py +80 -73
  51. symai/extended/.DS_Store +0 -0
  52. symai/extended/__init__.py +46 -12
  53. symai/extended/api_builder.py +11 -8
  54. symai/extended/arxiv_pdf_parser.py +13 -12
  55. symai/extended/bibtex_parser.py +2 -3
  56. symai/extended/conversation.py +101 -90
  57. symai/extended/document.py +17 -10
  58. symai/extended/file_merger.py +18 -13
  59. symai/extended/graph.py +18 -13
  60. symai/extended/html_style_template.py +2 -4
  61. symai/extended/interfaces/blip_2.py +1 -2
  62. symai/extended/interfaces/clip.py +1 -2
  63. symai/extended/interfaces/console.py +7 -1
  64. symai/extended/interfaces/dall_e.py +1 -1
  65. symai/extended/interfaces/flux.py +1 -1
  66. symai/extended/interfaces/gpt_image.py +1 -1
  67. symai/extended/interfaces/input.py +1 -1
  68. symai/extended/interfaces/llava.py +0 -1
  69. symai/extended/interfaces/naive_vectordb.py +7 -8
  70. symai/extended/interfaces/naive_webscraping.py +1 -1
  71. symai/extended/interfaces/ocr.py +1 -1
  72. symai/extended/interfaces/pinecone.py +6 -5
  73. symai/extended/interfaces/serpapi.py +1 -1
  74. symai/extended/interfaces/terminal.py +2 -3
  75. symai/extended/interfaces/tts.py +1 -1
  76. symai/extended/interfaces/whisper.py +1 -1
  77. symai/extended/interfaces/wolframalpha.py +1 -1
  78. symai/extended/metrics/__init__.py +11 -1
  79. symai/extended/metrics/similarity.py +11 -13
  80. symai/extended/os_command.py +17 -16
  81. symai/extended/packages/__init__.py +29 -3
  82. symai/extended/packages/symdev.py +19 -16
  83. symai/extended/packages/sympkg.py +12 -9
  84. symai/extended/packages/symrun.py +21 -19
  85. symai/extended/repo_cloner.py +11 -10
  86. symai/extended/seo_query_optimizer.py +1 -2
  87. symai/extended/solver.py +20 -23
  88. symai/extended/summarizer.py +4 -3
  89. symai/extended/taypan_interpreter.py +10 -12
  90. symai/extended/vectordb.py +99 -82
  91. symai/formatter/__init__.py +9 -1
  92. symai/formatter/formatter.py +12 -16
  93. symai/formatter/regex.py +62 -63
  94. symai/functional.py +176 -122
  95. symai/imports.py +136 -127
  96. symai/interfaces.py +56 -27
  97. symai/memory.py +14 -13
  98. symai/misc/console.py +49 -39
  99. symai/misc/loader.py +5 -3
  100. symai/models/__init__.py +17 -1
  101. symai/models/base.py +269 -181
  102. symai/models/errors.py +0 -1
  103. symai/ops/__init__.py +32 -22
  104. symai/ops/measures.py +11 -15
  105. symai/ops/primitives.py +348 -228
  106. symai/post_processors.py +32 -28
  107. symai/pre_processors.py +39 -41
  108. symai/processor.py +6 -4
  109. symai/prompts.py +59 -45
  110. symai/server/huggingface_server.py +23 -20
  111. symai/server/llama_cpp_server.py +7 -5
  112. symai/shell.py +3 -4
  113. symai/shellsv.py +499 -375
  114. symai/strategy.py +517 -287
  115. symai/symbol.py +111 -116
  116. symai/utils.py +42 -36
  117. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
  118. symbolicai-1.0.0.dist-info/RECORD +163 -0
  119. symbolicai-0.20.2.dist-info/RECORD +0 -162
  120. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
  121. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
  122. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  123. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/shellsv.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import glob
3
2
  import json
4
3
  import logging
5
4
  import os
@@ -10,8 +9,9 @@ import subprocess
10
9
  import sys
11
10
  import time
12
11
  import traceback
12
+ from collections.abc import Iterable
13
13
  from pathlib import Path
14
- from typing import Iterable, Tuple
14
+ from types import SimpleNamespace
15
15
 
16
16
  #@TODO: refactor to use rich instead of prompt_toolkit
17
17
  from prompt_toolkit import HTML, PromptSession, print_formatted_text
@@ -19,23 +19,27 @@ from prompt_toolkit.completion import Completer, Completion, WordCompleter
19
19
  from prompt_toolkit.history import History
20
20
  from prompt_toolkit.key_binding import KeyBindings
21
21
  from prompt_toolkit.keys import Keys
22
- from prompt_toolkit.lexers import PygmentsLexer
23
22
  from prompt_toolkit.patch_stdout import patch_stdout
24
23
  from prompt_toolkit.shortcuts import CompleteStyle, ProgressBar
25
24
  from prompt_toolkit.styles import Style
26
- from pygments.lexers.shell import BashLexer
27
25
 
28
26
  from .backend.settings import HOME_PATH, SYMSH_CONFIG
29
27
  from .components import FileReader, Function, Indexer
30
- from .extended import (ArxivPdfParser, Conversation, DocumentRetriever,
31
- FileMerger, RepositoryCloner,
32
- RetrievalAugmentedConversation)
28
+ from .extended import (
29
+ ArxivPdfParser,
30
+ Conversation,
31
+ DocumentRetriever,
32
+ FileMerger,
33
+ RepositoryCloner,
34
+ RetrievalAugmentedConversation,
35
+ )
33
36
  from .imports import Import
34
37
  from .interfaces import Interface
35
38
  from .menu.screen import show_intro_menu
36
39
  from .misc.console import ConsoleStyle
37
40
  from .misc.loader import Loader
38
41
  from .symbol import Symbol
42
+ from .utils import UserMessage
39
43
 
40
44
  logging.getLogger("prompt_toolkit").setLevel(logging.ERROR)
41
45
  logging.getLogger("asyncio").setLevel(logging.ERROR)
@@ -43,14 +47,14 @@ logging.getLogger("subprocess").setLevel(logging.ERROR)
43
47
 
44
48
  # load json config from home directory root
45
49
  home_path = HOME_PATH
46
- config_path = os.path.join(home_path, 'symsh.config.json')
50
+ config_path = home_path / 'symsh.config.json'
47
51
  # migrate config from old path
48
52
  if 'colors' not in SYMSH_CONFIG:
49
53
  __new_config__ = {"colors": SYMSH_CONFIG}
50
54
  # add command in config
51
55
  SYMSH_CONFIG = __new_config__
52
56
  # save config
53
- with open(config_path, 'w') as f:
57
+ with config_path.open('w') as f:
54
58
  json.dump(__new_config__, f, indent=4)
55
59
 
56
60
  # make sure map-nt-cmd is in config
@@ -58,15 +62,22 @@ if 'map-nt-cmd' not in SYMSH_CONFIG:
58
62
  # add command in config
59
63
  SYMSH_CONFIG['map-nt-cmd'] = True
60
64
  # save config
61
- with open(config_path, 'w') as f:
65
+ with config_path.open('w') as f:
62
66
  json.dump(SYMSH_CONFIG, f, indent=4)
63
67
 
64
- print = print_formatted_text
65
- FunctionType = Function
66
- ConversationType = Conversation
67
- RetrievalConversationType = RetrievalAugmentedConversation
68
- use_styles = False
69
- map_nt_cmd_enabled = SYMSH_CONFIG['map-nt-cmd']
68
+ print = print_formatted_text # noqa
69
+ map_nt_cmd_enabled = SYMSH_CONFIG['map-nt-cmd']
70
+
71
+ _shell_state = SimpleNamespace(
72
+ function_type=Function,
73
+ conversation_type=Conversation,
74
+ retrieval_conversation_type=RetrievalAugmentedConversation,
75
+ use_styles=False,
76
+ stateful_conversation=None,
77
+ previous_kwargs=None,
78
+ previous_prefix=None,
79
+ exec_prefix="default",
80
+ )
70
81
 
71
82
  SHELL_CONTEXT = """[Description]
72
83
  This shell program is the command interpreter on the Linux systems, MacOS and Windows PowerShell.
@@ -80,9 +91,6 @@ If additional instructions are provided the follow the user query to produce the
80
91
  A well related and helpful answer with suggested improvements is preferred over "I don't know" or "I don't understand" answers or stating the obvious.
81
92
  """
82
93
 
83
- stateful_conversation = None
84
- previous_kwargs = None
85
-
86
94
  def supports_ansi_escape():
87
95
  try:
88
96
  os.get_terminal_size(0)
@@ -91,14 +99,22 @@ def supports_ansi_escape():
91
99
  return False
92
100
 
93
101
  class PathCompleter(Completer):
94
- def get_completions(self, document, complete_event):
102
+ def get_completions(self, document, _complete_event):
95
103
  complete_word = document.get_word_before_cursor(WORD=True)
96
104
  sep = os.path.sep
97
105
  if complete_word.startswith(f'~{sep}'):
98
106
  complete_word = FileReader.expand_user_path(complete_word)
99
107
 
100
108
  # list all files and directories in current directory
101
- files = list(glob.glob(complete_word + '*'))
109
+ complete_path = Path(complete_word)
110
+ if complete_word.endswith(sep):
111
+ parent = complete_path
112
+ pattern = '*'
113
+ else:
114
+ baseline = Path()
115
+ parent = complete_path.parent if complete_path.parent != baseline else baseline
116
+ pattern = f"{complete_path.name}*" if complete_path.name else '*'
117
+ files = [str(path) for path in parent.glob(pattern)]
102
118
  if len(files) == 0:
103
119
  return None
104
120
 
@@ -106,6 +122,7 @@ class PathCompleter(Completer):
106
122
  files_ = []
107
123
 
108
124
  for file in files:
125
+ path_obj = Path(file)
109
126
  # split the command into words by space (ignore escaped spaces)
110
127
  command_words = document.text.split(' ')
111
128
  if len(command_words) > 1:
@@ -115,28 +132,27 @@ class PathCompleter(Completer):
115
132
  else:
116
133
  start_position = len(document.text)
117
134
  # if there is a space in the file name, then escape it
118
- if ' ' in file:
119
- file = file.replace(' ', '\\ ')
120
- if (document.text.startswith('cd') or document.text.startswith('mkdir')) and os.path.isfile(file):
135
+ display_name = file.replace(' ', '\\ ') if ' ' in file else file
136
+ if (document.text.startswith('cd') or document.text.startswith('mkdir')) and path_obj.is_file():
121
137
  continue
122
- if os.path.isdir(file):
123
- dirs_.append(file)
138
+ if path_obj.is_dir():
139
+ dirs_.append(display_name)
124
140
  else:
125
- files_.append(file)
141
+ files_.append(display_name)
126
142
 
127
143
  for d in dirs_:
128
144
  # if starts with home directory, then replace it with ~
129
- d = FileReader.expand_user_path(d)
130
- yield Completion(d, start_position=-start_position,
131
- style='class:path-completion',
132
- selected_style='class:path-completion-selected')
145
+ directory_completion = FileReader.expand_user_path(d)
146
+ yield Completion(directory_completion, start_position=-start_position,
147
+ style='class:path-completion',
148
+ selected_style='class:path-completion-selected')
133
149
 
134
150
  for f in files_:
135
151
  # if starts with home directory, then replace it with ~
136
- f = FileReader.expand_user_path(f)
137
- yield Completion(f, start_position=-start_position,
138
- style='class:file-completion',
139
- selected_style='class:file-completion-selected')
152
+ file_completion = FileReader.expand_user_path(f)
153
+ yield Completion(file_completion, start_position=-start_position,
154
+ style='class:file-completion',
155
+ selected_style='class:file-completion-selected')
140
156
 
141
157
  class HistoryCompleter(WordCompleter):
142
158
  def get_completions(self, document, complete_event):
@@ -182,25 +198,23 @@ class MergedCompleter(Completer):
182
198
 
183
199
  # Create custom keybindings
184
200
  bindings = KeyBindings()
185
- previous_prefix = None
186
- exec_prefix = 'default'
187
201
  # Get a copy of the current environment
188
202
  default_env = os.environ.copy()
189
203
 
190
204
  def get_exec_prefix():
205
+ exec_prefix = _shell_state.exec_prefix
191
206
  return sys.exec_prefix if exec_prefix == 'default' else exec_prefix
192
207
 
193
208
  def get_conda_env():
194
209
  # what conda env am I in (e.g., where is my Python process from)?
195
210
  ENVBIN = get_exec_prefix()
196
- env_name = os.path.basename(ENVBIN)
197
- return env_name
211
+ return Path(ENVBIN).name
198
212
 
199
213
  # bind to 'Ctrl' + 'Space'
200
214
  @bindings.add(Keys.ControlSpace)
201
215
  def _(event):
202
216
  current_user_input = event.current_buffer.document.text
203
- func = FunctionType(SHELL_CONTEXT)
217
+ func = _shell_state.function_type(SHELL_CONTEXT)
204
218
 
205
219
  bottom_toolbar = HTML(' <b>[f]</b> Print "f" <b>[x]</b> Abort.')
206
220
 
@@ -209,30 +223,29 @@ def _(event):
209
223
 
210
224
  cancel = [False]
211
225
  @kb.add('f')
212
- def _(event):
213
- print('You pressed `f`.')
226
+ def _(_event):
227
+ UserMessage('You pressed `f`.', style="alert")
214
228
 
215
229
  @kb.add('x')
216
- def _(event):
230
+ def _(_event):
217
231
  " Send Abort (control-c) signal. "
218
232
  cancel[0] = True
219
233
  os.kill(os.getpid(), signal.SIGINT)
220
234
 
221
235
  # Use `patch_stdout`, to make sure that prints go above the
222
236
  # application.
223
- with patch_stdout():
224
- with ProgressBar(key_bindings=kb, bottom_toolbar=bottom_toolbar) as pb:
225
- # TODO: hack to simulate progress bar of indeterminate length of an synchronous function
226
- for i in pb(range(100)):
227
- if i > 50 and i < 70:
228
- time.sleep(.01)
237
+ with patch_stdout(), ProgressBar(key_bindings=kb, bottom_toolbar=bottom_toolbar) as pb:
238
+ # TODO: hack to simulate progress bar of indeterminate length of an synchronous function
239
+ for i in pb(range(100)):
240
+ if i > 50 and i < 70:
241
+ time.sleep(.01)
229
242
 
230
- if i == 60:
231
- res = func(current_user_input) # hack to see progress bar
243
+ if i == 60:
244
+ res = func(current_user_input) # hack to see progress bar
232
245
 
233
- # Stop when the cancel flag has been set.
234
- if cancel[0]:
235
- break
246
+ # Stop when the cancel flag has been set.
247
+ if cancel[0]:
248
+ break
236
249
 
237
250
  with ConsoleStyle('code') as console:
238
251
  console.print(res)
@@ -255,14 +268,14 @@ class FileHistory(History):
255
268
  '''
256
269
 
257
270
  def __init__(self, filename: str) -> None:
258
- self.filename = filename
271
+ self.filename = Path(filename)
259
272
  super().__init__()
260
273
 
261
274
  def load_history_strings(self) -> Iterable[str]:
262
275
  lines: list[str] = []
263
276
 
264
- if os.path.exists(self.filename):
265
- with open(self.filename, "r") as f:
277
+ if self.filename.exists():
278
+ with self.filename.open() as f:
266
279
  lines = f.readlines()
267
280
  # Remove comments and empty lines.
268
281
  lines = [line for line in lines if line.strip() and not line.startswith("#")]
@@ -276,17 +289,17 @@ class FileHistory(History):
276
289
 
277
290
  def store_string(self, string: str) -> None:
278
291
  # Save to file.
279
- with open(self.filename, "ab") as f:
292
+ with self.filename.open("ab") as f:
280
293
 
281
294
  def write(t: str) -> None:
282
295
  f.write(t.encode("utf-8"))
283
296
 
284
297
  for line in string.split("\n"):
285
- write("%s\n" % line)
298
+ write(f"{line}\n")
286
299
 
287
300
  # Defining commands history
288
301
  def load_history(home_path=HOME_PATH, history_file='.bash_history'):
289
- history_file_path = os.path.join(home_path, history_file)
302
+ history_file_path = home_path / history_file
290
303
  history = FileHistory(history_file_path)
291
304
  return history, list(history.load_history_strings())
292
305
 
@@ -294,14 +307,14 @@ def load_history(home_path=HOME_PATH, history_file='.bash_history'):
294
307
  def get_git_branch():
295
308
  try:
296
309
  git_process = subprocess.Popen(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
297
- stdout, stderr = git_process.communicate()
310
+ stdout, _stderr = git_process.communicate()
298
311
  if git_process.returncode == 0:
299
312
  return stdout.strip().decode('utf-8')
300
313
  except FileNotFoundError:
301
314
  pass
302
315
  return None
303
316
 
304
- def disambiguate(cmds: str) -> Tuple[str, int]:
317
+ def disambiguate(cmds: str) -> tuple[str, int]:
305
318
  '''
306
319
  Ok, so, possible options for now:
307
320
  1. query | cmd
@@ -311,140 +324,199 @@ def disambiguate(cmds: str) -> Tuple[str, int]:
311
324
  4. query | cmd cmd ...
312
325
  5. query | file | cmd
313
326
  '''
314
- has_at_least_one_cmd = any([shutil.which(cmd) is not None for cmd in cmds.split(' ')])
327
+ has_at_least_one_cmd = any(shutil.which(cmd) is not None for cmd in cmds.split(' '))
315
328
  maybe_cmd = cmds.split(' ')[0].strip() # get first command
316
329
  maybe_files = FileReader.extract_files(cmds)
317
330
  # if cmd follows file(s) or file(s) follows cmd throw error as not supported
318
331
  if maybe_files is not None and has_at_least_one_cmd:
319
- raise ValueError('Cannot disambiguate commands that have both files and commands or multiple commands. Please provide correct order of commands. '
320
- 'Supported are: '
321
- 'query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
322
- 'and '
323
- 'query | cmd (e.g. "what flags can I use with rg?" | rg --help)')
332
+ msg = (
333
+ 'Cannot disambiguate commands that have both files and commands or multiple commands. Please provide '
334
+ 'correct order of commands. Supported are: query | file [file ...] (e.g. "what do these files have in '
335
+ 'common?" | file1 [file2 ...]) and query | cmd (e.g. "what flags can I use with rg?" | rg --help)'
336
+ )
337
+ UserMessage(msg, raise_with=ValueError)
324
338
  # now check order of commands and keep correct order
325
339
  if shutil.which(maybe_cmd) is not None:
326
- cmd_out = subprocess.run(cmds, capture_output=True, text=True, shell=True)
340
+ cmd_out = subprocess.run(cmds, check=False, capture_output=True, text=True, shell=True)
327
341
  if not cmd_out.stdout:
328
- raise ValueError(f'Command not found or failed. Error: {cmd_out.stderr}')
342
+ msg = f'Command not found or failed. Error: {cmd_out.stderr}'
343
+ UserMessage(msg, raise_with=ValueError)
329
344
  return cmd_out.stdout, 1
330
345
  if maybe_files is not None:
331
346
  return maybe_files, 2
347
+ return None
332
348
 
333
349
  # query language model
350
+ def _starts_with_prefix(query: str, prefix: str) -> bool:
351
+ return (
352
+ query.startswith(f'{prefix}"')
353
+ or query.startswith(f"{prefix}'")
354
+ or query.startswith(f'{prefix}`')
355
+ )
356
+
357
+
358
+ def _is_new_conversation_query(query: str) -> bool:
359
+ return _starts_with_prefix(query, '!')
360
+
361
+
362
+ def _is_followup_conversation_query(query: str) -> bool:
363
+ return _starts_with_prefix(query, '.')
364
+
365
+
366
+ def _is_stateful_query(query: str) -> bool:
367
+ return any(_starts_with_prefix(query, prefix) for prefix in ['.', '!'])
368
+
369
+
370
+ def _extract_query_kwargs(query: str, previous_kwargs, existing_kwargs):
371
+ if '--kwargs' not in query and '-kw' not in query:
372
+ return query, existing_kwargs, previous_kwargs
373
+
374
+ splitter = '--kwargs' if '--kwargs' in query else '-kw'
375
+ splits = query.split(splitter)
376
+ suffix = splits[-1]
377
+ if previous_kwargs is None and '=' not in suffix and ',' not in suffix:
378
+ msg = 'Kwargs format must be last in query.'
379
+ UserMessage(msg, raise_with=ValueError)
380
+ if previous_kwargs is not None and '=' not in suffix and ',' not in suffix:
381
+ cmd_kwargs = previous_kwargs
382
+ else:
383
+ query = splits[0].strip()
384
+ kwargs_str = suffix.strip()
385
+ cmd_kwargs = dict([kw.split('=') for kw in kwargs_str.split(',')])
386
+ cmd_kwargs = {k.strip(): Symbol(v.strip()).ast() for k, v in cmd_kwargs.items()}
387
+
388
+ previous_kwargs = cmd_kwargs
389
+ merged_kwargs = {**existing_kwargs, **cmd_kwargs}
390
+ return query, merged_kwargs, previous_kwargs
391
+
392
+
393
+ def _process_new_conversation(query, conversation_cls, symai_path, plugin, previous_kwargs, state):
394
+ symai_path.parent.mkdir(parents=True, exist_ok=True)
395
+ conversation = conversation_cls(auto_print=False)
396
+ conversation_cls.save_conversation_state(conversation, symai_path)
397
+ state.stateful_conversation = conversation
398
+ if plugin is None:
399
+ return conversation, previous_kwargs, None, False
400
+ with Loader(desc="Inference ...", end=""):
401
+ cmd = query[1:].strip('\'"')
402
+ cmd = f"symrun {plugin} '{cmd}' --disable-pbar"
403
+ cmd_out = run_shell_command(cmd, auto_query_on_error=True)
404
+ conversation.store(cmd_out)
405
+ conversation_cls.save_conversation_state(conversation, symai_path)
406
+ state.stateful_conversation = conversation
407
+ state.previous_kwargs = previous_kwargs
408
+ return conversation, previous_kwargs, cmd_out, True
409
+
410
+
411
+ def _process_followup_conversation(query, conversation, conversation_cls, symai_path, plugin, previous_kwargs, state):
412
+ try:
413
+ conversation = conversation.load_conversation_state(symai_path)
414
+ state.stateful_conversation = conversation
415
+ except Exception:
416
+ with ConsoleStyle('error') as console:
417
+ console.print('No conversation state found. Please start a new conversation.')
418
+ return conversation, previous_kwargs, None, True
419
+ if plugin is None:
420
+ return conversation, previous_kwargs, None, False
421
+ with Loader(desc="Inference ...", end=""):
422
+ trimmed_query = query[1:].strip('\'"')
423
+ answer = conversation(trimmed_query).value
424
+ cmd = f"symrun {plugin} '{answer}' --disable-pbar"
425
+ cmd_out = run_shell_command(cmd, auto_query_on_error=True)
426
+ conversation.store(cmd_out)
427
+ conversation_cls.save_conversation_state(conversation, symai_path)
428
+ state.stateful_conversation = conversation
429
+ state.previous_kwargs = previous_kwargs
430
+ return conversation, previous_kwargs, cmd_out, True
431
+
432
+
433
+ def _handle_piped_query(query, conversation, state):
434
+ cmds = query.split('|')
435
+ if len(cmds) > 2:
436
+ msg = (
437
+ 'Cannot disambiguate commands that have more than 1 pipes. Please provide correct order of commands. '
438
+ 'Supported are: query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
439
+ 'and query | cmd (e.g. "what flags can I use with rg?" | rg --help)'
440
+ )
441
+ UserMessage(msg, raise_with=ValueError)
442
+ base_query = cmds[0]
443
+ payload, order = disambiguate(cmds[1].strip())
444
+ is_stateful = _is_stateful_query(base_query)
445
+ if is_stateful:
446
+ func = conversation
447
+ else:
448
+ func = (
449
+ state.function_type(payload)
450
+ if order == 1
451
+ else state.conversation_type(file_link=payload, auto_print=False)
452
+ )
453
+ if is_stateful:
454
+ if order == 1:
455
+ func.store(payload)
456
+ elif order == 2:
457
+ for file in payload:
458
+ func.store_file(file)
459
+ return func, base_query
460
+
461
+
462
+ def _select_function_for_query(query, conversation, state):
463
+ if '|' in query:
464
+ return _handle_piped_query(query, conversation, state)
465
+ if _is_stateful_query(query):
466
+ return conversation, query
467
+ return state.function_type(SHELL_CONTEXT), query
468
+
469
+
470
+ def _should_save_conversation(conversation, query):
471
+ if conversation is None:
472
+ return False
473
+ return _is_stateful_query(query)
474
+
475
+
334
476
  def query_language_model(query: str, res=None, *args, **kwargs):
335
- global stateful_conversation, previous_kwargs
477
+ state = _shell_state
478
+ conversation = state.stateful_conversation
479
+ previous_kwargs = state.previous_kwargs
480
+ conversation_cls = state.conversation_type
336
481
  home_path = HOME_PATH
337
- symai_path = os.path.join(home_path, '.conversation_state')
482
+ symai_path = home_path / '.conversation_state'
338
483
  plugin = SYMSH_CONFIG.get('plugin_prefix')
339
484
 
340
- # check and extract kwargs from query if any
341
- # format --kwargs key1=value1,key2=value2,key3=value3,...keyN=valueN
342
- if '--kwargs' in query or '-kw' in query:
343
- splitter = '--kwargs' if '--kwargs' in query else '-kw'
344
- # check if kwargs format is last in query otherwise raise error
345
- splits = query.split(f'{splitter}')
346
- if previous_kwargs is None and '=' not in splits[-1] and ',' not in splits[-1]:
347
- raise ValueError('Kwargs format must be last in query.')
348
- elif previous_kwargs is not None and '=' not in splits[-1] and ',' not in splits[-1]:
349
- # use previous kwargs
350
- cmd_kwargs = previous_kwargs
351
- else:
352
- # remove kwargs from query
353
- query = splits[0].strip()
354
- # extract kwargs
355
- kwargs_str = splits[-1].strip()
356
- cmd_kwargs = dict([kw.split('=') for kw in kwargs_str.split(',')])
357
- cmd_kwargs = {k.strip(): Symbol(v.strip()).ast() for k, v in cmd_kwargs.items()}
358
-
359
- previous_kwargs = cmd_kwargs
360
- # unpack cmd_kwargs to kwargs
361
- kwargs = {**kwargs, **cmd_kwargs}
362
-
363
- # Handle stateful conversations:
364
- # 1. If query starts with !" (new conversation), create new conversation state
365
- # 2. If query starts with ." (follow-up), either:
366
- # - Load existing conversation state if it exists
367
- # - Create new conversation state if none exists
368
- if (query.startswith('!"') or query.startswith("!'") or query.startswith('!`')):
369
- os.makedirs(os.path.dirname(symai_path), exist_ok=True)
370
- stateful_conversation = ConversationType(auto_print=False)
371
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
372
- # Special case: if query starts with !" and has a prefix, run the prefix command and store the output
373
- if plugin is not None:
374
- with Loader(desc="Inference ...", end=""):
375
- cmd = query[1:].strip('\'"')
376
- cmd = f"symrun {plugin} '{cmd}' --disable-pbar"
377
- cmd_out = run_shell_command(cmd, auto_query_on_error=True)
378
- stateful_conversation.store(cmd_out)
379
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
380
- return cmd_out
381
- elif query.startswith('."') or query.startswith(".'") or query.startswith('.`'):
382
- try:
383
- stateful_conversation = stateful_conversation.load_conversation_state(symai_path)
384
- except Exception:
385
- with ConsoleStyle('error') as console:
386
- console.print('No conversation state found. Please start a new conversation.')
387
- return
388
- if plugin is not None:
389
- with Loader(desc="Inference ...", end=""):
390
- query = query[1:].strip('\'"')
391
- answer = stateful_conversation(query).value
392
- cmd = f"symrun {plugin} '{answer}' --disable-pbar"
393
- cmd_out = run_shell_command(cmd, auto_query_on_error=True)
394
- stateful_conversation.store(cmd_out)
395
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
396
- return cmd_out
397
- cmd = None
398
- if '|' in query:
399
- cmds = query.split('|')
400
- if len(cmds) > 2:
401
- raise ValueError(('Cannot disambiguate commands that have more than 1 pipes. Please provide correct order of commands. '
402
- 'Supported are: '
403
- 'query | file [file ...] (e.g. "what do these files have in common?" | file1 [file2 ...]) '
404
- 'and '
405
- 'query | cmd (e.g. "what flags can I use with rg?" | rg --help)'))
406
- query = cmds[0]
407
- payload, order = disambiguate(cmds[1].strip())
408
- # check if we're in a stateful conversation
409
- is_stateful = query.startswith(('.', '!')) and any(query.startswith(f"{prefix}{quote}")
410
- for prefix in ['.', '!']
411
- for quote in ['"', "'", '`'])
412
-
413
- if is_stateful:
414
- func = stateful_conversation
415
- else:
416
- func = FunctionType(payload) if order == 1 else ConversationType(file_link=payload, auto_print=False)
417
-
418
- if is_stateful:
419
- if order == 1:
420
- func.store(payload)
421
- elif order == 2:
422
- for file in payload:
423
- func.store_file(file)
424
- else:
425
- if query.startswith('."') or query.startswith(".'") or query.startswith('.`') or\
426
- query.startswith('!"') or query.startswith("!'") or query.startswith('!`'):
427
- func = stateful_conversation
428
- else:
429
- func = FunctionType(SHELL_CONTEXT)
430
-
485
+ query, kwargs, previous_kwargs = _extract_query_kwargs(query, previous_kwargs, kwargs)
486
+
487
+ if _is_new_conversation_query(query):
488
+ conversation, previous_kwargs, result, should_return = _process_new_conversation(
489
+ query, conversation_cls, symai_path, plugin, previous_kwargs, state
490
+ )
491
+ if should_return:
492
+ return result
493
+ elif _is_followup_conversation_query(query):
494
+ conversation, previous_kwargs, result, should_return = _process_followup_conversation(
495
+ query, conversation, conversation_cls, symai_path, plugin, previous_kwargs, state
496
+ )
497
+ if should_return:
498
+ return result
499
+ func, query = _select_function_for_query(query, conversation, state)
431
500
  with Loader(desc="Inference ...", end=""):
501
+ query_to_execute = query
502
+ if res is not None:
503
+ query_to_execute = f"[Context]\n{res}\n\n[Query]\n{query}"
504
+ msg = func(query_to_execute, *args, **kwargs)
432
505
  if res is not None:
433
- query = f"[Context]\n{res}\n\n[Query]\n{query}"
434
- msg = func(query, *args, **kwargs)
506
+ query = query_to_execute
435
507
 
436
- if stateful_conversation is not None and (
437
- query.startswith('."') or query.startswith(".'") or query.startswith('.`') or
438
- query.startswith('!"') or query.startswith("!'") or query.startswith('!`')
439
- ):
440
- ConversationType.save_conversation_state(stateful_conversation, symai_path)
508
+ if _should_save_conversation(conversation, query):
509
+ conversation_cls.save_conversation_state(conversation, symai_path)
441
510
 
511
+ state.stateful_conversation = conversation
512
+ state.previous_kwargs = previous_kwargs
442
513
  return msg
443
514
 
444
- def retrieval_augmented_indexing(query: str, index_name = None, *args, **kwargs):
445
- global stateful_conversation
515
+ def retrieval_augmented_indexing(query: str, index_name = None, *_args, **_kwargs):
516
+ state = _shell_state
446
517
  sep = os.path.sep
447
518
  path = query
519
+ home_path = HOME_PATH
448
520
 
449
521
  # check if path contains overwrite flag
450
522
  overwrite = False
@@ -468,16 +540,17 @@ def retrieval_augmented_indexing(query: str, index_name = None, *args, **kwargs)
468
540
  # check if path contains git flag
469
541
  if path.startswith('git@'):
470
542
  overwrite = True
471
- repo_path = os.path.join(home_path, 'temp')
472
- cloner = RepositoryCloner(repo_path=repo_path)
473
- url = path[4:]
474
- if 'http' not in url:
475
- url = 'https://' + url
476
- url = url.replace('.com:', '.com/')
477
- # if ends with '.git' then remove it
478
- if url.endswith('.git'):
479
- url = url[:-4]
480
- path = cloner(url)
543
+ repo_path = home_path / 'temp'
544
+ with Loader(desc="Cloning repo ...", end=""):
545
+ cloner = RepositoryCloner(repo_path=str(repo_path))
546
+ url = path[4:]
547
+ if 'http' not in url:
548
+ url = 'https://' + url
549
+ url = url.replace('.com:', '.com/')
550
+ # if ends with '.git' then remove it
551
+ if url.endswith('.git'):
552
+ url = url[:-4]
553
+ path = cloner(url)
481
554
 
482
555
  # merge files
483
556
  merger = FileMerger()
@@ -491,36 +564,35 @@ def retrieval_augmented_indexing(query: str, index_name = None, *args, **kwargs)
491
564
 
492
565
  index_name = path.split(sep)[-1] if index_name is None else index_name
493
566
  index_name = Indexer.replace_special_chars(index_name)
494
- print(f'Indexing {index_name} ...')
567
+ UserMessage(f'Indexing {index_name} ...', style="extensity")
495
568
 
496
569
  # creates index if not exists
497
570
  DocumentRetriever(index_name=index_name, file=file, overwrite=overwrite)
498
571
 
499
- home_path = HOME_PATH
500
- symai_path = os.path.join(home_path, '.conversation_state')
501
- os.makedirs(os.path.dirname(symai_path), exist_ok=True)
502
- stateful_conversation = RetrievalConversationType(auto_print=False, index_name=index_name)
572
+ symai_path = home_path / '.conversation_state'
573
+ symai_path.parent.mkdir(parents=True, exist_ok=True)
574
+ stateful_conversation = state.retrieval_conversation_type(auto_print=False, index_name=index_name)
575
+ state.stateful_conversation = stateful_conversation
503
576
  Conversation.save_conversation_state(stateful_conversation, symai_path)
504
577
  if use_index_name:
505
578
  message = 'New session '
506
579
  else:
507
580
  message = f'Repository {url} cloned and ' if query.startswith('git@') or query.startswith('git:') else f'Directory {path} '
508
- msg = f'{message}successfully indexed: {index_name}'
509
- return msg
581
+ return f'{message}successfully indexed: {index_name}'
510
582
 
511
- def search_engine(query: str, res=None, *args, **kwargs):
583
def search_engine(query: str, res=None, *_args, **_kwargs):
    """Answer *query* using web search results as context.

    The query is first condensed into a search-engine friendly form, sent to the
    ``serpapi`` interface, and the results are passed to the shell's current
    function type. Query, results and answer are dumped to ``HOME_PATH/.search_dump``.

    Args:
        query: The user's natural-language question.
        res: Ignored; it is overwritten by the search results (kept for call-site
            compatibility with the other command handlers).

    Returns:
        The model's answer.
    """
    search = Interface('serpapi')
    with Loader(desc="Searching ...", end=""):
        search_query = Symbol(query).extract('search engine optimized query')
        res = search(search_query)
    with Loader(desc="Inference ...", end=""):
        func = _shell_state.function_type(query)
        msg = func(res, payload=res)
    # write a temp dump file with the query and results
    home_path = HOME_PATH
    symai_path = home_path / '.search_dump'
    symai_path.parent.mkdir(parents=True, exist_ok=True)
    # explicit UTF-8: search results routinely contain non-ASCII text, which fails
    # under a non-UTF-8 locale default encoding (e.g. cp1252 on Windows)
    with symai_path.open('w', encoding='utf-8') as f:
        f.write(f'[SEARCH_QUERY]:\n{search_query}\n[RESULTS]\n{res}\n[MESSAGE]\n{msg}')
    return msg
526
598
 
@@ -528,12 +600,12 @@ def set_default_module(cmd: str):
528
600
  if cmd.startswith('set-plugin'):
529
601
  module = cmd.split('set-plugin')[-1].strip()
530
602
  SYMSH_CONFIG['plugin_prefix'] = module
531
- with open(config_path, 'w') as f:
603
+ with config_path.open('w') as f:
532
604
  json.dump(SYMSH_CONFIG, f, indent=4)
533
605
  msg = f"Default plugin set to '{module}'"
534
606
  elif cmd == 'unset-plugin':
535
607
  SYMSH_CONFIG['plugin_prefix'] = None
536
- with open(config_path, 'w') as f:
608
+ with config_path.open('w') as f:
537
609
  json.dump(SYMSH_CONFIG, f, indent=4)
538
610
  msg = "Default plugin unset"
539
611
  elif cmd == 'get-plugin':
@@ -543,35 +615,33 @@ def set_default_module(cmd: str):
543
615
  console.print(msg)
544
616
 
545
617
def handle_error(cmd, res, message, auto_query_on_error):
    """Turn a failed shell command into a user-facing result.

    Args:
        cmd: The shell command line that failed.
        res: The completed process object of the failed run (``stdout``/``stderr``
            are bytes when captured, ``None`` when they were not piped).
        message: Message already extracted by ``run_shell_command``; returned
            when there is nothing better to report.
        auto_query_on_error: When true and stderr was captured, the error text
            (plus up to 500 chars of the command's man page if stderr shows a
            ``usage:`` text) is sent to the language model.

    Returns:
        The decoded stderr for "command not found" errors, the language model's
        answer in auto-query mode, otherwise ``message`` (replaced by stderr when
        both streams were captured).
    """
    msg = Symbol(cmd) | f'\n{res!s}'
    res_text = str(res)
    stderr = res.stderr
    if 'command not found' in res_text or 'not recognized as an internal or external command' in res_text:
        # stderr is only available when it was piped; otherwise the shell already printed it
        return stderr.decode('utf-8', errors='replace') if stderr is not None else message
    if stderr and auto_query_on_error:
        rsp = stderr.decode('utf-8', errors='replace')
        UserMessage(rsp, style="alert")
        msg = msg | f"\n{rsp}"
        if 'usage:' in rsp:
            try:
                man_target = cmd.split('usage: ')[-1].split(' ')[0]
                # get man page result for command; an argument list avoids shell interpretation
                man_res = subprocess.run(['man', '-P', 'cat', man_target],
                                         check=False,
                                         stdout=subprocess.PIPE)
                if man_res.stdout:
                    man_text = man_res.stdout.decode('utf-8', errors='replace')[:500]
                    msg = msg | f"\n{man_text}"
            except Exception:
                # deliberate best effort: `man` may be missing (e.g. Windows); query without it
                pass
        return query_language_model(msg)
    # only decode stderr if it was actually captured (None when not piped)
    if res.stdout and stderr is not None:
        message = stderr.decode('utf-8', errors='replace')
    return message
575
645
 
576
646
  # run shell command
577
647
  def run_shell_command(cmd: str, prev=None, auto_query_on_error: bool=False, stdout=None, stderr=None):
@@ -581,14 +651,14 @@ def run_shell_command(cmd: str, prev=None, auto_query_on_error: bool=False, stdo
581
651
  conda_env = get_exec_prefix()
582
652
  # copy default_env
583
653
  new_env = default_env.copy()
584
- if exec_prefix != 'default':
654
+ if _shell_state.exec_prefix != 'default':
585
655
  # remove current env from PATH
586
656
  new_env["PATH"] = new_env["PATH"].replace(sys.exec_prefix, conda_env)
587
657
  # Execute the command
588
658
  try:
589
659
  stdout = subprocess.PIPE if auto_query_on_error else stdout
590
660
  stderr = subprocess.PIPE if auto_query_on_error else stderr
591
- res = subprocess.run(cmd, shell=True, stdout=stdout, stderr=stderr, env=new_env)
661
+ res = subprocess.run(cmd, check=False, shell=True, stdout=stdout, stderr=stderr, env=new_env)
592
662
  if res and stdout and res.stdout:
593
663
  message = res.stdout.decode('utf-8')
594
664
  elif res and stderr and res.stderr:
@@ -602,8 +672,7 @@ def run_shell_command(cmd: str, prev=None, auto_query_on_error: bool=False, stdo
602
672
  if res.returncode == 0:
603
673
  return message
604
674
  # If command not found, then try to query language model
605
- else:
606
- return handle_error(cmd, res, message, auto_query_on_error)
675
+ return handle_error(cmd, res, message, auto_query_on_error)
607
676
 
608
677
  def is_llm_request(cmd: str):
609
678
  return cmd.startswith('"') or cmd.startswith('."') or cmd.startswith('!"') or cmd.startswith('?"') or\
@@ -655,185 +724,238 @@ def map_nt_cmd(cmd: str, map_nt_cmd_enabled: bool = True):
655
724
  original_cmd = cmd
656
725
  cmd = re.sub(linux_cmd, windows_cmd, cmd)
657
726
  if cmd != original_cmd:
658
- print(f'symsh >> command "{original_cmd}" mapped to "{cmd}"\n')
727
+ UserMessage(f'symsh >> command "{original_cmd}" mapped to "{cmd}"\n', style="extensity")
659
728
 
660
729
  return cmd
661
730
 
662
- def process_command(cmd: str, res=None, auto_query_on_error: bool=False):
663
- global exec_prefix, previous_prefix
664
731
 
665
- # map commands to windows if needed
666
- cmd = map_nt_cmd(cmd)
732
+ def _handle_plugin_commands(cmd: str):
667
733
  if cmd.startswith('set-plugin') or cmd == 'unset-plugin' or cmd == 'get-plugin':
668
734
  return set_default_module(cmd)
735
+ return None
669
736
 
670
- sep = os.path.sep
671
- # check for '&&' to also preserve pipes '|' in normal shell commands
672
- if '" && ' in cmd or "' && " in cmd or '` && ' in cmd:
673
- if is_llm_request(cmd):
674
- # Process each command (the ones involving the LLM) separately
675
- cmds = cmd.split(' && ')
676
- if not is_llm_request(cmds[0]):
677
- return ValueError('The first command must be a LLM request.')
678
- # Process the first command as an LLM request
679
- res = query_language_model(cmds[0], res=res)
680
- rest = ' && '.join(cmds[1:])
681
- if '$1' in cmds[1]:
682
- res = str(res).replace('\n', r'\\n')
683
- rest = rest.replace('$1', '"%s"' % res)
684
- res = None
685
- cmd = rest
686
- # If it's a normal shell command with pipes or &&, pass it whole
687
- res = run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
688
- return res
689
737
 
738
+ def _handle_chained_llm_commands(cmd: str, res, auto_query_on_error: bool):
739
+ if '" && ' not in cmd and "' && " not in cmd and '` && ' not in cmd:
740
+ return None
741
+ if not is_llm_request(cmd):
742
+ return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
743
+ cmds = cmd.split(' && ')
744
+ if not is_llm_request(cmds[0]):
745
+ return ValueError('The first command must be a LLM request.')
746
+ first_res = query_language_model(cmds[0], res=res)
747
+ rest = ' && '.join(cmds[1:])
748
+ if len(cmds) > 1 and '$1' in cmds[1]:
749
+ first_res_str = str(first_res).replace('\n', r'\\n')
750
+ rest = rest.replace('$1', f'"{first_res_str}"')
751
+ first_res = None
752
+ return run_shell_command(rest, prev=first_res, auto_query_on_error=auto_query_on_error)
753
+
754
+
755
def _handle_llm_or_search(cmd: str, res):
    """Route ``?"...`` to the search engine and LLM requests to the language model.

    Commands containing ``...`` are also treated as LLM requests. Returns ``None``
    for anything else.
    """
    if cmd[:2] in ('?"', "?'", '?`'):
        # drop the leading '?' and keep the quoted query
        return search_engine(cmd[1:], res=res)
    if not (is_llm_request(cmd) or '...' in cmd):
        return None
    return query_language_model(cmd, res=res)
762
+
763
+
764
+ def _handle_retrieval_commands(cmd: str):
765
+ if cmd.startswith('*'):
766
+ return retrieval_augmented_indexing(cmd[1:])
767
+ return None
768
+
769
+
770
+ def _handle_man_command(cmd: str):
771
+ if cmd.startswith('man symsh'):
772
+ pkg_path = Path(__file__).resolve().parent
773
+ symsh_path = pkg_path / 'symsh.md'
774
+ with symsh_path.open(encoding="utf8") as file_ptr:
775
+ return file_ptr.read()
776
+ return None
777
+
696
778
 
697
- elif cmd.startswith('*'):
698
- cmd = cmd[1:]
699
- return retrieval_augmented_indexing(cmd)
700
-
701
- elif cmd.startswith('man symsh'):
702
- # read symsh.md file and print it
703
- # get symsh path
704
- pkg_path = os.path.dirname(os.path.abspath(__file__))
705
- symsh_path = os.path.join(pkg_path, 'symsh.md')
706
- with open(symsh_path, 'r', encoding="utf8") as f:
707
- return f.read()
708
-
709
- elif cmd.startswith('conda activate'):
710
- # check conda execution prefix and verify if environment exists
711
- env = sys.exec_prefix
712
- path_ = sep.join(env.split(sep)[:-1])
713
- env_base = os.path.join(sep, path_)
779
+ def _handle_conda_commands(cmd: str, state, res, auto_query_on_error: bool):
780
+ if cmd.startswith('conda activate'):
781
+ env = Path(sys.exec_prefix)
782
+ env_base = env.parent
714
783
  req_env = cmd.split(' ')[2]
715
- # check if environment exists
716
- env_path = os.path.join(env_base, req_env)
717
- if not os.path.exists(env_path):
784
+ env_path = env_base / req_env
785
+ if not env_path.exists():
718
786
  return f'Environment {req_env} does not exist!'
719
- previous_prefix = exec_prefix
720
- exec_prefix = os.path.join(env_base, req_env)
721
- return exec_prefix
722
-
723
- elif cmd.startswith('conda deactivate'):
724
- if previous_prefix is not None:
725
- exec_prefix = previous_prefix
726
- if previous_prefix == 'default':
727
- previous_prefix = None
787
+ state.previous_prefix = state.exec_prefix
788
+ state.exec_prefix = str(env_path)
789
+ return state.exec_prefix
790
+ if cmd.startswith('conda deactivate'):
791
+ prev_prefix = state.previous_prefix
792
+ if prev_prefix is not None:
793
+ state.exec_prefix = prev_prefix
794
+ if prev_prefix == 'default':
795
+ state.previous_prefix = None
728
796
  return get_exec_prefix()
797
+ if cmd.startswith('conda'):
798
+ env = Path(get_exec_prefix())
799
+ try:
800
+ env_base = env.parents[1]
801
+ except IndexError:
802
+ env_base = env.parent
803
+ cmd_rewritten = cmd.replace('conda', str(env_base / "condabin" / "conda"))
804
+ return run_shell_command(cmd_rewritten, prev=res, auto_query_on_error=auto_query_on_error)
805
+ return None
729
806
 
730
- elif cmd.startswith('conda'):
731
- env = get_exec_prefix()
732
- env_base = os.path.join(sep, *env.split(sep)[:-2])
733
- cmd = cmd.replace('conda', os.path.join(env_base, "condabin", "conda"))
734
- return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
735
807
 
736
- elif cmd.startswith('cd'):
808
+ def _handle_directory_navigation(cmd: str):
809
+ sep = os.path.sep
810
+ if cmd.startswith('cd'):
737
811
  try:
738
- # replace ~ with home directory
739
- cmd = FileReader.expand_user_path(cmd)
740
- # Change directory
741
- path = ' '.join(cmd.split(' ')[1:])
812
+ cmd_expanded = FileReader.expand_user_path(cmd)
813
+ path = ' '.join(cmd_expanded.split(' ')[1:])
742
814
  if path.endswith(sep):
743
815
  path = path[:-1]
744
816
  return os.chdir(path)
745
- except FileNotFoundError as e:
746
- return e
747
- except PermissionError as e:
748
- return e
749
-
750
- elif os.path.isdir(cmd):
817
+ except FileNotFoundError as err:
818
+ return err
819
+ except PermissionError as err:
820
+ return err
821
+ cmd_path = FileReader.expand_user_path(cmd)
822
+ if Path(cmd).is_dir():
751
823
  try:
752
- # replace ~ with home directory
753
- cmd = FileReader.expand_user_path(cmd)
754
- # Change directory
755
- os.chdir(cmd)
756
- except FileNotFoundError as e:
757
- return e
758
- except PermissionError as e:
759
- return e
760
-
761
- elif cmd.startswith('ll'):
762
-
763
- if os.name == 'nt':
764
- cmd = cmd.replace('ll', 'dir')
765
- return run_shell_command(cmd, prev=res)
766
- else:
767
- cmd = cmd.replace('ll', 'ls -la')
768
- return run_shell_command(cmd, prev=res)
824
+ os.chdir(cmd_path)
825
+ except FileNotFoundError as err:
826
+ return err
827
+ except PermissionError as err:
828
+ return err
829
+ return None
769
830
 
770
- else:
771
- return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
831
+
832
+ def _handle_ll_alias(cmd: str, res):
833
+ if not cmd.startswith('ll'):
834
+ return None
835
+ if os.name == 'nt':
836
+ rewritten = cmd.replace('ll', 'dir')
837
+ return run_shell_command(rewritten, prev=res)
838
+ rewritten = cmd.replace('ll', 'ls -la')
839
+ return run_shell_command(rewritten, prev=res)
840
+
841
+
842
def process_command(cmd: str, res=None, auto_query_on_error: bool=False):
    """Execute one symsh command line and return its result for display.

    Commands are routed by prefix, in this order:
    1. plugin config;
    2. quoted ``&&`` chains;
    3. ``?``-searches and LLM requests;
    4. ``*`` indexing;
    5. ``man symsh``;
    6. ``conda``;
    7. ``cd`` or a bare directory;
    8. the ``ll`` alias;
    9. finally the system shell.

    The route is decided here by prefix, and the chosen handler's result is
    returned as-is. ``None`` cannot mean "not handled", because handlers
    legitimately return ``None`` on success:
    * ``set_default_module`` ends by printing rather than returning;
    * ``os.chdir`` returns ``None``;
    * ``run_shell_command`` may return ``None`` when output was not captured.
    Falling through on ``None`` would re-run the command in the shell.

    Args:
        cmd: Raw command line typed by the user.
        res: Optional previous result passed on as context.
        auto_query_on_error: Forwarded to handlers that run shell commands.

    Returns:
        Whatever the selected handler returns (may be ``None``).
    """
    state = _shell_state

    # map commands to windows if needed
    cmd = map_nt_cmd(cmd)

    if cmd.startswith('set-plugin') or cmd in ('unset-plugin', 'get-plugin'):
        return _handle_plugin_commands(cmd)
    if '" && ' in cmd or "' && " in cmd or '` && ' in cmd:
        return _handle_chained_llm_commands(cmd, res, auto_query_on_error)
    if cmd.startswith(('?"', "?'", '?`')) or is_llm_request(cmd) or '...' in cmd:
        return _handle_llm_or_search(cmd, res)
    if cmd.startswith('*'):
        return _handle_retrieval_commands(cmd)
    if cmd.startswith('man symsh'):
        return _handle_man_command(cmd)
    if cmd.startswith('conda'):
        return _handle_conda_commands(cmd, state, res, auto_query_on_error)
    # os.path.isdir never raises, so overly long command lines are not mistaken for errors
    if cmd == 'cd' or cmd.startswith('cd ') or os.path.isdir(cmd):
        return _handle_directory_navigation(cmd)
    if cmd.startswith('ll'):
        return _handle_ll_alias(cmd, res)
    return run_shell_command(cmd, prev=res, auto_query_on_error=auto_query_on_error)
772
880
 
773
881
def save_conversation():
    """Persist the shell's current stateful conversation to ``HOME_PATH/.conversation_state``."""
    state_file = HOME_PATH / '.conversation_state'
    Conversation.save_conversation_state(_shell_state.stateful_conversation, state_file)
885
+
886
+
887
+ def _is_exit_command(cmd: str) -> bool:
888
+ return cmd in ['quit', 'exit', 'q']
889
+
890
+
891
def _format_working_directory():
    """Render the current working directory as prompt markup, bolding its last component.

    Absolute paths are first passed through ``FileReader.expand_user_path``.
    Path components are XML-escaped because the result is embedded in a
    prompt_toolkit ``HTML`` prompt. A directory name containing ``&`` or ``<``
    would otherwise make the prompt markup unparsable on every loop iteration.
    """
    from html import escape

    sep = os.path.sep
    cur_working_dir_str = str(Path.cwd())
    if cur_working_dir_str.startswith(sep):
        cur_working_dir_str = FileReader.expand_user_path(cur_working_dir_str)
    paths = cur_working_dir_str.split(sep)
    prev_paths = escape(sep.join(paths[:-1]), quote=False)
    last_path = escape(paths[-1], quote=False)
    if len(paths) > 1:
        return f'{prev_paths}{sep}<b>{last_path}</b>'
    return f'<b>{last_path}</b>'
903
+
904
+
905
def _build_prompt(git_branch, conda_env, cur_working_dir_str):
    """Assemble the symsh prompt as prompt_toolkit ``HTML``.

    Args:
        git_branch: Current git branch. The ``on git:[...]`` segment is shown only
            when it is truthy.
        conda_env: Active conda environment, shown in the ``conda:[...]`` segment.
        cur_working_dir_str: Pre-rendered working-directory markup (may contain
            ``<b>`` tags), inserted verbatim.

    ``git_branch`` and ``conda_env`` are XML-escaped. prompt_toolkit's ``HTML``
    parses its input as markup, so an ``&`` or ``<`` in either value would
    otherwise raise.
    """
    from html import escape

    conda_markup = escape(str(conda_env), quote=False)
    if git_branch:
        branch_markup = escape(str(git_branch), quote=False)
        return HTML(
            f"<ansiblue>{cur_working_dir_str}</ansiblue><ansiwhite> on git:[</ansiwhite>"
            f"<ansigreen>{branch_markup}</ansigreen><ansiwhite>]</ansiwhite> <ansiwhite>conda:[</ansiwhite>"
            f"<ansimagenta>{conda_markup}</ansimagenta><ansiwhite>]</ansiwhite> <ansicyan><b>symsh:</b> ></ansicyan> "
        )
    return HTML(
        f"<ansiblue>{cur_working_dir_str}</ansiblue> <ansiwhite>conda:[</ansiwhite>"
        f"<ansimagenta>{conda_markup}</ansimagenta><ansiwhite>]</ansiwhite> <ansicyan><b>symsh:</b> ></ansicyan> "
    )
916
+
917
+
918
def _handle_exit(state):
    """Save any active conversation, print a farewell and terminate the process.

    With conversation styles enabled, the farewell is generated by the passed-in
    ``state``'s function type; otherwise a fixed ``Goodbye!`` is printed. Ends
    with ``os._exit(0)``, so this function never returns.
    """
    if state.stateful_conversation is not None:
        save_conversation()
    if state.use_styles:
        # use the state we were given rather than the module global, for consistency
        farewell = state.function_type('Give short goodbye')
        UserMessage(farewell('bye'), style="extensity")
    else:
        UserMessage('Goodbye!', style="extensity")
    os._exit(0)
927
+
777
928
 
778
929
  # Function to listen for user input and execute commands
779
930
def listen(session: PromptSession, word_comp: WordCompleter, auto_query_on_error: bool=False, verbose: bool=False):
    """Run the interactive read-eval-print loop of symsh.

    Each iteration proceeds as follows:
    1. Rebuild the prompt from the git branch, the conda environment and the cwd.
    2. Read a line; blank input is skipped.
    3. Exit on ``quit``/``exit``/``q``.
    4. Otherwise execute the line with ``process_command`` and print any
       non-``None`` result.
    5. Remember the command in the word completer.

    A ``KeyboardInterrupt`` or any other exception is reported and the loop
    continues; with *verbose* the traceback is printed as well.
    """
    shell = _shell_state
    with patch_stdout():
        while True:
            try:
                prompt = _build_prompt(get_git_branch(), get_conda_env(), _format_working_directory())
                cmd = session.prompt(prompt)
                if not cmd.strip():
                    continue

                if _is_exit_command(cmd):
                    _handle_exit(shell)
                output = process_command(cmd, auto_query_on_error=auto_query_on_error)
                if output is not None:
                    with ConsoleStyle('code') as console:
                        console.print(output)

                # remember the command for auto-completion
                word_comp.words.append(cmd)

            except KeyboardInterrupt:
                UserMessage('', style="alert")
            except Exception as e:
                UserMessage(str(e), style="alert")
                if verbose:
                    traceback.print_exc()
836
- pass
837
959
 
838
960
  def create_session(history, merged_completer):
839
961
  colors = SYMSH_CONFIG['colors']
@@ -842,14 +964,12 @@ def create_session(history, merged_completer):
842
964
  style = Style.from_dict(colors)
843
965
 
844
966
  # Session for the auto-completion
845
- session = PromptSession(history=history,
846
- completer=merged_completer,
847
- complete_style=CompleteStyle.MULTI_COLUMN,
848
- reserve_space_for_menu=5,
849
- style=style,
850
- key_bindings=bindings)
851
-
852
- return session
967
+ return PromptSession(history=history,
968
+ completer=merged_completer,
969
+ complete_style=CompleteStyle.MULTI_COLUMN,
970
+ reserve_space_for_menu=5,
971
+ style=style,
972
+ key_bindings=bindings)
853
973
 
854
974
  def create_completer():
855
975
  # Load history
@@ -864,20 +984,24 @@ def create_completer():
864
984
  return history, word_comp, merged_completer
865
985
 
866
986
  def run(auto_query_on_error=False, conversation_style=None, verbose=False):
867
- global FunctionType, ConversationType, RetrievalConversationType, use_styles
987
+ state = _shell_state
868
988
  if conversation_style is not None and conversation_style != '':
869
- print('Loading style:', conversation_style)
989
+ UserMessage(f'Loading style: {conversation_style}', style="extensity")
870
990
  styles_ = Import.load_module_class(conversation_style)
871
- FunctionType, ConversationType, RetrievalConversationType = styles_
872
- use_styles = True
991
+ (
992
+ state.function_type,
993
+ state.conversation_type,
994
+ state.retrieval_conversation_type,
995
+ ) = styles_
996
+ state.use_styles = True
873
997
 
874
998
  if SYMSH_CONFIG['show-splash-screen']:
875
999
  show_intro_menu()
876
1000
  # set show splash screen to false
877
1001
  SYMSH_CONFIG['show-splash-screen'] = False
878
1002
  # save config
879
- _config_path = HOME_PATH / 'symsh.config.json'
880
- with open(_config_path, 'w') as f:
1003
+ _config_path = HOME_PATH / 'symsh.config.json'
1004
+ with _config_path.open('w') as f:
881
1005
  json.dump(SYMSH_CONFIG, f, indent=4)
882
1006
  if 'plugin_prefix' not in SYMSH_CONFIG:
883
1007
  SYMSH_CONFIG['plugin_prefix'] = None