zrb 1.15.3__py3-none-any.whl → 1.21.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic. Click here for more details.

Files changed (108) hide show
  1. zrb/__init__.py +2 -6
  2. zrb/attr/type.py +10 -7
  3. zrb/builtin/__init__.py +2 -0
  4. zrb/builtin/git.py +12 -1
  5. zrb/builtin/group.py +31 -15
  6. zrb/builtin/llm/attachment.py +40 -0
  7. zrb/builtin/llm/chat_completion.py +274 -0
  8. zrb/builtin/llm/chat_session.py +126 -167
  9. zrb/builtin/llm/chat_session_cmd.py +288 -0
  10. zrb/builtin/llm/chat_trigger.py +79 -0
  11. zrb/builtin/llm/history.py +4 -4
  12. zrb/builtin/llm/llm_ask.py +217 -135
  13. zrb/builtin/llm/tool/api.py +74 -70
  14. zrb/builtin/llm/tool/cli.py +35 -21
  15. zrb/builtin/llm/tool/code.py +55 -73
  16. zrb/builtin/llm/tool/file.py +278 -344
  17. zrb/builtin/llm/tool/note.py +84 -0
  18. zrb/builtin/llm/tool/rag.py +27 -34
  19. zrb/builtin/llm/tool/sub_agent.py +54 -41
  20. zrb/builtin/llm/tool/web.py +74 -98
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  23. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  24. zrb/builtin/searxng/config/settings.yml +5671 -0
  25. zrb/builtin/searxng/start.py +21 -0
  26. zrb/builtin/shell/autocomplete/bash.py +4 -3
  27. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  28. zrb/config/config.py +202 -27
  29. zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
  30. zrb/config/default_prompt/interactive_system_prompt.md +24 -30
  31. zrb/config/default_prompt/persona.md +1 -1
  32. zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
  33. zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
  34. zrb/config/default_prompt/summarization_prompt.md +57 -16
  35. zrb/config/default_prompt/system_prompt.md +36 -30
  36. zrb/config/llm_config.py +119 -23
  37. zrb/config/llm_context/config.py +127 -90
  38. zrb/config/llm_context/config_parser.py +1 -7
  39. zrb/config/llm_context/workflow.py +81 -0
  40. zrb/config/llm_rate_limitter.py +100 -47
  41. zrb/context/any_shared_context.py +7 -1
  42. zrb/context/context.py +8 -2
  43. zrb/context/shared_context.py +3 -7
  44. zrb/group/any_group.py +3 -3
  45. zrb/group/group.py +3 -3
  46. zrb/input/any_input.py +5 -1
  47. zrb/input/base_input.py +18 -6
  48. zrb/input/option_input.py +13 -1
  49. zrb/input/text_input.py +7 -24
  50. zrb/runner/cli.py +21 -20
  51. zrb/runner/common_util.py +24 -19
  52. zrb/runner/web_route/task_input_api_route.py +5 -5
  53. zrb/runner/web_util/user.py +7 -3
  54. zrb/session/any_session.py +12 -6
  55. zrb/session/session.py +39 -18
  56. zrb/task/any_task.py +24 -3
  57. zrb/task/base/context.py +17 -9
  58. zrb/task/base/execution.py +15 -8
  59. zrb/task/base/lifecycle.py +8 -4
  60. zrb/task/base/monitoring.py +12 -7
  61. zrb/task/base_task.py +69 -5
  62. zrb/task/base_trigger.py +12 -5
  63. zrb/task/llm/agent.py +128 -167
  64. zrb/task/llm/agent_runner.py +152 -0
  65. zrb/task/llm/config.py +39 -20
  66. zrb/task/llm/conversation_history.py +110 -29
  67. zrb/task/llm/conversation_history_model.py +4 -179
  68. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  69. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  70. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  71. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  72. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  73. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  74. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  75. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  76. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  77. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  78. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  79. zrb/task/llm/file_replacement.py +206 -0
  80. zrb/task/llm/file_tool_model.py +57 -0
  81. zrb/task/llm/history_processor.py +206 -0
  82. zrb/task/llm/history_summarization.py +2 -193
  83. zrb/task/llm/print_node.py +184 -64
  84. zrb/task/llm/prompt.py +175 -179
  85. zrb/task/llm/subagent_conversation_history.py +41 -0
  86. zrb/task/llm/tool_wrapper.py +226 -85
  87. zrb/task/llm/workflow.py +76 -0
  88. zrb/task/llm_task.py +109 -71
  89. zrb/task/make_task.py +2 -3
  90. zrb/task/rsync_task.py +25 -10
  91. zrb/task/scheduler.py +4 -4
  92. zrb/util/attr.py +54 -39
  93. zrb/util/cli/markdown.py +12 -0
  94. zrb/util/cli/text.py +30 -0
  95. zrb/util/file.py +12 -3
  96. zrb/util/git.py +2 -2
  97. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  98. zrb/util/string/conversion.py +1 -1
  99. zrb/util/truncate.py +23 -0
  100. zrb/util/yaml.py +204 -0
  101. zrb/xcom/xcom.py +10 -0
  102. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
  103. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
  104. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
  105. zrb/task/llm/default_workflow/coding.md +0 -24
  106. zrb/task/llm/default_workflow/copywriting.md +0 -17
  107. zrb/task/llm/default_workflow/researching.md +0 -18
  108. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
