camel-ai 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +114 -23
- camel/configs/__init__.py +6 -4
- camel/configs/base_config.py +21 -0
- camel/configs/gemini_config.py +17 -9
- camel/configs/qwen_config.py +91 -0
- camel/configs/samba_config.py +1 -38
- camel/configs/yi_config.py +58 -0
- camel/generators.py +93 -0
- camel/interpreters/docker_interpreter.py +5 -0
- camel/interpreters/ipython_interpreter.py +2 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/apify_reader.py +223 -0
- camel/memories/agent_memories.py +24 -1
- camel/messages/base.py +38 -0
- camel/models/__init__.py +4 -0
- camel/models/model_factory.py +6 -0
- camel/models/qwen_model.py +139 -0
- camel/models/samba_model.py +1 -1
- camel/models/yi_model.py +138 -0
- camel/prompts/image_craft.py +8 -0
- camel/prompts/video_description_prompt.py +8 -0
- camel/retrievers/vector_retriever.py +5 -1
- camel/societies/role_playing.py +29 -18
- camel/societies/workforce/base.py +7 -1
- camel/societies/workforce/task_channel.py +10 -0
- camel/societies/workforce/utils.py +6 -0
- camel/societies/workforce/worker.py +2 -0
- camel/storages/vectordb_storages/qdrant.py +147 -24
- camel/tasks/task.py +15 -0
- camel/terminators/base.py +4 -0
- camel/terminators/response_terminator.py +1 -0
- camel/terminators/token_limit_terminator.py +1 -0
- camel/toolkits/__init__.py +4 -1
- camel/toolkits/base.py +9 -0
- camel/toolkits/data_commons_toolkit.py +360 -0
- camel/toolkits/function_tool.py +174 -7
- camel/toolkits/github_toolkit.py +175 -176
- camel/toolkits/google_scholar_toolkit.py +36 -7
- camel/toolkits/notion_toolkit.py +279 -0
- camel/toolkits/search_toolkit.py +164 -36
- camel/types/enums.py +88 -0
- camel/types/unified_model_type.py +10 -0
- camel/utils/commons.py +2 -1
- camel/utils/constants.py +2 -0
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/METADATA +129 -79
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/RECORD +49 -42
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/LICENSE +0 -0
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/WHEEL +0 -0
camel/toolkits/github_toolkit.py
CHANGED
|
@@ -12,88 +12,15 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
|
|
15
|
+
import logging
|
|
15
16
|
import os
|
|
16
|
-
from datetime import datetime, timedelta
|
|
17
|
-
from typing import List, Optional
|
|
18
|
-
|
|
19
|
-
from pydantic import BaseModel
|
|
17
|
+
from typing import Dict, List, Literal, Optional, Union
|
|
20
18
|
|
|
21
19
|
from camel.toolkits import FunctionTool
|
|
22
20
|
from camel.toolkits.base import BaseToolkit
|
|
23
21
|
from camel.utils import dependencies_required
|
|
24
22
|
|
|
25
|
-
|
|
26
|
-
class GithubIssue(BaseModel):
|
|
27
|
-
r"""Represents a GitHub issue.
|
|
28
|
-
|
|
29
|
-
Attributes:
|
|
30
|
-
title (str): The title of the issue.
|
|
31
|
-
body (str): The body/content of the issue.
|
|
32
|
-
number (int): The issue number.
|
|
33
|
-
file_path (str): The path of the file associated with the issue.
|
|
34
|
-
file_content (str): The content of the file associated with the issue.
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
title: str
|
|
38
|
-
body: str
|
|
39
|
-
number: int
|
|
40
|
-
file_path: str
|
|
41
|
-
file_content: str
|
|
42
|
-
|
|
43
|
-
def __str__(self) -> str:
|
|
44
|
-
r"""Returns a string representation of the issue.
|
|
45
|
-
|
|
46
|
-
Returns:
|
|
47
|
-
str: A string containing the title, body, number, file path, and
|
|
48
|
-
file content of the issue.
|
|
49
|
-
"""
|
|
50
|
-
return (
|
|
51
|
-
f"Title: {self.title}\n"
|
|
52
|
-
f"Body: {self.body}\n"
|
|
53
|
-
f"Number: {self.number}\n"
|
|
54
|
-
f"File Path: {self.file_path}\n"
|
|
55
|
-
f"File Content: {self.file_content}"
|
|
56
|
-
)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class GithubPullRequestDiff(BaseModel):
|
|
60
|
-
r"""Represents a single diff of a pull request on Github.
|
|
61
|
-
|
|
62
|
-
Attributes:
|
|
63
|
-
filename (str): The name of the file that was changed.
|
|
64
|
-
patch (str): The diff patch for the file.
|
|
65
|
-
"""
|
|
66
|
-
|
|
67
|
-
filename: str
|
|
68
|
-
patch: str
|
|
69
|
-
|
|
70
|
-
def __str__(self) -> str:
|
|
71
|
-
r"""Returns a string representation of this diff."""
|
|
72
|
-
return f"Filename: {self.filename}\nPatch: {self.patch}"
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class GithubPullRequest(BaseModel):
|
|
76
|
-
r"""Represents a pull request on Github.
|
|
77
|
-
|
|
78
|
-
Attributes:
|
|
79
|
-
title (str): The title of the GitHub pull request.
|
|
80
|
-
body (str): The body/content of the GitHub pull request.
|
|
81
|
-
diffs (List[GithubPullRequestDiff]): A list of diffs for the pull
|
|
82
|
-
request.
|
|
83
|
-
"""
|
|
84
|
-
|
|
85
|
-
title: str
|
|
86
|
-
body: str
|
|
87
|
-
diffs: List[GithubPullRequestDiff]
|
|
88
|
-
|
|
89
|
-
def __str__(self) -> str:
|
|
90
|
-
r"""Returns a string representation of the pull request."""
|
|
91
|
-
diff_summaries = '\n'.join(str(diff) for diff in self.diffs)
|
|
92
|
-
return (
|
|
93
|
-
f"Title: {self.title}\n"
|
|
94
|
-
f"Body: {self.body}\n"
|
|
95
|
-
f"Diffs: {diff_summaries}\n"
|
|
96
|
-
)
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
97
24
|
|
|
98
25
|
|
|
99
26
|
class GithubToolkit(BaseToolkit):
|
|
@@ -122,29 +49,14 @@ class GithubToolkit(BaseToolkit):
|
|
|
122
49
|
with GitHub. If not provided, it will be obtained using the
|
|
123
50
|
`get_github_access_token` method.
|
|
124
51
|
"""
|
|
52
|
+
from github import Auth, Github
|
|
53
|
+
|
|
125
54
|
if access_token is None:
|
|
126
55
|
access_token = self.get_github_access_token()
|
|
127
56
|
|
|
128
|
-
from github import Auth, Github
|
|
129
|
-
|
|
130
57
|
self.github = Github(auth=Auth.Token(access_token))
|
|
131
58
|
self.repo = self.github.get_repo(repo_name)
|
|
132
59
|
|
|
133
|
-
def get_tools(self) -> List[FunctionTool]:
|
|
134
|
-
r"""Returns a list of FunctionTool objects representing the
|
|
135
|
-
functions in the toolkit.
|
|
136
|
-
|
|
137
|
-
Returns:
|
|
138
|
-
List[FunctionTool]: A list of FunctionTool objects
|
|
139
|
-
representing the functions in the toolkit.
|
|
140
|
-
"""
|
|
141
|
-
return [
|
|
142
|
-
FunctionTool(self.retrieve_issue_list),
|
|
143
|
-
FunctionTool(self.retrieve_issue),
|
|
144
|
-
FunctionTool(self.create_pull_request),
|
|
145
|
-
FunctionTool(self.retrieve_pull_requests),
|
|
146
|
-
]
|
|
147
|
-
|
|
148
60
|
def get_github_access_token(self) -> str:
|
|
149
61
|
r"""Retrieve the GitHub access token from environment variables.
|
|
150
62
|
|
|
@@ -165,87 +77,6 @@ class GithubToolkit(BaseToolkit):
|
|
|
165
77
|
)
|
|
166
78
|
return GITHUB_ACCESS_TOKEN
|
|
167
79
|
|
|
168
|
-
def retrieve_issue_list(self) -> List[GithubIssue]:
|
|
169
|
-
r"""Retrieve a list of open issues from the repository.
|
|
170
|
-
|
|
171
|
-
Returns:
|
|
172
|
-
A list of GithubIssue objects representing the open issues.
|
|
173
|
-
"""
|
|
174
|
-
issues = self.repo.get_issues(state='open')
|
|
175
|
-
return [
|
|
176
|
-
GithubIssue(
|
|
177
|
-
title=issue.title,
|
|
178
|
-
body=issue.body,
|
|
179
|
-
number=issue.number,
|
|
180
|
-
file_path=issue.labels[
|
|
181
|
-
0
|
|
182
|
-
].name, # we require file path to be the first label in the PR
|
|
183
|
-
file_content=self.retrieve_file_content(issue.labels[0].name),
|
|
184
|
-
)
|
|
185
|
-
for issue in issues
|
|
186
|
-
if not issue.pull_request
|
|
187
|
-
]
|
|
188
|
-
|
|
189
|
-
def retrieve_issue(self, issue_number: int) -> Optional[str]:
|
|
190
|
-
r"""Retrieves an issue from a GitHub repository.
|
|
191
|
-
|
|
192
|
-
This function retrieves an issue from a specified repository using the
|
|
193
|
-
issue number.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
issue_number (int): The number of the issue to retrieve.
|
|
197
|
-
|
|
198
|
-
Returns:
|
|
199
|
-
str: A formatted report of the retrieved issue.
|
|
200
|
-
"""
|
|
201
|
-
issues = self.retrieve_issue_list()
|
|
202
|
-
for issue in issues:
|
|
203
|
-
if issue.number == issue_number:
|
|
204
|
-
return str(issue)
|
|
205
|
-
return None
|
|
206
|
-
|
|
207
|
-
def retrieve_pull_requests(
|
|
208
|
-
self, days: int, state: str, max_prs: int
|
|
209
|
-
) -> List[str]:
|
|
210
|
-
r"""Retrieves a summary of merged pull requests from the repository.
|
|
211
|
-
The summary will be provided for the last specified number of days.
|
|
212
|
-
|
|
213
|
-
Args:
|
|
214
|
-
days (int): The number of days to retrieve merged pull requests
|
|
215
|
-
for.
|
|
216
|
-
state (str): A specific state of PRs to retrieve. Can be open or
|
|
217
|
-
closed.
|
|
218
|
-
max_prs (int): The maximum number of PRs to retrieve.
|
|
219
|
-
|
|
220
|
-
Returns:
|
|
221
|
-
List[str]: A list of merged pull request summaries.
|
|
222
|
-
"""
|
|
223
|
-
pull_requests = self.repo.get_pulls(state=state)
|
|
224
|
-
merged_prs = []
|
|
225
|
-
earliest_date: datetime = datetime.utcnow() - timedelta(days=days)
|
|
226
|
-
|
|
227
|
-
for pr in pull_requests[:max_prs]:
|
|
228
|
-
if (
|
|
229
|
-
pr.merged
|
|
230
|
-
and pr.merged_at is not None
|
|
231
|
-
and pr.merged_at.timestamp() > earliest_date.timestamp()
|
|
232
|
-
):
|
|
233
|
-
pr_details = GithubPullRequest(
|
|
234
|
-
title=pr.title, body=pr.body, diffs=[]
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
# Get files changed in the PR
|
|
238
|
-
files = pr.get_files()
|
|
239
|
-
|
|
240
|
-
for file in files:
|
|
241
|
-
diff = GithubPullRequestDiff(
|
|
242
|
-
filename=file.filename, patch=file.patch
|
|
243
|
-
)
|
|
244
|
-
pr_details.diffs.append(diff)
|
|
245
|
-
|
|
246
|
-
merged_prs.append(str(pr_details))
|
|
247
|
-
return merged_prs
|
|
248
|
-
|
|
249
80
|
def create_pull_request(
|
|
250
81
|
self,
|
|
251
82
|
file_path: str,
|
|
@@ -280,6 +111,7 @@ class GithubToolkit(BaseToolkit):
|
|
|
280
111
|
)
|
|
281
112
|
|
|
282
113
|
file = self.repo.get_contents(file_path)
|
|
114
|
+
|
|
283
115
|
from github.ContentFile import ContentFile
|
|
284
116
|
|
|
285
117
|
if isinstance(file, ContentFile):
|
|
@@ -300,6 +132,155 @@ class GithubToolkit(BaseToolkit):
|
|
|
300
132
|
else:
|
|
301
133
|
raise ValueError("PRs with multiple files aren't supported yet.")
|
|
302
134
|
|
|
135
|
+
def get_issue_list(
|
|
136
|
+
self, state: Literal["open", "closed", "all"] = "all"
|
|
137
|
+
) -> List[Dict[str, object]]:
|
|
138
|
+
r"""Retrieves all issues from the GitHub repository.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
state (Literal["open", "closed", "all"]): The state of pull
|
|
142
|
+
requests to retrieve. (default::obj: `all`)
|
|
143
|
+
Options are:
|
|
144
|
+
- "open": Retrieve only open pull requests.
|
|
145
|
+
- "closed": Retrieve only closed pull requests.
|
|
146
|
+
- "all": Retrieve all pull requests, regardless of state.
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
List[Dict[str, object]]: A list of dictionaries where each
|
|
150
|
+
dictionary contains the issue number and title.
|
|
151
|
+
"""
|
|
152
|
+
issues_info = []
|
|
153
|
+
issues = self.repo.get_issues(state=state)
|
|
154
|
+
|
|
155
|
+
for issue in issues:
|
|
156
|
+
issues_info.append({"number": issue.number, "title": issue.title})
|
|
157
|
+
|
|
158
|
+
return issues_info
|
|
159
|
+
|
|
160
|
+
def get_issue_content(self, issue_number: int) -> str:
|
|
161
|
+
r"""Retrieves the content of a specific issue by its number.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
issue_number (int): The number of the issue to retrieve.
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
str: issues content details.
|
|
168
|
+
"""
|
|
169
|
+
try:
|
|
170
|
+
issue = self.repo.get_issue(number=issue_number)
|
|
171
|
+
return issue.body
|
|
172
|
+
except Exception as e:
|
|
173
|
+
return f"can't get Issue number {issue_number}: {e!s}"
|
|
174
|
+
|
|
175
|
+
def get_pull_request_list(
|
|
176
|
+
self, state: Literal["open", "closed", "all"] = "all"
|
|
177
|
+
) -> List[Dict[str, object]]:
|
|
178
|
+
r"""Retrieves all pull requests from the GitHub repository.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
state (Literal["open", "closed", "all"]): The state of pull
|
|
182
|
+
requests to retrieve. (default::obj: `all`)
|
|
183
|
+
Options are:
|
|
184
|
+
- "open": Retrieve only open pull requests.
|
|
185
|
+
- "closed": Retrieve only closed pull requests.
|
|
186
|
+
- "all": Retrieve all pull requests, regardless of state.
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
list: A list of dictionaries where each dictionary contains the
|
|
190
|
+
pull request number and title.
|
|
191
|
+
"""
|
|
192
|
+
pull_requests_info = []
|
|
193
|
+
pull_requests = self.repo.get_pulls(state=state)
|
|
194
|
+
|
|
195
|
+
for pr in pull_requests:
|
|
196
|
+
pull_requests_info.append({"number": pr.number, "title": pr.title})
|
|
197
|
+
|
|
198
|
+
return pull_requests_info
|
|
199
|
+
|
|
200
|
+
def get_pull_request_code(self, pr_number: int) -> List[Dict[str, str]]:
|
|
201
|
+
r"""Retrieves the code changes of a specific pull request.
|
|
202
|
+
|
|
203
|
+
Args:
|
|
204
|
+
pr_number (int): The number of the pull request to retrieve.
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
List[Dict[str, str]]: A list of dictionaries where each dictionary
|
|
208
|
+
contains the file name and the corresponding code changes
|
|
209
|
+
(patch).
|
|
210
|
+
"""
|
|
211
|
+
# Retrieve the specific pull request
|
|
212
|
+
pr = self.repo.get_pull(number=pr_number)
|
|
213
|
+
|
|
214
|
+
# Collect the file changes from the pull request
|
|
215
|
+
files_changed = []
|
|
216
|
+
# Returns the files and their changes in the pull request
|
|
217
|
+
files = pr.get_files()
|
|
218
|
+
for file in files:
|
|
219
|
+
files_changed.append(
|
|
220
|
+
{
|
|
221
|
+
"filename": file.filename,
|
|
222
|
+
"patch": file.patch, # The code diff or changes
|
|
223
|
+
}
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
return files_changed
|
|
227
|
+
|
|
228
|
+
def get_pull_request_comments(
|
|
229
|
+
self, pr_number: int
|
|
230
|
+
) -> List[Dict[str, str]]:
|
|
231
|
+
r"""Retrieves the comments from a specific pull request.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
pr_number (int): The number of the pull request to retrieve.
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
List[Dict[str, str]]: A list of dictionaries where each dictionary
|
|
238
|
+
contains the user ID and the comment body.
|
|
239
|
+
"""
|
|
240
|
+
# Retrieve the specific pull request
|
|
241
|
+
pr = self.repo.get_pull(number=pr_number)
|
|
242
|
+
|
|
243
|
+
# Collect the comments from the pull request
|
|
244
|
+
comments = []
|
|
245
|
+
# Returns all the comments in the pull request
|
|
246
|
+
for comment in pr.get_comments():
|
|
247
|
+
comments.append({"user": comment.user.login, "body": comment.body})
|
|
248
|
+
|
|
249
|
+
return comments
|
|
250
|
+
|
|
251
|
+
def get_all_file_paths(self, path: str = "") -> List[str]:
|
|
252
|
+
r"""Recursively retrieves all file paths in the GitHub repository.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
path (str): The repository path to start the traversal from.
|
|
256
|
+
empty string means starts from the root directory.
|
|
257
|
+
(default::obj: `""`)
|
|
258
|
+
|
|
259
|
+
Returns:
|
|
260
|
+
List[str]: A list of file paths within the specified directory
|
|
261
|
+
structure.
|
|
262
|
+
"""
|
|
263
|
+
from github.ContentFile import ContentFile
|
|
264
|
+
|
|
265
|
+
files: List[str] = []
|
|
266
|
+
|
|
267
|
+
# Retrieves all contents of the current directory
|
|
268
|
+
contents: Union[List[ContentFile], ContentFile] = (
|
|
269
|
+
self.repo.get_contents(path)
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
if isinstance(contents, ContentFile):
|
|
273
|
+
files.append(contents.path)
|
|
274
|
+
else:
|
|
275
|
+
for content in contents:
|
|
276
|
+
if content.type == "dir":
|
|
277
|
+
# If it's a directory, recursively retrieve its file paths
|
|
278
|
+
files.extend(self.get_all_file_paths(content.path))
|
|
279
|
+
else:
|
|
280
|
+
# If it's a file, add its path to the list
|
|
281
|
+
files.append(content.path)
|
|
282
|
+
return files
|
|
283
|
+
|
|
303
284
|
def retrieve_file_content(self, file_path: str) -> str:
|
|
304
285
|
r"""Retrieves the content of a file from the GitHub repository.
|
|
305
286
|
|
|
@@ -309,11 +290,29 @@ class GithubToolkit(BaseToolkit):
|
|
|
309
290
|
Returns:
|
|
310
291
|
str: The decoded content of the file.
|
|
311
292
|
"""
|
|
312
|
-
file_content = self.repo.get_contents(file_path)
|
|
313
|
-
|
|
314
293
|
from github.ContentFile import ContentFile
|
|
315
294
|
|
|
295
|
+
file_content = self.repo.get_contents(file_path)
|
|
316
296
|
if isinstance(file_content, ContentFile):
|
|
317
297
|
return file_content.decoded_content.decode()
|
|
318
298
|
else:
|
|
319
299
|
raise ValueError("PRs with multiple files aren't supported yet.")
|
|
300
|
+
|
|
301
|
+
def get_tools(self) -> List[FunctionTool]:
|
|
302
|
+
r"""Returns a list of FunctionTool objects representing the functions
|
|
303
|
+
in the toolkit.
|
|
304
|
+
|
|
305
|
+
Returns:
|
|
306
|
+
List[FunctionTool]: A list of FunctionTool objects representing
|
|
307
|
+
the functions in the toolkit.
|
|
308
|
+
"""
|
|
309
|
+
return [
|
|
310
|
+
FunctionTool(self.create_pull_request),
|
|
311
|
+
FunctionTool(self.get_issue_list),
|
|
312
|
+
FunctionTool(self.get_issue_content),
|
|
313
|
+
FunctionTool(self.get_pull_request_list),
|
|
314
|
+
FunctionTool(self.get_pull_request_code),
|
|
315
|
+
FunctionTool(self.get_pull_request_comments),
|
|
316
|
+
FunctionTool(self.get_all_file_paths),
|
|
317
|
+
FunctionTool(self.retrieve_file_content),
|
|
318
|
+
]
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
import re
|
|
15
|
-
from typing import List, Optional
|
|
15
|
+
from typing import Any, Dict, List, Optional
|
|
16
16
|
|
|
17
17
|
from camel.toolkits import FunctionTool
|
|
18
18
|
from camel.toolkits.base import BaseToolkit
|
|
@@ -28,6 +28,8 @@ class GoogleScholarToolkit(BaseToolkit):
|
|
|
28
28
|
is_author_name (bool): Flag to indicate if the identifier is a name.
|
|
29
29
|
(default: :obj:`False`)
|
|
30
30
|
scholarly (module): The scholarly module for querying Google Scholar.
|
|
31
|
+
author (Optional[Dict[str, Any]]): Cached author details, allowing
|
|
32
|
+
manual assignment if desired.
|
|
31
33
|
"""
|
|
32
34
|
|
|
33
35
|
def __init__(
|
|
@@ -46,6 +48,35 @@ class GoogleScholarToolkit(BaseToolkit):
|
|
|
46
48
|
self.scholarly = scholarly
|
|
47
49
|
self.author_identifier = author_identifier
|
|
48
50
|
self.is_author_name = is_author_name
|
|
51
|
+
self._author: Optional[Dict[str, Any]] = None
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def author(self) -> Dict[str, Any]:
|
|
55
|
+
r"""Getter for the author attribute, fetching details if not cached.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
Dict[str, Any]: A dictionary containing author details. If no data
|
|
59
|
+
is available, returns an empty dictionary.
|
|
60
|
+
"""
|
|
61
|
+
if self._author is None:
|
|
62
|
+
self.get_author_detailed_info()
|
|
63
|
+
return self._author or {}
|
|
64
|
+
|
|
65
|
+
@author.setter
|
|
66
|
+
def author(self, value: Optional[Dict[str, Any]]) -> None:
|
|
67
|
+
r"""Sets or overrides the cached author information.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
value (Optional[Dict[str, Any]]): A dictionary containing author
|
|
71
|
+
details to cache or `None` to clear the cached data.
|
|
72
|
+
|
|
73
|
+
Raises:
|
|
74
|
+
ValueError: If `value` is not a dictionary or `None`.
|
|
75
|
+
"""
|
|
76
|
+
if value is None or isinstance(value, dict):
|
|
77
|
+
self._author = value
|
|
78
|
+
else:
|
|
79
|
+
raise ValueError("Author must be a dictionary or None.")
|
|
49
80
|
|
|
50
81
|
def _extract_author_id(self) -> Optional[str]:
|
|
51
82
|
r"""Extracts the author ID from a Google Scholar URL if provided.
|
|
@@ -73,8 +104,8 @@ class GoogleScholarToolkit(BaseToolkit):
|
|
|
73
104
|
author_id = self._extract_author_id()
|
|
74
105
|
first_author_result = self.scholarly.search_author_id(id=author_id)
|
|
75
106
|
|
|
76
|
-
|
|
77
|
-
return
|
|
107
|
+
self._author = self.scholarly.fill(first_author_result)
|
|
108
|
+
return self._author # type: ignore[return-value]
|
|
78
109
|
|
|
79
110
|
def get_author_publications(
|
|
80
111
|
self,
|
|
@@ -84,9 +115,8 @@ class GoogleScholarToolkit(BaseToolkit):
|
|
|
84
115
|
Returns:
|
|
85
116
|
List[str]: A list of publication titles authored by the author.
|
|
86
117
|
"""
|
|
87
|
-
author = self.get_author_detailed_info()
|
|
88
118
|
publication_titles = [
|
|
89
|
-
pub['bib']['title'] for pub in author['publications']
|
|
119
|
+
pub['bib']['title'] for pub in self.author['publications']
|
|
90
120
|
]
|
|
91
121
|
return publication_titles
|
|
92
122
|
|
|
@@ -105,8 +135,7 @@ class GoogleScholarToolkit(BaseToolkit):
|
|
|
105
135
|
Optional[dict]: A dictionary containing detailed information about
|
|
106
136
|
the publication if found; otherwise, `None`.
|
|
107
137
|
"""
|
|
108
|
-
|
|
109
|
-
publications = author['publications']
|
|
138
|
+
publications = self.author['publications']
|
|
110
139
|
for publication in publications:
|
|
111
140
|
if publication['bib']['title'] == publication_title:
|
|
112
141
|
return self.scholarly.fill(publication)
|