waldiez 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez might be problematic. Click here for more details.
- waldiez/__init__.py +2 -0
- waldiez/__main__.py +2 -0
- waldiez/_version.py +3 -1
- waldiez/cli.py +13 -3
- waldiez/cli_extras.py +4 -3
- waldiez/conflict_checker.py +4 -3
- waldiez/exporter.py +28 -105
- waldiez/exporting/__init__.py +8 -9
- waldiez/exporting/agent/__init__.py +7 -0
- waldiez/exporting/agent/agent_exporter.py +279 -0
- waldiez/exporting/agent/utils/__init__.py +23 -0
- waldiez/exporting/agent/utils/agent_class_name.py +34 -0
- waldiez/exporting/agent/utils/agent_imports.py +50 -0
- waldiez/exporting/{agents → agent/utils}/code_execution.py +9 -11
- waldiez/exporting/{agents → agent/utils}/group_manager.py +47 -35
- waldiez/exporting/{agents → agent/utils}/rag_user/__init__.py +2 -0
- waldiez/exporting/{agents → agent/utils}/rag_user/chroma_utils.py +22 -17
- waldiez/exporting/{agents → agent/utils}/rag_user/mongo_utils.py +14 -10
- waldiez/exporting/{agents → agent/utils}/rag_user/pgvector_utils.py +12 -8
- waldiez/exporting/{agents → agent/utils}/rag_user/qdrant_utils.py +11 -8
- waldiez/exporting/{agents → agent/utils}/rag_user/rag_user.py +78 -55
- waldiez/exporting/{agents → agent/utils}/rag_user/vector_db.py +10 -8
- waldiez/exporting/agent/utils/swarm_agent.py +463 -0
- waldiez/exporting/{agents → agent/utils}/teachability.py +10 -6
- waldiez/exporting/{agents → agent/utils}/termination_message.py +7 -8
- waldiez/exporting/base/__init__.py +25 -0
- waldiez/exporting/base/agent_position.py +75 -0
- waldiez/exporting/base/base_exporter.py +118 -0
- waldiez/exporting/base/export_position.py +48 -0
- waldiez/exporting/base/import_position.py +23 -0
- waldiez/exporting/base/mixin.py +134 -0
- waldiez/exporting/base/utils/__init__.py +18 -0
- waldiez/exporting/{utils → base/utils}/comments.py +12 -55
- waldiez/exporting/{utils → base/utils}/naming.py +14 -4
- waldiez/exporting/base/utils/path_check.py +68 -0
- waldiez/exporting/{utils/object_string.py → base/utils/to_string.py} +21 -20
- waldiez/exporting/chats/__init__.py +5 -12
- waldiez/exporting/chats/chats_exporter.py +240 -0
- waldiez/exporting/chats/utils/__init__.py +15 -0
- waldiez/exporting/chats/utils/common.py +81 -0
- waldiez/exporting/chats/{nested.py → utils/nested.py} +125 -86
- waldiez/exporting/chats/utils/sequential.py +244 -0
- waldiez/exporting/chats/utils/single_chat.py +313 -0
- waldiez/exporting/chats/utils/swarm.py +207 -0
- waldiez/exporting/flow/__init__.py +5 -3
- waldiez/exporting/flow/flow_exporter.py +503 -0
- waldiez/exporting/flow/utils/__init__.py +47 -0
- waldiez/exporting/flow/utils/agent_utils.py +204 -0
- waldiez/exporting/flow/utils/chat_utils.py +71 -0
- waldiez/exporting/flow/utils/def_main.py +62 -0
- waldiez/exporting/flow/utils/flow_content.py +112 -0
- waldiez/exporting/flow/utils/flow_names.py +115 -0
- waldiez/exporting/flow/utils/importing_utils.py +179 -0
- waldiez/exporting/{utils → flow/utils}/logging_utils.py +34 -31
- waldiez/exporting/models/__init__.py +7 -242
- waldiez/exporting/models/models_exporter.py +192 -0
- waldiez/exporting/models/utils.py +166 -0
- waldiez/exporting/skills/__init__.py +7 -161
- waldiez/exporting/skills/skills_exporter.py +169 -0
- waldiez/exporting/skills/utils.py +281 -0
- waldiez/models/__init__.py +25 -7
- waldiez/models/agents/__init__.py +70 -0
- waldiez/models/agents/agent/__init__.py +11 -1
- waldiez/models/agents/agent/agent.py +9 -4
- waldiez/models/agents/agent/agent_data.py +3 -1
- waldiez/models/agents/agent/code_execution.py +2 -0
- waldiez/models/agents/agent/linked_skill.py +2 -0
- waldiez/models/agents/agent/nested_chat.py +2 -0
- waldiez/models/agents/agent/teachability.py +2 -0
- waldiez/models/agents/agent/termination_message.py +49 -13
- waldiez/models/agents/agents.py +15 -3
- waldiez/models/agents/assistant/__init__.py +2 -0
- waldiez/models/agents/assistant/assistant.py +2 -0
- waldiez/models/agents/assistant/assistant_data.py +2 -0
- waldiez/models/agents/group_manager/__init__.py +9 -1
- waldiez/models/agents/group_manager/group_manager.py +2 -0
- waldiez/models/agents/group_manager/group_manager_data.py +2 -0
- waldiez/models/agents/group_manager/speakers.py +49 -13
- waldiez/models/agents/rag_user/__init__.py +21 -4
- waldiez/models/agents/rag_user/rag_user.py +3 -1
- waldiez/models/agents/rag_user/rag_user_data.py +2 -0
- waldiez/models/agents/rag_user/retrieve_config.py +268 -17
- waldiez/models/agents/rag_user/vector_db_config.py +5 -3
- waldiez/models/agents/swarm_agent/__init__.py +49 -0
- waldiez/models/agents/swarm_agent/after_work.py +178 -0
- waldiez/models/agents/swarm_agent/on_condition.py +103 -0
- waldiez/models/agents/swarm_agent/on_condition_available.py +140 -0
- waldiez/models/agents/swarm_agent/on_condition_target.py +40 -0
- waldiez/models/agents/swarm_agent/swarm_agent.py +107 -0
- waldiez/models/agents/swarm_agent/swarm_agent_data.py +125 -0
- waldiez/models/agents/swarm_agent/update_system_message.py +144 -0
- waldiez/models/agents/user_proxy/__init__.py +2 -0
- waldiez/models/agents/user_proxy/user_proxy.py +2 -0
- waldiez/models/agents/user_proxy/user_proxy_data.py +2 -0
- waldiez/models/chat/__init__.py +21 -3
- waldiez/models/chat/chat.py +241 -7
- waldiez/models/chat/chat_data.py +192 -48
- waldiez/models/chat/chat_message.py +153 -144
- waldiez/models/chat/chat_nested.py +33 -53
- waldiez/models/chat/chat_summary.py +2 -0
- waldiez/models/common/__init__.py +6 -6
- waldiez/models/common/base.py +4 -1
- waldiez/models/common/method_utils.py +163 -83
- waldiez/models/flow/__init__.py +2 -0
- waldiez/models/flow/flow.py +176 -40
- waldiez/models/flow/flow_data.py +63 -2
- waldiez/models/flow/utils.py +172 -0
- waldiez/models/model/__init__.py +2 -0
- waldiez/models/model/model.py +25 -6
- waldiez/models/model/model_data.py +3 -1
- waldiez/models/skill/__init__.py +4 -1
- waldiez/models/skill/skill.py +30 -2
- waldiez/models/skill/skill_data.py +2 -0
- waldiez/models/waldiez.py +28 -4
- waldiez/runner.py +142 -228
- waldiez/running/__init__.py +33 -0
- waldiez/running/environment.py +83 -0
- waldiez/running/gen_seq_diagram.py +185 -0
- waldiez/running/running.py +300 -0
- {waldiez-0.2.2.dist-info → waldiez-0.3.0.dist-info}/METADATA +32 -26
- waldiez-0.3.0.dist-info/RECORD +125 -0
- waldiez-0.3.0.dist-info/licenses/LICENSE +201 -0
- waldiez/exporting/agents/__init__.py +0 -5
- waldiez/exporting/agents/agent.py +0 -236
- waldiez/exporting/agents/agent_skills.py +0 -67
- waldiez/exporting/agents/llm_config.py +0 -53
- waldiez/exporting/chats/chats.py +0 -46
- waldiez/exporting/chats/helpers.py +0 -420
- waldiez/exporting/flow/def_main.py +0 -32
- waldiez/exporting/flow/flow.py +0 -189
- waldiez/exporting/utils/__init__.py +0 -36
- waldiez/exporting/utils/importing.py +0 -265
- waldiez/exporting/utils/method_utils.py +0 -35
- waldiez/exporting/utils/path_check.py +0 -51
- waldiez-0.2.2.dist-info/RECORD +0 -92
- waldiez-0.2.2.dist-info/licenses/LICENSE +0 -21
- {waldiez-0.2.2.dist-info → waldiez-0.3.0.dist-info}/WHEEL +0 -0
- {waldiez-0.2.2.dist-info → waldiez-0.3.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,19 +1,27 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Nested chat model."""
|
|
2
4
|
|
|
3
5
|
from typing import Any, Optional
|
|
4
6
|
|
|
5
|
-
from pydantic import
|
|
6
|
-
ConfigDict,
|
|
7
|
-
Field,
|
|
8
|
-
ValidationInfo,
|
|
9
|
-
field_validator,
|
|
10
|
-
model_validator,
|
|
11
|
-
)
|
|
12
|
-
from pydantic.alias_generators import to_camel
|
|
7
|
+
from pydantic import Field, field_validator, model_validator
|
|
13
8
|
from typing_extensions import Annotated, Self
|
|
14
9
|
|
|
15
|
-
from ..common import WaldiezBase
|
|
16
|
-
from .chat_message import WaldiezChatMessage
|
|
10
|
+
from ..common import WaldiezBase
|
|
11
|
+
from .chat_message import WaldiezChatMessage
|
|
12
|
+
|
|
13
|
+
NESTED_CHAT_MESSAGE = "nested_chat_message"
|
|
14
|
+
NESTED_CHAT_REPLY = "nested_chat_reply"
|
|
15
|
+
NESTED_CHAT_ARGS = ["recipient", "messages", "sender", "config"]
|
|
16
|
+
NESTED_CHAT_TYPES = (
|
|
17
|
+
[
|
|
18
|
+
"ConversableAgent",
|
|
19
|
+
"List[Dict[str, Any]]",
|
|
20
|
+
"ConversableAgent",
|
|
21
|
+
"Dict[str, Any]",
|
|
22
|
+
],
|
|
23
|
+
"Union[Dict[str, Any], str]",
|
|
24
|
+
)
|
|
17
25
|
|
|
18
26
|
|
|
19
27
|
class WaldiezChatNested(WaldiezBase):
|
|
@@ -27,13 +35,6 @@ class WaldiezChatNested(WaldiezBase):
|
|
|
27
35
|
The reply in a nested chat (recipient -> sender).
|
|
28
36
|
"""
|
|
29
37
|
|
|
30
|
-
model_config = ConfigDict(
|
|
31
|
-
extra="forbid",
|
|
32
|
-
alias_generator=to_camel,
|
|
33
|
-
populate_by_name=True,
|
|
34
|
-
frozen=False,
|
|
35
|
-
)
|
|
36
|
-
|
|
37
38
|
message: Annotated[
|
|
38
39
|
Optional[WaldiezChatMessage],
|
|
39
40
|
Field(
|
|
@@ -66,17 +67,13 @@ class WaldiezChatNested(WaldiezBase):
|
|
|
66
67
|
|
|
67
68
|
@field_validator("message", "reply", mode="before")
|
|
68
69
|
@classmethod
|
|
69
|
-
def validate_message(
|
|
70
|
-
cls, value: Any, info: ValidationInfo
|
|
71
|
-
) -> WaldiezChatMessage:
|
|
70
|
+
def validate_message(cls, value: Any) -> WaldiezChatMessage:
|
|
72
71
|
"""Validate the message.
|
|
73
72
|
|
|
74
73
|
Parameters
|
|
75
74
|
----------
|
|
76
75
|
value : Any
|
|
77
76
|
The value.
|
|
78
|
-
info : ValidationInfo
|
|
79
|
-
The validation info.
|
|
80
77
|
|
|
81
78
|
Returns
|
|
82
79
|
-------
|
|
@@ -88,11 +85,6 @@ class WaldiezChatNested(WaldiezBase):
|
|
|
88
85
|
ValueError
|
|
89
86
|
If the validation fails.
|
|
90
87
|
"""
|
|
91
|
-
function_name: WaldiezMethodName = (
|
|
92
|
-
"nested_chat_message"
|
|
93
|
-
if info.field_name == "message"
|
|
94
|
-
else "nested_chat_reply"
|
|
95
|
-
)
|
|
96
88
|
if not value:
|
|
97
89
|
return WaldiezChatMessage(
|
|
98
90
|
type="none", use_carryover=False, content=None, context={}
|
|
@@ -102,17 +94,9 @@ class WaldiezChatNested(WaldiezBase):
|
|
|
102
94
|
type="string", use_carryover=False, content=value, context={}
|
|
103
95
|
)
|
|
104
96
|
if isinstance(value, dict):
|
|
105
|
-
return
|
|
97
|
+
return WaldiezChatMessage.model_validate(value)
|
|
106
98
|
if isinstance(value, WaldiezChatMessage):
|
|
107
|
-
return
|
|
108
|
-
{
|
|
109
|
-
"type": value.type,
|
|
110
|
-
"use_carryover": False,
|
|
111
|
-
"content": value.content,
|
|
112
|
-
"context": value.context,
|
|
113
|
-
},
|
|
114
|
-
function_name=function_name,
|
|
115
|
-
)
|
|
99
|
+
return value
|
|
116
100
|
raise ValueError(f"Invalid message type: {type(value)}")
|
|
117
101
|
|
|
118
102
|
@model_validator(mode="after")
|
|
@@ -134,27 +118,23 @@ class WaldiezChatNested(WaldiezBase):
|
|
|
134
118
|
self._message_content = ""
|
|
135
119
|
elif self.message.type == "string":
|
|
136
120
|
self._message_content = self.message.content
|
|
121
|
+
elif self.message.type == "method":
|
|
122
|
+
self._message_content = self.message.validate_method(
|
|
123
|
+
function_name=NESTED_CHAT_MESSAGE,
|
|
124
|
+
function_args=NESTED_CHAT_ARGS,
|
|
125
|
+
)
|
|
137
126
|
else:
|
|
138
|
-
self._message_content =
|
|
139
|
-
value={
|
|
140
|
-
"type": "method",
|
|
141
|
-
"content": self.message.content,
|
|
142
|
-
},
|
|
143
|
-
function_name="nested_chat_message",
|
|
144
|
-
skip_definition=True,
|
|
145
|
-
).content
|
|
127
|
+
self._message_content = self.message.content_body
|
|
146
128
|
if self.reply is not None:
|
|
147
129
|
if self.reply.type == "none":
|
|
148
130
|
self._reply_content = ""
|
|
149
131
|
elif self.reply.type == "string":
|
|
150
132
|
self._reply_content = self.reply.content
|
|
133
|
+
elif self.reply.type == "method":
|
|
134
|
+
self._reply_content = self.reply.validate_method(
|
|
135
|
+
function_name=NESTED_CHAT_REPLY,
|
|
136
|
+
function_args=NESTED_CHAT_ARGS,
|
|
137
|
+
)
|
|
151
138
|
else:
|
|
152
|
-
self._reply_content =
|
|
153
|
-
value={
|
|
154
|
-
"type": "method",
|
|
155
|
-
"content": self.reply.content,
|
|
156
|
-
},
|
|
157
|
-
function_name="nested_chat_reply",
|
|
158
|
-
skip_definition=True,
|
|
159
|
-
).content
|
|
139
|
+
self._reply_content = self.reply.content_body
|
|
160
140
|
return self
|
|
@@ -1,13 +1,14 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Common utils for all models."""
|
|
2
4
|
|
|
3
5
|
from datetime import datetime, timezone
|
|
4
6
|
|
|
5
7
|
from .base import WaldiezBase
|
|
6
8
|
from .method_utils import (
|
|
7
|
-
METHOD_ARGS,
|
|
8
|
-
METHOD_TYPE_HINTS,
|
|
9
|
-
WaldiezMethodName,
|
|
10
9
|
check_function,
|
|
10
|
+
generate_function,
|
|
11
|
+
get_function,
|
|
11
12
|
parse_code_string,
|
|
12
13
|
)
|
|
13
14
|
|
|
@@ -29,10 +30,9 @@ def now() -> str:
|
|
|
29
30
|
|
|
30
31
|
__all__ = [
|
|
31
32
|
"WaldiezBase",
|
|
32
|
-
"METHOD_ARGS",
|
|
33
|
-
"METHOD_TYPE_HINTS",
|
|
34
|
-
"WaldiezMethodName",
|
|
35
33
|
"now",
|
|
36
34
|
"check_function",
|
|
35
|
+
"get_function",
|
|
36
|
+
"generate_function",
|
|
37
37
|
"parse_code_string",
|
|
38
38
|
]
|
waldiez/models/common/base.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Base class to inherit from."""
|
|
2
4
|
|
|
3
5
|
from typing import Any, Dict
|
|
@@ -19,7 +21,8 @@ class WaldiezBase(BaseModel):
|
|
|
19
21
|
alias_generator=to_camel,
|
|
20
22
|
# allow passing either `skill_id` or `skillId`
|
|
21
23
|
populate_by_name=True,
|
|
22
|
-
|
|
24
|
+
# allow setting any attribute after initialization
|
|
25
|
+
frozen=False,
|
|
23
26
|
)
|
|
24
27
|
|
|
25
28
|
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
|
@@ -1,48 +1,13 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Function related utilities."""
|
|
2
4
|
|
|
3
|
-
# flake8: noqa E501
|
|
4
5
|
import ast
|
|
5
|
-
from typing import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
"nested_chat_message", # Agents' NestedChat
|
|
11
|
-
"nested_chat_reply", # Agents' NestedChat
|
|
12
|
-
"custom_speaker_selection", # GroupChat
|
|
13
|
-
"custom_embedding_function", # RAG
|
|
14
|
-
"custom_token_count_function", # RAG
|
|
15
|
-
"custom_text_split_function", # RAG
|
|
16
|
-
]
|
|
17
|
-
|
|
18
|
-
METHOD_ARGS: Dict[WaldiezMethodName, List[str]] = {
|
|
19
|
-
"callable_message": ["sender", "recipient", "context"],
|
|
20
|
-
"is_termination_message": ["message"],
|
|
21
|
-
"nested_chat_message": ["recipient", "messages", "sender", "config"],
|
|
22
|
-
"nested_chat_reply": ["recipient", "messages", "sender", "config"],
|
|
23
|
-
"custom_speaker_selection": ["last_speaker", "groupchat"],
|
|
24
|
-
"custom_embedding_function": [],
|
|
25
|
-
"custom_token_count_function": ["text", "model"],
|
|
26
|
-
"custom_text_split_function": [
|
|
27
|
-
"text",
|
|
28
|
-
"max_tokens",
|
|
29
|
-
"chunk_mode",
|
|
30
|
-
"must_break_at_empty_line",
|
|
31
|
-
"overlap",
|
|
32
|
-
],
|
|
33
|
-
}
|
|
34
|
-
|
|
35
|
-
# pylint: disable=line-too-long
|
|
36
|
-
METHOD_TYPE_HINTS: Dict[WaldiezMethodName, str] = {
|
|
37
|
-
"callable_message": "# type: (ConversableAgent, ConversableAgent, dict) -> Union[dict, str]",
|
|
38
|
-
"is_termination_message": "# type: (dict) -> bool",
|
|
39
|
-
"nested_chat_message": "# type: (ConversableAgent, list[dict], ConversableAgent, dict) -> Union[dict, str]",
|
|
40
|
-
"nested_chat_reply": "# type: (ConversableAgent, list[dict], ConversableAgent, dict) -> Union[dict, str]",
|
|
41
|
-
"custom_speaker_selection": "# type: (ConversableAgent, GroupChat) -> Union[Agent, str, None]",
|
|
42
|
-
"custom_embedding_function": "# type: () -> Callable[..., Any]",
|
|
43
|
-
"custom_token_count_function": "# type: (str, str) -> int",
|
|
44
|
-
"custom_text_split_function": "# type: (str, int, str, bool, int) -> List[str]",
|
|
45
|
-
}
|
|
6
|
+
from typing import List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import parso
|
|
9
|
+
import parso.python
|
|
10
|
+
import parso.tree
|
|
46
11
|
|
|
47
12
|
|
|
48
13
|
def parse_code_string(
|
|
@@ -65,28 +30,27 @@ def parse_code_string(
|
|
|
65
30
|
try:
|
|
66
31
|
tree = ast.parse(code_string)
|
|
67
32
|
except SyntaxError as e:
|
|
68
|
-
return f"SyntaxError: {e}, in \n{code_string}", None
|
|
33
|
+
return f"SyntaxError: {e}, in " + "\n" + f"{code_string}", None
|
|
69
34
|
except BaseException as e: # pragma: no cover
|
|
70
|
-
return f"Invalid code: {e}, in \n{code_string}", None
|
|
35
|
+
return f"Invalid code: {e}, in " + "\n" + f"{code_string}", None
|
|
71
36
|
return None, tree
|
|
72
37
|
|
|
73
38
|
|
|
74
39
|
def check_function(
|
|
75
40
|
code_string: str,
|
|
76
|
-
function_name:
|
|
77
|
-
|
|
41
|
+
function_name: str,
|
|
42
|
+
function_args: List[str],
|
|
78
43
|
) -> Tuple[bool, str]:
|
|
79
44
|
"""Check the function.
|
|
80
45
|
|
|
81
46
|
Parameters
|
|
82
47
|
----------
|
|
83
48
|
code_string : str
|
|
84
|
-
The code string.
|
|
85
|
-
function_name :
|
|
86
|
-
The expected
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
49
|
+
The code string to check.
|
|
50
|
+
function_name : str
|
|
51
|
+
The expected method name.
|
|
52
|
+
function_args : List[str]
|
|
53
|
+
The expected method arguments.
|
|
90
54
|
Returns
|
|
91
55
|
-------
|
|
92
56
|
Tuple[bool, str]
|
|
@@ -96,24 +60,19 @@ def check_function(
|
|
|
96
60
|
error, tree = parse_code_string(code_string)
|
|
97
61
|
if error is not None or tree is None:
|
|
98
62
|
return False, error or "Invalid code"
|
|
99
|
-
|
|
100
|
-
return False, f"Invalid function name: {function_name}"
|
|
101
|
-
expected_method_args = METHOD_ARGS[function_name]
|
|
102
|
-
return _get_function_body(
|
|
63
|
+
return _validate_function_body(
|
|
103
64
|
tree,
|
|
104
65
|
code_string,
|
|
105
66
|
function_name,
|
|
106
|
-
|
|
107
|
-
skip_type_hints=skip_type_hints,
|
|
67
|
+
function_args,
|
|
108
68
|
)
|
|
109
69
|
|
|
110
70
|
|
|
111
|
-
def
|
|
71
|
+
def _validate_function_body(
|
|
112
72
|
tree: ast.Module,
|
|
113
73
|
code_string: str,
|
|
114
|
-
function_name:
|
|
115
|
-
|
|
116
|
-
skip_type_hints: bool = False,
|
|
74
|
+
function_name: str,
|
|
75
|
+
function_args: List[str],
|
|
117
76
|
) -> Tuple[bool, str]:
|
|
118
77
|
"""Get the function body.
|
|
119
78
|
|
|
@@ -121,13 +80,12 @@ def _get_function_body(
|
|
|
121
80
|
----------
|
|
122
81
|
tree : ast.Module
|
|
123
82
|
The ast module.
|
|
124
|
-
|
|
125
|
-
The
|
|
126
|
-
function_name :
|
|
127
|
-
The expected
|
|
128
|
-
|
|
83
|
+
function_body : str
|
|
84
|
+
The function body.
|
|
85
|
+
function_name : str
|
|
86
|
+
The expected method name.
|
|
87
|
+
function_args : List[str]
|
|
129
88
|
The expected method arguments.
|
|
130
|
-
|
|
131
89
|
Returns
|
|
132
90
|
-------
|
|
133
91
|
Tuple[bool, str]
|
|
@@ -138,29 +96,151 @@ def _get_function_body(
|
|
|
138
96
|
if isinstance(node, ast.FunctionDef):
|
|
139
97
|
if node.name != function_name:
|
|
140
98
|
continue
|
|
141
|
-
if len(node.args.args) != len(
|
|
99
|
+
if len(node.args.args) != len(function_args):
|
|
142
100
|
return (
|
|
143
101
|
False,
|
|
144
|
-
|
|
102
|
+
(
|
|
103
|
+
f"Invalid number of arguments, in function {node.name},"
|
|
104
|
+
f" expected: {len(function_args)},"
|
|
105
|
+
f" got: {len(node.args.args)} :("
|
|
106
|
+
),
|
|
145
107
|
)
|
|
146
|
-
for arg, expected_arg in zip(node.args.args,
|
|
108
|
+
for arg, expected_arg in zip(node.args.args, function_args):
|
|
147
109
|
if arg.arg != expected_arg:
|
|
148
110
|
return (
|
|
149
111
|
False,
|
|
150
|
-
|
|
112
|
+
(
|
|
113
|
+
f"Invalid argument name: {arg.arg}"
|
|
114
|
+
f" in function {node.name}"
|
|
115
|
+
),
|
|
151
116
|
)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
function_body = "\n".join(function_body_lines[1:])
|
|
156
|
-
if not skip_type_hints:
|
|
157
|
-
# add type hints after the function definition
|
|
158
|
-
function_body = (
|
|
159
|
-
f" {METHOD_TYPE_HINTS[function_name]}\n{function_body}"
|
|
160
|
-
)
|
|
117
|
+
if not node.body:
|
|
118
|
+
return False, "No body found in the function"
|
|
119
|
+
function_body = _get_function_body(code_string, node)
|
|
161
120
|
return True, function_body
|
|
162
121
|
error_msg = (
|
|
163
|
-
f"No
|
|
164
|
-
f" and arguments `{
|
|
122
|
+
f"No method with name `{function_name}`"
|
|
123
|
+
f" and arguments `{function_args}` found"
|
|
165
124
|
)
|
|
166
125
|
return False, error_msg
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_function(
|
|
129
|
+
code_string: str,
|
|
130
|
+
function_name: str,
|
|
131
|
+
) -> str:
|
|
132
|
+
"""Get the function signature and body.
|
|
133
|
+
|
|
134
|
+
Parameters
|
|
135
|
+
----------
|
|
136
|
+
code_string : str
|
|
137
|
+
The code string.
|
|
138
|
+
function_name : str
|
|
139
|
+
The function name.
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
str
|
|
144
|
+
The function signature and body.
|
|
145
|
+
"""
|
|
146
|
+
try:
|
|
147
|
+
tree = parso.parse(code_string) # type: ignore
|
|
148
|
+
except BaseException: # pylint: disable=broad-except
|
|
149
|
+
return ""
|
|
150
|
+
for node in tree.iter_funcdefs():
|
|
151
|
+
if node.name.value == function_name:
|
|
152
|
+
return node.get_code()
|
|
153
|
+
return ""
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _get_function_body(
|
|
157
|
+
code_string: str,
|
|
158
|
+
node: ast.FunctionDef,
|
|
159
|
+
) -> str:
|
|
160
|
+
"""Get the function body, including docstring and comments inside.
|
|
161
|
+
|
|
162
|
+
Parameters
|
|
163
|
+
----------
|
|
164
|
+
code_string : str
|
|
165
|
+
The code string.
|
|
166
|
+
node : ast.FunctionDef
|
|
167
|
+
The function node.
|
|
168
|
+
|
|
169
|
+
Returns
|
|
170
|
+
-------
|
|
171
|
+
str
|
|
172
|
+
The function body.
|
|
173
|
+
|
|
174
|
+
Raises
|
|
175
|
+
------
|
|
176
|
+
ValueError
|
|
177
|
+
If no body found in the function.
|
|
178
|
+
"""
|
|
179
|
+
lines = code_string.splitlines()
|
|
180
|
+
signature_start_line = node.lineno - 1
|
|
181
|
+
body_start_line = node.body[0].lineno - 1
|
|
182
|
+
signature_end_line = signature_start_line
|
|
183
|
+
for i in range(signature_start_line, body_start_line):
|
|
184
|
+
if ")" in lines[i]:
|
|
185
|
+
signature_end_line = i
|
|
186
|
+
break
|
|
187
|
+
function_body_lines = lines[signature_end_line + 1 :]
|
|
188
|
+
last_line = function_body_lines[-1]
|
|
189
|
+
if not last_line.strip() and len(function_body_lines) > 1:
|
|
190
|
+
function_body_lines = function_body_lines[:-1]
|
|
191
|
+
function_body = "\n".join(function_body_lines)
|
|
192
|
+
while function_body.startswith("\n"):
|
|
193
|
+
function_body = function_body[1:]
|
|
194
|
+
while function_body.endswith("\n"):
|
|
195
|
+
function_body = function_body[:-1]
|
|
196
|
+
return function_body
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def generate_function(
|
|
200
|
+
function_name: str,
|
|
201
|
+
function_args: List[str],
|
|
202
|
+
function_types: Tuple[List[str], str],
|
|
203
|
+
function_body: str,
|
|
204
|
+
types_as_comments: bool = False,
|
|
205
|
+
) -> str:
|
|
206
|
+
"""Generate a function.
|
|
207
|
+
|
|
208
|
+
Parameters
|
|
209
|
+
----------
|
|
210
|
+
function_name : str
|
|
211
|
+
The function name.
|
|
212
|
+
function_args : List[str]
|
|
213
|
+
The function arguments.
|
|
214
|
+
function_types : Tuple[List[str], str]
|
|
215
|
+
The function types.
|
|
216
|
+
function_body : str
|
|
217
|
+
The function body.
|
|
218
|
+
types_as_comments : bool, optional
|
|
219
|
+
Include the type hints as comments (or in the function signature)
|
|
220
|
+
(default is False).
|
|
221
|
+
Returns
|
|
222
|
+
-------
|
|
223
|
+
str
|
|
224
|
+
The generated function.
|
|
225
|
+
"""
|
|
226
|
+
function_string = f"def {function_name}("
|
|
227
|
+
if not function_args:
|
|
228
|
+
function_string += ")"
|
|
229
|
+
else:
|
|
230
|
+
function_string += "\n"
|
|
231
|
+
for arg, arg_type in zip(function_args, function_types[0]):
|
|
232
|
+
if types_as_comments:
|
|
233
|
+
function_string += f" {arg}, # type: {arg_type}" + "\n"
|
|
234
|
+
else:
|
|
235
|
+
function_string += f" {arg}: {arg_type}," + "\n"
|
|
236
|
+
function_string += ")"
|
|
237
|
+
if types_as_comments:
|
|
238
|
+
function_string += ":\n"
|
|
239
|
+
function_string += " # type: (...) -> " + function_types[1]
|
|
240
|
+
else:
|
|
241
|
+
function_string += " -> " + function_types[1] + ":"
|
|
242
|
+
function_string += "\n" if not function_body.startswith("\n") else ""
|
|
243
|
+
function_string += f"{function_body}"
|
|
244
|
+
if not function_string.endswith("\n"):
|
|
245
|
+
function_string += "\n"
|
|
246
|
+
return function_string
|