@@ -2,12 +2,124 @@ import os
2
2
 
3
3
  from zrb.config.config import CFG
4
4
  from zrb.config.llm_context.config_parser import markdown_to_dict
5
- from zrb.util.llm.prompt import demote_markdown_headers
5
+ from zrb.config.llm_context.workflow import LLMWorkflow
6
+ from zrb.util.markdown import demote_markdown_headers
6
7
 
7
8
 
8
9
  class LLMContextConfig:
9
10
  """High-level API for interacting with cascaded configurations."""
10
11
 
12
+ def write_note(
13
+ self,
14
+ content: str,
15
+ context_path: str | None = None,
16
+ cwd: str | None = None,
17
+ ):
18
+ """Writes content to a note block in the user's home configuration file."""
19
+ if cwd is None:
20
+ cwd = os.getcwd()
21
+ if context_path is None:
22
+ context_path = cwd
23
+ config_file = self._get_home_config_file()
24
+ sections = {}
25
+ if os.path.exists(config_file):
26
+ sections = self._parse_config(config_file)
27
+ abs_context_path = os.path.abspath(os.path.join(cwd, context_path))
28
+ found_key = None
29
+ for key in sections.keys():
30
+ if not key.startswith("Note:"):
31
+ continue
32
+ context_path_str = key[len("Note:") :].strip()
33
+ abs_key_path = self._normalize_context_path(
34
+ context_path_str,
35
+ os.path.dirname(config_file),
36
+ )
37
+ if abs_key_path == abs_context_path:
38
+ found_key = key
39
+ break
40
+ if found_key:
41
+ sections[found_key] = content
42
+ else:
43
+ config_dir = os.path.dirname(config_file)
44
+ formatted_path = self._format_context_path_for_writing(
45
+ abs_context_path,
46
+ config_dir,
47
+ )
48
+ new_key = f"Note: {formatted_path}"
49
+ sections[new_key] = content
50
+ # Serialize back to markdown
51
+ new_file_content = ""
52
+ for key, value in sections.items():
53
+ new_file_content += f"# {key}\n{demote_markdown_headers(value)}\n\n"
54
+ with open(config_file, "w") as f:
55
+ f.write(new_file_content)
56
+
57
+ def get_notes(self, cwd: str | None = None) -> dict[str, str]:
58
+ """Gathers all notes for a given path."""
59
+ if cwd is None:
60
+ cwd = os.getcwd()
61
+ config_file = self._get_home_config_file()
62
+ if not os.path.exists(config_file):
63
+ return {}
64
+ config_dir = os.path.dirname(config_file)
65
+ sections = self._parse_config(config_file)
66
+ notes: dict[str, str] = {}
67
+ for key, value in sections.items():
68
+ if key.lower().startswith("note:"):
69
+ context_path_str = key[len("note:") :].strip()
70
+ abs_context_path = self._normalize_context_path(
71
+ context_path_str,
72
+ config_dir,
73
+ )
74
+ # A context is relevant if its path is an ancestor of cwd
75
+ if os.path.commonpath([cwd, abs_context_path]) == abs_context_path:
76
+ notes[abs_context_path] = value
77
+ return notes
78
+
79
+ def get_workflows(self, cwd: str | None = None) -> dict[str, LLMWorkflow]:
80
+ """Gathers all relevant workflows for a given path."""
81
+ if cwd is None:
82
+ cwd = os.getcwd()
83
+ all_sections = self._get_all_sections(cwd)
84
+ workflows: dict[str, LLMWorkflow] = {}
85
+ # Iterate from closest to farthest
86
+ for config_dir, sections in all_sections:
87
+ for key, value in sections.items():
88
+ if key.lower().startswith("workflow:"):
89
+ workflow_name = key[len("workflow:") :].strip().lower()
90
+ # First one found wins
91
+ if workflow_name not in workflows:
92
+ workflows[workflow_name] = LLMWorkflow(
93
+ name=workflow_name,
94
+ content=value,
95
+ path=config_dir,
96
+ )
97
+ return workflows
98
+
99
+ def _format_context_path_for_writing(
100
+ self,
101
+ path_to_write: str,
102
+ relative_to_dir: str,
103
+ ) -> str:
104
+ """Formats a path for writing into a context file key."""
105
+ home_dir = os.path.expanduser("~")
106
+ abs_path_to_write = os.path.abspath(
107
+ os.path.join(relative_to_dir, path_to_write)
108
+ )
109
+ abs_relative_to_dir = os.path.abspath(relative_to_dir)
110
+ # Rule 1: Inside relative_to_dir
111
+ if abs_path_to_write.startswith(abs_relative_to_dir):
112
+ if abs_path_to_write == abs_relative_to_dir:
113
+ return "."
114
+ return os.path.relpath(abs_path_to_write, abs_relative_to_dir)
115
+ # Rule 2: Inside Home
116
+ if abs_path_to_write.startswith(home_dir):
117
+ if abs_path_to_write == home_dir:
118
+ return "~"
119
+ return os.path.join("~", os.path.relpath(abs_path_to_write, home_dir))
120
+ # Rule 3: Absolute
121
+ return abs_path_to_write
122
+
11
123
  def _find_config_files(self, cwd: str) -> list[str]:
