zrb 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. zrb/__main__.py +3 -0
  2. zrb/builtin/llm/llm_chat.py +85 -5
  3. zrb/builtin/llm/previous-session.js +13 -0
  4. zrb/builtin/llm/tool/api.py +29 -0
  5. zrb/builtin/llm/tool/cli.py +1 -1
  6. zrb/builtin/llm/tool/rag.py +108 -145
  7. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/client_method.py +6 -6
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +3 -1
  9. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +88 -44
  10. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +12 -0
  11. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +28 -22
  12. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +6 -6
  13. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +43 -29
  14. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_repository.py +8 -0
  15. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +46 -14
  16. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +158 -20
  17. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +29 -0
  18. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +36 -14
  19. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +14 -14
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +1 -1
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +34 -6
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +2 -6
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +41 -2
  24. zrb/builtin/todo.py +1 -0
  25. zrb/config.py +23 -4
  26. zrb/input/any_input.py +5 -0
  27. zrb/input/base_input.py +6 -0
  28. zrb/input/bool_input.py +2 -0
  29. zrb/input/float_input.py +2 -0
  30. zrb/input/int_input.py +2 -0
  31. zrb/input/option_input.py +2 -0
  32. zrb/input/password_input.py +2 -0
  33. zrb/input/text_input.py +2 -0
  34. zrb/runner/common_util.py +1 -1
  35. zrb/runner/web_route/error_page/show_error_page.py +2 -1
  36. zrb/runner/web_route/static/resources/session/current-session.js +4 -2
  37. zrb/runner/web_route/static/resources/session/event.js +8 -2
  38. zrb/runner/web_route/task_session_api_route.py +48 -3
  39. zrb/task/base_task.py +14 -13
  40. zrb/task/llm_task.py +214 -84
  41. zrb/util/llm/tool.py +3 -7
  42. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/METADATA +2 -1
  43. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/RECORD +45 -43
  44. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/WHEEL +0 -0
  45. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/entry_points.txt +0 -0
zrb/input/base_input.py CHANGED
@@ -16,6 +16,7 @@ class BaseInput(AnyInput):
16
16
  default_str: StrAttr = "",
17
17
  auto_render: bool = True,
18
18
  allow_empty: bool = False,
19
+ allow_positional_parsing: bool = True,
19
20
  ):
20
21
  self._name = name
21
22
  self._description = description
@@ -23,6 +24,7 @@ class BaseInput(AnyInput):
23
24
  self._default_str = default_str
24
25
  self._auto_render = auto_render
25
26
  self._allow_empty = allow_empty
27
+ self._allow_positional_parsing = allow_positional_parsing
26
28
 
27
29
  def __repr__(self):
28
30
  return f"<{self.__class__.__name__} name={self._name}>"
@@ -39,6 +41,10 @@ class BaseInput(AnyInput):
39
41
  def prompt_message(self) -> str:
40
42
  return self._prompt if self._prompt is not None else self.name
41
43
 
44
+ @property
45
+ def allow_positional_parsing(self) -> bool:
46
+ return self._allow_positional_parsing
47
+
42
48
  def to_html(self, ctx: AnySharedContext) -> str:
43
49
  name = self.name
44
50
  description = self.description
zrb/input/bool_input.py CHANGED
@@ -13,6 +13,7 @@ class BoolInput(BaseInput):
13
13
  default_str: StrAttr = "False",
14
14
  auto_render: bool = True,
15
15
  allow_empty: bool = False,
16
+ allow_positional_parsing: bool = True,
16
17
  ):
17
18
  super().__init__(
18
19
  name=name,
@@ -21,6 +22,7 @@ class BoolInput(BaseInput):
21
22
  default_str=default_str,
22
23
  auto_render=auto_render,
23
24
  allow_empty=allow_empty,
25
+ allow_positional_parsing=allow_positional_parsing,
24
26
  )
25
27
 
