clarifai-10.0.1-py3-none-any.whl → clarifai-10.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +23 -43
- clarifai/client/base.py +46 -4
- clarifai/client/dataset.py +85 -33
- clarifai/client/input.py +35 -7
- clarifai/client/model.py +192 -11
- clarifai/client/module.py +8 -6
- clarifai/client/runner.py +3 -1
- clarifai/client/search.py +6 -3
- clarifai/client/user.py +14 -12
- clarifai/client/workflow.py +8 -5
- clarifai/datasets/upload/features.py +3 -0
- clarifai/datasets/upload/image.py +57 -26
- clarifai/datasets/upload/loaders/README.md +3 -4
- clarifai/datasets/upload/loaders/xview_detection.py +9 -5
- clarifai/datasets/upload/utils.py +23 -7
- clarifai/models/model_serving/README.md +113 -121
- clarifai/models/model_serving/__init__.py +2 -0
- clarifai/models/model_serving/cli/_utils.py +53 -0
- clarifai/models/model_serving/cli/base.py +14 -0
- clarifai/models/model_serving/cli/build.py +79 -0
- clarifai/models/model_serving/cli/clarifai_clis.py +33 -0
- clarifai/models/model_serving/cli/create.py +171 -0
- clarifai/models/model_serving/cli/example_cli.py +34 -0
- clarifai/models/model_serving/cli/login.py +26 -0
- clarifai/models/model_serving/cli/upload.py +182 -0
- clarifai/models/model_serving/constants.py +20 -0
- clarifai/models/model_serving/docs/cli.md +150 -0
- clarifai/models/model_serving/docs/concepts.md +229 -0
- clarifai/models/model_serving/docs/dependencies.md +1 -1
- clarifai/models/model_serving/docs/inference_parameters.md +112 -107
- clarifai/models/model_serving/docs/model_types.md +16 -17
- clarifai/models/model_serving/model_config/__init__.py +4 -2
- clarifai/models/model_serving/model_config/base.py +369 -0
- clarifai/models/model_serving/model_config/config.py +219 -224
- clarifai/models/model_serving/model_config/inference_parameter.py +5 -0
- clarifai/models/model_serving/model_config/model_types_config/multimodal-embedder.yaml +25 -24
- clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml +19 -18
- clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml +20 -18
- clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml +19 -18
- clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml +19 -18
- clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml +22 -18
- clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml +32 -28
- clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml +19 -18
- clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +19 -18
- clarifai/models/model_serving/{models → model_config}/output.py +8 -0
- clarifai/models/model_serving/model_config/triton/__init__.py +14 -0
- clarifai/models/model_serving/model_config/{serializer.py → triton/serializer.py} +3 -1
- clarifai/models/model_serving/model_config/triton/triton_config.py +182 -0
- clarifai/models/model_serving/{models/model_types.py → model_config/triton/wrappers.py} +4 -4
- clarifai/models/model_serving/{models → repo_build}/__init__.py +2 -0
- clarifai/models/model_serving/repo_build/build.py +198 -0
- clarifai/models/model_serving/repo_build/static_files/_requirements.txt +2 -0
- clarifai/models/model_serving/repo_build/static_files/base_test.py +169 -0
- clarifai/models/model_serving/repo_build/static_files/inference.py +26 -0
- clarifai/models/model_serving/repo_build/static_files/sample_clarifai_config.yaml +25 -0
- clarifai/models/model_serving/repo_build/static_files/test.py +40 -0
- clarifai/models/model_serving/{models/pb_model.py → repo_build/static_files/triton/model.py} +15 -14
- clarifai/models/model_serving/utils.py +21 -0
- clarifai/rag/rag.py +67 -23
- clarifai/rag/utils.py +21 -5
- clarifai/utils/evaluation/__init__.py +427 -0
- clarifai/utils/evaluation/helpers.py +522 -0
- clarifai/utils/logging.py +7 -0
- clarifai/utils/model_train.py +3 -1
- clarifai/versions.py +1 -1
- {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/METADATA +58 -10
- clarifai-10.1.1.dist-info/RECORD +115 -0
- clarifai-10.1.1.dist-info/entry_points.txt +2 -0
- clarifai/datasets/upload/loaders/coco_segmentation.py +0 -98
- clarifai/models/model_serving/cli/deploy_cli.py +0 -123
- clarifai/models/model_serving/cli/model_zip.py +0 -61
- clarifai/models/model_serving/cli/repository.py +0 -89
- clarifai/models/model_serving/docs/custom_config.md +0 -33
- clarifai/models/model_serving/docs/output.md +0 -28
- clarifai/models/model_serving/models/default_test.py +0 -281
- clarifai/models/model_serving/models/inference.py +0 -50
- clarifai/models/model_serving/models/test.py +0 -64
- clarifai/models/model_serving/pb_model_repository.py +0 -108
- clarifai-10.0.1.dist-info/RECORD +0 -103
- clarifai-10.0.1.dist-info/entry_points.txt +0 -4
- {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/LICENSE +0 -0
- {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/WHEEL +0 -0
- {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/top_level.txt +0 -0
clarifai/rag/rag.py
CHANGED
@@ -17,6 +17,8 @@ from clarifai.rag.utils import (convert_messages_to_str, format_assistant_message
                                 split_document)
 from clarifai.utils.logging import get_logger
 
+DEFAULT_RAG_PROMPT_TEMPLATE = "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
+
 
 class RAG:
   """
@@ -24,7 +26,8 @@ class RAG:
 
   Example:
       >>> from clarifai.rag import RAG
-      >>> rag_agent = RAG()
+      >>> rag_agent = RAG(workflow_url=YOUR_WORKFLOW_URL)
+      >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
   """
   chat_state_id = None
 
@@ -49,43 +52,70 @@ class RAG:
   @classmethod
   def setup(cls,
             user_id: str = None,
+            app_url: str = None,
             llm_url: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
             base_workflow: str = "Text",
            workflow_yaml_filename: str = 'prompter_wf.yaml',
+            workflow_id: str = None,
             base_url: str = "https://api.clarifai.com",
             pat: str = None,
             **kwargs):
     """Creates an app with `Text` as base workflow, create prompt model, create prompt workflow.
 
+    **kwargs: Additional keyword arguments to be passed to rag-promter model.
+      - min_score (float): The minimum score for search hits.
+      - max_results (float): The maximum number of search hits.
+      - prompt_template (str): The prompt template used. Must contain {data.hits} for the search hits and {data.text.raw} for the query string.
+
     Example:
         >>> from clarifai.rag import RAG
-        >>> rag_agent = RAG.setup()
+        >>> rag_agent = RAG.setup(user_id=YOUR_USER_ID)
+        >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
+
+    Or if you already have an existing app with ingested data:
+        >>> rag_agent = RAG.setup(app_url=YOUR_APP_URL)
+        >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
     """
-
+    now_ts = str(int(datetime.now().timestamp()))
+    if user_id and not app_url:
+      user = User(user_id=user_id, base_url=base_url, pat=pat)
+      ## Create an App
+      app_id = f"rag_app_{now_ts}"
+      app = user.create_app(app_id=app_id, base_workflow=base_workflow)
+
+    if not user_id and app_url:
+      app = App(url=app_url, pat=pat)
+      uid = app_url.split(".com/")[1].split("/")[0]
+      user = User(user_id=uid, base_url=base_url, pat=pat)
+
+    if user_id and app_url:
+      raise UserError("Must provide one of user_id or app_url, not both.")
+
+    if not user_id and not app_url:
       raise UserError(
-          "user_id must be provided.
-
-
+          "user_id or app_url must be provided. The user_id can be found at https://clarifai.com/settings."
+      )
+
+    llm = Model(url=llm_url, pat=pat)
 
+    min_score = kwargs.get("min_score", 0.95)
+    max_results = kwargs.get("max_results", 5)
+    prompt_template = kwargs.get("prompt_template", DEFAULT_RAG_PROMPT_TEMPLATE)
     params = Struct()
     params.update({
-        "
-
+        "min_score": min_score,
+        "max_results": max_results,
+        "prompt_template": prompt_template
     })
     prompter_model_params = {"params": params}
 
-    ## Create an App
-    now_ts = str(int(datetime.now().timestamp()))
-    app_id = f"rag_app_{now_ts}"
-    app = user.create_app(app_id=app_id, base_workflow=base_workflow)
-
     ## Create rag-prompter model and version
-
-
+    model_id = f"prompter-{workflow_id}" if workflow_id is not None else f"rag-prompter-{now_ts}"
+    prompter_model = app.create_model(model_id=model_id, model_type_id="rag-prompter")
     prompter_model = prompter_model.create_version(output_info=prompter_model_params)
 
     ## Generate a tmp yaml file for workflow creation
-    workflow_id = f"rag-wf-{now_ts}"
+    workflow_id = f"rag-wf-{now_ts}" if workflow_id is None else workflow_id
     workflow_dict = {
         "workflow": {
             "id":
@@ -124,6 +154,8 @@ class RAG:
              batch_size: int = 128,
              chunk_size: int = 1024,
              chunk_overlap: int = 200,
+             dataset_id: str = None,
+             metadata: dict = None,
              **kwargs) -> None:
     """Uploads documents to the app.
         - Read from a local directory or public url or local filename.
@@ -141,9 +173,10 @@ class RAG:
 
     Example:
         >>> from clarifai.rag import RAG
-        >>> rag_agent = RAG.setup()
+        >>> rag_agent = RAG.setup(user_id=YOUR_USER_ID)
        >>> rag_agent.upload(folder_path = "~/work/docs")
        >>> rag_agent.upload(file_path = "~/work/docs/manual.pdf")
+        >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
     """
     #set batch size
     if batch_size > MAX_UPLOAD_BATCH_SIZE:
@@ -162,14 +195,15 @@ class RAG:
 
     #splitting documents into chunks
     text_chunks = []
-
+    metadata_list = []
 
     #iterate through documents
     for doc in documents:
+      doc_i = 0
       cur_text_chunks = split_document(
           text=doc.text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
       text_chunks.extend(cur_text_chunks)
-
+      metadata_list.extend([doc.metadata for _ in range(len(cur_text_chunks))])
       #if batch size is reached, upload the batch
       if len(text_chunks) > batch_size:
         for idx in range(0, len(text_chunks), batch_size):
@@ -178,18 +212,23 @@ class RAG:
          batch_texts = text_chunks[0:batch_size]
          batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
          #metadata
-          batch_metadatas =
+          batch_metadatas = metadata_list[0:batch_size]
          meta_list = []
          for meta in batch_metadatas:
            meta_struct = Struct()
            meta_struct.update(meta)
+            meta_struct.update({"doc_chunk_no": doc_i})
+            if metadata and isinstance(metadata, dict):
+              meta_struct.update(metadata)
            meta_list.append(meta_struct)
+            doc_i += 1
          del batch_metadatas
          #creating input proto
          input_batch = [
              self._app.inputs().get_text_input(
                  input_id=batch_ids[i],
                  raw_text=text,
+                  dataset_id=dataset_id,
                  metadata=meta_list[i],
              ) for i, text in enumerate(batch_texts)
          ]
@@ -197,32 +236,37 @@ class RAG:
          self._app.inputs().upload_inputs(inputs=input_batch)
          #delete uploaded chunks
          del text_chunks[0:batch_size]
-          del
+          del metadata_list[0:batch_size]
 
    #uploading the remaining chunks
    if len(text_chunks) > 0:
      batch_size = len(text_chunks)
      batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
      #metadata
-      batch_metadatas =
+      batch_metadatas = metadata_list[0:batch_size]
      meta_list = []
      for meta in batch_metadatas:
        meta_struct = Struct()
        meta_struct.update(meta)
+        meta_struct.update({"doc_chunk_no": doc_i})
+        if metadata and isinstance(metadata, dict):
+          meta_struct.update(metadata)
        meta_list.append(meta_struct)
+        doc_i += 1
      del batch_metadatas
      #creating input proto
      input_batch = [
          self._app.inputs().get_text_input(
              input_id=batch_ids[i],
              raw_text=text,
+              dataset_id=dataset_id,
              metadata=meta_list[i],
          ) for i, text in enumerate(text_chunks)
      ]
      #uploading input with metadata
      self._app.inputs().upload_inputs(inputs=input_batch)
      del text_chunks
-      del
+      del metadata_list
 
  def chat(self, messages: List[dict], client_manage_state: bool = False) -> List[dict]:
    """Chat interface in OpenAI API format.
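
The rag.py changes above add an `app_url` path to `setup()` (reuse an existing app instead of creating one), an optional `workflow_id`, prompter tuning via the `min_score`/`max_results`/`prompt_template` kwargs, and `dataset_id`/`metadata` passthrough on `upload()`. A minimal sketch of the new flow; the user id, PAT, paths, and dataset id below are placeholders, not real values:

from clarifai.rag import RAG

# Fresh setup: creates an app, a rag-prompter model, and a prompter workflow.
rag_agent = RAG.setup(
    user_id="YOUR_USER_ID",
    pat="YOUR_PAT",
    min_score=0.95,  # defaults shown in the diff above
    max_results=5,
)

# dataset_id and metadata are the new optional upload parameters; metadata is
# merged into every chunk's Struct alongside the new doc_chunk_no field.
rag_agent.upload(
    folder_path="~/work/docs",
    dataset_id="rag-docs",                # hypothetical dataset id
    metadata={"source": "user-manuals"},  # hypothetical extra metadata
)

print(rag_agent.chat(messages=[{"role": "human", "content": "What is Clarifai"}]))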
clarifai/rag/utils.py
CHANGED
@@ -3,9 +3,6 @@ from pathlib import Path
 from typing import List
 
 import requests
-from llama_index import Document, SimpleDirectoryReader, download_loader
-from llama_index.node_parser.text import SentenceSplitter
-from pypdf import PdfReader
 
 
 ## TODO: Make this token-aware.
@@ -35,8 +32,7 @@ def format_assistant_message(raw_text: str) -> dict:
   return {"role": "assistant", "content": raw_text}
 
 
-def load_documents(file_path: str = None, folder_path: str = None,
-                   url: str = None) -> List[Document]:
+def load_documents(file_path: str = None, folder_path: str = None, url: str = None) -> List[any]:
   """Loads documents from a local directory or public url or local filename.
 
   Args:
@@ -44,6 +40,13 @@ def load_documents(file_path: str = None, folder_path: str = None,
     folder_path (str): The path to the folder.
     url (str): The url to the file.
   """
+  #check import packages
+  try:
+    from llama_index.core import Document, SimpleDirectoryReader
+    from llama_index.core.readers.download import download_loader
+  except ImportError:
+    raise ImportError("Could not import llama index package. "
+                      "Please install it with `pip install llama-index-core==0.10.1`.")
   #document loaders for filepath
   if file_path:
     if file_path.endswith(".pdf"):
@@ -76,6 +79,12 @@ def load_documents(file_path: str = None, folder_path: str = None,
       documents = [Document(text=response.content)]
     #for pdf files
     except Exception:
+      #check import packages
+      try:
+        from pypdf import PdfReader
+      except ImportError:
+        raise ImportError("Could not import pypdf package. "
+                          "Please install it with `pip install pypdf==3.17.4`.")
       documents = []
       pdf_file = PdfReader(io.BytesIO(response.content))
       num_pages = len(pdf_file.pages)
@@ -97,6 +106,13 @@ def split_document(text: str, chunk_size: int, chunk_overlap: int, **kwargs) ->
     chunk_overlap (int): The amount of overlap between each chunk.
     **kwargs: Additional keyword arguments for the SentenceSplitter.
   """
+  #check import packages
+  try:
+    from llama_index.core.node_parser.text import SentenceSplitter
+  except ImportError:
+    raise ImportError("Could not import llama index package. "
+                      "Please install it with `pip install llama-index-core==0.10.1`.")
+  #document
   text_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
   text_chunks = text_parser.split_text(text)
   return text_chunks
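
The utils.py change makes llama-index and pypdf optional: the module-level imports move inside the functions that need them, against the post-0.10 `llama_index.core` namespace, so `clarifai.rag` imports cleanly when those extras are absent. A sketch of calling the helpers once the pinned extras are installed; the file path is a placeholder:

# pip install llama-index-core==0.10.1 pypdf==3.17.4
from clarifai.rag.utils import load_documents, split_document

# Raises ImportError with the install hint shown in the diff if the extras are missing.
docs = load_documents(file_path="~/work/docs/manual.pdf")
chunks = split_document(text=docs[0].text, chunk_size=1024, chunk_overlap=200)
print(len(chunks))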
clarifai/utils/evaluation/__init__.py
ADDED
@@ -0,0 +1,427 @@
+import os
+from enum import Enum
+from typing import List, Tuple, Union
+
+from clarifai.client.dataset import Dataset
+from clarifai.client.model import Model
+
+from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
+                      make_handler_by_type)
+
+try:
+  import seaborn as sns
+except ImportError:
+  raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
+
+try:
+  import matplotlib.pyplot as plt
+except ImportError:
+  raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
+
+try:
+  import pandas as pd
+except ImportError:
+  raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
+
+try:
+  from loguru import logger
+except ImportError:
+  from ..logging import get_logger
+  logger = get_logger(logger_level="INFO", name=__name__)
+
+__all__ = ['EvalResultCompare']
+
+
+class CompareMode(Enum):
+  MANY_MODELS_TO_ONE_DATA = 0
+  ONE_MODEL_TO_MANY_DATA = 1
+
+
+class EvalResultCompare:
+  """Compare evaluation result of models against datasets.
+  Note: The module will pick latest result on the datasets.
+  and models must be same model type
+
+  Args:
+  ---
+    models (Union[List[Model], List[str]]): List of Model or urls of models.
+    datasets (Union[Dataset, List[Dataset], str, List[str]]): A single or List of Url or Dataset
+    attempt_evaluate (bool): Evaluate when model is not evaluated with the datasets.
+    auth_kwargs (dict): Additional auth keyword arguments to be passed to the Dataset and Model if using url(s)
+  """
+
+  def __init__(self,
+               models: Union[List[Model], List[str]],
+               datasets: Union[Dataset, List[Dataset], str, List[str]],
+               attempt_evaluate: bool = False,
+               auth_kwargs: dict = {}):
+    assert isinstance(models, list), ValueError("Expected list")
+
+    if len(models) > 1:
+      self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
+      self.comparator = "Model"
+      assert isinstance(datasets, Dataset) or (
+          isinstance(datasets, list) and len(datasets) == 1
+      ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
+    else:
+      self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
+      self.comparator = "Dataset"
+
+    # validate models
+    if all(map(lambda x: isinstance(x, str), models)):
+      models = [Model(each, **auth_kwargs) for each in models]
+    elif not all(map(lambda x: isinstance(x, Model), models)):
+      raise ValueError(
+          f"Expected all models are list of string or list of Model, got {[type(each) for each in models]}"
+      )
+    # validate datasets
+    if not isinstance(datasets, list):
+      datasets = [
+          datasets,
+      ]
+    if all(map(lambda x: isinstance(x, str), datasets)):
+      datasets = [Dataset(each, **auth_kwargs) for each in datasets]
+    elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
+      raise ValueError(
+          f"Expected datasets must be str, list of string or Dataset, list of Dataset, got {[type(each) for each in datasets]}"
+      )
+    # Validate models vs datasets together
+    self._eval_handlers: List[_BaseEvalResultHandler] = []
+    self.model_type = None
+    logger.info("Initializing models...")
+    for model in models:
+      model.load_info()
+      model_type = model.model_info.model_type_id
+      if not self.model_type:
+        self.model_type = model_type
+      else:
+        assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
+      m = make_handler_by_type(model_type)(model=model)
+      logger.info(f"* {m.get_model_name(pretify=True)}")
+      m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
+      self._eval_handlers.append(m)
+
+  @property
+  def eval_handlers(self):
+    return self._eval_handlers
+
+  def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
+    """ Run methods of `eval_handlers[...].model`
+
+    Args:
+      func_name (str): method name, see `_BaseEvalResultHandler` child classes
+      kwargs: keyword arguments of the method
+
+    Return:
+      tuple:
+        - list of outputs
+        - list of comparator names
+
+    """
+    outs = []
+    comparators = []
+    logger.info(f'Running `{func_name}`')
+    for _, each in enumerate(self.eval_handlers):
+      for ds_index, _ in enumerate(each.eval_data):
+        func = eval(f'each.{func_name}')
+        out = func(index=ds_index, **kwargs)
+
+        if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+          name = each.get_model_name(pretify=True)
+        else:
+          name = each.get_dataset_name_by_index(ds_index, pretify=True)
+        if out is None:
+          logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
+          continue
+        comparators.append(name)
+        outs.append(out)
+
+    # remove app_id if models are in the same app
+    if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+      apps = set([comp.split('/')[0] for comp in comparators])
+      if len(apps) == 1:
+        comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
+
+    if not outs:
+      logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
+
+    return outs, comparators
+
+  def detailed_summary(self,
+                       confidence_threshold: float = .5,
+                       iou_threshold: float = .5,
+                       area: str = "all",
+                       bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
+    """
+    Retrieve and compute popular metrics of model.
+
+    Args:
+      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
+      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection
+      area (float): size of area, support {all, small, medium}, applicable for detection
+
+    Return:
+      None or tuple of dataframe: df summary per concept and total concepts
+
+    """
+    df = []
+    total = []
+    # loop over all eval_handlers/dataset and call its method
+    outs, comparators = self._loop_eval_handlers(
+        'detailed_summary',
+        confidence_threshold=confidence_threshold,
+        iou_threshold=iou_threshold,
+        area=area,
+        bypass_const=bypass_const)
+    for indx, out in enumerate(outs):
+      _df, _total = out
+      _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
+      _total['Concept'].replace(
+          to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
+      _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
+      df.append(_df)
+      total.append(_total)
+
+    if df:
+      df = pd.concat(df, axis=0)
+      total = pd.concat(total, axis=0)
+      return df, total
+    else:
+      return None
+
+  def confusion_matrix(self, show=True, save_path: str = None,
+                       cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of confusion matrix
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      cm_kwargs (dict): keyword args of `eval_handler[...].model.cm_kwargs` method.
+    Returns:
+      None or pd.Dataframe, If models don't have confusion matrix, return None
+    """
+    outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparators)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+
+        def _facet_heatmap(data, **kws):
+          data = data.dropna(axis=1)
+          data = data.drop(self.comparator, axis=1)
+          concepts = data.columns
+          colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
+          data.columns = colnames
+          ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
+          ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
+          ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
+
+        temp = all_dfs.copy()
+        temp.columns = ["_".join(pair) for pair in temp.columns]
+        with sns.plotting_context(font_scale=5.5):
+          g = sns.FacetGrid(
+              temp,
+              col=self.comparator,
+              col_wrap=3,
+              aspect=1,
+              height=3,
+              sharex=False,
+              sharey=False,
+          )
+          cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
+          g = g.map_dataframe(
+              _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
+          g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+          if show:
+            plt.show()
+          if save_path:
+            g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  @staticmethod
+  def _set_default_kwargs(kwargs: dict, var_name: str, value):
+    if var_name not in kwargs:
+      kwargs.update({var_name: value})
+    return kwargs
+
+  @staticmethod
+  def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
+    hue_order = df["concept"].unique().tolist()
+    hue_order.remove(MACRO_AVG)
+    hue_order.insert(0, MACRO_AVG)
+    EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
+
+    sizes = {}
+    for each in hue_order:
+      s = 1.5
+      if each == MACRO_AVG:
+        s = 4.
+      sizes.update({each: s})
+    EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
+    EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
+
+    EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
+    EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
+
+    return kwargs
+
+  def roc_curve_plot(self,
+                     show=True,
+                     save_path: str = None,
+                     roc_curve_kwargs: dict = {},
+                     relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of ROC curve
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.roc_curve` method.
+      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}. where x="fpr", y="tpr", hue="concept"
+    Returns:
+      None or pd.Dataframe, If models don't have ROC curve, return None
+    """
+    sns.set_palette("Paired")
+    outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparator)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+        g = sns.relplot(
+            data=all_dfs,
+            x="fpr",
+            y="tpr",
+            hue='concept',
+            kind="line",
+            col=self.comparator,
+            **relplot_kwargs)
+        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+        if show:
+          plt.show()
+        if save_path:
+          g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  def pr_plot(self,
+              show=True,
+              save_path: str = None,
+              pr_curve_kwargs: dict = {},
+              relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of PR curve
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.pr_curve` method.
+      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col} where x="recall", y="precision", hue="concept"
+    Returns:
+      None or pd.Dataframe, If models don't have PR curve, return None
+    """
+    sns.set_palette("Paired")
+    outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparator)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+        g = sns.relplot(
+            data=all_dfs,
+            x="recall",
+            y="precision",
+            hue='concept',
+            kind="line",
+            col=self.comparator,
+            **relplot_kwargs)
+        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+        if show:
+          plt.show()
+        if save_path:
+          g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  def all(
+      self,
+      output_folder: str,
+      confidence_threshold: float = 0.5,
+      iou_threshold: float = 0.5,
+      overwrite: bool = False,
+      metric_kwargs: dict = {},
+      pr_plot_kwargs: dict = {},
+      roc_plot_kwargs: dict = {},
+  ):
+    """Run all comparison methods one by one:
+    - detailed_summary
+    - pr_curve (if applicable)
+    - pr_plot
+    - confusion_matrix (if applicable)
+    And save to output_folder
+
+    Args:
+      output_folder (str): path to output
+      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
+      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection.
+      overwrite (bool): overwrite result of output_folder.
+      metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
+      roc_plot_kwargs (dict): for relplot_kwargs of `roc_curve_plot` method.
+      pr_plot_kwargs (dict): for relplot_kwargs of `pr_plot` method.
+    """
+    eval_type = get_eval_type(self.model_type)
+    area = metric_kwargs.pop("area", "all")
+    bypass_const = metric_kwargs.pop("bypass_const", False)
+
+    fname = f"conf-{confidence_threshold}"
+    if eval_type == EvalType.DETECTION:
+      fname = f"{fname}_iou-{iou_threshold}_area-{area}"
+
+    def join_root(*args):
+      return os.path.join(output_folder, *args)
+
+    output_folder = join_root(fname)
+    if os.path.exists(output_folder) and not overwrite:
+      raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
+
+    os.makedirs(output_folder, exist_ok=True)
+
+    logger.info("Making summary tables...")
+    dfs = self.detailed_summary(
+        confidence_threshold=confidence_threshold,
+        iou_threshold=iou_threshold,
+        area=area,
+        bypass_const=bypass_const)
+    if dfs is not None:
+      concept_df, total_df = dfs
+      concept_df.to_csv(join_root("concepts_summary.csv"))
+      total_df.to_csv(join_root("total_summary.csv"))
+
+    curve_metric_kwargs = dict(
+        confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
+    curve_metric_kwargs.update(metric_kwargs)
+
+    self.roc_curve_plot(
+        show=False,
+        save_path=join_root("roc.jpg"),
+        roc_curve_kwargs=curve_metric_kwargs,
+        relplot_kwargs=roc_plot_kwargs)
+
+    self.pr_plot(
+        show=False,
+        save_path=join_root("pr.jpg"),
+        pr_curve_kwargs=curve_metric_kwargs,
+        relplot_kwargs=pr_plot_kwargs)
+
+    self.confusion_matrix(
+        show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
+
+    logger.info(f"Done. Your outputs are saved at {output_folder}")
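
The new `clarifai.utils.evaluation` package is driven through `EvalResultCompare`. A minimal sketch based on the constructor and `all()` method shown above; the model/dataset URLs and PAT are placeholders:

from clarifai.utils.evaluation import EvalResultCompare

# MANY_MODELS_TO_ONE_DATA mode: several models of the same type, one dataset.
compare = EvalResultCompare(
    models=[
        "https://clarifai.com/some_user/some_app/models/model-a",
        "https://clarifai.com/some_user/some_app/models/model-b",
    ],
    datasets="https://clarifai.com/some_user/some_app/datasets/test-set",
    attempt_evaluate=True,            # evaluate first if no result exists yet
    auth_kwargs={"pat": "YOUR_PAT"},  # forwarded to Model/Dataset for the URLs
)

# Writes concepts_summary.csv, total_summary.csv, roc.jpg, pr.jpg and
# confusion_matrix.jpg under output/conf-0.5 (suffix varies by model type).
compare.all(output_folder="output", confidence_threshold=0.5, iou_threshold=0.5)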