camel-ai 0.1.6.2__py3-none-any.whl → 0.1.6.4__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/configs/mistral_config.py +13 -9
- camel/embeddings/mistral_embedding.py +5 -5
- camel/interpreters/docker_interpreter.py +1 -1
- camel/loaders/__init__.py +1 -2
- camel/loaders/base_io.py +118 -52
- camel/loaders/jina_url_reader.py +6 -6
- camel/loaders/unstructured_io.py +24 -286
- camel/models/__init__.py +2 -0
- camel/models/mistral_model.py +120 -26
- camel/models/model_factory.py +3 -3
- camel/models/openai_compatibility_model.py +105 -0
- camel/retrievers/auto_retriever.py +25 -35
- camel/retrievers/vector_retriever.py +20 -18
- camel/storages/object_storages/__init__.py +22 -0
- camel/storages/object_storages/amazon_s3.py +205 -0
- camel/storages/object_storages/azure_blob.py +166 -0
- camel/storages/object_storages/base.py +115 -0
- camel/storages/object_storages/google_cloud.py +152 -0
- camel/toolkits/retrieval_toolkit.py +5 -5
- camel/toolkits/search_toolkit.py +4 -4
- camel/types/enums.py +7 -0
- camel/utils/token_counting.py +7 -3
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.4.dist-info}/METADATA +9 -5
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.4.dist-info}/RECORD +26 -20
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.4.dist-info}/WHEEL +0 -0
camel/models/openai_compatibility_model.py (new file)

@@ -0,0 +1,105 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+from typing import Any, Dict, List, Optional, Union
+
+from openai import OpenAI, Stream
+
+from camel.messages import OpenAIMessage
+from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
+from camel.utils import (
+    BaseTokenCounter,
+    OpenAITokenCounter,
+)
+
+
+class OpenAICompatibilityModel:
+    r"""Constructor for model backend supporting OpenAI compatibility."""
+
+    def __init__(
+        self,
+        model_type: str,
+        model_config_dict: Dict[str, Any],
+        api_key: str,
+        url: str,
+        token_counter: Optional[BaseTokenCounter] = None,
+    ) -> None:
+        r"""Constructor for model backend.
+
+        Args:
+            model_type (ModelType): Model for which a backend is created.
+            model_config_dict (Dict[str, Any]): A dictionary that will
+                be fed into openai.ChatCompletion.create().
+            api_key (str): The API key for authenticating with the
+                model service. (default: :obj:`None`)
+            url (str): The url to the model service. (default:
+                :obj:`None`)
+            token_counter (Optional[BaseTokenCounter]): Token counter to use
+                for the model. If not provided, `OpenAITokenCounter(ModelType.
+                GPT_3_5_TURBO)` will be used.
+        """
+        self.model_type = model_type
+        self.model_config_dict = model_config_dict
+        self._token_counter = token_counter
+        self._client = OpenAI(
+            timeout=60,
+            max_retries=3,
+            api_key=api_key,
+            base_url=url,
+        )
+
+    def run(
+        self,
+        messages: List[OpenAIMessage],
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        r"""Runs inference of OpenAI chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `Stream[ChatCompletionChunk]` in the stream mode.
+        """
+        response = self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **self.model_config_dict,
+        )
+        return response
+
+    @property
+    def token_counter(self) -> BaseTokenCounter:
+        r"""Initialize the token counter for the model backend.
+
+        Returns:
+            OpenAITokenCounter: The token counter following the model's
+                tokenization style.
+        """
+        if not self._token_counter:
+            self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
+        return self._token_counter
+
+    @property
+    def stream(self) -> bool:
+        r"""Returns whether the model is in stream mode, which sends partial
+        results each time.
+
+        Returns:
+            bool: Whether the model is in stream mode.
+        """
+        return self.model_config_dict.get('stream', False)
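The new backend is a thin wrapper over the `openai` client, so any OpenAI-compatible endpoint can be targeted by passing its base URL. A minimal usage sketch follows; the endpoint URL, API key, and model name are placeholders, and the direct module import path is an assumption (the release also touches `camel/models/__init__.py` and `model_factory.py`, which may expose a different entry point):

from camel.models.openai_compatibility_model import OpenAICompatibilityModel

# Placeholder endpoint and credentials, for illustration only.
model = OpenAICompatibilityModel(
    model_type="my-served-model",
    model_config_dict={"temperature": 0.2, "stream": False},
    api_key="sk-placeholder",
    url="http://localhost:8000/v1",
)

# Messages use the standard OpenAI chat format.
response = model.run(messages=[{"role": "user", "content": "Hello!"}])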
camel/retrievers/auto_retriever.py

@@ -97,36 +97,36 @@ class AutoRetriever:
                 f"Unsupported vector storage type: {self.storage_type}"
             )
 
-    def _collection_name_generator(self,
+    def _collection_name_generator(self, content: str) -> str:
         r"""Generates a valid collection name from a given file path or URL.
 
         Args:
-
-                generate the collection name.
+            contents (str): Local file path, remote URL or string content.
 
         Returns:
             str: A sanitized, valid collection name suitable for use.
         """
-        # Check
-        parsed_url = urlparse(
-
+        # Check if the content is URL
+        parsed_url = urlparse(content)
+        is_url = all([parsed_url.scheme, parsed_url.netloc])
 
         # Convert given path into a collection name, ensuring it only
         # contains numbers, letters, and underscores
-        if
+        if is_url:
             # For URLs, remove https://, replace /, and any characters not
             # allowed by Milvus with _
             collection_name = re.sub(
                 r'[^0-9a-zA-Z]+',
                 '_',
-
+                content.replace("https://", ""),
             )
-
+        elif os.path.exists(content):
             # For file paths, get the stem and replace spaces with _, also
             # ensuring only allowed characters are present
-            collection_name = re.sub(
-
-
+            collection_name = re.sub(r'[^0-9a-zA-Z]+', '_', Path(content).stem)
+        else:
+            # the content is string input
+            collection_name = content[:10]
 
         # Ensure the collection name does not start or end with underscore
         collection_name = collection_name.strip("_")

@@ -193,7 +193,7 @@ class AutoRetriever:
     def run_vector_retriever(
         self,
         query: str,
-
+        contents: Union[str, List[str]],
         top_k: int = DEFAULT_TOP_K_RESULTS,
         similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
         return_detailed_info: bool = False,

@@ -203,8 +203,8 @@ class AutoRetriever:
 
         Args:
             query (str): Query string for information retriever.
-
-
+            contents (Union[str, List[str]]): Local file paths, remote URLs or
+                string contents.
             top_k (int, optional): The number of top results to return during
                 retrieve. Must be a positive integer. Defaults to
                 `DEFAULT_TOP_K_RESULTS`.

@@ -223,24 +223,18 @@ class AutoRetriever:
         Raises:
             ValueError: If there's an vector storage existing with content
                 name in the vector path but the payload is None. If
-                `
+                `contents` is empty.
             RuntimeError: If any errors occur during the retrieve process.
         """
-        if not
-            raise ValueError("
+        if not contents:
+            raise ValueError("content cannot be empty.")
 
-
-            [content_input_paths]
-            if isinstance(content_input_paths, str)
-            else content_input_paths
-        )
+        contents = [contents] if isinstance(contents, str) else contents
 
         all_retrieved_info = []
-        for
+        for content in contents:
             # Generate a valid collection name
-            collection_name = self._collection_name_generator(
-                content_input_path
-            )
+            collection_name = self._collection_name_generator(content)
             try:
                 vector_storage_instance = self._initialize_vector_storage(
                     collection_name

@@ -251,13 +245,11 @@ class AutoRetriever:
                 file_is_modified = False  # initialize with a default value
                 if (
                     vector_storage_instance.status().vector_count != 0
-                    and
+                    and os.path.exists(content)
                 ):
                     # Get original modified date from file
                     modified_date_from_file = (
-                        self._get_file_modified_date_from_file(
-                            content_input_path
-                        )
+                        self._get_file_modified_date_from_file(content)
                     )
                     # Get modified date from vector storage
                     modified_date_from_storage = (

@@ -280,18 +272,16 @@ class AutoRetriever:
                     # Process and store the content to the vector storage
                     vr = VectorRetriever(
                         storage=vector_storage_instance,
-                        similarity_threshold=similarity_threshold,
                         embedding_model=self.embedding_model,
                     )
-                    vr.process(
+                    vr.process(content)
                 else:
                     vr = VectorRetriever(
                         storage=vector_storage_instance,
-                        similarity_threshold=similarity_threshold,
                         embedding_model=self.embedding_model,
                     )
                 # Retrieve info by given query from the vector storage
-                retrieved_info = vr.query(query, top_k)
+                retrieved_info = vr.query(query, top_k, similarity_threshold)
                 all_retrieved_info.extend(retrieved_info)
             except Exception as e:
                 raise RuntimeError(
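With these changes, `run_vector_retriever` takes a single `contents` argument (a string or list of strings) that may mix local file paths, remote URLs, and raw text, and the similarity threshold is applied at query time rather than baked into each `VectorRetriever`. A rough sketch of a call against the new signature; the constructor arguments and input values are illustrative and not taken from this diff:

from camel.retrievers import AutoRetriever

retriever = AutoRetriever()  # construction details unchanged by this diff
results = retriever.run_vector_retriever(
    query="What is CAMEL-AI?",
    contents=[
        "https://www.camel-ai.org/",          # remote URL
        "local_data/camel_paper.pdf",         # local file path
        "CAMEL is a multi-agent framework.",  # raw string content
    ],
    top_k=1,
    similarity_threshold=0.75,
)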
camel/retrievers/vector_retriever.py

@@ -11,7 +11,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import os
 from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse
 
 from camel.embeddings import BaseEmbedding, OpenAIEmbedding
 from camel.loaders import UnstructuredIO

@@ -38,24 +40,18 @@ class VectorRetriever(BaseRetriever):
         embedding_model (BaseEmbedding): Embedding model used to generate
             vector embeddings.
         storage (BaseVectorStorage): Vector storage to query.
-        similarity_threshold (float, optional): The similarity threshold
-            for filtering results. Defaults to `DEFAULT_SIMILARITY_THRESHOLD`.
         unstructured_modules (UnstructuredIO): A module for parsing files and
             URLs and chunking content based on specified parameters.
     """
 
     def __init__(
         self,
-        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
         embedding_model: Optional[BaseEmbedding] = None,
         storage: Optional[BaseVectorStorage] = None,
     ) -> None:
         r"""Initializes the retriever class with an optional embedding model.
 
         Args:
-            similarity_threshold (float, optional): The similarity threshold
-                for filtering results. Defaults to
-                `DEFAULT_SIMILARITY_THRESHOLD`.
             embedding_model (Optional[BaseEmbedding]): The embedding model
                 instance. Defaults to `OpenAIEmbedding` if not provided.
             storage (BaseVectorStorage): Vector storage to query.

@@ -68,12 +64,11 @@ class VectorRetriever(BaseRetriever):
                 vector_dim=self.embedding_model.get_output_dim()
             )
         )
-        self.
-        self.unstructured_modules: UnstructuredIO = UnstructuredIO()
+        self.uio: UnstructuredIO = UnstructuredIO()
 
     def process(
         self,
-
+        content: str,
         chunk_type: str = "chunk_by_title",
         **kwargs: Any,
     ) -> None:

@@ -82,16 +77,19 @@ class VectorRetriever(BaseRetriever):
         vector storage.
 
         Args:
-
-                processed.
+            contents (str): Local file path, remote URL or string content.
             chunk_type (str): Type of chunking going to apply. Defaults to
                 "chunk_by_title".
             **kwargs (Any): Additional keyword arguments for content parsing.
         """
-
-
-            )
-
+        # Check if the content is URL
+        parsed_url = urlparse(content)
+        is_url = all([parsed_url.scheme, parsed_url.netloc])
+        if is_url or os.path.exists(content):
+            elements = self.uio.parse_file_or_url(content, **kwargs)
+        else:
+            elements = [self.uio.create_element_from_text(text=content)]
+        chunks = self.uio.chunk_elements(
             chunk_type=chunk_type, elements=elements
         )
         # Iterate to process and store embeddings, set batch of 50

@@ -105,7 +103,7 @@ class VectorRetriever(BaseRetriever):
             # Prepare the payload for each vector record, includes the content
             # path, chunk metadata, and chunk text
             for vector, chunk in zip(batch_vectors, batch_chunks):
-                content_path_info = {"content path":
+                content_path_info = {"content path": content}
                 chunk_metadata = {"metadata": chunk.metadata.to_dict()}
                 chunk_text = {"text": str(chunk)}
                 combined_dict = {

@@ -124,12 +122,16 @@ class VectorRetriever(BaseRetriever):
         self,
         query: str,
         top_k: int = DEFAULT_TOP_K_RESULTS,
+        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
     ) -> List[Dict[str, Any]]:
         r"""Executes a query in vector storage and compiles the retrieved
         results into a dictionary.
 
         Args:
             query (str): Query string for information retriever.
+            similarity_threshold (float, optional): The similarity threshold
+                for filtering results. Defaults to
+                `DEFAULT_SIMILARITY_THRESHOLD`.
             top_k (int, optional): The number of top results to return during
                 retriever. Must be a positive integer. Defaults to 1.
 

@@ -161,7 +163,7 @@ class VectorRetriever(BaseRetriever):
         formatted_results = []
         for result in query_results:
             if (
-                result.similarity >=
+                result.similarity >= similarity_threshold
                 and result.record.payload is not None
             ):
                 result_dict = {

@@ -182,7 +184,7 @@ class VectorRetriever(BaseRetriever):
                     'text': (
                         f"No suitable information retrieved "
                         f"from {content_path} with similarity_threshold"
-                        f" = {
+                        f" = {similarity_threshold}."
                     )
                 }
             ]
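The corresponding `VectorRetriever` change moves `similarity_threshold` from the constructor to `query()`, and `process()` now accepts raw text in addition to file paths and URLs. A hedged sketch of the new call pattern; the import path and input values are illustrative:

from camel.retrievers import VectorRetriever

# Defaults to OpenAIEmbedding and a default vector storage when none is given.
vr = VectorRetriever()
vr.process(content="https://www.camel-ai.org/")  # URL, file path, or raw text
hits = vr.query(query="What is CAMEL?", top_k=3, similarity_threshold=0.8)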
camel/storages/object_storages/__init__.py (new file)

@@ -0,0 +1,22 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+from .amazon_s3 import AmazonS3Storage
+from .azure_blob import AzureBlobStorage
+from .google_cloud import GoogleCloudStorage
+
+__all__ = [
+    "AmazonS3Storage",
+    "AzureBlobStorage",
+    "GoogleCloudStorage",
+]
camel/storages/object_storages/amazon_s3.py (new file)

@@ -0,0 +1,205 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+import os
+from pathlib import Path, PurePath
+from typing import Optional, Tuple
+from warnings import warn
+
+from camel.loaders import File
+from camel.storages.object_storages.base import BaseObjectStorage
+
+
+class AmazonS3Storage(BaseObjectStorage):
+    r"""A class to connect with AWS S3 object storage to put and get objects
+    from one S3 bucket. The class will first try to use the credentials passed
+    as arguments, if not provided, it will look for the environment variables
+    `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. If none of these are
+    provided, it will try to use the local credentials (will be created if
+    logged in with AWS CLI).
+
+    Args:
+        bucket_name (str): The name of the S3 bucket.
+        create_if_not_exists (bool, optional): Whether to create the bucket if
+            it does not exist. Defaults to True.
+        access_key_id (Optional[str], optional): The AWS access key ID.
+            Defaults to None.
+        secret_access_key (Optional[str], optional): The AWS secret access key.
+            Defaults to None.
+        anonymous (bool, optional): Whether to use anonymous access. Defaults
+            to False.
+
+    References:
+        https://aws.amazon.com/pm/serv-s3/
+
+        https://aws.amazon.com/cli/
+    """
+
+    def __init__(
+        self,
+        bucket_name: str,
+        create_if_not_exists: bool = True,
+        access_key_id: Optional[str] = None,
+        secret_access_key: Optional[str] = None,
+        anonymous: bool = False,
+    ) -> None:
+        self._bucket_name = bucket_name
+        self._create_if_not_exists = create_if_not_exists
+
+        aws_key_id = access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
+        aws_secret_key = secret_access_key or os.getenv(
+            "AWS_SECRET_ACCESS_KEY"
+        )
+        if not all([aws_key_id, aws_secret_key]) and not anonymous:
+            warn(
+                "AWS access key not configured. Local credentials will be "
+                "used."
+            )
+            # Make all the empty values None
+            aws_key_id = None
+            aws_secret_key = None
+
+        import boto3
+        from botocore import UNSIGNED
+        from botocore.config import Config
+
+        if not anonymous:
+            self._client = boto3.client(
+                "s3",
+                aws_access_key_id=aws_key_id,
+                aws_secret_access_key=aws_secret_key,
+            )
+        else:
+            self._client = boto3.client(
+                "s3", config=Config(signature_version=UNSIGNED)
+            )
+
+        self._prepare_and_check()
+
+    def _prepare_and_check(self) -> None:
+        r"""Check privileges and existence of the bucket."""
+        from botocore.exceptions import ClientError, NoCredentialsError
+
+        try:
+            self._client.head_bucket(Bucket=self._bucket_name)
+        except ClientError as e:
+            error_code = e.response['Error']['Code']
+            if error_code == '403':
+                raise PermissionError(
+                    f"Failed to access bucket {self._bucket_name}: "
+                    f"No permission."
+                )
+            elif error_code == '404':
+                if self._create_if_not_exists:
+                    self._client.create_bucket(Bucket=self._bucket_name)
+                    warn(
+                        f"Bucket {self._bucket_name} not found. Automatically "
+                        f"created."
+                    )
+                else:
+                    raise FileNotFoundError(
+                        f"Failed to access bucket {self._bucket_name}: Not "
+                        f"found."
+                    )
+            else:
+                raise e
+        except NoCredentialsError as e:
+            raise PermissionError("No AWS credentials found.") from e
+
+    @staticmethod
+    def canonicalize_path(file_path: PurePath) -> Tuple[str, str]:
+        r"""Canonicalize file path for Amazon S3.
+
+        Args:
+            file_path (PurePath): The path to be canonicalized.
+
+        Returns:
+            Tuple[str, str]: The canonicalized file key and file name.
+        """
+        return file_path.as_posix(), file_path.name
+
+    def _put_file(self, file_key: str, file: File) -> None:
+        r"""Put a file to the Amazon S3 bucket.
+
+        Args:
+            file_key (str): The path to the object in the bucket.
+            file (File): The file to be uploaded.
+        """
+        self._client.put_object(
+            Bucket=self._bucket_name, Key=file_key, Body=file.raw_bytes
+        )
+
+    def _get_file(self, file_key: str, filename: str) -> File:
+        r"""Get a file from the Amazon S3 bucket.
+
+        Args:
+            file_key (str): The path to the object in the bucket.
+            filename (str): The name of the file.
+
+        Returns:
+            File: The object from the S3 bucket.
+        """
+        response = self._client.get_object(
+            Bucket=self._bucket_name, Key=file_key
+        )
+        raw_bytes = response["Body"].read()
+        return File.create_file_from_raw_bytes(raw_bytes, filename)
+
+    def _upload_file(
+        self, local_file_path: Path, remote_file_key: str
+    ) -> None:
+        r"""Upload a local file to the Amazon S3 bucket.
+
+        Args:
+            local_file_path (Path): The path to the local file to be uploaded.
+            remote_file_key (str): The path to the object in the bucket.
+        """
+        self._client.upload_file(
+            Bucket=self._bucket_name,
+            Key=remote_file_key,
+            Filename=local_file_path,
+        )
+
+    def _download_file(
+        self,
+        local_file_path: Path,
+        remote_file_key: str,
+    ) -> None:
+        r"""Download a file from the Amazon S3 bucket to the local system.
+
+        Args:
+            local_file_path (Path): The path to the local file to be saved.
+            remote_file_key (str): The key of the object in the bucket.
+        """
+        self._client.download_file(
+            Bucket=self._bucket_name,
+            Key=remote_file_key,
+            Filename=local_file_path,
+        )
+
+    def _object_exists(self, file_key: str) -> bool:
+        r"""
+        Check if the object exists in the Amazon S3 bucket.
+
+        Args:
+            file_key: The key of the object in the bucket.
+
+        Returns:
+            bool: Whether the object exists in the bucket.
+        """
+        try:
+            self._client.head_object(Bucket=self._bucket_name, Key=file_key)
+            return True
+        except self._client.exceptions.ClientError:
+            return False