clarifai 10.0.1__py3-none-any.whl → 10.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. clarifai/client/app.py +23 -43
  2. clarifai/client/base.py +46 -4
  3. clarifai/client/dataset.py +85 -33
  4. clarifai/client/input.py +35 -7
  5. clarifai/client/model.py +192 -11
  6. clarifai/client/module.py +8 -6
  7. clarifai/client/runner.py +3 -1
  8. clarifai/client/search.py +6 -3
  9. clarifai/client/user.py +14 -12
  10. clarifai/client/workflow.py +8 -5
  11. clarifai/datasets/upload/features.py +3 -0
  12. clarifai/datasets/upload/image.py +57 -26
  13. clarifai/datasets/upload/loaders/README.md +3 -4
  14. clarifai/datasets/upload/loaders/xview_detection.py +9 -5
  15. clarifai/datasets/upload/utils.py +23 -7
  16. clarifai/models/model_serving/README.md +113 -121
  17. clarifai/models/model_serving/__init__.py +2 -0
  18. clarifai/models/model_serving/cli/_utils.py +53 -0
  19. clarifai/models/model_serving/cli/base.py +14 -0
  20. clarifai/models/model_serving/cli/build.py +79 -0
  21. clarifai/models/model_serving/cli/clarifai_clis.py +33 -0
  22. clarifai/models/model_serving/cli/create.py +171 -0
  23. clarifai/models/model_serving/cli/example_cli.py +34 -0
  24. clarifai/models/model_serving/cli/login.py +26 -0
  25. clarifai/models/model_serving/cli/upload.py +182 -0
  26. clarifai/models/model_serving/constants.py +20 -0
  27. clarifai/models/model_serving/docs/cli.md +150 -0
  28. clarifai/models/model_serving/docs/concepts.md +229 -0
  29. clarifai/models/model_serving/docs/dependencies.md +1 -1
  30. clarifai/models/model_serving/docs/inference_parameters.md +112 -107
  31. clarifai/models/model_serving/docs/model_types.md +16 -17
  32. clarifai/models/model_serving/model_config/__init__.py +4 -2
  33. clarifai/models/model_serving/model_config/base.py +369 -0
  34. clarifai/models/model_serving/model_config/config.py +219 -224
  35. clarifai/models/model_serving/model_config/inference_parameter.py +5 -0
  36. clarifai/models/model_serving/model_config/model_types_config/multimodal-embedder.yaml +25 -24
  37. clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml +19 -18
  38. clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml +20 -18
  39. clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml +19 -18
  40. clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml +19 -18
  41. clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml +22 -18
  42. clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml +32 -28
  43. clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml +19 -18
  44. clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +19 -18
  45. clarifai/models/model_serving/{models → model_config}/output.py +8 -0
  46. clarifai/models/model_serving/model_config/triton/__init__.py +14 -0
  47. clarifai/models/model_serving/model_config/{serializer.py → triton/serializer.py} +3 -1
  48. clarifai/models/model_serving/model_config/triton/triton_config.py +182 -0
  49. clarifai/models/model_serving/{models/model_types.py → model_config/triton/wrappers.py} +4 -4
  50. clarifai/models/model_serving/{models → repo_build}/__init__.py +2 -0
  51. clarifai/models/model_serving/repo_build/build.py +198 -0
  52. clarifai/models/model_serving/repo_build/static_files/_requirements.txt +2 -0
  53. clarifai/models/model_serving/repo_build/static_files/base_test.py +169 -0
  54. clarifai/models/model_serving/repo_build/static_files/inference.py +26 -0
  55. clarifai/models/model_serving/repo_build/static_files/sample_clarifai_config.yaml +25 -0
  56. clarifai/models/model_serving/repo_build/static_files/test.py +40 -0
  57. clarifai/models/model_serving/{models/pb_model.py → repo_build/static_files/triton/model.py} +15 -14
  58. clarifai/models/model_serving/utils.py +21 -0
  59. clarifai/rag/rag.py +67 -23
  60. clarifai/rag/utils.py +21 -5
  61. clarifai/utils/evaluation/__init__.py +427 -0
  62. clarifai/utils/evaluation/helpers.py +522 -0
  63. clarifai/utils/logging.py +7 -0
  64. clarifai/utils/model_train.py +3 -1
  65. clarifai/versions.py +1 -1
  66. {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/METADATA +58 -10
  67. clarifai-10.1.1.dist-info/RECORD +115 -0
  68. clarifai-10.1.1.dist-info/entry_points.txt +2 -0
  69. clarifai/datasets/upload/loaders/coco_segmentation.py +0 -98
  70. clarifai/models/model_serving/cli/deploy_cli.py +0 -123
  71. clarifai/models/model_serving/cli/model_zip.py +0 -61
  72. clarifai/models/model_serving/cli/repository.py +0 -89
  73. clarifai/models/model_serving/docs/custom_config.md +0 -33
  74. clarifai/models/model_serving/docs/output.md +0 -28
  75. clarifai/models/model_serving/models/default_test.py +0 -281
  76. clarifai/models/model_serving/models/inference.py +0 -50
  77. clarifai/models/model_serving/models/test.py +0 -64
  78. clarifai/models/model_serving/pb_model_repository.py +0 -108
  79. clarifai-10.0.1.dist-info/RECORD +0 -103
  80. clarifai-10.0.1.dist-info/entry_points.txt +0 -4
  81. {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/LICENSE +0 -0
  82. {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/WHEEL +0 -0
  83. {clarifai-10.0.1.dist-info → clarifai-10.1.1.dist-info}/top_level.txt +0 -0
clarifai/rag/rag.py CHANGED
@@ -17,6 +17,8 @@ from clarifai.rag.utils import (convert_messages_to_str, format_assistant_messag
17
17
  split_document)
18
18
  from clarifai.utils.logging import get_logger
19
19
 
20
+ DEFAULT_RAG_PROMPT_TEMPLATE = "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
21
+
20
22
 
21
23
  class RAG:
22
24
  """
@@ -24,7 +26,8 @@ class RAG:
24
26
 
25
27
  Example:
26
28
  >>> from clarifai.rag import RAG
27
- >>> rag_agent = RAG()
29
+ >>> rag_agent = RAG(workflow_url=YOUR_WORKFLOW_URL)
30
+ >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
28
31
  """
29
32
  chat_state_id = None
30
33
 
@@ -49,43 +52,70 @@ class RAG:
49
52
  @classmethod
50
53
  def setup(cls,
51
54
  user_id: str = None,
55
+ app_url: str = None,
52
56
  llm_url: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
53
57
  base_workflow: str = "Text",
54
58
  workflow_yaml_filename: str = 'prompter_wf.yaml',
59
+ workflow_id: str = None,
55
60
  base_url: str = "https://api.clarifai.com",
56
61
  pat: str = None,
57
62
  **kwargs):
58
63
  """Creates an app with `Text` as base workflow, create prompt model, create prompt workflow.
59
64
 
65
+ **kwargs: Additional keyword arguments to be passed to the rag-prompter model.
66
+ - min_score (float): The minimum score for search hits.
67
+ - max_results (int): The maximum number of search hits.
68
+ - prompt_template (str): The prompt template used. Must contain {data.hits} for the search hits and {data.text.raw} for the query string.
69
+
60
70
  Example:
61
71
  >>> from clarifai.rag import RAG
62
- >>> rag_agent = RAG.setup()
72
+ >>> rag_agent = RAG.setup(user_id=YOUR_USER_ID)
73
+ >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
74
+
75
+ Or if you already have an existing app with ingested data:
76
+ >>> rag_agent = RAG.setup(app_url=YOUR_APP_URL)
77
+ >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
63
78
  """
64
- if not user_id:
79
+ now_ts = str(int(datetime.now().timestamp()))
80
+ if user_id and not app_url:
81
+ user = User(user_id=user_id, base_url=base_url, pat=pat)
82
+ ## Create an App
83
+ app_id = f"rag_app_{now_ts}"
84
+ app = user.create_app(app_id=app_id, base_workflow=base_workflow)
85
+
86
+ if not user_id and app_url:
87
+ app = App(url=app_url, pat=pat)
88
+ uid = app_url.split(".com/")[1].split("/")[0]
89
+ user = User(user_id=uid, base_url=base_url, pat=pat)
90
+
91
+ if user_id and app_url:
92
+ raise UserError("Must provide one of user_id or app_url, not both.")
93
+
94
+ if not user_id and not app_url:
65
95
  raise UserError(
66
- "user_id must be provided. It can be found at https://clarifai.com/settings.")
67
- user = User(user_id=user_id, base_url=base_url, pat=pat)
68
- llm = Model(llm_url)
96
+ "user_id or app_url must be provided. The user_id can be found at https://clarifai.com/settings."
97
+ )
98
+
99
+ llm = Model(url=llm_url, pat=pat)
69
100
 
101
+ min_score = kwargs.get("min_score", 0.95)
102
+ max_results = kwargs.get("max_results", 5)
103
+ prompt_template = kwargs.get("prompt_template", DEFAULT_RAG_PROMPT_TEMPLATE)
70
104
  params = Struct()
71
105
  params.update({
72
- "prompt_template":
73
- "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
106
+ "min_score": min_score,
107
+ "max_results": max_results,
108
+ "prompt_template": prompt_template
74
109
  })
75
110
  prompter_model_params = {"params": params}
76
111
 
77
- ## Create an App
78
- now_ts = str(int(datetime.now().timestamp()))
79
- app_id = f"rag_app_{now_ts}"
80
- app = user.create_app(app_id=app_id, base_workflow=base_workflow)
81
-
82
112
  ## Create rag-prompter model and version
83
- prompter_model = app.create_model(
84
- model_id=f"rag_prompter_{now_ts}", model_type_id="rag-prompter")
113
+ model_id = f"prompter-{workflow_id}" if workflow_id is not None else f"rag-prompter-{now_ts}"
114
+ prompter_model = app.create_model(model_id=model_id, model_type_id="rag-prompter")
85
115
  prompter_model = prompter_model.create_version(output_info=prompter_model_params)
86
116
 
87
117
  ## Generate a tmp yaml file for workflow creation
88
- workflow_id = f"rag-wf-{now_ts}"
118
+ workflow_id = f"rag-wf-{now_ts}" if workflow_id is None else workflow_id
89
119
  workflow_dict = {
90
120
  "workflow": {
91
121
  "id":
@@ -124,6 +154,8 @@ class RAG:
124
154
  batch_size: int = 128,
125
155
  chunk_size: int = 1024,
126
156
  chunk_overlap: int = 200,
157
+ dataset_id: str = None,
158
+ metadata: dict = None,
127
159
  **kwargs) -> None:
128
160
  """Uploads documents to the app.
129
161
  - Read from a local directory or public url or local filename.
@@ -141,9 +173,10 @@ class RAG:
141
173
 
142
174
  Example:
143
175
  >>> from clarifai.rag import RAG
144
- >>> rag_agent = RAG.setup()
176
+ >>> rag_agent = RAG.setup(user_id=YOUR_USER_ID)
145
177
  >>> rag_agent.upload(folder_path = "~/work/docs")
146
178
  >>> rag_agent.upload(file_path = "~/work/docs/manual.pdf")
179
+ >>> rag_agent.chat(messages=[{"role":"human", "content":"What is Clarifai"}])
147
180
  """
148
181
  #set batch size
149
182
  if batch_size > MAX_UPLOAD_BATCH_SIZE:
@@ -162,14 +195,15 @@ class RAG:
162
195
 
163
196
  #splitting documents into chunks
164
197
  text_chunks = []
165
- metadata = []
198
+ metadata_list = []
166
199
 
167
200
  #iterate through documents
168
201
  for doc in documents:
202
+ doc_i = 0
169
203
  cur_text_chunks = split_document(
170
204
  text=doc.text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
171
205
  text_chunks.extend(cur_text_chunks)
172
- metadata.extend([doc.metadata for _ in range(len(cur_text_chunks))])
206
+ metadata_list.extend([doc.metadata for _ in range(len(cur_text_chunks))])
173
207
  #if batch size is reached, upload the batch
174
208
  if len(text_chunks) > batch_size:
175
209
  for idx in range(0, len(text_chunks), batch_size):
@@ -178,18 +212,23 @@ class RAG:
178
212
  batch_texts = text_chunks[0:batch_size]
179
213
  batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
180
214
  #metadata
181
- batch_metadatas = metadata[0:batch_size]
215
+ batch_metadatas = metadata_list[0:batch_size]
182
216
  meta_list = []
183
217
  for meta in batch_metadatas:
184
218
  meta_struct = Struct()
185
219
  meta_struct.update(meta)
220
+ meta_struct.update({"doc_chunk_no": doc_i})
221
+ if metadata and isinstance(metadata, dict):
222
+ meta_struct.update(metadata)
186
223
  meta_list.append(meta_struct)
224
+ doc_i += 1
187
225
  del batch_metadatas
188
226
  #creating input proto
189
227
  input_batch = [
190
228
  self._app.inputs().get_text_input(
191
229
  input_id=batch_ids[i],
192
230
  raw_text=text,
231
+ dataset_id=dataset_id,
193
232
  metadata=meta_list[i],
194
233
  ) for i, text in enumerate(batch_texts)
195
234
  ]
@@ -197,32 +236,37 @@ class RAG:
197
236
  self._app.inputs().upload_inputs(inputs=input_batch)
198
237
  #delete uploaded chunks
199
238
  del text_chunks[0:batch_size]
200
- del metadata[0:batch_size]
239
+ del metadata_list[0:batch_size]
201
240
 
202
241
  #uploading the remaining chunks
203
242
  if len(text_chunks) > 0:
204
243
  batch_size = len(text_chunks)
205
244
  batch_ids = [uuid.uuid4().hex for _ in range(batch_size)]
206
245
  #metadata
207
- batch_metadatas = metadata[0:batch_size]
246
+ batch_metadatas = metadata_list[0:batch_size]
208
247
  meta_list = []
209
248
  for meta in batch_metadatas:
210
249
  meta_struct = Struct()
211
250
  meta_struct.update(meta)
251
+ meta_struct.update({"doc_chunk_no": doc_i})
252
+ if metadata and isinstance(metadata, dict):
253
+ meta_struct.update(metadata)
212
254
  meta_list.append(meta_struct)
255
+ doc_i += 1
213
256
  del batch_metadatas
214
257
  #creating input proto
215
258
  input_batch = [
216
259
  self._app.inputs().get_text_input(
217
260
  input_id=batch_ids[i],
218
261
  raw_text=text,
262
+ dataset_id=dataset_id,
219
263
  metadata=meta_list[i],
220
264
  ) for i, text in enumerate(text_chunks)
221
265
  ]
222
266
  #uploading input with metadata
223
267
  self._app.inputs().upload_inputs(inputs=input_batch)
224
268
  del text_chunks
225
- del metadata
269
+ del metadata_list
226
270
 
227
271
  def chat(self, messages: List[dict], client_manage_state: bool = False) -> List[dict]:
228
272
  """Chat interface in OpenAI API format.
clarifai/rag/utils.py CHANGED
@@ -3,9 +3,6 @@ from pathlib import Path
3
3
  from typing import List
4
4
 
5
5
  import requests
6
- from llama_index import Document, SimpleDirectoryReader, download_loader
7
- from llama_index.node_parser.text import SentenceSplitter
8
- from pypdf import PdfReader
9
6
 
10
7
 
11
8
  ## TODO: Make this token-aware.
@@ -35,8 +32,7 @@ def format_assistant_message(raw_text: str) -> dict:
35
32
  return {"role": "assistant", "content": raw_text}
36
33
 
37
34
 
38
- def load_documents(file_path: str = None, folder_path: str = None,
39
- url: str = None) -> List[Document]:
35
+ def load_documents(file_path: str = None, folder_path: str = None, url: str = None) -> List[any]:
40
36
  """Loads documents from a local directory or public url or local filename.
41
37
 
42
38
  Args:
@@ -44,6 +40,13 @@ def load_documents(file_path: str = None, folder_path: str = None,
44
40
  folder_path (str): The path to the folder.
45
41
  url (str): The url to the file.
46
42
  """
43
+ #check import packages
44
+ try:
45
+ from llama_index.core import Document, SimpleDirectoryReader
46
+ from llama_index.core.readers.download import download_loader
47
+ except ImportError:
48
+ raise ImportError("Could not import llama index package. "
49
+ "Please install it with `pip install llama-index-core==0.10.1`.")
47
50
  #document loaders for filepath
48
51
  if file_path:
49
52
  if file_path.endswith(".pdf"):
@@ -76,6 +79,12 @@ def load_documents(file_path: str = None, folder_path: str = None,
76
79
  documents = [Document(text=response.content)]
77
80
  #for pdf files
78
81
  except Exception:
82
+ #check import packages
83
+ try:
84
+ from pypdf import PdfReader
85
+ except ImportError:
86
+ raise ImportError("Could not import pypdf package. "
87
+ "Please install it with `pip install pypdf==3.17.4`.")
79
88
  documents = []
80
89
  pdf_file = PdfReader(io.BytesIO(response.content))
81
90
  num_pages = len(pdf_file.pages)
@@ -97,6 +106,13 @@ def split_document(text: str, chunk_size: int, chunk_overlap: int, **kwargs) ->
97
106
  chunk_overlap (int): The amount of overlap between each chunk.
98
107
  **kwargs: Additional keyword arguments for the SentenceSplitter.
99
108
  """
109
+ #check import packages
110
+ try:
111
+ from llama_index.core.node_parser.text import SentenceSplitter
112
+ except ImportError:
113
+ raise ImportError("Could not import llama index package. "
114
+ "Please install it with `pip install llama-index-core==0.10.1`.")
115
+ #document
100
116
  text_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
101
117
  text_chunks = text_parser.split_text(text)
102
118
  return text_chunks
@@ -0,0 +1,427 @@
1
+ import os
2
+ from enum import Enum
3
+ from typing import List, Tuple, Union
4
+
5
+ from clarifai.client.dataset import Dataset
6
+ from clarifai.client.model import Model
7
+
8
+ from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
9
+ make_handler_by_type)
10
+
11
+ try:
12
+ import seaborn as sns
13
+ except ImportError:
14
+ raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
15
+
16
+ try:
17
+ import matplotlib.pyplot as plt
18
+ except ImportError:
19
+ raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
20
+
21
+ try:
22
+ import pandas as pd
23
+ except ImportError:
24
+ raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
25
+
26
+ try:
27
+ from loguru import logger
28
+ except ImportError:
29
+ from ..logging import get_logger
30
+ logger = get_logger(logger_level="INFO", name=__name__)
31
+
32
+ __all__ = ['EvalResultCompare']
33
+
34
+
35
+ class CompareMode(Enum):
36
+ MANY_MODELS_TO_ONE_DATA = 0
37
+ ONE_MODEL_TO_MANY_DATA = 1
38
+
39
+
40
+ class EvalResultCompare:
41
+ """Compare evaluation result of models against datasets.
42
+ Note: The module will pick the latest result on the datasets,
43
+ and all models must be of the same model type.
44
+
45
+ Args:
46
+ ---
47
+ models (Union[List[Model], List[str]]): List of Model or urls of models.
48
+ datasets (Union[Dataset, List[Dataset], str, List[str]]): A single or List of Url or Dataset
49
+ attempt_evaluate (bool): Evaluate when model is not evaluated with the datasets.
50
+ auth_kwargs (dict): Additional auth keyword arguments to be passed to the Dataset and Model if using url(s)
51
+ """
52
+
53
+ def __init__(self,
54
+ models: Union[List[Model], List[str]],
55
+ datasets: Union[Dataset, List[Dataset], str, List[str]],
56
+ attempt_evaluate: bool = False,
57
+ auth_kwargs: dict = {}):
58
+ assert isinstance(models, list), ValueError("Expected list")
59
+
60
+ if len(models) > 1:
61
+ self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
62
+ self.comparator = "Model"
63
+ assert isinstance(datasets, Dataset) or (
64
+ isinstance(datasets, list) and len(datasets) == 1
65
+ ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
66
+ else:
67
+ self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
68
+ self.comparator = "Dataset"
69
+
70
+ # validate models
71
+ if all(map(lambda x: isinstance(x, str), models)):
72
+ models = [Model(each, **auth_kwargs) for each in models]
73
+ elif not all(map(lambda x: isinstance(x, Model), models)):
74
+ raise ValueError(
75
+ f"Expected all models are list of string or list of Model, got {[type(each) for each in models]}"
76
+ )
77
+ # validate datasets
78
+ if not isinstance(datasets, list):
79
+ datasets = [
80
+ datasets,
81
+ ]
82
+ if all(map(lambda x: isinstance(x, str), datasets)):
83
+ datasets = [Dataset(each, **auth_kwargs) for each in datasets]
84
+ elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
85
+ raise ValueError(
86
+ f"Expected datasets must be str, list of string or Dataset, list of Dataset, got {[type(each) for each in datasets]}"
87
+ )
88
+ # Validate models vs datasets together
89
+ self._eval_handlers: List[_BaseEvalResultHandler] = []
90
+ self.model_type = None
91
+ logger.info("Initializing models...")
92
+ for model in models:
93
+ model.load_info()
94
+ model_type = model.model_info.model_type_id
95
+ if not self.model_type:
96
+ self.model_type = model_type
97
+ else:
98
+ assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
99
+ m = make_handler_by_type(model_type)(model=model)
100
+ logger.info(f"* {m.get_model_name(pretify=True)}")
101
+ m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
102
+ self._eval_handlers.append(m)
103
+
104
+ @property
105
+ def eval_handlers(self):
106
+ return self._eval_handlers
107
+
108
+ def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
109
+ """ Run methods of `eval_handlers[...].model`
110
+
111
+ Args:
112
+ func_name (str): method name, see `_BaseEvalResultHandler` child classes
113
+ kwargs: keyword arguments of the method
114
+
115
+ Return:
116
+ tuple:
117
+ - list of outputs
118
+ - list of comparator names
119
+
120
+ """
121
+ outs = []
122
+ comparators = []
123
+ logger.info(f'Running `{func_name}`')
124
+ for _, each in enumerate(self.eval_handlers):
125
+ for ds_index, _ in enumerate(each.eval_data):
126
+ func = eval(f'each.{func_name}')
127
+ out = func(index=ds_index, **kwargs)
128
+
129
+ if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
130
+ name = each.get_model_name(pretify=True)
131
+ else:
132
+ name = each.get_dataset_name_by_index(ds_index, pretify=True)
133
+ if out is None:
134
+ logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
135
+ continue
136
+ comparators.append(name)
137
+ outs.append(out)
138
+
139
+ # remove the app_id prefix from the names if all models belong to the same app
140
+ if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
141
+ apps = set([comp.split('/')[0] for comp in comparators])
142
+ if len(apps) == 1:
143
+ comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
144
+
145
+ if not outs:
146
+ logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
147
+
148
+ return outs, comparators
149
+
150
+ def detailed_summary(self,
151
+ confidence_threshold: float = .5,
152
+ iou_threshold: float = .5,
153
+ area: str = "all",
154
+ bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
155
+ """
156
+ Retrieve and compute popular metrics of model.
157
+
158
+ Args:
159
+ confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
160
+ iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection
161
+ area (str): size of area, support {all, small, medium}, applicable for detection
162
+
163
+ Return:
164
+ None or tuple of dataframe: df summary per concept and total concepts
165
+
166
+ """
167
+ df = []
168
+ total = []
169
+ # loop over all eval_handlers/dataset and call its method
170
+ outs, comparators = self._loop_eval_handlers(
171
+ 'detailed_summary',
172
+ confidence_threshold=confidence_threshold,
173
+ iou_threshold=iou_threshold,
174
+ area=area,
175
+ bypass_const=bypass_const)
176
+ for indx, out in enumerate(outs):
177
+ _df, _total = out
178
+ _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
179
+ _total['Concept'].replace(
180
+ to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
181
+ _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
182
+ df.append(_df)
183
+ total.append(_total)
184
+
185
+ if df:
186
+ df = pd.concat(df, axis=0)
187
+ total = pd.concat(total, axis=0)
188
+ return df, total
189
+ else:
190
+ return None
191
+
192
+ def confusion_matrix(self, show=True, save_path: str = None,
193
+ cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
194
+ """Return dataframe of confusion matrix
195
+ Args:
196
+ show (bool, optional): Show the chart. Defaults to True.
197
+ save_path (str): path to save rendered chart.
198
+ cm_kwargs (dict): keyword args of `eval_handler[...].model.confusion_matrix` method.
199
+ Returns:
200
+ None or pd.Dataframe, If models don't have confusion matrix, return None
201
+ """
202
+ outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
203
+ all_dfs = []
204
+ for _, (df, anchor) in enumerate(zip(outs, comparators)):
205
+ df[self.comparator] = [anchor for _ in range(len(df))]
206
+ all_dfs.append(df)
207
+
208
+ if all_dfs:
209
+ all_dfs = pd.concat(all_dfs, axis=0)
210
+ if save_path or show:
211
+
212
+ def _facet_heatmap(data, **kws):
213
+ data = data.dropna(axis=1)
214
+ data = data.drop(self.comparator, axis=1)
215
+ concepts = data.columns
216
+ colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
217
+ data.columns = colnames
218
+ ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
219
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
220
+ ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
221
+
222
+ temp = all_dfs.copy()
223
+ temp.columns = ["_".join(pair) for pair in temp.columns]
224
+ with sns.plotting_context(font_scale=5.5):
225
+ g = sns.FacetGrid(
226
+ temp,
227
+ col=self.comparator,
228
+ col_wrap=3,
229
+ aspect=1,
230
+ height=3,
231
+ sharex=False,
232
+ sharey=False,
233
+ )
234
+ cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
235
+ g = g.map_dataframe(
236
+ _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
237
+ g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
238
+ if show:
239
+ plt.show()
240
+ if save_path:
241
+ g.savefig(save_path)
242
+
243
+ return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
244
+
245
+ @staticmethod
246
+ def _set_default_kwargs(kwargs: dict, var_name: str, value):
247
+ if var_name not in kwargs:
248
+ kwargs.update({var_name: value})
249
+ return kwargs
250
+
251
+ @staticmethod
252
+ def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
253
+ hue_order = df["concept"].unique().tolist()
254
+ hue_order.remove(MACRO_AVG)
255
+ hue_order.insert(0, MACRO_AVG)
256
+ EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
257
+
258
+ sizes = {}
259
+ for each in hue_order:
260
+ s = 1.5
261
+ if each == MACRO_AVG:
262
+ s = 4.
263
+ sizes.update({each: s})
264
+ EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
265
+ EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
266
+
267
+ EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
268
+ EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
269
+
270
+ return kwargs
271
+
272
+ def roc_curve_plot(self,
273
+ show=True,
274
+ save_path: str = None,
275
+ roc_curve_kwargs: dict = {},
276
+ relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
277
+ """Return dataframe of ROC curve
278
+ Args:
279
+ show (bool, optional): Show the chart. Defaults to True.
280
+ save_path (str): path to save rendered chart.
281
+ roc_curve_kwargs (dict): keyword args of `eval_handler[...].model.roc_curve` method.
282
+ relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}. where x="fpr", y="tpr", hue="concept"
283
+ Returns:
284
+ None or pd.Dataframe, If models don't have ROC curve, return None
285
+ """
286
+ sns.set_palette("Paired")
287
+ outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
288
+ all_dfs = []
289
+ for _, (df, anchor) in enumerate(zip(outs, comparator)):
290
+ df[self.comparator] = [anchor for _ in range(len(df))]
291
+ all_dfs.append(df)
292
+
293
+ if all_dfs:
294
+ all_dfs = pd.concat(all_dfs, axis=0)
295
+ if save_path or show:
296
+ relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
297
+ g = sns.relplot(
298
+ data=all_dfs,
299
+ x="fpr",
300
+ y="tpr",
301
+ hue='concept',
302
+ kind="line",
303
+ col=self.comparator,
304
+ **relplot_kwargs)
305
+ g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
306
+ if show:
307
+ plt.show()
308
+ if save_path:
309
+ g.savefig(save_path)
310
+
311
+ return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
312
+
313
+ def pr_plot(self,
314
+ show=True,
315
+ save_path: str = None,
316
+ pr_curve_kwargs: dict = {},
317
+ relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
318
+ """Return dataframe of PR curve
319
+ Args:
320
+ show (bool, optional): Show the chart. Defaults to True.
321
+ save_path (str): path to save rendered chart.
322
+ pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.pr_curve` method.
323
+ relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col} where x="recall", y="precision", hue="concept"
324
+ Returns:
325
+ None or pd.Dataframe, If models don't have PR curve, return None
326
+ """
327
+ sns.set_palette("Paired")
328
+ outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
329
+ all_dfs = []
330
+ for _, (df, anchor) in enumerate(zip(outs, comparator)):
331
+ df[self.comparator] = [anchor for _ in range(len(df))]
332
+ all_dfs.append(df)
333
+
334
+ if all_dfs:
335
+ all_dfs = pd.concat(all_dfs, axis=0)
336
+ if save_path or show:
337
+ relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
338
+ g = sns.relplot(
339
+ data=all_dfs,
340
+ x="recall",
341
+ y="precision",
342
+ hue='concept',
343
+ kind="line",
344
+ col=self.comparator,
345
+ **relplot_kwargs)
346
+ g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
347
+ if show:
348
+ plt.show()
349
+ if save_path:
350
+ g.savefig(save_path)
351
+
352
+ return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
353
+
354
+ def all(
355
+ self,
356
+ output_folder: str,
357
+ confidence_threshold: float = 0.5,
358
+ iou_threshold: float = 0.5,
359
+ overwrite: bool = False,
360
+ metric_kwargs: dict = {},
361
+ pr_plot_kwargs: dict = {},
362
+ roc_plot_kwargs: dict = {},
363
+ ):
364
+ """Run all comparison methods one by one:
365
+ - detailed_summary
366
+ - roc_curve_plot (if applicable)
367
+ - pr_plot
368
+ - confusion_matrix (if applicable)
369
+ And save to output_folder
370
+
371
+ Args:
372
+ output_folder (str): path to output
373
+ confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
374
+ iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection.
375
+ overwrite (bool): overwrite result of output_folder.
376
+ metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
377
+ roc_plot_kwargs (dict): for relplot_kwargs of `roc_curve_plot` method.
378
+ pr_plot_kwargs (dict): for relplot_kwargs of `pr_plot` method.
379
+ """
380
+ eval_type = get_eval_type(self.model_type)
381
+ area = metric_kwargs.pop("area", "all")
382
+ bypass_const = metric_kwargs.pop("bypass_const", False)
383
+
384
+ fname = f"conf-{confidence_threshold}"
385
+ if eval_type == EvalType.DETECTION:
386
+ fname = f"{fname}_iou-{iou_threshold}_area-{area}"
387
+
388
+ def join_root(*args):
389
+ return os.path.join(output_folder, *args)
390
+
391
+ output_folder = join_root(fname)
392
+ if os.path.exists(output_folder) and not overwrite:
393
+ raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
394
+
395
+ os.makedirs(output_folder, exist_ok=True)
396
+
397
+ logger.info("Making summary tables...")
398
+ dfs = self.detailed_summary(
399
+ confidence_threshold=confidence_threshold,
400
+ iou_threshold=iou_threshold,
401
+ area=area,
402
+ bypass_const=bypass_const)
403
+ if dfs is not None:
404
+ concept_df, total_df = dfs
405
+ concept_df.to_csv(join_root("concepts_summary.csv"))
406
+ total_df.to_csv(join_root("total_summary.csv"))
407
+
408
+ curve_metric_kwargs = dict(
409
+ confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
410
+ curve_metric_kwargs.update(metric_kwargs)
411
+
412
+ self.roc_curve_plot(
413
+ show=False,
414
+ save_path=join_root("roc.jpg"),
415
+ roc_curve_kwargs=curve_metric_kwargs,
416
+ relplot_kwargs=roc_plot_kwargs)
417
+
418
+ self.pr_plot(
419
+ show=False,
420
+ save_path=join_root("pr.jpg"),
421
+ pr_curve_kwargs=curve_metric_kwargs,
422
+ relplot_kwargs=pr_plot_kwargs)
423
+
424
+ self.confusion_matrix(
425
+ show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
426
+
427
+ logger.info(f"Done. Your outputs are saved at {output_folder}")