26
28
  def to_html(self, ctx: AnySharedContext) -> str:
zrb/input/float_input.py CHANGED
@@ -12,6 +12,7 @@ class FloatInput(BaseInput):
12
12
  default_str: StrAttr = "0.0",
13
13
  auto_render: bool = True,
14
14
  allow_empty: bool = False,
15
+ allow_positional_parsing: bool = True,
15
16
  ):
16
17
  super().__init__(
17
18
  name=name,
@@ -20,6 +21,7 @@ class FloatInput(BaseInput):
20
21
  default_str=default_str,
21
22
  auto_render=auto_render,
22
23
  allow_empty=allow_empty,
24
+ allow_positional_parsing=allow_positional_parsing,
23
25
  )
24
26
 
25
27
  def to_html(self, ctx: AnySharedContext) -> str:
zrb/input/int_input.py CHANGED
@@ -12,6 +12,7 @@ class IntInput(BaseInput):
12
12
  default_str: StrAttr = "0",
13
13
  auto_render: bool = True,
14
14
  allow_empty: bool = False,
15
+ allow_positional_parsing: bool = True,
15
16
  ):
16
17
  super().__init__(
17
18
  name=name,
@@ -20,6 +21,7 @@ class IntInput(BaseInput):
20
21
  default_str=default_str,
21
22
  auto_render=auto_render,
22
23
  allow_empty=allow_empty,
24
+ allow_positional_parsing=allow_positional_parsing,
23
25
  )
24
26
 
25
27
  def to_html(self, ctx: AnySharedContext) -> str:
zrb/input/option_input.py CHANGED
@@ -14,6 +14,7 @@ class OptionInput(BaseInput):
14
14
  default_str: StrAttr = "",
15
15
  auto_render: bool = True,
16
16
  allow_empty: bool = False,
17
+ allow_positional_parsing: bool = True,
17
18
  ):
18
19
  super().__init__(
19
20
  name=name,
@@ -22,6 +23,7 @@ class OptionInput(BaseInput):
22
23
  default_str=default_str,
23
24
  auto_render=auto_render,
24
25
  allow_empty=allow_empty,
26
+ allow_positional_parsing=allow_positional_parsing,
25
27
  )
26
28
  self._options = options
27
29
 
@@ -14,6 +14,7 @@ class PasswordInput(BaseInput):
14
14
  default_str: str | Callable[[AnySharedContext], str] = "",
15
15
  auto_render: bool = True,
16
16
  allow_empty: bool = False,
17
+ allow_positional_parsing: bool = True,
17
18
  ):
18
19
  super().__init__(
19
20
  name=name,
@@ -22,6 +23,7 @@ class PasswordInput(BaseInput):
22
23
  default_str=default_str,
23
24
  auto_render=auto_render,
24
25
  allow_empty=allow_empty,
26
+ allow_positional_parsing=allow_positional_parsing,
25
27
  )
26
28
  self._is_secret = True
27
29
 
zrb/input/text_input.py CHANGED
@@ -18,6 +18,7 @@ class TextInput(BaseInput):
18
18
  default_str: str | Callable[[AnySharedContext], str] = "",
19
19
  auto_render: bool = True,
20
20
  allow_empty: bool = False,
21
+ allow_positional_parsing: bool = True,
21
22
  editor: str = DEFAULT_EDITOR,
22
23
  extension: str = ".txt",
23
24
  comment_start: str | None = None,
@@ -30,6 +31,7 @@ class TextInput(BaseInput):
30
31
  default_str=default_str,
31
32
  auto_render=auto_render,
32
33
  allow_empty=allow_empty,
34
+ allow_positional_parsing=allow_positional_parsing,
33
35
  )
34
36
  self._editor = editor
35
37
  self._extension = extension