12
124
  configs = []
13
125
  current_dir = cwd
@@ -21,6 +133,10 @@ class LLMContextConfig:
21
133
  current_dir = os.path.dirname(current_dir)
22
134
  return configs
23
135
 
136
+ def _get_home_config_file(self) -> str:
137
+ home_dir = os.path.expanduser("~")
138
+ return os.path.join(home_dir, CFG.LLM_CONTEXT_FILE)
139
+
24
140
  def _parse_config(self, file_path: str) -> dict[str, str]:
25
141
  with open(file_path, "r") as f:
26
142
  content = f.read()
@@ -35,95 +151,16 @@ class LLMContextConfig:
35
151
  all_sections.append((config_dir, sections))
36
152
  return all_sections
37
153
 
38
- def get_contexts(self, cwd: str | None = None) -> dict[str, str]:
39
- """Gathers all relevant contexts for a given path."""
40
- if cwd is None:
41
- cwd = os.getcwd()
42
- all_sections = self._get_all_sections(cwd)
43
- contexts: dict[str, str] = {}
44
- for config_dir, sections in reversed(all_sections):
45
- for key, value in sections.items():
46
- if key.startswith("Context:"):
47
- context_path = key[len("Context:") :].strip()
48
- if context_path == ".":
49
- context_path = config_dir
50
- elif not os.path.isabs(context_path):
51
- context_path = os.path.abspath(
52
- os.path.join(config_dir, context_path)
53
- )
54
- if os.path.isabs(context_path) or cwd.startswith(context_path):
55
- contexts[context_path] = value
56
- return contexts
57
-
58
- def get_workflows(self, cwd: str | None = None) -> dict[str, str]:
59
- """Gathers all relevant workflows for a given path."""
60
- if cwd is None:
61
- cwd = os.getcwd()
62
- all_sections = self._get_all_sections(cwd)
63
- workflows: dict[str, str] = {}
64
- for _, sections in reversed(all_sections):
65
- for key, value in sections.items():
66
- if key.startswith("Workflow:"):
67
- workflow_name = key[len("Workflow:") :].strip()
68
- workflow_name = key.replace("Workflow:", "").lower().strip()
69
- workflows[workflow_name] = value
70
- return workflows
71
-
72
- def write_context(
73
- self, content: str, context_path: str | None = None, cwd: str | None = None
74
- ):
75
- """Writes content to a context block in the nearest configuration file."""
76
- if cwd is None:
77
- cwd = os.getcwd()
78
- if context_path is None:
79
- context_path = cwd
80
-
81
- config_files = self._find_config_files(cwd)
82
- if config_files:
83
- config_file = config_files[0] # Closest config file
84
- else:
85
- config_file = os.path.join(cwd, CFG.LLM_CONTEXT_FILE)
86
-
87
- sections = {}
88
- if os.path.exists(config_file):
89
- sections = self._parse_config(config_file)
90
-
91
- # Determine the section key
92
- section_key_path = context_path
93
- if not os.path.isabs(context_path):
94
- config_dir = os.path.dirname(config_file)
95
- section_key_path = os.path.abspath(os.path.join(config_dir, context_path))
96
-
97
- # Find existing key
98
- found_key = ""
99
- for key in sections.keys():
100
- if not key.startswith("Context:"):
101
- continue
102
- key_path = key.replace("Context:", "").strip()
103
- if key_path == ".":
104
- key_path = os.path.dirname(config_file)
105
- elif not os.path.isabs(key_path):
106
- key_path = os.path.abspath(
107
- os.path.join(os.path.dirname(config_file), key_path)
108
- )
109
- if key_path == section_key_path:
110
- found_key = key
111
- break
112
-
113
- if found_key != "":
114
- sections[found_key] = content
115
- else:
116
- # Add new entry
117
- new_key = f"Context: {context_path}"
118
- sections[new_key] = content
119
-
120
- # Serialize back to markdown
121
- new_file_content = ""
122
- for key, value in sections.items():
123
- new_file_content += f"# {key}\n{demote_markdown_headers(value)}\n\n"
124
-
125
- with open(config_file, "w") as f:
126
- f.write(new_file_content)
154
+ def _normalize_context_path(
155
+ self,
156
+ path_str: str,
157
+ relative_to_dir: str,
158
+ ) -> str:
159
+ """Normalizes a context path string to an absolute path."""
160
+ expanded_path = os.path.expanduser(path_str)
161
+ if os.path.isabs(expanded_path):
162
+ return os.path.abspath(expanded_path)
163
+ return os.path.abspath(os.path.join(relative_to_dir, expanded_path))
127
164
 
128
165
 
129
166
  llm_context_config = LLMContextConfig()
@@ -1,6 +1,6 @@
1
1
  import re
2
2
 
3
- from zrb.util.llm.prompt import promote_markdown_headers
3
+ from zrb.util.markdown import promote_markdown_headers
4
4
 
5
5
 
6
6
  def markdown_to_dict(markdown: str) -> dict[str, str]:
@@ -8,21 +8,17 @@ def markdown_to_dict(markdown: str) -> dict[str, str]:
8
8
  current_title = ""
9
9
  current_content: list[str] = []
10
10
  fence_stack: list[str] = []
11
-
12
11
  fence_pattern = re.compile(r"^([`~]{3,})(.*)$")
13
12
  h1_pattern = re.compile(r"^# (.+)$")
14
-
15
13
  for line in markdown.splitlines():
16
14
  # Detect code fence open/close
17
15
  fence_match = fence_pattern.match(line.strip())
18
-
19
16
  if fence_match:
20
17
  fence = fence_match.group(1)
21
18
  if fence_stack and fence_stack[-1] == fence:
22
19
  fence_stack.pop() # close current fence
23
20
  else:
24
21
  fence_stack.append(fence) # open new fence
25
-
26
22
  # Only parse H1 when not inside a code fence
27
23
  if not fence_stack:
28
24
  h1_match = h1_pattern.match(line)
@@ -34,9 +30,7 @@ def markdown_to_dict(markdown: str) -> dict[str, str]:
34
30
  current_title = h1_match.group(1).strip()
