clarifai 10.1.0__py3-none-any.whl → 10.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +23 -43
- clarifai/client/base.py +44 -4
- clarifai/client/dataset.py +8 -12
- clarifai/client/input.py +29 -1
- clarifai/client/model.py +191 -10
- clarifai/client/module.py +7 -5
- clarifai/client/runner.py +3 -1
- clarifai/client/search.py +6 -3
- clarifai/client/user.py +14 -12
- clarifai/client/workflow.py +7 -4
- clarifai/datasets/upload/loaders/README.md +3 -4
- clarifai/datasets/upload/loaders/xview_detection.py +5 -5
- clarifai/rag/rag.py +25 -11
- clarifai/rag/utils.py +21 -6
- clarifai/utils/evaluation/__init__.py +427 -0
- clarifai/utils/evaluation/helpers.py +522 -0
- clarifai/utils/model_train.py +3 -1
- clarifai/versions.py +1 -1
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/METADATA +32 -7
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/RECORD +24 -23
- clarifai/datasets/upload/loaders/coco_segmentation.py +0 -98
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/LICENSE +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/WHEEL +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/entry_points.txt +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/top_level.txt +0 -0
clarifai/client/module.py
CHANGED
@@ -18,6 +18,7 @@ class Module(Lister, BaseClient):
|
|
18
18
|
module_version: Dict = {'id': ""},
|
19
19
|
base_url: str = "https://api.clarifai.com",
|
20
20
|
pat: str = None,
|
21
|
+
token: str = None,
|
21
22
|
**kwargs):
|
22
23
|
"""Initializes a Module object.
|
23
24
|
|
@@ -26,7 +27,8 @@ class Module(Lister, BaseClient):
|
|
26
27
|
module_id (str): The Module ID to interact with.
|
27
28
|
module_version (dict): The Module Version to interact with.
|
28
29
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
29
|
-
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
30
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
|
31
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
|
30
32
|
**kwargs: Additional keyword arguments to be passed to the Module.
|
31
33
|
"""
|
32
34
|
if url and module_id:
|
@@ -41,7 +43,8 @@ class Module(Lister, BaseClient):
|
|
41
43
|
self.kwargs = {**kwargs, 'id': module_id, 'module_version': module_version}
|
42
44
|
self.module_info = resources_pb2.Module(**self.kwargs)
|
43
45
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
44
|
-
BaseClient.__init__(
|
46
|
+
BaseClient.__init__(
|
47
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
45
48
|
Lister.__init__(self)
|
46
49
|
|
47
50
|
def list_versions(self, page_no: int = None,
|
@@ -78,10 +81,9 @@ class Module(Lister, BaseClient):
|
|
78
81
|
for module_version_info in all_module_versions_info:
|
79
82
|
module_version_info['id'] = module_version_info['module_version_id']
|
80
83
|
del module_version_info['module_version_id']
|
81
|
-
yield Module(
|
84
|
+
yield Module.from_auth_helper(
|
85
|
+
self.auth_helper,
|
82
86
|
module_id=self.id,
|
83
|
-
base_url=self.base,
|
84
|
-
pat=self.pat,
|
85
87
|
**dict(self.kwargs, module_version=module_version_info))
|
86
88
|
|
87
89
|
def __getattr__(self, name):
|
clarifai/client/runner.py
CHANGED
@@ -39,6 +39,7 @@ class Runner(BaseClient):
|
|
39
39
|
check_runner_exists: bool = True,
|
40
40
|
base_url: str = "https://api.clarifai.com",
|
41
41
|
pat: str = None,
|
42
|
+
token: str = None,
|
42
43
|
num_parallel_polls: int = 4,
|
43
44
|
**kwargs) -> None:
|
44
45
|
"""
|
@@ -47,6 +48,7 @@ class Runner(BaseClient):
|
|
47
48
|
user_id (str): Clarifai User ID
|
48
49
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
49
50
|
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
51
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
50
52
|
num_parallel_polls (int): the max number of threads for parallel run loops to be fetching work from
|
51
53
|
"""
|
52
54
|
user_id = user_id or os.environ.get("CLARIFAI_USER_ID", "")
|
@@ -60,7 +62,7 @@ class Runner(BaseClient):
|
|
60
62
|
self.kwargs = {**kwargs, 'id': runner_id, 'user_id': user_id}
|
61
63
|
self.runner_info = resources_pb2.Runner(**self.kwargs)
|
62
64
|
self.num_parallel_polls = min(10, num_parallel_polls)
|
63
|
-
BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat)
|
65
|
+
BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat, token=token)
|
64
66
|
|
65
67
|
# Check that the runner exists.
|
66
68
|
if check_runner_exists:
|
clarifai/client/search.py
CHANGED
@@ -23,7 +23,8 @@ class Search(Lister, BaseClient):
|
|
23
23
|
top_k: int = DEFAULT_TOP_K,
|
24
24
|
metric: str = DEFAULT_SEARCH_METRIC,
|
25
25
|
base_url: str = "https://api.clarifai.com",
|
26
|
-
pat: str = None
|
26
|
+
pat: str = None,
|
27
|
+
token: str = None):
|
27
28
|
"""Initialize the Search object.
|
28
29
|
|
29
30
|
Args:
|
@@ -33,6 +34,7 @@ class Search(Lister, BaseClient):
|
|
33
34
|
metric (str, optional): Similarity metric (either 'cosine' or 'euclidean'). Defaults to 'cosine'.
|
34
35
|
base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
|
35
36
|
pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
37
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
36
38
|
|
37
39
|
Raises:
|
38
40
|
UserError: If the metric is not 'cosine' or 'euclidean'.
|
@@ -46,9 +48,10 @@ class Search(Lister, BaseClient):
|
|
46
48
|
self.data_proto = resources_pb2.Data()
|
47
49
|
self.top_k = top_k
|
48
50
|
|
49
|
-
self.inputs = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat)
|
51
|
+
self.inputs = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat, token=token)
|
50
52
|
self.rank_filter_schema = get_schema()
|
51
|
-
BaseClient.__init__(
|
53
|
+
BaseClient.__init__(
|
54
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
52
55
|
Lister.__init__(self, page_size=1000)
|
53
56
|
|
54
57
|
def _get_annot_proto(self, **kwargs):
|
clarifai/client/user.py
CHANGED
@@ -19,6 +19,7 @@ class User(Lister, BaseClient):
|
|
19
19
|
user_id: str = None,
|
20
20
|
base_url: str = "https://api.clarifai.com",
|
21
21
|
pat: str = None,
|
22
|
+
token: str = None,
|
22
23
|
**kwargs):
|
23
24
|
"""Initializes an User object.
|
24
25
|
|
@@ -26,12 +27,13 @@ class User(Lister, BaseClient):
|
|
26
27
|
user_id (str): The user ID for the user to interact with.
|
27
28
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
28
29
|
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
30
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
29
31
|
**kwargs: Additional keyword arguments to be passed to the User.
|
30
32
|
"""
|
31
33
|
self.kwargs = {**kwargs, 'id': user_id}
|
32
34
|
self.user_info = resources_pb2.User(**self.kwargs)
|
33
35
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
34
|
-
BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat)
|
36
|
+
BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
|
35
37
|
Lister.__init__(self)
|
36
38
|
|
37
39
|
def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
|
@@ -62,7 +64,9 @@ class User(Lister, BaseClient):
|
|
62
64
|
per_page=per_page,
|
63
65
|
page_no=page_no)
|
64
66
|
for app_info in all_apps_info:
|
65
|
-
yield App(
|
67
|
+
yield App.from_auth_helper(
|
68
|
+
self.auth_helper,
|
69
|
+
**app_info) #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
|
66
70
|
|
67
71
|
def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
|
68
72
|
per_page: int = None) -> Generator[Runner, None, None]:
|
@@ -94,7 +98,8 @@ class User(Lister, BaseClient):
|
|
94
98
|
page_no=page_no)
|
95
99
|
|
96
100
|
for runner_info in all_runners_info:
|
97
|
-
yield Runner(
|
101
|
+
yield Runner.from_auth_helper(
|
102
|
+
auth=self.auth_helper, check_runner_exists=False, **runner_info)
|
98
103
|
|
99
104
|
def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
|
100
105
|
"""Creates an app for the user.
|
@@ -120,8 +125,7 @@ class User(Lister, BaseClient):
|
|
120
125
|
if response.status.code != status_code_pb2.SUCCESS:
|
121
126
|
raise Exception(response.status)
|
122
127
|
self.logger.info("\nApp created\n%s", response.status)
|
123
|
-
|
124
|
-
return App(app_id=app_id, **kwargs)
|
128
|
+
return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
|
125
129
|
|
126
130
|
def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
|
127
131
|
"""Create a runner
|
@@ -151,14 +155,13 @@ class User(Lister, BaseClient):
|
|
151
155
|
raise Exception(response.status)
|
152
156
|
self.logger.info("\nRunner created\n%s", response.status)
|
153
157
|
|
154
|
-
return Runner(
|
158
|
+
return Runner.from_auth_helper(
|
159
|
+
auth=self.auth_helper,
|
155
160
|
runner_id=runner_id,
|
156
161
|
user_id=self.id,
|
157
162
|
labels=labels,
|
158
163
|
description=description,
|
159
|
-
check_runner_exists=False
|
160
|
-
base_url=self.base,
|
161
|
-
pat=self.pat)
|
164
|
+
check_runner_exists=False)
|
162
165
|
|
163
166
|
def app(self, app_id: str, **kwargs) -> App:
|
164
167
|
"""Returns an App object for the specified app ID.
|
@@ -181,8 +184,7 @@ class User(Lister, BaseClient):
|
|
181
184
|
raise Exception(response.status)
|
182
185
|
|
183
186
|
kwargs['user_id'] = self.id
|
184
|
-
|
185
|
-
return App(app_id=app_id, **kwargs)
|
187
|
+
return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
|
186
188
|
|
187
189
|
def runner(self, runner_id: str) -> Runner:
|
188
190
|
"""Returns a Runner object if exists.
|
@@ -210,7 +212,7 @@ class User(Lister, BaseClient):
|
|
210
212
|
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
|
211
213
|
list(dict_response.keys())[1])
|
212
214
|
|
213
|
-
return Runner(
|
215
|
+
return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
|
214
216
|
|
215
217
|
def delete_app(self, app_id: str) -> None:
|
216
218
|
"""Deletes an app for the user.
|
clarifai/client/workflow.py
CHANGED
@@ -27,6 +27,7 @@ class Workflow(Lister, BaseClient):
|
|
27
27
|
output_config: Dict = {'min_value': 0},
|
28
28
|
base_url: str = "https://api.clarifai.com",
|
29
29
|
pat: str = None,
|
30
|
+
token: str = None,
|
30
31
|
**kwargs):
|
31
32
|
"""Initializes a Workflow object.
|
32
33
|
|
@@ -40,6 +41,8 @@ class Workflow(Lister, BaseClient):
|
|
40
41
|
select_concepts (list[Concept]): The concepts to select.
|
41
42
|
sample_ms (int): The number of milliseconds to sample.
|
42
43
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
44
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
45
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
43
46
|
**kwargs: Additional keyword arguments to be passed to the Workflow.
|
44
47
|
"""
|
45
48
|
if url and workflow_id:
|
@@ -55,7 +58,8 @@ class Workflow(Lister, BaseClient):
|
|
55
58
|
self.output_config = output_config
|
56
59
|
self.workflow_info = resources_pb2.Workflow(**self.kwargs)
|
57
60
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
58
|
-
BaseClient.__init__(
|
61
|
+
BaseClient.__init__(
|
62
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
59
63
|
Lister.__init__(self)
|
60
64
|
|
61
65
|
def predict(self, inputs: List[Input], workflow_state_id: str = None):
|
@@ -206,10 +210,9 @@ class Workflow(Lister, BaseClient):
|
|
206
210
|
for workflow_version_info in all_workflow_versions_info:
|
207
211
|
workflow_version_info['id'] = workflow_version_info['workflow_version_id']
|
208
212
|
del workflow_version_info['workflow_version_id']
|
209
|
-
yield Workflow(
|
213
|
+
yield Workflow.from_auth_helper(
|
214
|
+
auth=self.auth_helper,
|
210
215
|
workflow_id=self.id,
|
211
|
-
base_url=self.base,
|
212
|
-
pat=self.pat,
|
213
216
|
**dict(self.kwargs, version=workflow_version_info))
|
214
217
|
|
215
218
|
def export(self, out_path: str):
|
@@ -8,15 +8,15 @@ If a dataset module exists in the zoo, uploading the specific dataset can be eas
|
|
8
8
|
|
9
9
|
```python
|
10
10
|
from clarifai.client.app import App
|
11
|
-
from clarifai.datasets.upload.loaders.coco_segmentation import COCOSegmentationDataLoader
|
11
|
+
from clarifai.datasets.upload.loaders.coco_detection import COCODetectionDataLoader
|
12
12
|
|
13
13
|
app = App(app_id="", user_id="")
|
14
14
|
# Create a dataset in Clarifai App
|
15
15
|
dataset = app.create_dataset(dataset_id="")
|
16
16
|
# instantiate dataloader object
|
17
|
-
|
17
|
+
coco_det_dataloader = COCODetectionDataLoader(images_dir="", label_filepath="")
|
18
18
|
# execute data upload to Clarifai app dataset
|
19
|
-
dataset.upload_dataset(dataloader=coco_seg_dataloader)
|
19
|
+
dataset.upload_dataset(dataloader=coco_det_dataloader)
|
20
20
|
```
|
21
21
|
|
22
22
|
## Dataset Loaders
|
@@ -24,7 +24,6 @@ dataset.upload_dataset(dataloader=coco_seg_dataloader)
|
|
24
24
|
| dataset name | task | module name (.py)
|
25
25
|
| --- | --- | ---
|
26
26
|
| [COCO 2017](https://cocodataset.org/#download) | Detection | `coco_detection` |
|
27
|
-
| | Segmentation | `coco_segmentation` |
|
28
27
|
| | Captions | `coco_captions` |
|
29
28
|
|[xVIEW](http://xviewdataset.org/) | Detection | `xview_detection` |
|
30
29
|
| [ImageNet](https://www.image-net.org/) | Classification | `imagenet_classification` |
|
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|
6
6
|
from multiprocessing import cpu_count
|
7
7
|
from typing import DefaultDict, Dict, List
|
8
8
|
|
9
|
-
import cv2
|
9
|
+
from PIL import Image
|
10
10
|
from tqdm import tqdm
|
11
11
|
|
12
12
|
from clarifai.datasets.upload.base import ClarifaiDataLoader
|
@@ -54,9 +54,8 @@ class xviewDetectionDataLoader(ClarifaiDataLoader):
|
|
54
54
|
def compress_tiff(self, img_path: str) -> None:
|
55
55
|
"""Compress tiff image"""
|
56
56
|
img_comp_path = os.path.join(self.img_comp_dir, os.path.basename(img_path))
|
57
|
-
img_arr = cv2.imread(img_path)
|
58
|
-
|
59
|
-
img_comp_path, img_arr, params=(cv2.IMWRITE_TIFF_COMPRESSION, 8)) # 8: Adobe Deflate
|
57
|
+
img_arr = Image.open(img_path)
|
58
|
+
img_arr.save(img_comp_path, 'TIFF', compression='tiff_deflate')
|
60
59
|
|
61
60
|
def preprocess(self):
|
62
61
|
"""Compress the tiff images to comply with clarifai grpc image encoding limit(<20MB) Uses ADOBE_DEFLATE compression algorithm"""
|
@@ -133,7 +132,8 @@ class xviewDetectionDataLoader(ClarifaiDataLoader):
|
|
133
132
|
_id = os.path.splitext(os.path.basename(self.image_paths[index]))[0]
|
134
133
|
image_path = self.image_paths[index]
|
135
134
|
|
136
|
-
|
135
|
+
image = Image.open(image_path)
|
136
|
+
image_width, image_height = image.size
|
137
137
|
annots = []
|
138
138
|
class_names = []
|
139
139
|
for bbox, concept in zip(self.all_data[_id]['bboxes'], self.all_data[_id]['concepts']):
|
clarifai/rag/rag.py
CHANGED
@@ -76,16 +76,17 @@ class RAG:
|
|
76
76
|
>>> rag_agent = RAG.setup(app_url=YOUR_APP_URL)
|
77
77
|
>>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
|
78
78
|
"""
|
79
|
-
|
79
|
+
now_ts = str(int(datetime.now().timestamp()))
|
80
80
|
if user_id and not app_url:
|
81
81
|
user = User(user_id=user_id, base_url=base_url, pat=pat)
|
82
82
|
## Create an App
|
83
|
-
now_ts = str(int(datetime.now().timestamp()))
|
84
83
|
app_id = f"rag_app_{now_ts}"
|
85
84
|
app = user.create_app(app_id=app_id, base_workflow=base_workflow)
|
86
85
|
|
87
86
|
if not user_id and app_url:
|
88
87
|
app = App(url=app_url, pat=pat)
|
88
|
+
uid = app_url.split(".com/")[1].split("/")[0]
|
89
|
+
user = User(user_id=uid, base_url=base_url, pat=pat)
|
89
90
|
|
90
91
|
if user_id and app_url:
|
91
92
|
raise UserError("Must provide one of user_id or app_url, not both.")
|
@@ -95,7 +96,7 @@ class RAG:
|
|
95
96
|
"user_id or app_url must be provided. The user_id can be found at https://clarifai.com/settings."
|
96
97
|
)
|
97
98
|
|
98
|
-
llm = Model(llm_url)
|
99
|
+
llm = Model(url=llm_url, pat=pat)
|
99
100
|
|
100
101
|
min_score = kwargs.get("min_score", 0.95)
|
101
102
|
max_results = kwargs.get("max_results", 5)
|
@@ -109,8 +110,8 @@ class RAG:
|
|
109
110
|
prompter_model_params = {"params": params}
|
110
111
|
|
111
112
|
## Create rag-prompter model and version
|
112
|
-
|
113
|
-
|
113
|
+
model_id = f"prompter-{workflow_id}" if workflow_id is not None else f"rag-prompter-{now_ts}"
|
114
|
+
prompter_model = app.create_model(model_id=model_id, model_type_id="rag-prompter")
|
114
115
|
prompter_model = prompter_model.create_version(output_info=prompter_model_params)
|
115
116
|
|
116
117
|
## Generate a tmp yaml file for workflow creation
|
@@ -153,6 +154,8 @@ class RAG:
|
|
153
154
|
batch_size: int = 128,
|
154
155
|
chunk_size: int = 1024,
|
155
156
|
chunk_overlap: int = 200,
|
157
|
+
dataset_id: str = None,
|
158
|
+
metadata: dict = None,
|
156
159
|
**kwargs) -> None:
|
157
160
|
"""Uploads documents to the app.
|
158
161
|
- Read from a local directory or public url or local filename.
|
@@ -192,14 +195,15 @@ class RAG:
|
|
192
195
|
|
193
196
|
#splitting documents into chunks
|
194
197
|
text_chunks = []
|
195
|
-
|
198
|
+
metadata_list = []
|
196
199
|
|
197
200
|
#iterate through documents
|
198
201
|
for doc in documents:
|
202
|
+
doc_i = 0
|
199
203
|
cur_text_chunks = split_document(
|
200
204
|
text=doc.text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
|
201
205
|
text_chunks.extend(cur_text_chunks)
|
202
|
-
|
206
|
+
metadata_list.extend([doc.metadata for _ in range(len(cur_text_chunks))])
|
203
207
|
#if batch size is reached, upload the batch
|
204
208
|
if len(text_chunks) > batch_size:
|
205
209
|
for idx in range(0, len(text_chunks), batch_size):
|
@@ -208,18 +212,23 @@ class RAG:
|
|
208
212
|
batch_texts = text_chunks[0:batch_size]
|
209
213
|
batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
|
210
214
|
#metadata
|
211
|
-
batch_metadatas =
|
215
|
+
batch_metadatas = metadata_list[0:batch_size]
|
212
216
|
meta_list = []
|
213
217
|
for meta in batch_metadatas:
|
214
218
|
meta_struct = Struct()
|
215
219
|
meta_struct.update(meta)
|
220
|
+
meta_struct.update({"doc_chunk_no": doc_i})
|
221
|
+
if metadata and isinstance(metadata, dict):
|
222
|
+
meta_struct.update(metadata)
|
216
223
|
meta_list.append(meta_struct)
|
224
|
+
doc_i += 1
|
217
225
|
del batch_metadatas
|
218
226
|
#creating input proto
|
219
227
|
input_batch = [
|
220
228
|
self._app.inputs().get_text_input(
|
221
229
|
input_id=batch_ids[i],
|
222
230
|
raw_text=text,
|
231
|
+
dataset_id=dataset_id,
|
223
232
|
metadata=meta_list[i],
|
224
233
|
) for i, text in enumerate(batch_texts)
|
225
234
|
]
|
@@ -227,32 +236,37 @@ class RAG:
|
|
227
236
|
self._app.inputs().upload_inputs(inputs=input_batch)
|
228
237
|
#delete uploaded chunks
|
229
238
|
del text_chunks[0:batch_size]
|
230
|
-
del
|
239
|
+
del metadata_list[0:batch_size]
|
231
240
|
|
232
241
|
#uploading the remaining chunks
|
233
242
|
if len(text_chunks) > 0:
|
234
243
|
batch_size = len(text_chunks)
|
235
244
|
batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
|
236
245
|
#metadata
|
237
|
-
batch_metadatas =
|
246
|
+
batch_metadatas = metadata_list[0:batch_size]
|
238
247
|
meta_list = []
|
239
248
|
for meta in batch_metadatas:
|
240
249
|
meta_struct = Struct()
|
241
250
|
meta_struct.update(meta)
|
251
|
+
meta_struct.update({"doc_chunk_no": doc_i})
|
252
|
+
if metadata and isinstance(metadata, dict):
|
253
|
+
meta_struct.update(metadata)
|
242
254
|
meta_list.append(meta_struct)
|
255
|
+
doc_i += 1
|
243
256
|
del batch_metadatas
|
244
257
|
#creating input proto
|
245
258
|
input_batch = [
|
246
259
|
self._app.inputs().get_text_input(
|
247
260
|
input_id=batch_ids[i],
|
248
261
|
raw_text=text,
|
262
|
+
dataset_id=dataset_id,
|
249
263
|
metadata=meta_list[i],
|
250
264
|
) for i, text in enumerate(text_chunks)
|
251
265
|
]
|
252
266
|
#uploading input with metadata
|
253
267
|
self._app.inputs().upload_inputs(inputs=input_batch)
|
254
268
|
del text_chunks
|
255
|
-
del
|
269
|
+
del metadata_list
|
256
270
|
|
257
271
|
def chat(self, messages: List[dict], client_manage_state: bool = False) -> List[dict]:
|
258
272
|
"""Chat interface in OpenAI API format.
|
clarifai/rag/utils.py
CHANGED
@@ -3,10 +3,6 @@ from pathlib import Path
|
|
3
3
|
from typing import List
|
4
4
|
|
5
5
|
import requests
|
6
|
-
from llama_index.core import Document, SimpleDirectoryReader
|
7
|
-
from llama_index.core.node_parser.text import SentenceSplitter
|
8
|
-
from llama_index.core.readers.download import download_loader
|
9
|
-
from pypdf import PdfReader
|
10
6
|
|
11
7
|
|
12
8
|
## TODO: Make this token-aware.
|
@@ -36,8 +32,7 @@ def format_assistant_message(raw_text: str) -> dict:
|
|
36
32
|
return {"role": "assistant", "content": raw_text}
|
37
33
|
|
38
34
|
|
39
|
-
def load_documents(file_path: str = None, folder_path: str = None,
|
40
|
-
url: str = None) -> List[Document]:
|
35
|
+
def load_documents(file_path: str = None, folder_path: str = None, url: str = None) -> List[any]:
|
41
36
|
"""Loads documents from a local directory or public url or local filename.
|
42
37
|
|
43
38
|
Args:
|
@@ -45,6 +40,13 @@ def load_documents(file_path: str = None, folder_path: str = None,
|
|
45
40
|
folder_path (str): The path to the folder.
|
46
41
|
url (str): The url to the file.
|
47
42
|
"""
|
43
|
+
#check import packages
|
44
|
+
try:
|
45
|
+
from llama_index.core import Document, SimpleDirectoryReader
|
46
|
+
from llama_index.core.readers.download import download_loader
|
47
|
+
except ImportError:
|
48
|
+
raise ImportError("Could not import llama index package. "
|
49
|
+
"Please install it with `pip install llama-index-core==0.10.1`.")
|
48
50
|
#document loaders for filepath
|
49
51
|
if file_path:
|
50
52
|
if file_path.endswith(".pdf"):
|
@@ -77,6 +79,12 @@ def load_documents(file_path: str = None, folder_path: str = None,
|
|
77
79
|
documents = [Document(text=response.content)]
|
78
80
|
#for pdf files
|
79
81
|
except Exception:
|
82
|
+
#check import packages
|
83
|
+
try:
|
84
|
+
from pypdf import PdfReader
|
85
|
+
except ImportError:
|
86
|
+
raise ImportError("Could not import pypdf package. "
|
87
|
+
"Please install it with `pip install pypdf==3.17.4`.")
|
80
88
|
documents = []
|
81
89
|
pdf_file = PdfReader(io.BytesIO(response.content))
|
82
90
|
num_pages = len(pdf_file.pages)
|
@@ -98,6 +106,13 @@ def split_document(text: str, chunk_size: int, chunk_overlap: int, **kwargs) ->
|
|
98
106
|
chunk_overlap (int): The amount of overlap between each chunk.
|
99
107
|
**kwargs: Additional keyword arguments for the SentenceSplitter.
|
100
108
|
"""
|
109
|
+
#check import packages
|
110
|
+
try:
|
111
|
+
from llama_index.core.node_parser.text import SentenceSplitter
|
112
|
+
except ImportError:
|
113
|
+
raise ImportError("Could not import llama index package. "
|
114
|
+
"Please install it with `pip install llama-index-core==0.10.1`.")
|
115
|
+
#document
|
101
116
|
text_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
|
102
117
|
text_chunks = text_parser.split_text(text)
|
103
118
|
return text_chunks
|