coauthor 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of coauthor might be problematic. Click here for more details.

@@ -0,0 +1,73 @@
1
+ import yaml
2
+ import os
3
+
4
+
5
def read_config(file_path, logger=None):
    """Load a YAML configuration file with ``yaml.safe_load``.

    :param file_path: path of the YAML file to read.
    :param logger: optional logger; an info message naming the file is logged.
    :return: the parsed document (``None`` for an empty file).
    """
    if logger:
        logger.info(f"Reading configuration from {file_path}")
    # Explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
10
+
11
+
12
def get_config_path(config_filename=".coauthor.yml", search_dir=None):
    """Locate *config_filename*, walking from *search_dir* up to the filesystem root.

    If the file is not found in *search_dir* or any of its ancestors, the user's
    home directory is checked as a fallback.

    :param config_filename: name of the configuration file to look for.
    :param search_dir: directory to start from; ``None`` means the current working
        directory at call time.
    :return: ``(path, traversed_paths)``. *path* is the configuration file found, or
        ``None``. *traversed_paths* lists the directories checked without success,
        nearest first. If the file is found in the home directory, the home
        directory is appended as the last entry. If nothing is found, the home
        directory is included once.
    """
    if search_dir is None:
        # Resolved per call; a signature default of os.getcwd() is frozen at import.
        search_dir = os.getcwd()
    # Absolute path so a relative start (e.g. ".") still walks up to its parents;
    # os.path.dirname("") == "" would otherwise end the walk immediately.
    current_dir = os.path.abspath(search_dir)

    traversed_paths = []
    while True:
        candidate = os.path.join(current_dir, config_filename)
        if os.path.exists(candidate):
            return candidate, traversed_paths
        traversed_paths.append(current_dir)
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:  # reached the filesystem root
            break
        current_dir = parent_dir

    home_dir = os.path.expanduser("~")
    home_path = os.path.join(home_dir, config_filename)
    if os.path.exists(home_path):
        traversed_paths.append(home_dir)
        return home_path, traversed_paths

    # Report the home directory as checked too, unless it was already visited
    # as an ancestor of the start directory.
    if home_dir not in traversed_paths:
        traversed_paths.append(home_dir)
    return None, traversed_paths
36
+
37
+
38
def get_config(path=None, logger=None, config_filename=".coauthor.yml", search_dir=None, args=None):
    """Load the coauthor configuration and fill in default settings.

    The configuration file is chosen in this order:

    1. ``args.config_path``, when *args* defines a truthy ``config_path``;
    2. the explicit *path* argument;
    3. a search for *config_filename* via ``get_config_path`` starting at *search_dir*.

    :param path: explicit configuration file path.
    :param logger: optional logger. It receives a warning when no file is found and
        is passed on to ``read_config``.
    :param config_filename: file name searched for when no path is given.
    :param search_dir: directory the search starts from; ``None`` means the current
        working directory at call time.
    :param args: optional object (e.g. an ``argparse.Namespace``) that may define
        ``config_path``.
    :return: configuration dict, with defaults for ``watch_directory``, ``callback``
        and the ``agent`` keys ``api_key_var``, ``api_url_var`` and ``model``.
    """
    if search_dir is None:
        # Resolved per call; a signature default of os.getcwd() is frozen at import.
        search_dir = os.getcwd()

    config_path = None
    if args and "config_path" in args:
        config_path = args.config_path
    if not config_path:
        if path:
            config_path = path
        else:
            config_path, searched_paths = get_config_path(config_filename, search_dir)
            if not config_path and logger:
                logger.warning(f"Configuration file not found. Searched directories: {', '.join(searched_paths)}")

    config = {}
    if config_path:
        # An empty YAML file loads as None; treat it as an empty mapping so the
        # defaults below can still be applied.
        config = read_config(config_path, logger) or {}

    config.setdefault("watch_directory", os.getcwd())
    if "callback" not in config and "callbacks" not in config:
        config["callback"] = "process_file_with_openai_agent"
    if config.get("agent") is None:
        # Covers both a missing key and an empty "agent:" entry in YAML.
        config["agent"] = {}
    agent = config["agent"]
    agent.setdefault("api_key_var", "OPENAI_API_KEY")
    agent.setdefault("api_url_var", "OPENAI_API_URL")
    agent.setdefault("model", "openai/gpt-4o")
    return config