35
31
  current_content = []
36
32
  continue
37
-
38
33
  current_content.append(line)
39
-
40
34
  # Save final section
41
35
  if current_title:
42
36
  sections[current_title] = "\n".join(current_content).strip()
@@ -0,0 +1,81 @@
1
+ class LLMWorkflow:
2
+ def __init__(
3
+ self, name: str, path: str, content: str, description: str | None = None
4
+ ):
5
+ self._name = name
6
+ self._path = path
7
+
8
+ # Extract YAML metadata and clean content
9
+ (
10
+ extracted_description,
11
+ cleaned_content,
12
+ ) = self._extract_yaml_metadata_and_clean_content(content)
13
+ self._content = cleaned_content
14
+
15
+ # Use provided description or extracted one
16
+ self._description = (
17
+ description if description is not None else extracted_description
18
+ )
19
+
20
+ def _extract_yaml_metadata_and_clean_content(
21
+ self, content: str
22
+ ) -> tuple[str | None, str]:
23
+ """Extract YAML metadata and clean content.
24
+
25
+ Looks for YAML metadata between --- lines, extracts the 'description' field,
26
+ and returns the content without the YAML metadata.
27
+ """
28
+ import re
29
+
30
+ import yaml
31
+
32
+ # Pattern to match YAML metadata between --- delimiters
33
+ yaml_pattern = r"^---\s*\n(.*?)\n---\s*\n"
34
+ match = re.search(yaml_pattern, content, re.DOTALL | re.MULTILINE)
35
+
36
+ if match:
37
+ yaml_content = match.group(1)
38
+ try:
39
+ metadata = yaml.safe_load(yaml_content)
40
+ description = (
41
+ metadata.get("description") if isinstance(metadata, dict) else None
42
+ )
43
+ # Remove the YAML metadata from content
44
+ cleaned_content = re.sub(
45
+ yaml_pattern, "", content, count=1, flags=re.DOTALL | re.MULTILINE
46
+ )
47
+ return description, cleaned_content.strip()
48
+ except yaml.YAMLError:
49
+ # If YAML parsing fails, return original content
50
+ pass
51
+
52
+ # No YAML metadata found, return original content
53
+ return None, content
54
+
55
+ @property
56
+ def name(self) -> str:
57
+ return self._name
58
+
59
+ @property
60
+ def path(self) -> str:
61
+ return self._path
62
+
63
+ @property
64
+ def content(self) -> str:
65
+ return self._content
66
+
67
+ @property
68
+ def description(self) -> str:
69
+ if self._description is not None:
70
+ return self._description
71
+ if len(self._content) > 1000:
72
+ non_empty_lines = [
73
+ line for line in self._content.split("\n") if line.strip() != ""
74
+ ]
75
+ first_non_empty_line = (
76
+ non_empty_lines[0] if len(non_empty_lines) > 0 else ""
77
+ )
78
+ if len(first_non_empty_line) > 200:
79
+ return first_non_empty_line[:200] + "... (more)"
80
+ return first_non_empty_line
81
+ return self._content
@@ -1,30 +1,13 @@
1
1
  import asyncio
2
+ import json
2
3
  import time
3
4
  from collections import deque
4
- from typing import Callable
5
-
6
- import tiktoken
5
+ from typing import Any, Callable
7
6
 
8
7
  from zrb.config.config import CFG
9
8
 
10
9
 