zrb/runner/common_util.py CHANGED
@@ -15,7 +15,7 @@ def get_run_kwargs(
15
15
  if task_input.name in str_kwargs:
16
16
  # Update shared context for next input default value
17
17
  task_input.update_shared_context(shared_ctx, str_kwargs[task_input.name])
18
- elif arg_index < len(args):
18
+ elif arg_index < len(args) and task_input.allow_positional_parsing:
19
19
  run_kwargs[task_input.name] = args[arg_index]
20
20
  # Update shared context for next input default value
21
21
  task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
@@ -23,5 +23,6 @@ def show_error_page(user: User, root_group: AnyGroup, status_code: int, message:
23
23
  "error_status_code": status_code,
24
24
  "error_message": message,
25
25
  },
26
- )
26
+ ),
27
+ status_code=status_code,
27
28
  )
@@ -13,14 +13,16 @@ const CURRENT_SESSION = {
13
13
  for (const inputName in dataInputs) {
14
14
  const inputValue = dataInputs[inputName];
15
15
  const input = submitTaskForm.querySelector(`[name="${inputName}"]`);
16
- input.value = inputValue;
16
+ if (input) {
17
+ input.value = inputValue;
18
+ }
17
19
  }
18
20
  resultLineCount = data.final_result.split("\n").length;
19
21
  resultTextarea.rows = resultLineCount <= 5 ? resultLineCount : 5;
20
22
  // update text areas
21
23
  resultTextarea.value = data.final_result;
22
24
  logTextarea.value = data.log.join("\n");
23
- logTextarea.scrollTop = logTextarea.scrollHeight;
25
+ // logTextarea.scrollTop = logTextarea.scrollHeight;
24
26
  // visualize history
25
27
  this.showCurrentSession(data.task_status, data.finished);
26
28
  if (data.finished) {
@@ -20,9 +20,9 @@ window.addEventListener("load", async function () {
20
20
 
21
21
 
22
22
  const submitTaskForm = document.getElementById("submit-task-form");
23
- submitTaskForm.addEventListener("input", async function(event) {
23
+ submitTaskForm.addEventListener("change", async function(event) {
24
24
  const currentInput = event.target;
25
- const inputs = Array.from(submitTaskForm.querySelectorAll("input[name]"));
25
+ const inputs = Array.from(submitTaskForm.querySelectorAll("input[name], textarea[name], select[name]"));
26
26
  const inputMap = {};
27
27
  const fixedInputNames = [];
28
28
  for (const input of inputs) {
@@ -53,6 +53,12 @@ submitTaskForm.addEventListener("input", async function(event) {
53
53
  return;
54
54
  }
55
55
  const input = submitTaskForm.querySelector(`[name="${key}"]`);
56
+ if (input === currentInput) {
57
+ return;
58
+ }
59
+ if (value === "") {
60
+ return;
61
+ }
56
62
  input.value = value;
57
63
  });
58
64
  } else {
@@ -94,9 +94,54 @@ def serve_task_session_api(
94
94
  if min_start_query is None
95
95
  else datetime.strptime(min_start_query, "%Y-%m-%d %H:%M:%S")
96
96
  )
97
- return session_state_logger.list(
98
- task_path, min_start_time, max_start_time, page, limit
97
+ return sanitize_session_state_log_list(
98
+ task,
99
+ session_state_logger.list(
100
+ task_path, min_start_time, max_start_time, page, limit
101
+ ),
99
102
  )
100
103
  else:
101
- return session_state_logger.read(residual_args[0])
104
+ return sanitize_session_state_log(
105
+ task, session_state_logger.read(residual_args[0])
106
+ )
102
107
  return JSONResponse(content={"detail": "Not found"}, status_code=404)
108
+
109
+
110
+ def sanitize_session_state_log_list(
111
+ task: AnyTask, session_state_log_list: SessionStateLogList
112
+ ) -> SessionStateLogList:
113
+ return SessionStateLogList(
114
+ total=session_state_log_list.total,
115
+ data=[
116
+ sanitize_session_state_log(task, data)
117
+ for data in session_state_log_list.data
118
+ ],
119
+ )
120
+
121
+
122
+ def sanitize_session_state_log(
123
+ task: AnyTask, session_state_log: SessionStateLog
124
+ ) -> SessionStateLog:
125
+ """
126
+ In session, we create snake_case aliases of inputs.
127
+ The purpose was to increase ergonomics, so that user can use `input.system_prompt`
128
+ instead of `input["system-prompt"]`
129
+ However, when we serve the session through HTTP API,
130
+ we only want to show the original input names.
131
+ """
132
+ enhanced_inputs = session_state_log.input
133
+ real_inputs = {}
134
+ for real_input in task.inputs:
135
+ real_input_name = real_input.name
136
+ real_inputs[real_input_name] = enhanced_inputs[real_input_name]
137
+ return SessionStateLog(
138
+ name=session_state_log.name,
139
+ start_time=session_state_log.start_time,
140
+ main_task_name=session_state_log.main_task_name,
141
+ path=session_state_log.path,
142
+ input=real_inputs,
143
+ final_result=session_state_log.final_result,
144
+ finished=session_state_log.finished,
145
+ log=session_state_log.log,
146
+ task_status=session_state_log.task_status,
147
+ )
zrb/task/base_task.py CHANGED
@@ -242,28 +242,28 @@ class BaseTask(AnyTask):
242
242
  def run(
243
243
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
244
244
  ) -> Any:
245
- loop = asyncio.new_event_loop()
246
- try:
247
- return loop.run_until_complete(self._run_and_cleanup(session, str_kwargs))
248
- finally:
249
- loop.close()
245
+ return asyncio.run(self._run_and_cleanup(session, str_kwargs))
250
246
 
251
247
  async def _run_and_cleanup(
252
- self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
248
+ self,
249
+ session: AnySession | None = None,
250
+ str_kwargs: dict[str, str] = {},
253
251
  ) -> Any:
252
+ current_task = asyncio.create_task(self.async_run(session, str_kwargs))
254
253
  try:
255
- result = await self.async_run(session, str_kwargs)
254
+ result = await current_task
256
255
  finally:
257
- if not session.is_terminated:
256
+ if session and not session.is_terminated:
258
257
  session.terminate()
259
258
  # Cancel all running tasks except the current one
260
- current_task = asyncio.current_task()
261
259
  pending = [task for task in asyncio.all_tasks() if task is not current_task]
262
260
  for task in pending:
263
261
  task.cancel()
264
- # Wait for all tasks to complete with a timeout
265
262
  if pending:
266
- await asyncio.wait(pending, timeout=5)
263
+ try:
264
+ await asyncio.wait(pending, timeout=5)
265
+ except asyncio.CancelledError:
266
+ pass
267
267
  return result
268
268
 
269
269
  async def async_run(
@@ -274,7 +274,8 @@ class BaseTask(AnyTask):
274
274
  # Update session
275
275
  self.__fill_shared_context_inputs(session.shared_ctx, str_kwargs)
276
276
  self.__fill_shared_context_envs(session.shared_ctx)
277
- return await run_async(self.exec_root_tasks(session))
277
+ result = await run_async(self.exec_root_tasks(session))
278
+ return result
278
279
 
279
280
  def __fill_shared_context_inputs(
280
281
  self, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
@@ -352,7 +353,7 @@ class BaseTask(AnyTask):
352
353
  session.get_task_status(self).mark_as_skipped()
353
354
  return
354
355
  # Wait for task to be ready
355
- await run_async(self.__exec_action_until_ready(session))
356
+ return await run_async(self.__exec_action_until_ready(session))
356
357
 
357
358
  def __get_execute_condition(self, session: Session) -> bool:
358
359
  ctx = self.get_ctx(session)
zrb/task/llm_task.py CHANGED
@@ -17,6 +17,7 @@ from zrb.util.attr import get_str_attr
17
17
  from zrb.util.cli.style import stylize_faint
18
18
  from zrb.util.file import read_file, write_file
19
19
  from zrb.util.llm.tool import callable_to_tool_schema
20
+ from zrb.util.run import run_async
20
21
 
21
22
  ListOfDict = list[dict[str, Any]]
22
23
 
@@ -24,14 +25,18 @@ ListOfDict = list[dict[str, Any]]
24
25
  class AdditionalTool(BaseModel):
25
26
  fn: Callable
26
27
  name: str | None
27
- description: str | None
28
28
 
29
29
 
30
30
  def scratchpad(thought: str) -> str:
31
- """Use this tool to note your thought and planning"""
31
+ """Write your thought, analysis, reasoning, and evaluation here."""
32
32
  return thought
33
33
 
34
34
 
35
+ def end_conversation(final_answer: str) -> str:
36
+ """End conversation with a final answer containing all necessary information"""
37
+ return final_answer
38
+
39
+
35
40
  class LLMTask(BaseTask):
36
41
  def __init__(
37
42
  self,
@@ -47,11 +52,17 @@ class LLMTask(BaseTask):
47
52
  system_prompt: StrAttr | None = LLM_SYSTEM_PROMPT,
48
53
  render_system_prompt: bool = True,
49
54
  message: StrAttr | None = None,
50
- tools: (
51
- dict[str, Callable] | Callable[[AnySharedContext], dict[str, Callable]]
52
- ) = {},
53
- history: ListOfDict | Callable[[AnySharedContext], ListOfDict] = [],
54
- history_file: StrAttr | None = None,
55
+ tools: list[Callable] | Callable[[AnySharedContext], list[Callable]] = [],
56
+ conversation_history: (
57
+ ListOfDict | Callable[[AnySharedContext], ListOfDict]
58
+ ) = [],
59
+ conversation_history_reader: (
60
+ Callable[[AnySharedContext], ListOfDict] | None
61
+ ) = None,
62
+ conversation_history_writer: (
63
+ Callable[[AnySharedContext, ListOfDict], None] | None
64
+ ) = None,
65
+ conversation_history_file: StrAttr | None = None,
55
66
  render_history_file: bool = True,
56
67
  model_kwargs: (
57
68
  dict[str, Any] | Callable[[AnySharedContext], dict[str, Any]]
@@ -97,84 +108,165 @@ class LLMTask(BaseTask):
97
108
  self._render_system_prompt = render_system_prompt
98
109
  self._message = message
99
110
  self._tools = tools
100
- self._history = history
101
- self._history_file = history_file
111
+ self._conversation_history = conversation_history
112
+ self._conversation_history_reader = conversation_history_reader
113
+ self._conversation_history_writer = conversation_history_writer
114
+ self._conversation_history_file = conversation_history_file
102
115
  self._render_history_file = render_history_file
103
- self._additional_tools: list[AdditionalTool] = []
104
116
 
105
- def add_tool(
106
- self, tool: Callable, name: str | None = None, description: str | None = None
107
- ):
108
- self._additional_tools.append(
109
- AdditionalTool(fn=tool, name=name, description=description)
110
- )
117
+ def add_tool(self, tool: Callable):
118
+ self._tools.append(tool)
111
119
 
112
120
  async def _exec_action(self, ctx: AnyContext) -> Any:
113
- from litellm import acompletion, supports_function_calling
121
+ import litellm
122
+ from litellm.utils import supports_function_calling
114
123
 
124
+ user_message = {"role": "user", "content": self._get_message(ctx)}
125
+ ctx.print(stylize_faint(f"{user_message}"))
115
126
  model = self._get_model(ctx)
116
127
  try:
117
- allow_function_call = supports_function_calling(model=model)
128
+ is_function_call_supported = supports_function_calling(model=model)
118
129
  except Exception:
119
- allow_function_call = False
120
- model_kwargs = self._get_model_kwargs(ctx)
130
+ is_function_call_supported = False
131
+ litellm.add_function_to_prompt = True
132
+ if not is_function_call_supported:
133
+ ctx.log_warning(f"Model {model} doesn't support function call")
134
+ available_tools = self._get_available_tools(
135
+ ctx, include_end_conversation=not is_function_call_supported
136
+ )
137
+ model_kwargs = self._get_model_kwargs(ctx, available_tools)
121
138
  ctx.log_debug("MODEL KWARGS", model_kwargs)
122
139
  system_prompt = self._get_system_prompt(ctx)
123
140
  ctx.log_debug("SYSTEM PROMPT", system_prompt)
124
- history = self._get_history(ctx)
141
+ history = await self._read_conversation_history(ctx)
125
142
  ctx.log_debug("HISTORY PROMPT", history)
126
- user_message = {"role": "user", "content": self._get_message(ctx)}
127
- ctx.print(stylize_faint(f"{user_message}"))
128
- messages = history + [user_message]
129
- available_tools = self._get_tools(ctx)
130
- available_tools["scratchpad"] = scratchpad
131
- if allow_function_call:
132
- tool_schema = [
133
- callable_to_tool_schema(tool, name)
134
- for name, tool in available_tools.items()
135
- ]
136
- for additional_tool in self._additional_tools:
137
- fn = additional_tool.fn
138
- tool_name = additional_tool.name or fn.__name__
139
- tool_description = additional_tool.description
140
- available_tools[tool_name] = additional_tool.fn
141
- tool_schema.append(
142
- callable_to_tool_schema(
143
- fn, name=tool_name, description=tool_description
144
- )
145
- )
146
- model_kwargs["tools"] = tool_schema
147
- ctx.log_debug("TOOL SCHEMA", tool_schema)
148
- history_file = self._get_history_file(ctx)
143
+ conversations = history + [user_message]
149
144
  while True:
150
- response = await acompletion(
151
- model=model,
152
- messages=[{"role": "system", "content": system_prompt}] + messages,
153
- **model_kwargs,
145
+ llm_response = await self._get_llm_response(
146
+ model, system_prompt, conversations, model_kwargs
147
+ )
148
+ llm_response_dict = llm_response.to_dict()
149
+ ctx.print(stylize_faint(f"{llm_response_dict}"))
150
+ conversations.append(llm_response_dict)
151
+ ctx.log_debug("RESPONSE MESSAGE", llm_response)
152
+ if is_function_call_supported:
153
+ if not llm_response.tool_calls:
154
+ # No tool call, end conversation
155
+ await self._write_conversation_history(ctx, conversations)
156
+ return llm_response.content
157
+ await self._handle_tool_calls(
158
+ ctx, available_tools, conversations, llm_response
159
+ )
160
+ if not is_function_call_supported:
161
+ try:
162
+ json_payload = json.loads(llm_response.content)
163
+ function_name = _get_fallback_function_name(json_payload)
164
+ function_kwargs = _get_fallback_function_kwargs(json_payload)
165
+ tool_execution_message = (
166
+ await self._create_fallback_tool_exec_message(
167
+ available_tools, function_name, function_kwargs
168
+ )
169
+ )
170
+ ctx.print(stylize_faint(f"{tool_execution_message}"))
171
+ conversations.append(tool_execution_message)
172
+ if function_name == "end_conversation":
173
+ await self._write_conversation_history(ctx, conversations)
174
+ return function_kwargs.get("final_answer", "")
175
+ except Exception as e:
176
+ ctx.log_error(e)
177
+ tool_execution_message = self._create_exec_scratchpad_message(
178
+ f"{e}"
179
+ )
180
+ conversations.append(tool_execution_message)
181
+
182
+ async def _handle_tool_calls(
183
+ self,
184
+ ctx: AnyContext,
185
+ available_tools: dict[str, Callable],
186
+ conversations: list[dict[str, Any]],
187
+ llm_response: Any,
188
+ ):
189
+ # noqa Reference: https://docs.litellm.ai/docs/completion/function_call#full-code---parallel-function-calling-with-gpt-35-turbo-1106
190
+ for tool_call in llm_response.tool_calls:
191
+ tool_execution_message = await self._create_tool_exec_message(
192
+ available_tools, tool_call
154
193
  )
155
- response_message = response.choices[0].message
156
- ctx.print(stylize_faint(f"{response_message.to_dict()}"))
157
- messages.append(response_message.to_dict())
158
- tool_calls = response_message.tool_calls
159
- if tool_calls:
160
- # noqa Reference: https://docs.litellm.ai/docs/completion/function_call#full-code---parallel-function-calling-with-gpt-35-turbo-1106
161
- for tool_call in tool_calls:
162
- function_name = tool_call.function.name
163
- function_to_call = available_tools[function_name]
164
- function_kwargs = json.loads(tool_call.function.arguments)
165
- function_response = function_to_call(**function_kwargs)
166
- tool_call_message = {
167
- "tool_call_id": tool_call.id,
168
- "role": "tool",
169
- "name": function_name,
170
- "content": function_response,
171
- }
172
- ctx.print(stylize_faint(f"{tool_call_message}"))
173
- messages.append(tool_call_message)
174
- continue
175
- if history_file != "":
176
- write_file(history_file, json.dumps(messages, indent=2))
177
- return response_message.content
194
+ ctx.print(stylize_faint(f"{tool_execution_message}"))
195
+ conversations.append(tool_execution_message)
196
+
197
+ async def _write_conversation_history(
198
+ self, ctx: AnyContext, conversations: list[Any]
199
+ ):
200
+ if self._conversation_history_writer is not None:
201
+ await run_async(self._conversation_history_writer(ctx, conversations))
202
+ history_file = self._get_history_file(ctx)
203
+ if history_file != "":
204
+ write_file(history_file, json.dumps(conversations, indent=2))
205
+
206
+ async def _get_llm_response(
207
+ self,
208
+ model: str,
209
+ system_prompt: str,
210
+ conversations: list[Any],
211
+ model_kwargs: dict[str, Any],
212
+ ) -> Any:
213
+ from litellm import acompletion
214
+
215
+ llm_response = await acompletion(
216
+ model=model,
217
+ messages=[{"role": "system", "content": system_prompt}] + conversations,
218
+ **model_kwargs,
219
+ )
220
+ return llm_response.choices[0].message
221
+
222
+ async def _create_tool_exec_message(
223
+ self, available_tools: dict[str, Callable], tool_call: Any
224
+ ) -> dict[str, Any]:
225
+ function_name = tool_call.function.name
226
+ function_kwargs = json.loads(tool_call.function.arguments)
227
+ return {
228
+ "tool_call_id": tool_call.id,
229
+ "role": "tool",
230
+ "name": function_name,
231
+ "content": await self._get_exec_tool_result(
232
+ available_tools, function_name, function_kwargs
233
+ ),
234
+ }
235
+
236
+ async def _create_fallback_tool_exec_message(
237
+ self,
238
+ available_tools: dict[str, Callable],
239
+ function_name: str,
240
+ function_kwargs: dict[str, Any],
241
+ ) -> dict[str, Any]:
242
+ result = await self._get_exec_tool_result(
243
+ available_tools, function_name, function_kwargs
244
+ )
245
+ return self._create_exec_scratchpad_message(
246
+ f"Result of {function_name} call: {result}"
247
+ )
248
+
249
+ def _create_exec_scratchpad_message(self, message: str) -> dict[str, Any]:
250
+ return {
251
+ "role": "assistant",
252
+ "content": json.dumps(
253
+ {"name": "scratchpad", "arguments": {"thought": message}}
254
+ ),
255
+ }
256
+
257
+ async def _get_exec_tool_result(
258
+ self,
259
+ available_tools: dict[str, Callable],
260
+ function_name: str,
261
+ function_kwargs: dict[str, Any],
262
+ ) -> str:
263
+ if function_name not in available_tools:
264
+ return f"[ERROR] Invalid tool: {function_name}"
265
+ function_to_call = available_tools[function_name]
266
+ try:
267
+ return await run_async(function_to_call(**function_kwargs))
268
+ except Exception as e:
269
+ return f"[ERROR] {e}"
178
270
 
179
271
  def _get_model(self, ctx: AnyContext) -> str:
180
272
  return get_str_attr(
@@ -192,29 +284,67 @@ class LLMTask(BaseTask):
192
284
  def _get_message(self, ctx: AnyContext) -> str:
193
285
  return get_str_attr(ctx, self._message, "How are you?", auto_render=True)
194
286
 
195
- def _get_model_kwargs(self, ctx: AnyContext) -> dict[str, Callable]:
287
+ def _get_model_kwargs(
288
+ self, ctx: AnyContext, available_tools: dict[str, Callable]
289
+ ) -> dict[str, Any]:
290
+ model_kwargs = {}
196
291
  if callable(self._model_kwargs):
197
- return self._model_kwargs(ctx)
198
- return {**self._model_kwargs}
292
+ model_kwargs = self._model_kwargs(ctx)
293
+ else:
294
+ model_kwargs = self._model_kwargs
295
+ model_kwargs["tools"] = [
296
+ callable_to_tool_schema(tool) for tool in available_tools.values()
297
+ ]
298
+ return model_kwargs
199
299
 
200
- def _get_tools(self, ctx: AnyContext) -> dict[str, Callable]:
201
- if callable(self._tools):
202
- return self._tools(ctx)
203
- return self._tools
300
+ def _get_available_tools(
301
+ self, ctx: AnyContext, include_end_conversation: bool
302
+ ) -> dict[str, Callable]:
303
+ tools = {"scratchpad": scratchpad}
304
+ if include_end_conversation:
305
+ tools["end_conversation"] = end_conversation
306
+ tool_list = self._tools(ctx) if callable(self._tools) else self._tools
307
+ for tool in tool_list:
308
+ tools[tool.__name__] = tool
309
+ return tools
204
310
 
205
- def _get_history(self, ctx: AnyContext) -> ListOfDict:
206
- if callable(self._history):
207
- return self._history(ctx)
311
+ async def _read_conversation_history(self, ctx: AnyContext) -> ListOfDict:
312
+ if self._conversation_history_reader is not None:
313
+ return await run_async(self._conversation_history_reader(ctx))
314
+ if callable(self._conversation_history):
315
+ return self._conversation_history(ctx)
208
316
  history_file = self._get_history_file(ctx)
209
317
  if (
210
- len(self._history) == 0
318
+ len(self._conversation_history) == 0
211
319
  and history_file != ""
212
320
  and os.path.isfile(history_file)
213
321
  ):
214
322
  return json.loads(read_file(history_file))
215
- return self._history
323
+ return self._conversation_history
216
324
 
217
325
  def _get_history_file(self, ctx: AnyContext) -> str:
218
326
  return get_str_attr(
219
- ctx, self._history_file, "", auto_render=self._render_history_file
327
+ ctx,
328
+ self._conversation_history_file,
329
+ "",
330
+ auto_render=self._render_history_file,
220
331
  )
332
+
333
+
334
+ def _get_fallback_function_name(json_payload: dict[str, Any]) -> str:
335
+ for key in ("name",):
336
+ if key in json_payload:
337
+ return json_payload[key]
338
+ raise ValueError("Function name not provided")
339
+
340
+
341
+ def _get_fallback_function_kwargs(json_payload: dict[str, Any]) -> str:
342
+ for key in (
343
+ "arguments",
344
+ "args",
345
+ "parameters",
346
+ "params",
347
+ ):
348
+ if key in json_payload:
349
+ return json_payload[key]
350
+ raise ValueError("Function arguments not provided")