67
+
68
+
69
def get_jinja_config(config):
    """Return the ``jinja`` section of *config*.

    When the section is absent, a default pointing at ``.coauthor/templates`` is
    returned instead. A present key is returned as-is, even if its value is ``None``.
    """
    return config.get("jinja", {"search_path": ".coauthor/templates"})
coauthor/utils/git.py ADDED
@@ -0,0 +1,47 @@
1
+ import subprocess
2
+ import os
3
+
4
+
5
def get_git_diff(file_path):
    """Return the uncommitted ``git diff`` of *file_path*, or ``""`` when unavailable.

    An empty string is returned when:

    - the file's directory does not exist;
    - ``git`` cannot be executed;
    - the directory is not inside a Git work tree;
    - ``git diff`` fails.

    ``""`` is used instead of ``None`` because this function is registered as a
    Jinja filter, and ``None`` would render as the literal text "None".

    :param file_path: path to the file (absolute or relative to the current directory).
    :return: the diff text (decoded as UTF-8, undecodable bytes replaced) or ``""``.
    """
    # Resolve once: git is run inside the file's directory, so a relative path
    # must not be interpreted a second time relative to that directory.
    abs_path = os.path.abspath(file_path)
    directory = os.path.dirname(abs_path)
    if not os.path.isdir(directory):
        return ""

    try:
        # `git rev-parse --is-inside-work-tree` prints "true" inside a work tree.
        inside = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            cwd=directory,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except OSError:
        # git executable missing or not runnable.
        return ""
    if inside.returncode != 0 or inside.stdout.strip() != b"true":
        return ""

    # "--" keeps the path from being mistaken for a revision name.
    diff_result = subprocess.run(
        ["git", "diff", "--", abs_path],
        cwd=directory,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if diff_result.returncode != 0:
        return ""
    return diff_result.stdout.decode("utf-8", errors="replace")
43
+
44
+
45
+ # Example usage:
46
+ # diff = get_git_diff('/path/to/your/file')
47
+ # print(diff)
@@ -0,0 +1,101 @@
1
+ import jinja2
2
+ import os
3
+ from coauthor.utils.config import get_jinja_config
4
+ from coauthor.utils.git import get_git_diff
5
+ from coauthor.utils.jinja_filters import select_task, get_task_attribute
6
+
7
+
8
def render_template_to_file(task, template, path, config, logger):
    """Render *template* for *task* via ``render_template`` and write the text to *path*.

    The output file is written as UTF-8 and overwritten if it already exists.
    """
    rendered_text = render_template(task, template, config, logger)
    with open(path, mode="w", encoding="utf-8") as output:
        output.write(rendered_text)
12
+
13
+
14
def search_path_directories(search_path):
    """Return *search_path* followed by every directory beneath it.

    The order is top-down, as produced by ``os.walk``. A missing *search_path*
    yields an empty list.
    """
    return [directory for directory, _subdirs, _files in os.walk(search_path)]
20
+
21
+
22
def template_exists(task, template, config, logger):
    """Report whether *template* is listed by the configured Jinja search path.

    The search path and all of its subdirectories are given to a
    ``FileSystemLoader``, and *template* is looked up in ``list_templates()``.
    *task* is accepted for signature compatibility but not used.

    :return: True if the template is listed, otherwise False (a debug message is logged).
    """
    jinja_settings = get_jinja_config(config)
    loader = jinja2.FileSystemLoader(searchpath=search_path_directories(jinja_settings["search_path"]))
    if template not in loader.list_templates():
        logger.debug(f"Template {template} not found!")
        return False
    return True
31
+
32
+
33
def _build_environment(jinja_config, logger, loader=None):
    """Create a Jinja environment, honouring optional custom delimiters, with coauthor filters.

    :param jinja_config: dict from ``get_jinja_config``; it may contain
        ``custom_delimiters``, whose keys override Jinja's default delimiters.
    :param logger: logger used for a debug message when custom delimiters apply.
    :param loader: Jinja loader, or ``None`` for environments that only render strings.
    :return: the configured ``jinja2.Environment``.
    """
    options = {"loader": loader}
    if "custom_delimiters" in jinja_config:
        logger.debug("Creating Jinja environment using custom delimiters")
        custom_delimiters = jinja_config["custom_delimiters"]
        options.update(
            block_start_string=custom_delimiters.get("block_start_string", "{%"),
            block_end_string=custom_delimiters.get("block_end_string", "%}"),
            variable_start_string=custom_delimiters.get("variable_start_string", "{{"),
            variable_end_string=custom_delimiters.get("variable_end_string", "}}"),
            comment_start_string=custom_delimiters.get("comment_start_string", "{#"),
            comment_end_string=custom_delimiters.get("comment_end_string", "#}"),
        )
    template_env = jinja2.Environment(**options)
    template_env.filters["include_file_content"] = include_file_content
    template_env.filters["get_git_diff"] = get_git_diff
    template_env.filters["file_exists"] = file_exists
    template_env.filters["select_task"] = select_task
    template_env.filters["get_task_attribute"] = get_task_attribute
    return template_env


def render_template(task, template_path, config, logger):
    """Render the template file *template_path* for *task*.

    Templates are loaded from the configured Jinja search path and all of its
    subdirectories.

    :param task: task dict; its ``id`` is used for logging.
    :param template_path: template name as known to the ``FileSystemLoader``.
    :param config: configuration dict; ``config["current-workflow"]`` must be present.
    :param logger: logger for debug output.
    :return: the rendered text. The template context holds ``config``, ``task``
        and ``workflow``.
    """
    logger.debug(f"Render template {template_path} for task {task['id']}")
    jinja_config = get_jinja_config(config)
    search_paths = search_path_directories(jinja_config["search_path"])
    template_loader = jinja2.FileSystemLoader(searchpath=search_paths)
    templates = template_loader.list_templates()
    logger.debug(f"templates: {templates}")
    template_env = _build_environment(jinja_config, logger, loader=template_loader)

    logger.debug(f"Get Jinja template: {template_path}")
    template = template_env.get_template(template_path)
    context = {"config": config, "task": task, "workflow": config["current-workflow"]}
    return template.render(context)


def render_content(task, template_string, config, logger):
    """Render *template_string* as an inline Jinja template.

    :param task: task dict; its ``id`` is used for logging.
    :param template_string: the template source text.
    :param config: configuration dict, exposed to the template as ``config``.
    :param logger: logger for debug output.
    :return: the rendered text. Only ``config`` is placed in the template context.
    """
    logger.debug(f"Render content for task {task['id']}")
    jinja_config = get_jinja_config(config)
    template_env = _build_environment(jinja_config, logger)

    template = template_env.from_string(template_string)
    context = {"config": config}
    return template.render(context)
93
+
94
+
95
def include_file_content(path):
    """Jinja filter: return the full text of *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as handle:
        text = handle.read()
    return text
98
+
99
+
100
def file_exists(path):
    """Jinja filter: report whether *path* exists (file or directory)."""
    exists = os.path.exists(path)
    return exists
@@ -0,0 +1,53 @@
1
+ import logging
2
+ from coauthor.utils.logger import Logger
3
+ import os
4
+
5
+
6
class TaskNotFoundError(Exception):
    """Raised when no task with the requested ID can be found.

    The looked-up ID is kept on the ``task_id`` attribute.
    """

    def __init__(self, task_id):
        message = f"Task with ID {task_id} not found."
        super().__init__(message)
        self.task_id = task_id
12
+
13
+
14
class AttributeNotFoundError(Exception):
    """Raised when a task does not contain the requested attribute (key).

    The missing key is kept on the ``attribute`` attribute.
    """

    def __init__(self, attribute):
        message = f"Attribute '{attribute}' not found in task."
        super().__init__(message)
        self.attribute = attribute
20
+
21
+
22
def select_task(config, task_id):
    """Return the task whose ``id`` equals *task_id* in the current workflow.

    :param config: configuration dict; tasks are read from
        ``config["current-workflow"]["tasks"]``.
    :param task_id: the ID to look for.
    :return: the first task dict with a matching ``id``.
    :raises TaskNotFoundError: if no task has that ID.
    """
    candidates = config["current-workflow"]["tasks"]
    found = next((entry for entry in candidates if entry.get("id") == task_id), None)
    if found is None:
        raise TaskNotFoundError(task_id)
    return found
36
+
37
+
38
def get_task_attribute(config, id, attribute):
    """Return ``task[attribute]`` for the current-workflow task with the given *id*.

    :param config: configuration dict passed through to ``select_task``.
    :param id: ID of the task to look up.
    :param attribute: key to read from that task.
    :return: the value stored under *attribute*.
    :raises TaskNotFoundError: if no task has that ID.
    :raises AttributeNotFoundError: if the task has no such key.
    """
    task = select_task(config, id)
    if attribute not in task:
        raise AttributeNotFoundError(attribute)
    return task[attribute]
@@ -0,0 +1,43 @@
1
+ import logging
2
+ import os
3
+
4
+
5
class Logger:
    """Wrapper around :mod:`logging` that writes to a log file and to the console.

    Both handlers share one formatter and the level given at construction. Loggers
    are shared per name (``logging.getLogger`` semantics). Creating another
    ``Logger`` with the same *name* replaces the handlers installed by an earlier
    instance instead of stacking them, so each record is emitted only once per
    destination.
    """

    # Marker set on handlers created by this class, so a later instance for the
    # same logger name can recognise and replace them.
    _HANDLER_MARKER = "_coauthor_handler"

    def __init__(self, name, level=logging.INFO, log_file=None):
        """
        :param name: name passed to ``logging.getLogger``.
        :param level: level applied to the logger and to both handlers.
        :param log_file: log file path; defaults to ``coauthor.log`` in the current
            working directory.
        """
        if not log_file:
            log_file = os.path.join(os.getcwd(), "coauthor.log")
        self.log_file = log_file
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)
        self._remove_own_handlers()

        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        self._file_handler = logging.FileHandler(log_file)
        console_handler = logging.StreamHandler()
        for handler in (self._file_handler, console_handler):
            handler.setLevel(level)
            handler.setFormatter(formatter)
            setattr(handler, self._HANDLER_MARKER, True)
            self.logger.addHandler(handler)

    def _remove_own_handlers(self):
        """Detach and close handlers that a previous ``Logger`` installed on this logger."""
        for handler in list(self.logger.handlers):
            if getattr(handler, self._HANDLER_MARKER, False):
                self.logger.removeHandler(handler)
                handler.close()

    def clean_log_file(self):
        """Delete the log file if it exists.

        The file handler's stream is closed first, for two reasons. Windows refuses
        to delete an open file. On POSIX, later records would otherwise go to the
        unlinked file and be lost. ``logging.FileHandler`` reopens (recreates) the
        file on the next emitted record because its stream is then ``None``.
        """
        if self.log_file and os.path.exists(self.log_file):
            file_handler = getattr(self, "_file_handler", None)
            if file_handler is not None:
                file_handler.acquire()
                try:
                    if file_handler.stream is not None:
                        file_handler.stream.close()
                        file_handler.stream = None
                finally:
                    file_handler.release()
            os.remove(self.log_file)
            print(f"Log file '{self.log_file}' has been removed.")

    def debug(self, msg):
        self.logger.debug(msg)

    def info(self, msg):
        self.logger.info(msg)

    def warning(self, msg):
        self.logger.warning(msg)

    def error(self, msg):
        self.logger.error(msg)

    def exception(self, msg):
        """Log *msg* at ERROR level together with the current exception's traceback."""
        self.logger.exception(msg)
@@ -0,0 +1,127 @@
1
+ import os
2
+ import re
3
+
4
+
5
def file_submit_to_ai(config, logger):
    """Return the modified file's content if the current workflow should process it.

    The file is ``config["current-task"]["path-modify-event"]``. It qualifies when
    ``file_path_match`` accepts its path and ``file_content_match`` accepts its
    content. These check the ``path_patterns`` and ``content_patterns`` of
    ``config["current-workflow"]``.

    :param config: configuration dict holding ``current-workflow`` and ``current-task``.
    :param logger: logger used for debug output.
    :return: the file content (read as UTF-8) when both checks pass, otherwise ``None``.
    """
    path = config["current-task"]["path-modify-event"]
    path_match = file_path_match(config, logger)
    # Inspect the content only for relevant paths. This avoids reading files this
    # workflow does not handle, and the warnings that content check may log.
    content_match = path_match and file_content_match(config, logger)
    content = None
    if content_match:
        with open(path, "r", encoding="utf-8") as file1:
            content = file1.read()
    logger.debug(f"file_submit_to_ai: path_match: {path_match}, content_match: {content_match}")
    return content
31
+
32
+
33
def regex_content_match(content, regex, logger):
    """Return True when *content* satisfies the pattern specification *regex*.

    Patterns are applied with ``re.match``, which is anchored at the start of
    *content*. The flags are ``re.IGNORECASE | re.DOTALL``, so ``.*`` also spans
    newlines.

    *regex* may be:

    - a string: a single pattern;
    - a list of alternatives (any one matching is enough), where each item is
      either a pattern string or a list of pattern strings that must all match.

    Any other value (e.g. ``None``) never matches.

    :param content: text to test (file content or a file path).
    :param regex: pattern specification as described above.
    :param logger: logger used for debug output.
    :return: True on the first satisfied alternative, otherwise False.
    """
    flags = re.IGNORECASE | re.DOTALL

    if isinstance(regex, str):
        if re.match(regex, content, flags):
            logger.debug(f"regex_content_match: match: True for regex: {regex}, content: {content}")
            return True
        return False

    if not isinstance(regex, list):
        return False

    for pattern in regex:
        if isinstance(pattern, list):
            # AND-group: all() stops at the first sub-pattern that fails.
            if all(re.match(sub_pattern, content, flags) for sub_pattern in pattern):
                logger.debug(f"regex_content_match: match True for pattern: {pattern}, content: {content}")
                return True
            logger.debug("regex_content_match: sub_patterns_all_match: False")
        elif re.match(pattern, content, flags):
            logger.debug(f"regex_content_match: match: True for pattern: {pattern}, content: {content}")
            return True
    return False
66
+
67
+
68
def file_path_match(config, logger):
    """Check the current task's path against the workflow's ``path_patterns``.

    :param config: configuration dict. The workflow is ``config["current-workflow"]``
        and the path is ``config["current-task"]["path-modify-event"]``.
    :param logger: logger passed to ``regex_content_match``.
    :return: True when the workflow has no ``path_patterns``; otherwise the result
        of ``regex_content_match`` for the path.
    """
    active_workflow = config["current-workflow"]
    event_path = config["current-task"]["path-modify-event"]
    if "path_patterns" not in active_workflow:
        return True
    return regex_content_match(event_path, active_workflow["path_patterns"], logger)
84
+
85
+
86
def file_content_match(config, logger):
    """Check the current task's file content against the workflow's ``content_patterns``.

    :param config: configuration dict. The workflow is ``config["current-workflow"]``
        and the file is ``config["current-task"]["path-modify-event"]``.
    :param logger: logger for warnings and debug output.
    :return: the result of ``regex_content_match`` on the file's UTF-8 text. False is
        returned (with a warning) in these cases:

        - the file does not exist or cannot be read;
        - the file is not valid UTF-8;
        - the workflow defines no ``content_patterns``.
    """
    workflow = config["current-workflow"]
    path = config["current-task"]["path-modify-event"]
    if not os.path.exists(path):
        logger.warning(f"file_content_match: path {path} does not exist!")
        return False

    if "content_patterns" not in workflow:
        logger.warning("file_content_match: workflow has no content_patterns! So content match is false! ")
        logger.debug(f"workflow: {workflow}")
        return False

    logger.debug(f"file_content_match: content_patterns: {workflow['content_patterns']}")
    try:
        with open(path, "r", encoding="utf-8") as file:
            content = file.read()
    except OSError as error:
        # The file may vanish between the existence check and open() (watch events
        # on temporary files), or the path may be a directory or unreadable.
        logger.warning(f"file_content_match: cannot read {path}: {error}")
        return False
    except UnicodeDecodeError:
        logger.warning(f"file_content_match: path {path} is not valid UTF-8 text; skipping")
        return False

    return regex_content_match(content, workflow["content_patterns"], logger)
113
+
114
+
115
def path_new_replace(path, search, replace):
    """Return *path* with every occurrence of *search* replaced by *replace*.

    :param path: the original path string.
    :param search: substring to look for.
    :param replace: text substituted for each occurrence.
    :return: the rewritten path string.
    """
    replaced = path.replace(search, replace)
    return replaced
@@ -0,0 +1,35 @@
1
def get_all_watch_directories_from_workflows(config, logger):
    """Return the unique filesystem paths listed under the workflows' ``watch`` sections."""
    return get_all_directories_from_workflows(config, logger, watch_or_scan_key="watch")
3
+
4
+
5
def get_all_scan_directories_from_workflows(config, logger):
    """Return the unique filesystem paths listed under the workflows' ``scan`` sections."""
    return get_all_directories_from_workflows(config, logger, watch_or_scan_key="scan")
7
+
8
+
9
def get_all_directories_from_workflows(config, logger, watch_or_scan_key):
    """Collect the unique ``filesystem.paths`` configured under *watch_or_scan_key*.

    :param config: configuration dict with a ``workflows`` list.
    :param logger: not used; kept for a uniform signature with the other helpers.
    :param watch_or_scan_key: section name to read, e.g. ``"watch"`` or ``"scan"``.
    :return: list of unique paths in first-seen order.
    """
    # A dict gives de-duplication with a stable, reproducible order; the
    # iteration order of a set of strings varies between interpreter runs.
    directories = {}
    for workflow in config["workflows"]:
        # A workflow may only watch or only scan. A missing or empty section
        # (e.g. "watch:" with no value in YAML) is skipped instead of raising.
        section = workflow.get(watch_or_scan_key) or {}
        if "filesystem" in section:
            directories.update(dict.fromkeys(section["filesystem"]["paths"]))
    return list(directories)
15
+
16
+
17
def get_workflows_that_watch(config, logger):
    """Return the workflows in ``config["workflows"]`` that define a ``watch`` key.

    The selection is logged at debug level.
    """
    watching = [workflow for workflow in config["workflows"] if "watch" in workflow]
    logger.debug(f"get_workflows_that_watch: workflows_that_watch: {watching}")
    return watching
25
+
26
+
27
def get_workflows_that_scan(config, logger):
    """Return the workflows in ``config["workflows"]`` that define a ``scan`` key.

    The selection is logged at debug level.
    """
    scanning = [workflow for workflow in config["workflows"] if "scan" in workflow]
    logger.debug(f"get_workflows_that_scan: workflows_that_scan: {scanning}")
    return scanning
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE