camel-ai 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +362 -237
- camel/benchmarks/__init__.py +11 -1
- camel/benchmarks/apibank.py +560 -0
- camel/benchmarks/apibench.py +496 -0
- camel/benchmarks/gaia.py +2 -2
- camel/benchmarks/nexus.py +518 -0
- camel/datagen/__init__.py +21 -0
- camel/datagen/cotdatagen.py +448 -0
- camel/datagen/self_instruct/__init__.py +36 -0
- camel/datagen/self_instruct/filter/__init__.py +34 -0
- camel/datagen/self_instruct/filter/filter_function.py +216 -0
- camel/datagen/self_instruct/filter/filter_registry.py +56 -0
- camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
- camel/datagen/self_instruct/self_instruct.py +393 -0
- camel/datagen/self_instruct/templates.py +384 -0
- camel/datahubs/huggingface.py +12 -2
- camel/datahubs/models.py +4 -2
- camel/embeddings/mistral_embedding.py +5 -1
- camel/embeddings/openai_compatible_embedding.py +6 -1
- camel/embeddings/openai_embedding.py +5 -1
- camel/interpreters/e2b_interpreter.py +5 -1
- camel/loaders/apify_reader.py +5 -1
- camel/loaders/chunkr_reader.py +5 -1
- camel/loaders/firecrawl_reader.py +0 -30
- camel/logger.py +11 -5
- camel/messages/conversion/sharegpt/hermes/hermes_function_formatter.py +4 -1
- camel/models/anthropic_model.py +5 -1
- camel/models/azure_openai_model.py +1 -2
- camel/models/cohere_model.py +5 -1
- camel/models/deepseek_model.py +5 -1
- camel/models/gemini_model.py +5 -1
- camel/models/groq_model.py +5 -1
- camel/models/mistral_model.py +5 -1
- camel/models/nemotron_model.py +5 -1
- camel/models/nvidia_model.py +5 -1
- camel/models/openai_model.py +5 -1
- camel/models/qwen_model.py +5 -1
- camel/models/reka_model.py +5 -1
- camel/models/reward/nemotron_model.py +5 -1
- camel/models/samba_model.py +5 -1
- camel/models/togetherai_model.py +5 -1
- camel/models/yi_model.py +5 -1
- camel/models/zhipuai_model.py +5 -1
- camel/retrievers/auto_retriever.py +8 -0
- camel/retrievers/vector_retriever.py +6 -3
- camel/schemas/openai_converter.py +5 -1
- camel/societies/role_playing.py +4 -4
- camel/societies/workforce/workforce.py +2 -2
- camel/storages/graph_storages/nebula_graph.py +119 -27
- camel/storages/graph_storages/neo4j_graph.py +138 -0
- camel/toolkits/__init__.py +4 -0
- camel/toolkits/arxiv_toolkit.py +20 -3
- camel/toolkits/dappier_toolkit.py +196 -0
- camel/toolkits/function_tool.py +61 -61
- camel/toolkits/meshy_toolkit.py +5 -1
- camel/toolkits/notion_toolkit.py +1 -1
- camel/toolkits/openbb_toolkit.py +869 -0
- camel/toolkits/search_toolkit.py +91 -5
- camel/toolkits/stripe_toolkit.py +5 -1
- camel/toolkits/twitter_toolkit.py +24 -16
- camel/types/enums.py +7 -1
- camel/types/unified_model_type.py +5 -0
- camel/utils/__init__.py +4 -0
- camel/utils/commons.py +142 -20
- {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/METADATA +17 -5
- {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/RECORD +69 -55
- {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/LICENSE +0 -0
- {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/WHEEL +0 -0
camel/models/cohere_model.py
CHANGED
|
@@ -43,6 +43,11 @@ except (ImportError, AttributeError):
|
|
|
43
43
|
class CohereModel(BaseModelBackend):
|
|
44
44
|
r"""Cohere API in a unified BaseModelBackend interface."""
|
|
45
45
|
|
|
46
|
+
@api_keys_required(
|
|
47
|
+
[
|
|
48
|
+
("api_key", 'COHERE_API_KEY'),
|
|
49
|
+
]
|
|
50
|
+
)
|
|
46
51
|
def __init__(
|
|
47
52
|
self,
|
|
48
53
|
model_type: Union[ModelType, str],
|
|
@@ -210,7 +215,6 @@ class CohereModel(BaseModelBackend):
|
|
|
210
215
|
)
|
|
211
216
|
return self._token_counter
|
|
212
217
|
|
|
213
|
-
@api_keys_required("COHERE_API_KEY")
|
|
214
218
|
def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
|
|
215
219
|
r"""Runs inference of Cohere chat completion.
|
|
216
220
|
|
camel/models/deepseek_model.py
CHANGED
|
@@ -50,6 +50,11 @@ class DeepSeekModel(BaseModelBackend):
|
|
|
50
50
|
https://api-docs.deepseek.com/
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
|
+
@api_keys_required(
|
|
54
|
+
[
|
|
55
|
+
("api_key", "DEEPSEEK_API_KEY"),
|
|
56
|
+
]
|
|
57
|
+
)
|
|
53
58
|
def __init__(
|
|
54
59
|
self,
|
|
55
60
|
model_type: Union[ModelType, str],
|
|
@@ -90,7 +95,6 @@ class DeepSeekModel(BaseModelBackend):
|
|
|
90
95
|
)
|
|
91
96
|
return self._token_counter
|
|
92
97
|
|
|
93
|
-
@api_keys_required("DEEPSEEK_API_KEY")
|
|
94
98
|
def run(
|
|
95
99
|
self,
|
|
96
100
|
messages: List[OpenAIMessage],
|
camel/models/gemini_model.py
CHANGED
|
@@ -52,6 +52,11 @@ class GeminiModel(BaseModelBackend):
|
|
|
52
52
|
(default: :obj:`None`)
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
+
@api_keys_required(
|
|
56
|
+
[
|
|
57
|
+
("api_key", 'GEMINI_API_KEY'),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
55
60
|
def __init__(
|
|
56
61
|
self,
|
|
57
62
|
model_type: Union[ModelType, str],
|
|
@@ -77,7 +82,6 @@ class GeminiModel(BaseModelBackend):
|
|
|
77
82
|
base_url=self._url,
|
|
78
83
|
)
|
|
79
84
|
|
|
80
|
-
@api_keys_required("GEMINI_API_KEY")
|
|
81
85
|
def run(
|
|
82
86
|
self,
|
|
83
87
|
messages: List[OpenAIMessage],
|
camel/models/groq_model.py
CHANGED
|
@@ -51,6 +51,11 @@ class GroqModel(BaseModelBackend):
|
|
|
51
51
|
(default: :obj:`None`)
|
|
52
52
|
"""
|
|
53
53
|
|
|
54
|
+
@api_keys_required(
|
|
55
|
+
[
|
|
56
|
+
("api_key", "GROQ_API_KEY"),
|
|
57
|
+
]
|
|
58
|
+
)
|
|
54
59
|
def __init__(
|
|
55
60
|
self,
|
|
56
61
|
model_type: Union[ModelType, str],
|
|
@@ -89,7 +94,6 @@ class GroqModel(BaseModelBackend):
|
|
|
89
94
|
self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
|
|
90
95
|
return self._token_counter
|
|
91
96
|
|
|
92
|
-
@api_keys_required("GROQ_API_KEY")
|
|
93
97
|
def run(
|
|
94
98
|
self,
|
|
95
99
|
messages: List[OpenAIMessage],
|
camel/models/mistral_model.py
CHANGED
|
@@ -59,6 +59,11 @@ class MistralModel(BaseModelBackend):
|
|
|
59
59
|
be used. (default: :obj:`None`)
|
|
60
60
|
"""
|
|
61
61
|
|
|
62
|
+
@api_keys_required(
|
|
63
|
+
[
|
|
64
|
+
("api_key", "MISTRAL_API_KEY"),
|
|
65
|
+
]
|
|
66
|
+
)
|
|
62
67
|
@dependencies_required('mistralai')
|
|
63
68
|
def __init__(
|
|
64
69
|
self,
|
|
@@ -200,7 +205,6 @@ class MistralModel(BaseModelBackend):
|
|
|
200
205
|
)
|
|
201
206
|
return self._token_counter
|
|
202
207
|
|
|
203
|
-
@api_keys_required("MISTRAL_API_KEY")
|
|
204
208
|
def run(
|
|
205
209
|
self,
|
|
206
210
|
messages: List[OpenAIMessage],
|
camel/models/nemotron_model.py
CHANGED
|
@@ -40,6 +40,11 @@ class NemotronModel(BaseModelBackend):
|
|
|
40
40
|
Nemotron model doesn't support additional model config like OpenAI.
|
|
41
41
|
"""
|
|
42
42
|
|
|
43
|
+
@api_keys_required(
|
|
44
|
+
[
|
|
45
|
+
("api_key", "NVIDIA_API_KEY"),
|
|
46
|
+
]
|
|
47
|
+
)
|
|
43
48
|
def __init__(
|
|
44
49
|
self,
|
|
45
50
|
model_type: Union[ModelType, str],
|
|
@@ -58,7 +63,6 @@ class NemotronModel(BaseModelBackend):
|
|
|
58
63
|
api_key=self._api_key,
|
|
59
64
|
)
|
|
60
65
|
|
|
61
|
-
@api_keys_required("NVIDIA_API_KEY")
|
|
62
66
|
def run(
|
|
63
67
|
self,
|
|
64
68
|
messages: List[OpenAIMessage],
|
camel/models/nvidia_model.py
CHANGED
|
@@ -48,6 +48,11 @@ class NvidiaModel(BaseModelBackend):
|
|
|
48
48
|
(default: :obj:`None`)
|
|
49
49
|
"""
|
|
50
50
|
|
|
51
|
+
@api_keys_required(
|
|
52
|
+
[
|
|
53
|
+
("api_key", "NVIDIA_API_KEY"),
|
|
54
|
+
]
|
|
55
|
+
)
|
|
51
56
|
def __init__(
|
|
52
57
|
self,
|
|
53
58
|
model_type: Union[ModelType, str],
|
|
@@ -72,7 +77,6 @@ class NvidiaModel(BaseModelBackend):
|
|
|
72
77
|
base_url=self._url,
|
|
73
78
|
)
|
|
74
79
|
|
|
75
|
-
@api_keys_required("NVIDIA_API_KEY")
|
|
76
80
|
def run(
|
|
77
81
|
self,
|
|
78
82
|
messages: List[OpenAIMessage],
|
camel/models/openai_model.py
CHANGED
|
@@ -52,6 +52,11 @@ class OpenAIModel(BaseModelBackend):
|
|
|
52
52
|
be used. (default: :obj:`None`)
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
+
@api_keys_required(
|
|
56
|
+
[
|
|
57
|
+
("api_key", "OPENAI_API_KEY"),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
55
60
|
def __init__(
|
|
56
61
|
self,
|
|
57
62
|
model_type: Union[ModelType, str],
|
|
@@ -86,7 +91,6 @@ class OpenAIModel(BaseModelBackend):
|
|
|
86
91
|
self._token_counter = OpenAITokenCounter(self.model_type)
|
|
87
92
|
return self._token_counter
|
|
88
93
|
|
|
89
|
-
@api_keys_required("OPENAI_API_KEY")
|
|
90
94
|
def run(
|
|
91
95
|
self,
|
|
92
96
|
messages: List[OpenAIMessage],
|
camel/models/qwen_model.py
CHANGED
|
@@ -52,6 +52,11 @@ class QwenModel(BaseModelBackend):
|
|
|
52
52
|
(default: :obj:`None`)
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
+
@api_keys_required(
|
|
56
|
+
[
|
|
57
|
+
("api_key", "QWEN_API_KEY"),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
55
60
|
def __init__(
|
|
56
61
|
self,
|
|
57
62
|
model_type: Union[ModelType, str],
|
|
@@ -77,7 +82,6 @@ class QwenModel(BaseModelBackend):
|
|
|
77
82
|
base_url=self._url,
|
|
78
83
|
)
|
|
79
84
|
|
|
80
|
-
@api_keys_required("QWEN_API_KEY")
|
|
81
85
|
def run(
|
|
82
86
|
self,
|
|
83
87
|
messages: List[OpenAIMessage],
|
camel/models/reka_model.py
CHANGED
|
@@ -56,6 +56,11 @@ class RekaModel(BaseModelBackend):
|
|
|
56
56
|
be used. (default: :obj:`None`)
|
|
57
57
|
"""
|
|
58
58
|
|
|
59
|
+
@api_keys_required(
|
|
60
|
+
[
|
|
61
|
+
("api_key", "REKA_API_KEY"),
|
|
62
|
+
]
|
|
63
|
+
)
|
|
59
64
|
@dependencies_required('reka')
|
|
60
65
|
def __init__(
|
|
61
66
|
self,
|
|
@@ -168,7 +173,6 @@ class RekaModel(BaseModelBackend):
|
|
|
168
173
|
)
|
|
169
174
|
return self._token_counter
|
|
170
175
|
|
|
171
|
-
@api_keys_required("REKA_API_KEY")
|
|
172
176
|
def run(
|
|
173
177
|
self,
|
|
174
178
|
messages: List[OpenAIMessage],
|
|
@@ -53,7 +53,11 @@ class NemotronRewardModel(BaseRewardModel):
|
|
|
53
53
|
api_key=self.api_key,
|
|
54
54
|
)
|
|
55
55
|
|
|
56
|
-
@api_keys_required(
|
|
56
|
+
@api_keys_required(
|
|
57
|
+
[
|
|
58
|
+
(None, "NVIDIA_API_KEY"),
|
|
59
|
+
]
|
|
60
|
+
)
|
|
57
61
|
def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
|
|
58
62
|
r"""Evaluate the messages using the Nemotron model.
|
|
59
63
|
|
camel/models/samba_model.py
CHANGED
|
@@ -74,6 +74,11 @@ class SambaModel(BaseModelBackend):
|
|
|
74
74
|
ModelType.GPT_4O_MINI)` will be used.
|
|
75
75
|
"""
|
|
76
76
|
|
|
77
|
+
@api_keys_required(
|
|
78
|
+
[
|
|
79
|
+
("api_key", 'SAMBA_API_KEY'),
|
|
80
|
+
]
|
|
81
|
+
)
|
|
77
82
|
def __init__(
|
|
78
83
|
self,
|
|
79
84
|
model_type: Union[ModelType, str],
|
|
@@ -143,7 +148,6 @@ class SambaModel(BaseModelBackend):
|
|
|
143
148
|
" SambaNova service"
|
|
144
149
|
)
|
|
145
150
|
|
|
146
|
-
@api_keys_required("SAMBA_API_KEY")
|
|
147
151
|
def run( # type: ignore[misc]
|
|
148
152
|
self, messages: List[OpenAIMessage]
|
|
149
153
|
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
camel/models/togetherai_model.py
CHANGED
|
@@ -53,6 +53,11 @@ class TogetherAIModel(BaseModelBackend):
|
|
|
53
53
|
ModelType.GPT_4O_MINI)` will be used.
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
|
+
@api_keys_required(
|
|
57
|
+
[
|
|
58
|
+
("api_key", 'TOGETHER_API_KEY'),
|
|
59
|
+
]
|
|
60
|
+
)
|
|
56
61
|
def __init__(
|
|
57
62
|
self,
|
|
58
63
|
model_type: Union[ModelType, str],
|
|
@@ -78,7 +83,6 @@ class TogetherAIModel(BaseModelBackend):
|
|
|
78
83
|
base_url=self._url,
|
|
79
84
|
)
|
|
80
85
|
|
|
81
|
-
@api_keys_required("TOGETHER_API_KEY")
|
|
82
86
|
def run(
|
|
83
87
|
self,
|
|
84
88
|
messages: List[OpenAIMessage],
|
camel/models/yi_model.py
CHANGED
|
@@ -52,6 +52,11 @@ class YiModel(BaseModelBackend):
|
|
|
52
52
|
(default: :obj:`None`)
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
+
@api_keys_required(
|
|
56
|
+
[
|
|
57
|
+
("api_key", 'YI_API_KEY'),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
55
60
|
def __init__(
|
|
56
61
|
self,
|
|
57
62
|
model_type: Union[ModelType, str],
|
|
@@ -76,7 +81,6 @@ class YiModel(BaseModelBackend):
|
|
|
76
81
|
base_url=self._url,
|
|
77
82
|
)
|
|
78
83
|
|
|
79
|
-
@api_keys_required("YI_API_KEY")
|
|
80
84
|
def run(
|
|
81
85
|
self,
|
|
82
86
|
messages: List[OpenAIMessage],
|
camel/models/zhipuai_model.py
CHANGED
|
@@ -52,6 +52,11 @@ class ZhipuAIModel(BaseModelBackend):
|
|
|
52
52
|
(default: :obj:`None`)
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
+
@api_keys_required(
|
|
56
|
+
[
|
|
57
|
+
("api_key", 'ZHIPUAI_API_KEY'),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
55
60
|
def __init__(
|
|
56
61
|
self,
|
|
57
62
|
model_type: Union[ModelType, str],
|
|
@@ -76,7 +81,6 @@ class ZhipuAIModel(BaseModelBackend):
|
|
|
76
81
|
base_url=self._url,
|
|
77
82
|
)
|
|
78
83
|
|
|
79
|
-
@api_keys_required("ZHIPUAI_API_KEY")
|
|
80
84
|
def run(
|
|
81
85
|
self,
|
|
82
86
|
messages: List[OpenAIMessage],
|
|
@@ -121,6 +121,14 @@ class AutoRetriever:
|
|
|
121
121
|
|
|
122
122
|
collection_name = re.sub(r'[^a-zA-Z0-9]', '', content)[:20]
|
|
123
123
|
|
|
124
|
+
# Ensure the first character is either an underscore or a letter for
|
|
125
|
+
# Milvus
|
|
126
|
+
if (
|
|
127
|
+
self.storage_type == StorageType.MILVUS
|
|
128
|
+
and not collection_name[0].isalpha()
|
|
129
|
+
):
|
|
130
|
+
collection_name = f"_{collection_name}"
|
|
131
|
+
|
|
124
132
|
return collection_name
|
|
125
133
|
|
|
126
134
|
def run_vector_retriever(
|
|
@@ -161,13 +161,16 @@ class VectorRetriever(BaseRetriever):
|
|
|
161
161
|
# content path, chunk metadata, and chunk text
|
|
162
162
|
for vector, chunk in zip(batch_vectors, batch_chunks):
|
|
163
163
|
if isinstance(content, str):
|
|
164
|
-
content_path_info = {"content path": content}
|
|
164
|
+
content_path_info = {"content path": content[:100]}
|
|
165
165
|
elif isinstance(content, IOBase):
|
|
166
166
|
content_path_info = {"content path": "From file bytes"}
|
|
167
167
|
elif isinstance(content, Element):
|
|
168
168
|
content_path_info = {
|
|
169
|
-
"content path": content.metadata.file_directory
|
|
170
|
-
|
|
169
|
+
"content path": content.metadata.file_directory[
|
|
170
|
+
:100
|
|
171
|
+
]
|
|
172
|
+
if content.metadata.file_directory
|
|
173
|
+
else ""
|
|
171
174
|
}
|
|
172
175
|
|
|
173
176
|
chunk_metadata = {"metadata": chunk.metadata.to_dict()}
|
|
@@ -53,6 +53,11 @@ class OpenAISchemaConverter(BaseConverter):
|
|
|
53
53
|
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
|
+
@api_keys_required(
|
|
57
|
+
[
|
|
58
|
+
("api_key", "OPENAI_API_KEY"),
|
|
59
|
+
]
|
|
60
|
+
)
|
|
56
61
|
def __init__(
|
|
57
62
|
self,
|
|
58
63
|
model_type: ModelType = ModelType.GPT_4O_MINI,
|
|
@@ -69,7 +74,6 @@ class OpenAISchemaConverter(BaseConverter):
|
|
|
69
74
|
)._client
|
|
70
75
|
super().__init__()
|
|
71
76
|
|
|
72
|
-
@api_keys_required("OPENAI_API_KEY")
|
|
73
77
|
def convert( # type: ignore[override]
|
|
74
78
|
self,
|
|
75
79
|
content: str,
|
camel/societies/role_playing.py
CHANGED
|
@@ -509,8 +509,8 @@ class RolePlaying:
|
|
|
509
509
|
# step and once in role play), and the model generates only one
|
|
510
510
|
# response when multi-response support is enabled.
|
|
511
511
|
if (
|
|
512
|
-
'n' in self.user_agent.model_config_dict.keys()
|
|
513
|
-
and self.user_agent.model_config_dict['n'] > 1
|
|
512
|
+
'n' in self.user_agent.model_backend.model_config_dict.keys()
|
|
513
|
+
and self.user_agent.model_backend.model_config_dict['n'] > 1
|
|
514
514
|
):
|
|
515
515
|
self.user_agent.record_message(user_msg)
|
|
516
516
|
|
|
@@ -532,8 +532,8 @@ class RolePlaying:
|
|
|
532
532
|
# step and once in role play), and the model generates only one
|
|
533
533
|
# response when multi-response support is enabled.
|
|
534
534
|
if (
|
|
535
|
-
'n' in self.assistant_agent.model_config_dict.keys()
|
|
536
|
-
and self.assistant_agent.model_config_dict['n'] > 1
|
|
535
|
+
'n' in self.assistant_agent.model_backend.model_config_dict.keys()
|
|
536
|
+
and self.assistant_agent.model_backend.model_config_dict['n'] > 1
|
|
537
537
|
):
|
|
538
538
|
self.assistant_agent.record_message(assistant_msg)
|
|
539
539
|
|
|
@@ -251,7 +251,7 @@ class Workforce(BaseNode):
|
|
|
251
251
|
additional_info = "A Workforce node"
|
|
252
252
|
elif isinstance(child, SingleAgentWorker):
|
|
253
253
|
additional_info = "tools: " + (
|
|
254
|
-
", ".join(child.worker.
|
|
254
|
+
", ".join(child.worker.tool_dict.keys())
|
|
255
255
|
)
|
|
256
256
|
elif isinstance(child, RolePlayingWorker):
|
|
257
257
|
additional_info = "A Role playing node"
|
|
@@ -369,7 +369,7 @@ class Workforce(BaseNode):
|
|
|
369
369
|
model_config_dict=model_config_dict,
|
|
370
370
|
)
|
|
371
371
|
|
|
372
|
-
return ChatAgent(worker_sys_msg, model=model, tools=function_list)
|
|
372
|
+
return ChatAgent(worker_sys_msg, model=model, tools=function_list) # type: ignore[arg-type]
|
|
373
373
|
|
|
374
374
|
async def _get_returned_task(self) -> Task:
|
|
375
375
|
r"""Get the task that's published by this node and just get returned
|
|
@@ -12,8 +12,19 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
|
+
import logging
|
|
16
|
+
import re
|
|
15
17
|
import time
|
|
16
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
from camel.storages.graph_storages.base import BaseGraphStorage
|
|
21
|
+
from camel.storages.graph_storages.graph_element import (
|
|
22
|
+
GraphElement,
|
|
23
|
+
)
|
|
24
|
+
from camel.utils.commons import dependencies_required
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
17
28
|
|
|
18
29
|
if TYPE_CHECKING:
|
|
19
30
|
from nebula3.data.ResultSet import ( # type: ignore[import-untyped]
|
|
@@ -24,11 +35,6 @@ if TYPE_CHECKING:
|
|
|
24
35
|
Session,
|
|
25
36
|
)
|
|
26
37
|
|
|
27
|
-
from camel.storages.graph_storages.base import BaseGraphStorage
|
|
28
|
-
from camel.storages.graph_storages.graph_element import (
|
|
29
|
-
GraphElement,
|
|
30
|
-
)
|
|
31
|
-
from camel.utils.commons import dependencies_required
|
|
32
38
|
|
|
33
39
|
MAX_RETRIES = 5
|
|
34
40
|
RETRY_DELAY = 3
|
|
@@ -178,55 +184,81 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
178
184
|
"""
|
|
179
185
|
nodes = self._extract_nodes(graph_elements)
|
|
180
186
|
for node in nodes:
|
|
181
|
-
|
|
187
|
+
try:
|
|
188
|
+
self.add_node(node['id'], node['type'])
|
|
189
|
+
except Exception as e:
|
|
190
|
+
logger.warning(f"Failed to add node {node}. Error: {e}")
|
|
191
|
+
continue
|
|
182
192
|
|
|
183
193
|
relationships = self._extract_relationships(graph_elements)
|
|
184
194
|
for rel in relationships:
|
|
185
|
-
|
|
195
|
+
try:
|
|
196
|
+
self.add_triplet(
|
|
197
|
+
rel['subj']['id'], rel['obj']['id'], rel['type']
|
|
198
|
+
)
|
|
199
|
+
except Exception as e:
|
|
200
|
+
logger.warning(f"Failed to add relationship {rel}. Error: {e}")
|
|
201
|
+
continue
|
|
186
202
|
|
|
187
203
|
def ensure_edge_type_exists(
|
|
188
204
|
self,
|
|
189
205
|
edge_type: str,
|
|
206
|
+
time_label: Optional[str] = None,
|
|
190
207
|
) -> None:
|
|
191
208
|
r"""Ensures that a specified edge type exists in the NebulaGraph
|
|
192
209
|
database. If the edge type already exists, this method does nothing.
|
|
193
210
|
|
|
194
211
|
Args:
|
|
195
212
|
edge_type (str): The name of the edge type to be created.
|
|
213
|
+
time_label (str, optional): A specific timestamp to set as the
|
|
214
|
+
default value for the time label property. If not
|
|
215
|
+
provided, no timestamp will be added. (default: :obj:`None`)
|
|
196
216
|
|
|
197
217
|
Raises:
|
|
198
218
|
Exception: If the edge type creation fails after multiple retry
|
|
199
219
|
attempts, an exception is raised with the error message.
|
|
200
220
|
"""
|
|
201
|
-
create_edge_stmt = f
|
|
221
|
+
create_edge_stmt = f"CREATE EDGE IF NOT EXISTS {edge_type} ()"
|
|
222
|
+
if time_label is not None:
|
|
223
|
+
time_label = self._validate_time_label(time_label)
|
|
224
|
+
create_edge_stmt = f"""CREATE EDGE IF NOT EXISTS {edge_type}
|
|
225
|
+
(time_label DATETIME DEFAULT {time_label})"""
|
|
202
226
|
|
|
203
227
|
for attempt in range(MAX_RETRIES):
|
|
204
228
|
res = self.query(create_edge_stmt)
|
|
205
229
|
if res.is_succeeded():
|
|
206
|
-
return #
|
|
230
|
+
return # Edge type creation succeeded
|
|
207
231
|
|
|
208
232
|
if attempt < MAX_RETRIES - 1:
|
|
209
233
|
time.sleep(RETRY_DELAY)
|
|
210
234
|
else:
|
|
211
235
|
# Final attempt failed, raise an exception
|
|
212
236
|
raise Exception(
|
|
213
|
-
f"Failed to create
|
|
237
|
+
f"Failed to create edge type `{edge_type}` after "
|
|
214
238
|
f"{MAX_RETRIES} attempts: {res.error_msg()}"
|
|
215
239
|
)
|
|
216
240
|
|
|
217
|
-
def ensure_tag_exists(
|
|
241
|
+
def ensure_tag_exists(
|
|
242
|
+
self, tag_name: str, time_label: Optional[str] = None
|
|
243
|
+
) -> None:
|
|
218
244
|
r"""Ensures a tag is created in the NebulaGraph database. If the tag
|
|
219
245
|
already exists, it does nothing.
|
|
220
246
|
|
|
221
247
|
Args:
|
|
222
248
|
tag_name (str): The name of the tag to be created.
|
|
249
|
+
time_label (str, optional): A specific timestamp to set as the
|
|
250
|
+
default value for the time label property. If not provided,
|
|
251
|
+
no timestamp will be added. (default: :obj:`None`)
|
|
223
252
|
|
|
224
253
|
Raises:
|
|
225
254
|
Exception: If the tag creation fails after retries, an exception
|
|
226
255
|
is raised with the error message.
|
|
227
256
|
"""
|
|
228
|
-
|
|
229
|
-
|
|
257
|
+
create_tag_stmt = f"CREATE TAG IF NOT EXISTS {tag_name} ()"
|
|
258
|
+
if time_label is not None:
|
|
259
|
+
time_label = self._validate_time_label(time_label)
|
|
260
|
+
create_tag_stmt = f"""CREATE TAG IF NOT EXISTS {tag_name}
|
|
261
|
+
(time_label DATETIME DEFAULT {time_label})"""
|
|
230
262
|
|
|
231
263
|
for attempt in range(MAX_RETRIES):
|
|
232
264
|
res = self.query(create_tag_stmt)
|
|
@@ -246,24 +278,39 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
246
278
|
self,
|
|
247
279
|
node_id: str,
|
|
248
280
|
tag_name: str,
|
|
281
|
+
time_label: Optional[str] = None,
|
|
249
282
|
) -> None:
|
|
250
283
|
r"""Add a node with the specified tag and properties.
|
|
251
284
|
|
|
252
285
|
Args:
|
|
253
286
|
node_id (str): The ID of the node.
|
|
254
287
|
tag_name (str): The tag name of the node.
|
|
288
|
+
time_label (str, optional): A specific timestamp to set for
|
|
289
|
+
the node's time label property. If not provided, no timestamp
|
|
290
|
+
will be added. (default: :obj:`None`)
|
|
255
291
|
"""
|
|
256
|
-
|
|
292
|
+
node_id = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', node_id)
|
|
293
|
+
tag_name = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', tag_name)
|
|
257
294
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
295
|
+
self.ensure_tag_exists(tag_name, time_label)
|
|
296
|
+
|
|
297
|
+
# Insert node with or without time_label property
|
|
298
|
+
if time_label is not None:
|
|
299
|
+
time_label = self._validate_time_label(time_label)
|
|
300
|
+
insert_stmt = (
|
|
301
|
+
f'INSERT VERTEX IF NOT EXISTS {tag_name}(time_label) VALUES '
|
|
302
|
+
f'"{node_id}":("{time_label}")'
|
|
303
|
+
)
|
|
304
|
+
else:
|
|
305
|
+
insert_stmt = (
|
|
306
|
+
f'INSERT VERTEX IF NOT EXISTS {tag_name}() VALUES '
|
|
307
|
+
f'"{node_id}":()'
|
|
308
|
+
)
|
|
262
309
|
|
|
263
310
|
for attempt in range(MAX_RETRIES):
|
|
264
311
|
res = self.query(insert_stmt)
|
|
265
312
|
if res.is_succeeded():
|
|
266
|
-
return #
|
|
313
|
+
return # Node creation succeeded, exit the method
|
|
267
314
|
|
|
268
315
|
if attempt < MAX_RETRIES - 1:
|
|
269
316
|
time.sleep(RETRY_DELAY)
|
|
@@ -329,7 +376,7 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
329
376
|
@property
|
|
330
377
|
def get_structured_schema(self) -> Dict[str, Any]:
|
|
331
378
|
r"""Generates a structured schema consisting of node and relationship
|
|
332
|
-
properties, relationships, and metadata.
|
|
379
|
+
properties, relationships, and metadata, including timestamps.
|
|
333
380
|
|
|
334
381
|
Returns:
|
|
335
382
|
Dict[str, Any]: A dictionary representing the structured schema.
|
|
@@ -400,6 +447,7 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
400
447
|
subj: str,
|
|
401
448
|
obj: str,
|
|
402
449
|
rel: str,
|
|
450
|
+
time_label: Optional[str] = None,
|
|
403
451
|
) -> None:
|
|
404
452
|
r"""Adds a relationship (triplet) between two entities in the Nebula
|
|
405
453
|
Graph database.
|
|
@@ -408,24 +456,44 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
408
456
|
subj (str): The identifier for the subject entity.
|
|
409
457
|
obj (str): The identifier for the object entity.
|
|
410
458
|
rel (str): The relationship between the subject and object.
|
|
459
|
+
time_label (str, optional): A specific timestamp to set for the
|
|
460
|
+
time label property of the relationship. If not provided,
|
|
461
|
+
no timestamp will be added. (default: :obj:`None`)
|
|
462
|
+
|
|
463
|
+
Raises:
|
|
464
|
+
ValueError: If the time_label format is invalid.
|
|
465
|
+
Exception: If creating the relationship fails.
|
|
411
466
|
"""
|
|
467
|
+
subj = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', subj)
|
|
468
|
+
obj = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', obj)
|
|
469
|
+
rel = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', rel)
|
|
470
|
+
|
|
412
471
|
self.ensure_tag_exists(subj)
|
|
413
472
|
self.ensure_tag_exists(obj)
|
|
414
|
-
self.ensure_edge_type_exists(rel)
|
|
473
|
+
self.ensure_edge_type_exists(rel, time_label)
|
|
415
474
|
self.add_node(node_id=subj, tag_name=subj)
|
|
416
475
|
self.add_node(node_id=obj, tag_name=obj)
|
|
417
476
|
|
|
418
|
-
# Avoid
|
|
477
|
+
# Avoid latency
|
|
419
478
|
time.sleep(1)
|
|
420
479
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
480
|
+
# Create edge with or without time_label property
|
|
481
|
+
if time_label is not None:
|
|
482
|
+
time_label = self._validate_time_label(time_label)
|
|
483
|
+
insert_stmt = (
|
|
484
|
+
f'INSERT EDGE IF NOT EXISTS {rel}(time_label) VALUES '
|
|
485
|
+
f'"{subj}"->"{obj}":("{time_label}")'
|
|
486
|
+
)
|
|
487
|
+
else:
|
|
488
|
+
insert_stmt = (
|
|
489
|
+
f'INSERT EDGE IF NOT EXISTS {rel}() VALUES '
|
|
490
|
+
f'"{subj}"->"{obj}":()'
|
|
491
|
+
)
|
|
424
492
|
|
|
425
493
|
res = self.query(insert_stmt)
|
|
426
494
|
if not res.is_succeeded():
|
|
427
495
|
raise Exception(
|
|
428
|
-
f'create relationship `
|
|
496
|
+
f'create relationship `{subj}` -> `{obj}`'
|
|
429
497
|
+ f'failed: {res.error_msg()}'
|
|
430
498
|
)
|
|
431
499
|
|
|
@@ -545,3 +613,27 @@ class NebulaGraph(BaseGraphStorage):
|
|
|
545
613
|
)
|
|
546
614
|
|
|
547
615
|
return rel_schema_props, rel_structure_props
|
|
616
|
+
|
|
617
|
+
def _validate_time_label(self, time_label: str) -> str:
|
|
618
|
+
r"""Validates the format of a time label string.
|
|
619
|
+
|
|
620
|
+
Args:
|
|
621
|
+
time_label (str): The time label string to validate.
|
|
622
|
+
Should be in format 'YYYY-MM-DDThh:mm:ss'.
|
|
623
|
+
|
|
624
|
+
Returns:
|
|
625
|
+
str: The validated time label.
|
|
626
|
+
|
|
627
|
+
Raises:
|
|
628
|
+
ValueError: If the time label format is invalid.
|
|
629
|
+
"""
|
|
630
|
+
try:
|
|
631
|
+
# Check if the format matches YYYY-MM-DDThh:mm:ss
|
|
632
|
+
pattern = r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$'
|
|
633
|
+
if not re.match(pattern, time_label):
|
|
634
|
+
raise ValueError(
|
|
635
|
+
"Time label must be in format 'YYYY-MM-DDThh:mm:ss'"
|
|
636
|
+
)
|
|
637
|
+
return time_label
|
|
638
|
+
except Exception as e:
|
|
639
|
+
raise ValueError(f"Invalid time label format: {e!s}")
|