zrb 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,10 +6,11 @@ from zrb.builtin.group import llm_group
6
6
  from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
7
7
  from zrb.builtin.llm.tool.cli import run_shell_command
8
8
  from zrb.builtin.llm.tool.file import (
9
+ apply_diff,
9
10
  list_files,
10
- read_all_files,
11
- read_text_file,
12
- write_text_file,
11
+ read_from_file,
12
+ search_files,
13
+ write_to_file,
13
14
  )
14
15
  from zrb.builtin.llm.tool.web import (
15
16
  create_search_internet_tool,
@@ -161,10 +162,11 @@ llm_chat: LLMTask = llm_group.add_task(
161
162
 
162
163
 
163
164
  if LLM_ALLOW_ACCESS_LOCAL_FILE:
164
- llm_chat.add_tool(read_all_files)
165
165
  llm_chat.add_tool(list_files)
166
- llm_chat.add_tool(read_text_file)
167
- llm_chat.add_tool(write_text_file)
166
+ llm_chat.add_tool(read_from_file)
167
+ llm_chat.add_tool(write_to_file)
168
+ llm_chat.add_tool(search_files)
169
+ llm_chat.add_tool(apply_diff)
168
170
 
169
171
  if LLM_ALLOW_ACCESS_SHELL:
170
172
  llm_chat.add_tool(run_shell_command)
@@ -5,7 +5,7 @@ from typing import Annotated, Literal
5
5
  def get_current_location() -> (
6
6
  Annotated[str, "JSON string representing latitude and longitude"]
7
7
  ): # noqa
8
- """Get the user's current location."""
8
+ """Get the user's current location. This function take no argument."""
9
9
  import requests
10
10
 
11
11
  return json.dumps(requests.get("http://ip-api.com/json?fields=lat,lon").json())
@@ -1,119 +1,125 @@
1
1
  import fnmatch
2
2
  import os
3
+ import re
4
+ from typing import Dict, List, Optional, Tuple, Union
3
5
 
4
- from zrb.util.file import read_file, write_file
5
-
6
- _INCLUDED_PATTERNS: list[str] = [
7
- "*.py", # Python
8
- "*.go", # Go
9
- "*.rs", # Rust
10
- "*.js", # JavaScript
11
- "*.ts", # TypeScript
12
- "*.java", # Java
13
- "*.c", # C
14
- "*.cpp", # C++
15
- "*.cc", # Alternative C++ extension
16
- "*.cxx", # Alternative C++ extension
17
- "*.rb", # Ruby
18
- "*.swift", # Swift
19
- "*.kt", # Kotlin
20
- "*.php", # PHP
21
- "*.pl", # Perl / Prolog
22
- "*.pm", # Perl module
23
- "*.sh", # Shell
24
- "*.bat", # Batch
25
- "*.ps1", # PowerShell
26
- "*.R", # R (capital)
27
- "*.r", # R (lowercase)
28
- "*.scala", # Scala
29
- "*.hs", # Haskell
30
- "*.cs", # C#
31
- "*.fs", # F#
32
- "*.ex", # Elixir
33
- "*.exs", # Elixir script
34
- "*.erl", # Erlang
35
- "*.hrl", # Erlang header
36
- "*.dart", # Dart
37
- "*.m", # Objective-C / Matlab (note: conflicts may arise)
38
- "*.mm", # Objective-C++
39
- "*.lua", # Lua
40
- "*.jl", # Julia
41
- "*.groovy", # Groovy
42
- "*.clj", # Clojure
43
- "*.cljs", # ClojureScript
44
- "*.cljc", # Clojure common
45
- "*.vb", # Visual Basic
46
- "*.f90", # Fortran
47
- "*.f95", # Fortran
48
- "*.adb", # Ada
49
- "*.ads", # Ada specification
50
- "*.pas", # Pascal
51
- "*.pp", # Pascal
52
- "*.ml", # OCaml
53
- "*.mli", # OCaml interface
54
- "*.nim", # Nim
55
- "*.rkt", # Racket
56
- "*.d", # D
57
- "*.lisp", # Common Lisp
58
- "*.lsp", # Lisp variant
59
- "*.cl", # Common Lisp
60
- "*.scm", # Scheme
61
- "*.st", # Smalltalk
62
- "*.vhd", # VHDL
63
- "*.vhdl", # VHDL
64
- "*.v", # Verilog
65
- "*.asm", # Assembly
66
- "*.s", # Assembly (alternative)
67
- "*.sql", # SQL (if desired)
68
- ]
6
+ from zrb.util.file import read_file as _read_file
7
+ from zrb.util.file import write_file as _write_file
69
8
 
70
- # Extended list of directories and patterns to exclude.
71
- _EXCLUDED_PATTERNS: list[str] = [
72
- "venv", # Python virtual environments
9
+ # Common directories and files to exclude from file operations
10
+ _DEFAULT_EXCLUDES = [
11
+ # Version control
12
+ ".git",
13
+ ".svn",
14
+ ".hg",
15
+ # Dependencies and packages
16
+ "node_modules",
17
+ "venv",
73
18
  ".venv",
74
- "node_modules", # Node.js dependencies
75
- ".git", # Git repositories
76
- "__pycache__", # Python cache directories
77
- "build", # Build directories
78
- "dist", # Distribution directories
79
- "target", # Build output directories (Java, Rust, etc.)
80
- "bin", # Binary directories
81
- "obj", # Object files directories
82
- ".idea", # JetBrains IDEs
83
- ".vscode", # VS Code settings
84
- ".eggs", # Python eggs
19
+ "env",
20
+ ".env",
21
+ # Build and cache
22
+ "__pycache__",
23
+ "*.pyc",
24
+ "build",
25
+ "dist",
26
+ "target",
27
+ # IDE and editor files
28
+ ".idea",
29
+ ".vscode",
30
+ "*.swp",
31
+ "*.swo",
32
+ # OS-specific
33
+ ".DS_Store",
34
+ "Thumbs.db",
35
+ # Temporary and backup files
36
+ "*.tmp",
37
+ "*.bak",
38
+ "*.log",
85
39
  ]
86
40
 
41
+ # Maximum number of lines to read before truncating
42
+ _MAX_LINES_BEFORE_TRUNCATION = 1000
43
+
44
+ # Number of context lines to show around method definitions when truncating
45
+ _CONTEXT_LINES = 5
46
+
87
47
 
88
48
  def list_files(
89
- directory: str = ".",
90
- included_patterns: list[str] = _INCLUDED_PATTERNS,
91
- excluded_patterns: list[str] = _EXCLUDED_PATTERNS,
49
+ path: str = ".",
50
+ recursive: bool = True,
51
+ file_pattern: Optional[str] = None,
52
+ excluded_patterns: list[str] = _DEFAULT_EXCLUDES,
92
53
  ) -> list[str]:
93
- """List all files in a directory that match any of the included glob patterns
94
- and do not reside in any directory matching an excluded pattern.
95
- Patterns are evaluated using glob-style matching.
54
+ """
55
+ List files in a directory that match specified patterns.
56
+
57
+ Args:
58
+ path: The path of the directory to list contents for
59
+ (relative to the current working directory)
60
+ recursive: Whether to list files recursively.
61
+ Use True for recursive listing, False for top-level only.
62
+ file_pattern: Optional glob pattern to filter files.
63
+ None by default (all files will be included).
64
+ excluded_patterns: List of glob patterns to exclude. By default, contains sane values
65
+ to exclude common directories and files like version control, build artifacts,
66
+ and temporary files.
67
+
68
+ Returns:
69
+ A list of file paths matching the criteria
96
70
  """
97
71
  all_files: list[str] = []
98
- for root, dirs, files in os.walk(directory):
99
- for filename in files:
100
- if any(fnmatch.fnmatch(filename, pat) for pat in included_patterns):
72
+
73
+ if recursive:
74
+ for root, dirs, files in os.walk(path):
75
+ # Filter out excluded directories to avoid descending into them
76
+ dirs[:] = [
77
+ d
78
+ for d in dirs
79
+ if not _should_exclude(os.path.join(root, d), excluded_patterns)
80
+ ]
81
+
82
+ for filename in files:
101
83
  full_path = os.path.join(root, filename)
102
- if _should_exclude(full_path, excluded_patterns):
103
- continue
104
- all_files.append(full_path)
105
- return all_files
84
+ # If file_pattern is None, include all files, otherwise match the pattern
85
+ if file_pattern is None or fnmatch.fnmatch(filename, file_pattern):
86
+ if not _should_exclude(full_path, excluded_patterns):
87
+ all_files.append(full_path)
88
+ else:
89
+ # Non-recursive listing (top-level only)
90
+ try:
91
+ for item in os.listdir(path):
92
+ full_path = os.path.join(path, item)
93
+ if os.path.isfile(full_path):
94
+ # If file_pattern is None, include all files, otherwise match the pattern
95
+ if file_pattern is None or fnmatch.fnmatch(item, file_pattern):
96
+ if not _should_exclude(full_path, excluded_patterns):
97
+ all_files.append(full_path)
98
+ except (FileNotFoundError, PermissionError) as e:
99
+ print(f"Error listing files in {path}: {e}")
100
+
101
+ return sorted(all_files)
106
102
 
107
103
 
108
- def _should_exclude(full_path: str, excluded_patterns: list[str]) -> bool:
104
+ def _should_exclude(
105
+ full_path: str, excluded_patterns: list[str] = _DEFAULT_EXCLUDES
106
+ ) -> bool:
109
107
  """
110
108
  Return True if the file at full_path should be excluded based on
111
109
  the list of excluded_patterns. Patterns that include a path separator
112
110
  are applied to the full normalized path; otherwise they are matched
113
111
  against each individual component of the path.
112
+
113
+ Args:
114
+ full_path: The full path to check
115
+ excluded_patterns: List of patterns to exclude
116
+
117
+ Returns:
118
+ True if the path should be excluded, False otherwise
114
119
  """
115
120
  norm_path = os.path.normpath(full_path)
116
121
  path_parts = norm_path.split(os.sep)
122
+
117
123
  for pat in excluded_patterns:
118
124
  # If the pattern seems intended for full path matching (contains a separator)
119
125
  if os.sep in pat or "/" in pat:
@@ -123,30 +129,382 @@ def _should_exclude(full_path: str, excluded_patterns: list[str]) -> bool:
123
129
  # Otherwise check each part of the path
124
130
  if any(fnmatch.fnmatch(part, pat) for part in path_parts):
125
131
  return True
132
+ # Also check the filename against the pattern
133
+ if os.path.isfile(full_path) and fnmatch.fnmatch(
134
+ os.path.basename(full_path), pat
135
+ ):
136
+ return True
137
+
126
138
  return False
127
139
 
128
140
 
129
- def read_text_file(file: str) -> str:
130
- """Read a text file and return a string containing the file content."""
131
- return read_file(os.path.abspath(file))
141
+ def read_from_file(
142
+ path: str,
143
+ start_line: Optional[int] = None,
144
+ end_line: Optional[int] = None,
145
+ auto_truncate: bool = False,
146
+ ) -> str:
147
+ """
148
+ Read the contents of a file at the specified path.
149
+
150
+ Args:
151
+ path: The path of the file to read (relative to the current working directory)
152
+ start_line: The starting line number to read from (1-based).
153
+ If not provided, starts from the beginning.
154
+ end_line: The ending line number to read to (1-based, inclusive).
155
+ If not provided, reads to the end.
156
+ auto_truncate: Whether to automatically truncate large files when start_line
157
+ and end_line are not specified. If true and the file exceeds a certain
158
+ line threshold, it will return a subset of lines with information about
159
+ the total line count and method definitions. Default is False for backward
160
+ compatibility, but setting to True is recommended for large files.
132
161
 
162
+ Returns:
163
+ A string containing the file content, with line numbers prefixed to each line.
164
+ For truncated files, includes summary information.
165
+ """
166
+ try:
167
+ abs_path = os.path.abspath(path)
133
168
 
134
- def write_text_file(file: str, content: str):
135
- """Write content to a text file"""
136
- return write_file(os.path.abspath(file), content)
169
+ # Read the entire file content
170
+ content = _read_file(abs_path)
171
+ lines = content.splitlines()
172
+ total_lines = len(lines)
137
173
 
174
+ # Determine if we should truncate
175
+ should_truncate = (
176
+ auto_truncate
177
+ and start_line is None
178
+ and end_line is None
179
+ and total_lines > _MAX_LINES_BEFORE_TRUNCATION
180
+ )
138
181
 
139
- def read_all_files(
140
- directory: str = ".",
141
- included_patterns: list[str] = _INCLUDED_PATTERNS,
142
- excluded_patterns: list[str] = _EXCLUDED_PATTERNS,
143
- ) -> list[str]:
144
- """Read all files in a directory that match any of the included glob patterns
145
- and do not match any of the excluded glob patterns.
146
- Patterns are evaluated using glob-style matching.
147
- """
148
- files = list_files(directory, included_patterns, excluded_patterns)
149
- for index, file in enumerate(files):
150
- content = read_text_file(file)
151
- files[index] = f"# {file}\n```\n{content}\n```"
152
- return files
182
+ # Adjust line indices (convert from 1-based to 0-based)
183
+ start_idx = (start_line - 1) if start_line is not None else 0
184
+ end_idx = end_line if end_line is not None else total_lines
185
+
186
+ # Validate indices
187
+ if start_idx < 0:
188
+ start_idx = 0
189
+ if end_idx > total_lines:
190
+ end_idx = total_lines
191
+
192
+ if should_truncate:
193
+ # Find method definitions and their line ranges
194
+ method_info = _find_method_definitions(lines)
195
+
196
+ # Create a truncated view with method definitions
197
+ result_lines = []
198
+
199
+ # Add file info header
200
+ result_lines.append(f"File: {path} (truncated, {total_lines} lines total)")
201
+ result_lines.append("")
202
+
203
+ # Add beginning of file (first 100 lines)
204
+ first_chunk = min(100, total_lines // 3)
205
+ for i in range(first_chunk):
206
+ result_lines.append(f"{i+1} | {lines[i]}")
207
+
208
+ result_lines.append("...")
209
+ omitted_msg = (
210
+ f"[{first_chunk+1} - {total_lines-100}] Lines omitted for brevity"
211
+ )
212
+ result_lines.append(omitted_msg)
213
+ result_lines.append("...")
214
+
215
+ # Add end of file (last 100 lines)
216
+ for i in range(max(first_chunk, total_lines - 100), total_lines):
217
+ result_lines.append(f"{i+1} | {lines[i]}")
218
+
219
+ # Add method definitions summary
220
+ if method_info:
221
+ result_lines.append("")
222
+ result_lines.append("Method definitions found:")
223
+ for method in method_info:
224
+ method_line = (
225
+ f"- {method['name']} "
226
+ f"(lines {method['start_line']}-{method['end_line']})"
227
+ )
228
+ result_lines.append(method_line)
229
+
230
+ return "\n".join(result_lines)
231
+ else:
232
+ # Return the requested range with line numbers
233
+ result_lines = []
234
+ for i in range(start_idx, end_idx):
235
+ result_lines.append(f"{i+1} | {lines[i]}")
236
+
237
+ return "\n".join(result_lines)
238
+
239
+ except Exception as e:
240
+ return f"Error reading file {path}: {str(e)}"
241
+
242
+
243
+ def _find_method_definitions(lines: List[str]) -> List[Dict[str, Union[str, int]]]:
244
+ """
245
+ Find method definitions in the given lines of code.
246
+
247
+ Args:
248
+ lines: List of code lines to analyze
249
+
250
+ Returns:
251
+ List of dictionaries containing method name, start line, and end line
252
+ """
253
+ method_info = []
254
+
255
+ # Simple regex patterns for common method/function definitions
256
+ patterns = [
257
+ # Python
258
+ r"^\s*def\s+([a-zA-Z0-9_]+)\s*\(",
259
+ # JavaScript/TypeScript
260
+ r"^\s*(function\s+([a-zA-Z0-9_]+)|([a-zA-Z0-9_]+)\s*=\s*function|"
261
+ r"\s*([a-zA-Z0-9_]+)\s*\([^)]*\)\s*{)",
262
+ # Java/C#/C++
263
+ r"^\s*(?:public|private|protected|static|final|abstract|synchronized)?"
264
+ r"\s+(?:[a-zA-Z0-9_<>[\]]+\s+)+([a-zA-Z0-9_]+)\s*\(",
265
+ ]
266
+
267
+ current_method = None
268
+
269
+ for i, line in enumerate(lines):
270
+ # Check if this line starts a method definition
271
+ for pattern in patterns:
272
+ match = re.search(pattern, line)
273
+ if match:
274
+ # If we were tracking a method, close it
275
+ if current_method:
276
+ current_method["end_line"] = i
277
+ method_info.append(current_method)
278
+
279
+ # Start tracking a new method
280
+ method_name = next(
281
+ group for group in match.groups() if group is not None
282
+ )
283
+ current_method = {
284
+ "name": method_name,
285
+ "start_line": i + 1, # 1-based line numbering
286
+ "end_line": None,
287
+ }
288
+ break
289
+
290
+ # Check for method end (simplistic approach)
291
+ if current_method and line.strip() == "}":
292
+ current_method["end_line"] = i + 1
293
+ method_info.append(current_method)
294
+ current_method = None
295
+
296
+ # Close any open method at the end of the file
297
+ if current_method:
298
+ current_method["end_line"] = len(lines)
299
+ method_info.append(current_method)
300
+
301
+ return method_info
302
+
303
+
304
def write_to_file(path: str, content: str) -> bool:
    """
    Write content to the file at the given path, creating the parent
    directory first when it does not exist yet.

    Args:
        path: The path of the file to write to (relative to the current working directory)
        content: The content to write to the file

    Returns:
        True if the write succeeded, False if any error occurred (the error
        is printed, not raised)
    """
    try:
        parent_dir = os.path.dirname(os.path.abspath(path))
        parent_missing = bool(parent_dir) and not os.path.exists(parent_dir)
        if parent_missing:
            os.makedirs(parent_dir, exist_ok=True)

        _write_file(os.path.abspath(path), content)
    except Exception as e:
        print(f"Error writing to file {path}: {str(e)}")
        return False
    else:
        return True
327
+
328
+
329
def search_files(
    path: str, regex: str, file_pattern: Optional[str] = None, context_lines: int = 2
) -> str:
    """
    Search for a regex pattern across files in a specified directory.

    Args:
        path: The path of the directory to search in
            (relative to the current working directory)
        regex: The regular expression pattern to search for
        file_pattern: Optional glob pattern to filter files.
            Default is None, which includes all files. Only specify this if you need to
            filter to specific file types (but in most cases, leaving as None is better).
        context_lines: Number of context lines to show before and after each match.
            Default is 2, which provides good context without overwhelming output.

    Returns:
        A string containing the search results with context. It starts with a
        summary line ("Found N matches in M files:") followed, per file, by a
        header and one "Line A-B:" section per matching line. If nothing was
        found (and no file failed to read) a "No matches found" message is
        returned; an invalid regex or any other failure yields an
        "Error searching files" message instead of raising.
    """
    try:
        # Compile the regex pattern
        pattern = re.compile(regex)

        # Get the list of files to search
        files = list_files(path, recursive=True, file_pattern=file_pattern)

        results = []
        match_count = 0
        # Counted explicitly: the per-file header begins with a newline, so
        # it cannot be recognised afterwards by a dash prefix in `results`.
        file_count = 0

        for file_path in files:
            try:
                with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                    lines = f.readlines()

                file_matches = []

                for i, line in enumerate(lines):
                    if not pattern.search(line):
                        continue

                    # Determine context range (0-based, end exclusive)
                    start = max(0, i - context_lines)
                    end = min(len(lines), i + context_lines + 1)

                    # Add file header if this is the first match in the file
                    if not file_matches:
                        file_matches.append(
                            f"\n{'-' * 80}\n{file_path}\n{'-' * 80}"
                        )

                    # Every match gets its own range heading
                    file_matches.append(f"\nLine {start+1}-{end}:")

                    # Add context lines, marking the matching line with ">"
                    for j in range(start, end):
                        prefix = ">" if j == i else " "
                        file_matches.append(f"{prefix} {j+1}: {lines[j].rstrip()}")

                    match_count += 1

                if file_matches:
                    file_count += 1
                    results.extend(file_matches)

            except Exception as e:
                # A single unreadable file must not abort the whole search
                results.append(f"Error reading {file_path}: {str(e)}")

        if not results:
            return f"No matches found for pattern '{regex}' in {path}"

        summary = f"Found {match_count} matches in {file_count} files:\n"
        return summary + "\n".join(results)

    except Exception as e:
        return f"Error searching files: {str(e)}"
408
+
409
+
410
def apply_diff(path: str, diff: str, start_line: int, end_line: int) -> bool:
    """
    Replace existing code using a search and replace block.

    The lines ``start_line``..``end_line`` of the file must be exactly equal
    to the SEARCH part of the diff; only then are they replaced by the
    REPLACE part and the file written back.

    Args:
        path: The path of the file to modify (relative to the current working directory)
        diff: The search/replace block defining the changes
        start_line: The line number where the search block starts (1-based)
        end_line: The line number where the search block ends (1-based, inclusive)

    Returns:
        True if successful, False otherwise (the reason is printed, nothing
        is raised)

    The diff format should be:
    ```
    <<<<<<< SEARCH
    [exact content to find including whitespace]
    =======
    [new content to replace with]
    >>>>>>> REPLACE
    ```
    """
    try:
        # Read the file
        abs_path = os.path.abspath(path)
        content = _read_file(abs_path)
        lines = content.splitlines()

        # Validate line numbers
        if start_line < 1 or end_line > len(lines) or start_line > end_line:
            print(
                f"Invalid line range: {start_line}-{end_line} (file has {len(lines)} lines)"
            )
            return False

        # Parse the diff
        search_content, replace_content = _parse_diff(diff)
        if search_content is None or replace_content is None:
            print("Invalid diff format")
            return False

        # Extract the content to be replaced
        original_content = "\n".join(lines[start_line - 1 : end_line])

        # Verify the search content matches
        if original_content != search_content:
            print("Search content does not match the specified lines in the file")
            return False

        # Replace the content
        new_lines = (
            lines[: start_line - 1] + replace_content.splitlines() + lines[end_line:]
        )
        new_content = "\n".join(new_lines)

        # splitlines() discards the final line terminator, so joining alone
        # would silently drop the newline at the end of the file. Put it
        # back when the content that was read ended with one.
        if new_lines and content.endswith("\n"):
            new_content += "\n"

        # Write the modified content back to the file
        _write_file(abs_path, new_content)
        return True

    except Exception as e:
        print(f"Error applying diff to {path}: {str(e)}")
        return False
472
+
473
+
474
+ def _parse_diff(diff: str) -> Tuple[Optional[str], Optional[str]]:
475
+ """
476
+ Parse a diff string to extract search and replace content.
477
+
478
+ Args:
479
+ diff: The diff string to parse
480
+
481
+ Returns:
482
+ A tuple of (search_content, replace_content), or (None, None) if parsing fails
483
+ """
484
+ try:
485
+ # Split the diff into sections
486
+ search_marker = "<<<<<<< SEARCH"
487
+ separator = "======="
488
+ replace_marker = ">>>>>>> REPLACE"
489
+
490
+ if (
491
+ search_marker not in diff
492
+ or separator not in diff
493
+ or replace_marker not in diff
494
+ ):
495
+ return None, None
496
+
497
+ # Extract search content
498
+ search_start = diff.index(search_marker) + len(search_marker)
499
+ search_end = diff.index(separator)
500
+ search_content = diff[search_start:search_end].strip()
501
+
502
+ # Extract replace content
503
+ replace_start = diff.index(separator) + len(separator)
504
+ replace_end = diff.index(replace_marker)
505
+ replace_content = diff[replace_start:replace_end].strip()
506
+
507
+ return search_content, replace_content
508
+
509
+ except Exception:
510
+ return None, None