11
- def _estimate_token(text: str) -> int:
12
- """
13
- Estimates the number of tokens in a given text.
14
- Tries to use the 'gpt-4o' model's tokenizer for an accurate count.
15
- If the tokenizer is unavailable (e.g., due to network issues),
16
- it falls back to a heuristic of 4 characters per token.
17
- """
18
- try:
19
- # Primary method: Use tiktoken for an accurate count
20
- enc = tiktoken.encoding_for_model("gpt-4o")
21
- return len(enc.encode(text))
22
- except Exception:
23
- # Fallback method: Heuristic (4 characters per token)
24
- return len(text) // 4
25
-
26
-
27
- class LLMRateLimiter:
10
+ class LLMRateLimitter:
28
11
  """
29
12
  Helper class to enforce LLM API rate limits and throttling.
30
13
  Tracks requests and tokens in a rolling 60-second window.
@@ -35,14 +18,18 @@ class LLMRateLimiter:
35
18
  max_requests_per_minute: int | None = None,
36
19
  max_tokens_per_minute: int | None = None,
37
20
  max_tokens_per_request: int | None = None,
21
+ max_tokens_per_tool_call_result: int | None = None,
38
22
  throttle_sleep: float | None = None,
39
- token_counter_fn: Callable[[str], int] | None = None,
23
+ use_tiktoken: bool | None = None,
24
+ tiktoken_encoding_name: str | None = None,
40
25
  ):
41
26
  self._max_requests_per_minute = max_requests_per_minute
42
27
  self._max_tokens_per_minute = max_tokens_per_minute
43
28
  self._max_tokens_per_request = max_tokens_per_request
29
+ self._max_tokens_per_tool_call_result = max_tokens_per_tool_call_result
44
30
  self._throttle_sleep = throttle_sleep
45
- self._token_counter_fn = token_counter_fn
31
+ self._use_tiktoken = use_tiktoken
32
+ self._tiktoken_encoding_name = tiktoken_encoding_name
46
33
  self.request_times = deque()
47
34
  self.token_times = deque()
48
35
 
@@ -64,6 +51,12 @@ class LLMRateLimiter:
64
51
  return self._max_tokens_per_request
65
52
  return CFG.LLM_MAX_TOKENS_PER_REQUEST
66
53
 
54
+ @property
55
+ def max_tokens_per_tool_call_result(self) -> int:
56
+ if self._max_tokens_per_tool_call_result is not None:
57
+ return self._max_tokens_per_tool_call_result
58
+ return CFG.LLM_MAX_TOKENS_PER_TOOL_CALL_RESULT
59
+
67
60
  @property
68
61
  def throttle_sleep(self) -> float:
69
62
  if self._throttle_sleep is not None:
@@ -71,10 +64,16 @@ class LLMRateLimiter:
71
64
  return CFG.LLM_THROTTLE_SLEEP
72
65
 
73
66
  @property
74
- def count_token(self) -> Callable[[str], int]:
75
- if self._token_counter_fn is not None:
76
- return self._token_counter_fn
77
- return _estimate_token
67
+ def use_tiktoken(self) -> bool:
68
+ if self._use_tiktoken is not None:
69
+ return self._use_tiktoken
70
+ return CFG.USE_TIKTOKEN
71
+
72
+ @property
73
+ def tiktoken_encoding_name(self) -> str:
74
+ if self._tiktoken_encoding_name is not None:
75
+ return self._tiktoken_encoding_name
76
+ return CFG.TIKTOKEN_ENCODING_NAME
78
77
 
79
78
  def set_max_requests_per_minute(self, value: int):
80
79
  self._max_requests_per_minute = value
@@ -85,29 +84,56 @@ class LLMRateLimiter:
85
84
  def set_max_tokens_per_request(self, value: int):
86
85
  self._max_tokens_per_request = value
87
86
 
87
+ def set_max_tokens_per_tool_call_result(self, value: int):
88
+ self._max_tokens_per_tool_call_result = value
89
+
88
90
  def set_throttle_sleep(self, value: float):
89
91
  self._throttle_sleep = value
90
92
 
91
- def set_token_counter_fn(self, fn: Callable[[str], int]):
92
- self._token_counter_fn = fn
93
-
94
- def clip_prompt(self, prompt: str, limit: int) -> str:
95
- token_count = self.count_token(prompt)
96
- if token_count <= limit:
97
- return prompt
98
- while token_count > limit:
99
- prompt_parts = prompt.split(" ")
100
- last_part_index = len(prompt_parts) - 2
101
- clipped_prompt = " ".join(prompt_parts[:last_part_index])
102
- clipped_prompt += "(Content clipped...)"
103
- token_count = self.count_token(clipped_prompt)
104
- if token_count < limit:
105
- return clipped_prompt
106
- return prompt[:limit]
107
-
108
- async def throttle(self, prompt: str):
93
+ def count_token(self, prompt: Any) -> int:
94
+ str_prompt = self._prompt_to_str(prompt)
95
+ if not self.use_tiktoken:
96
+ return self._fallback_count_token(str_prompt)
97
+ try:
98
+ import tiktoken
99
+
100
+ enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
101
+ return len(enc.encode(str_prompt))
102
+ except Exception:
103
+ return self._fallback_count_token(str_prompt)
104
+
105
+ def _fallback_count_token(self, str_prompt: str) -> int:
106
+ return len(str_prompt) // 4
107
+
108
+ def clip_prompt(self, prompt: Any, limit: int) -> str:
109
+ str_prompt = self._prompt_to_str(prompt)
110
+ if not self.use_tiktoken:
111
+ return self._fallback_clip_prompt(str_prompt, limit)
112
+ try:
113
+ import tiktoken
114
+
115
+ enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
116
+ tokens = enc.encode(str_prompt)
117
+ if len(tokens) <= limit:
118
+ return str_prompt
119
+ truncated = tokens[: limit - 3]
120
+ clipped_text = enc.decode(truncated)
121
+ return clipped_text + "..."
122
+ except Exception:
123
+ return self._fallback_clip_prompt(str_prompt, limit)
124
+
125
+ def _fallback_clip_prompt(self, str_prompt: str, limit: int) -> str:
126
+ char_limit = limit * 4 if limit * 4 <= 10 else limit * 4 - 10
127
+ return str_prompt[:char_limit] + "..."
128
+
129
+ async def throttle(
130
+ self,
131
+ prompt: Any,
132
+ throttle_notif_callback: Callable[[str], Any] | None = None,
133
+ ):
109
134
  now = time.time()
110
- tokens = self.count_token(prompt)
135
+ str_prompt = self._prompt_to_str(prompt)
136
+ tokens = self.count_token(str_prompt)
111
137
  # Clean up old entries
112
138
  while self.request_times and now - self.request_times[0] > 60:
113
139
  self.request_times.popleft()
@@ -116,13 +142,34 @@ class LLMRateLimiter:
116
142
  # Check per-request token limit
117
143
  if tokens > self.max_tokens_per_request:
118
144
  raise ValueError(
119
- f"Request exceeds max_tokens_per_request ({self.max_tokens_per_request})."
145
+ (
146
+ "Request exceeds max_tokens_per_request "
147
+ f"({tokens} > {self.max_tokens_per_request})."
148
+ )
149
+ )
150
+ if tokens > self.max_tokens_per_minute:
151
+ raise ValueError(
152
+ (
153
+ "Request exceeds max_tokens_per_minute "
154
+ f"({tokens} > {self.max_tokens_per_minute})."
155
+ )
120
156
  )
121
157
  # Wait if over per-minute request or token limit
122
158
  while (
123
159
  len(self.request_times) >= self.max_requests_per_minute
124
160
  or sum(t for _, t in self.token_times) + tokens > self.max_tokens_per_minute
125
161
  ):
162
+ if throttle_notif_callback is not None:
163
+ if len(self.request_times) >= self.max_requests_per_minute:
164
+ rpm = len(self.request_times)
165
+ throttle_notif_callback(
166
+ f"Max request per minute exceeded: {rpm} of {self.max_requests_per_minute}"
167
+ )
168
+ else:
169
+ tpm = sum(t for _, t in self.token_times) + tokens
170
+ throttle_notif_callback(
171
+ f"Max token per minute exceeded: {tpm} of {self.max_tokens_per_minute}"
172
+ )
126
173
  await asyncio.sleep(self.throttle_sleep)
127
174
  now = time.time()
128
175
  while self.request_times and now - self.request_times[0] > 60:
@@ -133,5 +180,11 @@ class LLMRateLimiter:
133
180
  self.request_times.append(now)
134
181
  self.token_times.append((now, tokens))
135
182
 
183
+ def _prompt_to_str(self, prompt: Any) -> str:
184
+ try:
185
+ return json.dumps(prompt)
186
+ except Exception:
187
+ return f"{prompt}"
188
+
136
189
 
137
- llm_rate_limitter = LLMRateLimiter()
190
+ llm_rate_limitter = LLMRateLimitter()
@@ -29,26 +29,32 @@ class AnySharedContext(ABC):
29
29
  pass
30
30
 
31
31
  @property
32
+ @abstractmethod
32
33
  def input(self) -> DotDict:
33
34
  pass
34
35
 
35
36
  @property
37
+ @abstractmethod
36
38
  def env(self) -> DotDict:
37
39
  pass
38
40
 
39
41
  @property
42
+ @abstractmethod
40
43
  def args(self) -> list[Any]:
41
44
  pass
42
45
 
43
46
  @property
44
- def xcom(self) -> DotDict[str, Xcom]:
47
+ @abstractmethod
48
+ def xcom(self) -> DotDict:
45
49
  pass
46
50
 
47
51
  @property
52
+ @abstractmethod
48
53
  def shared_log(self) -> list[str]:
49
54
  pass
50
55
 
51
56
  @property
57
+ @abstractmethod
52
58
  def session(self) -> any_session.AnySession | None:
53
59
  pass
54
60
 
zrb/context/context.py CHANGED
@@ -63,7 +63,7 @@ class Context(AnyContext):
63
63
 
64
64
  @property
65
65
  def session(self) -> AnySession | None:
66
- return self._shared_ctx._session
66
+ return self._shared_ctx.session
67
67
 
68
68
  def update_task_env(self, task_env: dict[str, str]):
69
69
  self._env.update(task_env)
@@ -119,7 +119,13 @@ class Context(AnyContext):
119
119
  return
120
120
  color = self._color
121
121
  icon = self._icon
122
- max_name_length = max(len(name) + len(icon) for name in self.session.task_names)
122
+ # Handle case where session is None (e.g., in tests)
123
+ if self.session is None:
124
+ max_name_length = len(self._task_name) + len(icon)
125
+ else:
126
+ max_name_length = max(
127
+ len(name) + len(icon) for name in self.session.task_names
128
+ )
123
129
  styled_task_name = f"{icon} {self._task_name}"
124
130
  padded_styled_task_name = styled_task_name.rjust(max_name_length + 1)
125
131
  if self._attempt == 0:
@@ -40,11 +40,7 @@ class SharedContext(AnySharedContext):
40
40
 
41
41
  def __repr__(self):
42
42
  class_name = self.__class__.__name__
43
- input = self._input
44
- args = self._args
45
- env = self._env
46
- xcom = self._xcom
47
- return f"<{class_name} input={input} args={args} xcom={xcom} env={env}>"
43
+ return f"<{class_name}>"
48
44
 
49
45
  @property
50
46
  def is_web_mode(self) -> bool:
@@ -70,7 +66,7 @@ class SharedContext(AnySharedContext):
70
66
  return self._args
71
67
 
72
68
  @property
73
- def xcom(self) -> DotDict[str, Xcom]:
69
+ def xcom(self) -> DotDict:
74
70
  return self._xcom
75
71
 
76
72
  @property
@@ -85,7 +81,7 @@ class SharedContext(AnySharedContext):
85
81
  self._log.append(message)
86
82
  session = self.session
87
83
  if session is not None:
88
- session_parent: AnySession = session.parent
84
+ session_parent: AnySession | None = session.parent
89
85
  if session_parent is not None:
90
86
  session_parent.shared_ctx.append_to_shared_log(message)
91
87
 
zrb/group/any_group.py CHANGED
@@ -35,11 +35,11 @@ class AnyGroup(ABC):
35
35
  pass
36
36
 
37
37
  @abstractmethod
38
- def add_group(self, group: "AnyGroup | str") -> "AnyGroup":
38
+ def add_group(self, group: "AnyGroup", alias: str | None = None) -> "AnyGroup":
39
39
  pass
40
40
 
41
41
  @abstractmethod
42
- def add_task(self, task: AnyTask, alias: str | None = None) -> AnyTask:
42
+ def add_task(self, task: "AnyTask", alias: str | None = None) -> "AnyTask":
43
43
  pass
44
44
 
45
45
  @abstractmethod
@@ -55,5 +55,5 @@ class AnyGroup(ABC):
55
55
  pass
56
56
 
57
57
  @abstractmethod
58
- def get_group_by_alias(self, name: str) -> "AnyGroup | None":
58
+ def get_group_by_alias(self, alias: str) -> "AnyGroup | None":
59
59
  pass