rapidata 2.21.4__py3-none-any.whl → 2.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +5 -0
- rapidata/api_client/__init__.py +8 -4
- rapidata/api_client/api/__init__.py +1 -0
- rapidata/api_client/api/evaluation_workflow_api.py +372 -0
- rapidata/api_client/api/identity_api.py +268 -0
- rapidata/api_client/api/rapid_api.py +353 -1987
- rapidata/api_client/api/simple_workflow_api.py +6 -6
- rapidata/api_client/models/__init__.py +7 -4
- rapidata/api_client/models/add_campaign_model.py +25 -1
- rapidata/api_client/models/add_validation_rapid_model_truth.py +24 -10
- rapidata/api_client/models/compare_result.py +2 -0
- rapidata/api_client/models/create_order_model.py +43 -2
- rapidata/api_client/models/evaluation_workflow_model1.py +115 -0
- rapidata/api_client/models/filter.py +2 -2
- rapidata/api_client/models/get_validation_rapids_result.py +11 -4
- rapidata/api_client/models/get_validation_rapids_result_truth.py +24 -10
- rapidata/api_client/models/get_workflow_by_id_result_workflow.py +23 -9
- rapidata/api_client/models/get_workflow_results_result.py +118 -0
- rapidata/api_client/models/get_workflow_results_result_paged_result.py +105 -0
- rapidata/api_client/models/google_one_tap_login_model.py +87 -0
- rapidata/api_client/models/labeling_selection.py +22 -3
- rapidata/api_client/models/logic_operator.py +1 -0
- rapidata/api_client/models/rapid_response.py +3 -1
- rapidata/api_client/models/retrieval_mode.py +38 -0
- rapidata/api_client/models/root_filter.py +2 -2
- rapidata/api_client/models/skip_truth.py +94 -0
- rapidata/api_client/models/sticky_state.py +38 -0
- rapidata/api_client/models/update_validation_rapid_model.py +11 -4
- rapidata/api_client/models/update_validation_rapid_model_truth.py +24 -10
- rapidata/api_client/rest.py +1 -0
- rapidata/api_client_README.md +10 -11
- rapidata/rapidata_client/__init__.py +7 -0
- rapidata/rapidata_client/api/rapidata_exception.py +5 -3
- rapidata/rapidata_client/assets/__init__.py +1 -0
- rapidata/rapidata_client/assets/_media_asset.py +16 -10
- rapidata/rapidata_client/assets/_multi_asset.py +6 -0
- rapidata/rapidata_client/assets/_sessions.py +35 -0
- rapidata/rapidata_client/assets/_text_asset.py +6 -0
- rapidata/rapidata_client/demographic/demographic_manager.py +2 -35
- rapidata/rapidata_client/logging/__init__.py +2 -0
- rapidata/rapidata_client/logging/logger.py +47 -0
- rapidata/rapidata_client/logging/output_manager.py +16 -0
- rapidata/rapidata_client/order/_rapidata_dataset.py +11 -15
- rapidata/rapidata_client/order/_rapidata_order_builder.py +15 -2
- rapidata/rapidata_client/order/rapidata_order.py +23 -14
- rapidata/rapidata_client/order/rapidata_order_manager.py +4 -2
- rapidata/rapidata_client/order/rapidata_results.py +2 -1
- rapidata/rapidata_client/rapidata_client.py +6 -1
- rapidata/rapidata_client/selection/__init__.py +1 -0
- rapidata/rapidata_client/selection/labeling_selection.py +8 -2
- rapidata/rapidata_client/selection/retrieval_modes.py +9 -0
- rapidata/rapidata_client/settings/alert_on_fast_response.py +2 -1
- rapidata/rapidata_client/settings/free_text_minimum_characters.py +2 -1
- rapidata/rapidata_client/validation/rapidata_validation_set.py +4 -34
- rapidata/rapidata_client/validation/rapids/rapids.py +6 -7
- rapidata/rapidata_client/validation/validation_set_manager.py +39 -36
- rapidata/service/credential_manager.py +22 -30
- rapidata/service/openapi_service.py +11 -0
- {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/METADATA +2 -1
- {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/RECORD +62 -49
- {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/WHEEL +1 -1
- {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/LICENSE +0 -0
|
@@ -29,6 +29,8 @@ from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset, B
|
|
|
29
29
|
|
|
30
30
|
from typing import Optional, cast, Sequence
|
|
31
31
|
|
|
32
|
+
from rapidata.rapidata_client.logging import logger, managed_print
|
|
33
|
+
|
|
32
34
|
|
|
33
35
|
class RapidataOrderBuilder:
|
|
34
36
|
"""Builder object for creating Rapidata orders.
|
|
@@ -73,7 +75,7 @@ class RapidataOrderBuilder:
|
|
|
73
75
|
raise ValueError("You must provide a workflow to create an order.")
|
|
74
76
|
|
|
75
77
|
if self.__referee is None:
|
|
76
|
-
|
|
78
|
+
managed_print("No referee provided, using default NaiveReferee.")
|
|
77
79
|
self.__referee = NaiveReferee()
|
|
78
80
|
|
|
79
81
|
return CreateOrderModel(
|
|
@@ -113,6 +115,7 @@ class RapidataOrderBuilder:
|
|
|
113
115
|
RapidataOrder: The created RapidataOrder instance.
|
|
114
116
|
"""
|
|
115
117
|
order_model = self._to_model()
|
|
118
|
+
logger.debug(f"Creating order with model: {order_model}")
|
|
116
119
|
if isinstance(
|
|
117
120
|
self.__workflow, CompareWorkflow
|
|
118
121
|
): # Temporary fix; will be handled by backend in the future
|
|
@@ -125,12 +128,17 @@ class RapidataOrderBuilder:
|
|
|
125
128
|
)
|
|
126
129
|
|
|
127
130
|
self.order_id = str(result.order_id)
|
|
131
|
+
logger.debug(f"Order created with ID: {self.order_id}")
|
|
128
132
|
|
|
129
133
|
self.__dataset = (
|
|
130
134
|
RapidataDataset(result.dataset_id, self.__openapi_service)
|
|
131
135
|
if result.dataset_id
|
|
132
136
|
else None
|
|
133
137
|
)
|
|
138
|
+
if self.__dataset:
|
|
139
|
+
logger.debug(f"Dataset created with ID: {self.__dataset.dataset_id}")
|
|
140
|
+
else:
|
|
141
|
+
logger.warning("No dataset created for this order.")
|
|
134
142
|
|
|
135
143
|
order = RapidataOrder(
|
|
136
144
|
order_id=self.order_id,
|
|
@@ -138,6 +146,9 @@ class RapidataOrderBuilder:
|
|
|
138
146
|
name=self._name,
|
|
139
147
|
)
|
|
140
148
|
|
|
149
|
+
logger.debug(f"Order created: {order}")
|
|
150
|
+
logger.debug("Adding media to the order.")
|
|
151
|
+
|
|
141
152
|
if all(isinstance(item, MediaAsset) for item in self.__assets) and self.__dataset:
|
|
142
153
|
assets = cast(list[MediaAsset], self.__assets)
|
|
143
154
|
self.__dataset._add_media_from_paths(assets, self.__metadata, max_upload_workers)
|
|
@@ -183,6 +194,8 @@ class RapidataOrderBuilder:
|
|
|
183
194
|
"Media paths must all be of the same type: MediaAsset, TextAsset, or MultiAsset."
|
|
184
195
|
)
|
|
185
196
|
|
|
197
|
+
logger.debug("Media added to the order.")
|
|
198
|
+
logger.debug("Setting order to preview")
|
|
186
199
|
self.__openapi_service.order_api.order_order_id_preview_post(self.order_id)
|
|
187
200
|
|
|
188
201
|
return order
|
|
@@ -291,7 +304,7 @@ class RapidataOrderBuilder:
|
|
|
291
304
|
raise TypeError("Filters must be of type Filter.")
|
|
292
305
|
|
|
293
306
|
if len(self.__user_filters) > 0:
|
|
294
|
-
|
|
307
|
+
managed_print("Overwriting existing user filters.")
|
|
295
308
|
|
|
296
309
|
self.__user_filters = filters
|
|
297
310
|
return self
|
|
@@ -17,6 +17,7 @@ from rapidata.api_client.models.preliminary_download_model import PreliminaryDow
|
|
|
17
17
|
from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
|
|
18
18
|
from rapidata.rapidata_client.order.rapidata_results import RapidataResults
|
|
19
19
|
from rapidata.service.openapi_service import OpenAPIService
|
|
20
|
+
from rapidata.rapidata_client.logging import logger, managed_print
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class RapidataOrder:
|
|
@@ -47,23 +48,29 @@ class RapidataOrder:
|
|
|
47
48
|
self._max_retries = 10
|
|
48
49
|
self._retry_delay = 2
|
|
49
50
|
self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}"
|
|
51
|
+
logger.debug("RapidataOrder initialized")
|
|
50
52
|
|
|
51
|
-
def run(self
|
|
53
|
+
def run(self) -> "RapidataOrder":
|
|
52
54
|
"""Runs the order to start collecting responses."""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
logger.info(f"Starting order '{self}'")
|
|
56
|
+
self.__openapi_service.order_api.order_order_id_submit_post(self.order_id)
|
|
57
|
+
logger.debug(f"Order '{self}' has been started.")
|
|
58
|
+
managed_print(f"Order '{self.name}' is now viewable under: {self.order_details_page}")
|
|
56
59
|
return self
|
|
57
60
|
|
|
58
61
|
def pause(self) -> None:
|
|
59
62
|
"""Pauses the order."""
|
|
63
|
+
logger.info(f"Pausing order '{self}'")
|
|
60
64
|
self.__openapi_service.order_api.order_pause_post(self.order_id)
|
|
61
|
-
|
|
65
|
+
logger.debug(f"Order '{self}' has been paused.")
|
|
66
|
+
managed_print(f"Order '{self}' has been paused.")
|
|
62
67
|
|
|
63
68
|
def unpause(self) -> None:
|
|
64
69
|
"""Unpauses/resumes the order."""
|
|
70
|
+
logger.info(f"Unpausing order '{self}'")
|
|
65
71
|
self.__openapi_service.order_api.order_resume_post(self.order_id)
|
|
66
|
-
|
|
72
|
+
logger.debug(f"Order '{self}' has been unpaused.")
|
|
73
|
+
managed_print(f"Order '{self}' has been unpaused.")
|
|
67
74
|
|
|
68
75
|
def get_status(self) -> str:
|
|
69
76
|
"""
|
|
@@ -95,12 +102,12 @@ class RapidataOrder:
|
|
|
95
102
|
raise Exception("Order has not been started yet. Please start it first.")
|
|
96
103
|
|
|
97
104
|
while self.get_status() == OrderState.SUBMITTED:
|
|
98
|
-
|
|
105
|
+
managed_print(f"Order '{self}' is submitted and being reviewed. Standby...", end="\r")
|
|
99
106
|
sleep(1)
|
|
100
107
|
|
|
101
108
|
if self.get_status() == OrderState.MANUALREVIEW:
|
|
102
109
|
raise Exception(
|
|
103
|
-
f"Order '{self
|
|
110
|
+
f"Order '{self}' is in manual review. It might take some time to start. "
|
|
104
111
|
"To speed up the process, contact support (info@rapidata.ai).\n"
|
|
105
112
|
"Once started, run this method again to display the progress bar."
|
|
106
113
|
)
|
|
@@ -145,12 +152,12 @@ class RapidataOrder:
|
|
|
145
152
|
Note that preliminary results are not final and may not contain all the datapoints & responses. Only the ones that are already available.
|
|
146
153
|
This will throw an exception if there are no responses available yet.
|
|
147
154
|
"""
|
|
148
|
-
|
|
155
|
+
logger.info(f"Getting results for order '{self}'...")
|
|
149
156
|
if preliminary_results and self.get_status() not in [OrderState.COMPLETED]:
|
|
150
157
|
return self.__get_preliminary_results()
|
|
151
158
|
|
|
152
159
|
elif preliminary_results and self.get_status() in [OrderState.COMPLETED]:
|
|
153
|
-
|
|
160
|
+
managed_print("Order is already completed. Returning final results.")
|
|
154
161
|
|
|
155
162
|
while self.get_status() not in [OrderState.COMPLETED, OrderState.PAUSED, OrderState.MANUALREVIEW, OrderState.FAILED]:
|
|
156
163
|
sleep(5)
|
|
@@ -167,10 +174,11 @@ class RapidataOrder:
|
|
|
167
174
|
Raises:
|
|
168
175
|
Exception: If the order is not in processing state.
|
|
169
176
|
"""
|
|
177
|
+
logger.info("Opening order details page in browser...")
|
|
170
178
|
could_open_browser = webbrowser.open(self.order_details_page)
|
|
171
179
|
if not could_open_browser:
|
|
172
180
|
encoded_url = urllib.parse.quote(self.order_details_page, safe="%/:=&?~#+!$,;'@()*[]")
|
|
173
|
-
|
|
181
|
+
managed_print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
|
|
174
182
|
|
|
175
183
|
def preview(self) -> None:
|
|
176
184
|
"""
|
|
@@ -178,13 +186,14 @@ class RapidataOrder:
|
|
|
178
186
|
|
|
179
187
|
Raises:
|
|
180
188
|
Exception: If the order is not in processing state.
|
|
181
|
-
"""
|
|
189
|
+
"""
|
|
190
|
+
logger.info("Opening order preview in browser...")
|
|
182
191
|
campaign_id = self.__get_campaign_id()
|
|
183
192
|
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
|
|
184
193
|
could_open_browser = webbrowser.open(auth_url)
|
|
185
194
|
if not could_open_browser:
|
|
186
195
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
187
|
-
|
|
196
|
+
managed_print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
|
|
188
197
|
|
|
189
198
|
def __get_pipeline_id(self) -> str:
|
|
190
199
|
"""Internal method to fetch and cache the pipeline ID."""
|
|
@@ -249,7 +258,7 @@ class RapidataOrder:
|
|
|
249
258
|
raise Exception(f"Failed to get preliminary results: {str(e)}") from e
|
|
250
259
|
|
|
251
260
|
def __str__(self) -> str:
|
|
252
|
-
return f"name
|
|
261
|
+
return f"RapidataOrder(name='{self.name}', order_id='{self.order_id}')"
|
|
253
262
|
|
|
254
263
|
def __repr__(self) -> str:
|
|
255
264
|
return f"RapidataOrder(name='{self.name}', order_id='{self.order_id}')"
|
|
@@ -34,6 +34,7 @@ from rapidata.api_client.models.page_info import PageInfo
|
|
|
34
34
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
35
35
|
from rapidata.api_client.models.filter import Filter
|
|
36
36
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
37
|
+
from rapidata.rapidata_client.logging import logger
|
|
37
38
|
|
|
38
39
|
from tqdm import tqdm
|
|
39
40
|
|
|
@@ -53,6 +54,7 @@ class RapidataOrderManager:
|
|
|
53
54
|
self.settings = RapidataSettings
|
|
54
55
|
self.selections = RapidataSelections
|
|
55
56
|
self.__priority = 50
|
|
57
|
+
logger.debug("RapidataOrderManager initialized")
|
|
56
58
|
|
|
57
59
|
def __get_selections(self, validation_set_id: str | None, labeling_amount=3) -> Sequence[RapidataSelection]:
|
|
58
60
|
if validation_set_id:
|
|
@@ -85,7 +87,7 @@ class RapidataOrderManager:
|
|
|
85
87
|
raise ValueError("You can only use contexts or sentences, not both")
|
|
86
88
|
|
|
87
89
|
if contexts and data_type == RapidataDataTypes.TEXT:
|
|
88
|
-
|
|
90
|
+
logger.warning("Warning: Contexts are not supported for text data type. Ignoring contexts.")
|
|
89
91
|
|
|
90
92
|
if not confidence_threshold:
|
|
91
93
|
referee = NaiveReferee(responses=responses_per_datapoint)
|
|
@@ -98,7 +100,7 @@ class RapidataOrderManager:
|
|
|
98
100
|
order_builder = RapidataOrderBuilder(name=name, openapi_service=self._openapi_service)
|
|
99
101
|
|
|
100
102
|
if selections and validation_set_id:
|
|
101
|
-
|
|
103
|
+
logger.warning("Warning: Both selections and validation_set_id provided. Ignoring validation_set_id.")
|
|
102
104
|
|
|
103
105
|
if selections is None:
|
|
104
106
|
selections = self.__get_selections(validation_set_id, labeling_amount=default_labeling_amount)
|
|
@@ -2,6 +2,7 @@ import pandas as pd
|
|
|
2
2
|
from typing import Any
|
|
3
3
|
from pandas.core.indexes.base import Index
|
|
4
4
|
import json
|
|
5
|
+
from rapidata.rapidata_client.logging import managed_print
|
|
5
6
|
|
|
6
7
|
class RapidataResults(dict):
|
|
7
8
|
"""
|
|
@@ -32,7 +33,7 @@ class RapidataResults(dict):
|
|
|
32
33
|
return pd.DataFrame()
|
|
33
34
|
|
|
34
35
|
if self["info"].get("orderType") is None:
|
|
35
|
-
|
|
36
|
+
managed_print("Warning: Results are old and Order type is not specified. Dataframe might be wrong.")
|
|
36
37
|
|
|
37
38
|
# Check for detailed results if split_details is True
|
|
38
39
|
if split_details:
|
|
@@ -8,6 +8,7 @@ from rapidata.rapidata_client.validation.validation_set_manager import (
|
|
|
8
8
|
|
|
9
9
|
from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
|
|
10
10
|
|
|
11
|
+
from rapidata.rapidata_client.logging import logger
|
|
11
12
|
|
|
12
13
|
class RapidataClient:
|
|
13
14
|
"""The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
|
|
@@ -38,6 +39,7 @@ class RapidataClient:
|
|
|
38
39
|
order (RapidataOrderManager): The RapidataOrderManager instance.
|
|
39
40
|
validation (ValidationSetManager): The ValidationSetManager instance.
|
|
40
41
|
"""
|
|
42
|
+
logger.debug("Initializing OpenAPIService")
|
|
41
43
|
self._openapi_service = OpenAPIService(
|
|
42
44
|
client_id=client_id,
|
|
43
45
|
client_secret=client_secret,
|
|
@@ -48,12 +50,15 @@ class RapidataClient:
|
|
|
48
50
|
leeway=leeway,
|
|
49
51
|
)
|
|
50
52
|
|
|
53
|
+
logger.debug("Initializing RapidataOrderManager")
|
|
51
54
|
self.order = RapidataOrderManager(openapi_service=self._openapi_service)
|
|
52
55
|
|
|
56
|
+
logger.debug("Initializing ValidationSetManager")
|
|
53
57
|
self.validation = ValidationSetManager(openapi_service=self._openapi_service)
|
|
54
58
|
|
|
59
|
+
logger.debug("Initializing DemographicManager")
|
|
55
60
|
self._demographic = DemographicManager(openapi_service=self._openapi_service)
|
|
56
|
-
|
|
61
|
+
|
|
57
62
|
def reset_credentials(self):
|
|
58
63
|
"""Reset the credentials saved in the configuration file for the current environment."""
|
|
59
64
|
self._openapi_service.reset_credentials()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
from rapidata.rapidata_client.selection._base_selection import RapidataSelection
|
|
3
|
+
from rapidata.rapidata_client.selection.retrieval_modes import RetrievalMode
|
|
3
4
|
from rapidata.api_client.models.labeling_selection import (
|
|
4
5
|
LabelingSelection as LabelingSelectionModel,
|
|
5
6
|
)
|
|
@@ -12,10 +13,15 @@ class LabelingSelection(RapidataSelection):
|
|
|
12
13
|
|
|
13
14
|
Args:
|
|
14
15
|
amount (int): The amount of labeling rapids that will be shown per session.
|
|
16
|
+
retrieval_mode (RetrievalMode): The retrieval mode to use. Defaults to "Random".
|
|
17
|
+
max_iterations (int | None): The maximum number of times an annotator can see the same task. Defaults to None.
|
|
18
|
+
This parameter is only taken into account when using "Shuffled" or "Sequential" retrieval modes.
|
|
15
19
|
"""
|
|
16
20
|
|
|
17
|
-
def __init__(self, amount: int):
|
|
21
|
+
def __init__(self, amount: int, retrieval_mode: RetrievalMode = RetrievalMode.Random, max_iterations: int | None = None):
|
|
18
22
|
self.amount = amount
|
|
23
|
+
self.retrieval_mode = retrieval_mode
|
|
24
|
+
self.max_iterations = max_iterations
|
|
19
25
|
|
|
20
26
|
def _to_model(self) -> Any:
|
|
21
|
-
return LabelingSelectionModel(_t="LabelingSelection", amount=self.amount)
|
|
27
|
+
return LabelingSelectionModel(_t="LabelingSelection", amount=self.amount, retrievalMode=self.retrieval_mode.value, maxIterations=self.max_iterations)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
class RetrievalMode(Enum):
|
|
4
|
+
# Will just randomly shuffle the datapoints. This is the default and will NOT take into account the "max_iterations" parameter.
|
|
5
|
+
Random = "Random"
|
|
6
|
+
# Will shuffle the datapoints randomly for each user. The user will then see the datapoints in that order. This will take into account the "max_iterations" parameter.
|
|
7
|
+
Shuffled = "Shuffled"
|
|
8
|
+
# Will show the datapoints in the order they are in the dataset. This will take into account the "max_iterations" parameter.
|
|
9
|
+
Sequential = "Sequential"
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
2
|
+
from rapidata.rapidata_client.logging import managed_print
|
|
2
3
|
|
|
3
4
|
class AlertOnFastResponse(RapidataSetting):
|
|
4
5
|
"""
|
|
@@ -12,7 +13,7 @@ class AlertOnFastResponse(RapidataSetting):
|
|
|
12
13
|
if not isinstance(threshold, int):
|
|
13
14
|
raise ValueError("The alert must be an integer.")
|
|
14
15
|
if threshold < 10:
|
|
15
|
-
|
|
16
|
+
managed_print(f"Warning: Are you sure you want to set the threshold so low ({threshold} milliseconds)?")
|
|
16
17
|
if threshold > 25000:
|
|
17
18
|
raise ValueError("The alert must be less than 25000 milliseconds.")
|
|
18
19
|
if threshold < 0:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
2
|
+
from rapidata.rapidata_client.logging import managed_print, logger
|
|
2
3
|
|
|
3
4
|
class FreeTextMinimumCharacters(RapidataSetting):
|
|
4
5
|
"""
|
|
@@ -12,5 +13,5 @@ class FreeTextMinimumCharacters(RapidataSetting):
|
|
|
12
13
|
if value < 1:
|
|
13
14
|
raise ValueError("The minimum number of characters must be greater than or equal to 1.")
|
|
14
15
|
if value > 40:
|
|
15
|
-
|
|
16
|
+
managed_print(f"Warning: Are you sure you want to set the minimum number of characters at {value}?")
|
|
16
17
|
super().__init__(key="free_text_minimum_characters", value=value)
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from rapidata.rapidata_client.validation.rapids.rapids import Rapid
|
|
2
2
|
from rapidata.service.openapi_service import OpenAPIService
|
|
3
|
-
from
|
|
4
|
-
import requests
|
|
3
|
+
from rapidata.rapidata_client.logging import logger
|
|
5
4
|
from rapidata.api_client.models.update_dimensions_model import UpdateDimensionsModel
|
|
5
|
+
from rapidata.rapidata_client.assets._sessions import SessionManager
|
|
6
6
|
|
|
7
7
|
class RapidataValidationSet:
|
|
8
8
|
"""A class for interacting with a Rapidata validation set.
|
|
@@ -20,7 +20,6 @@ class RapidataValidationSet:
|
|
|
20
20
|
self.id = validation_set_id
|
|
21
21
|
self.name = name
|
|
22
22
|
self.__openapi_service = openapi_service
|
|
23
|
-
self.__session = self._get_session()
|
|
24
23
|
|
|
25
24
|
def add_rapid(self, rapid: Rapid):
|
|
26
25
|
"""Add a Rapid to the validation set.
|
|
@@ -28,7 +27,7 @@ class RapidataValidationSet:
|
|
|
28
27
|
Args:
|
|
29
28
|
rapid (Rapid): The Rapid to add to the validation set.
|
|
30
29
|
"""
|
|
31
|
-
rapid._add_to_validation_set(self.id, self.__openapi_service
|
|
30
|
+
rapid._add_to_validation_set(self.id, self.__openapi_service)
|
|
32
31
|
return self
|
|
33
32
|
|
|
34
33
|
def update_dimensions(self, dimensions: list[str]):
|
|
@@ -37,39 +36,10 @@ class RapidataValidationSet:
|
|
|
37
36
|
Args:
|
|
38
37
|
dimensions (list[str]): The new dimensions of the validation set.
|
|
39
38
|
"""
|
|
39
|
+
logger.debug(f"Updating dimensions for validation set {self.id} to {dimensions}")
|
|
40
40
|
self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions))
|
|
41
41
|
return self
|
|
42
42
|
|
|
43
|
-
def _get_session(self, max_retries: int = 5, max_workers: int = 10) -> requests.Session:
|
|
44
|
-
"""Get a requests session with retry logic.
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
Args:
|
|
48
|
-
max_retries (int): The maximum number of retries.
|
|
49
|
-
max_workers (int): The maximum number of workers.
|
|
50
|
-
|
|
51
|
-
Returns:
|
|
52
|
-
requests.Session: A requests session with retry logic.
|
|
53
|
-
"""
|
|
54
|
-
session = requests.Session()
|
|
55
|
-
retries = Retry(
|
|
56
|
-
total=max_retries,
|
|
57
|
-
backoff_factor=1,
|
|
58
|
-
status_forcelist=[500, 502, 503, 504],
|
|
59
|
-
allowed_methods=["GET"],
|
|
60
|
-
respect_retry_after_header=True
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
adapter = HTTPAdapter(
|
|
64
|
-
pool_connections=max_workers * 2,
|
|
65
|
-
pool_maxsize=max_workers * 4,
|
|
66
|
-
max_retries=retries
|
|
67
|
-
)
|
|
68
|
-
session.mount('http://', adapter)
|
|
69
|
-
session.mount('https://', adapter)
|
|
70
|
-
|
|
71
|
-
return session
|
|
72
|
-
|
|
73
43
|
def __str__(self):
|
|
74
44
|
return f"name: '{self.name}' id: {self.id}"
|
|
75
45
|
|
|
@@ -19,7 +19,8 @@ from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner
|
|
|
19
19
|
|
|
20
20
|
from rapidata.service.openapi_service import OpenAPIService
|
|
21
21
|
|
|
22
|
-
import
|
|
22
|
+
from rapidata.rapidata_client.logging import logger
|
|
23
|
+
|
|
23
24
|
|
|
24
25
|
class Rapid():
|
|
25
26
|
def __init__(self, asset: MediaAsset | TextAsset | MultiAsset, metadata: Sequence[Metadata], payload: Any, truth: Any, randomCorrectProbability: float, explanation: str | None):
|
|
@@ -29,15 +30,16 @@ class Rapid():
|
|
|
29
30
|
self.truth = truth
|
|
30
31
|
self.randomCorrectProbability = randomCorrectProbability
|
|
31
32
|
self.explanation = explanation
|
|
33
|
+
logger.debug(f"Created Rapid with asset: {self.asset}, metadata: {self.metadata}, payload: {self.payload}, truth: {self.truth}, randomCorrectProbability: {self.randomCorrectProbability}, explanation: {self.explanation}")
|
|
32
34
|
|
|
33
|
-
def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService
|
|
35
|
+
def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService) -> None:
|
|
34
36
|
if isinstance(self.asset, TextAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], TextAsset)):
|
|
35
37
|
openapi_service.validation_api.validation_add_validation_text_rapid_post(
|
|
36
38
|
add_validation_text_rapid_model=self.__to_text_model(validationSetId)
|
|
37
39
|
)
|
|
38
40
|
|
|
39
41
|
elif isinstance(self.asset, MediaAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], MediaAsset)):
|
|
40
|
-
model = self.__to_media_model(validationSetId
|
|
42
|
+
model = self.__to_media_model(validationSetId)
|
|
41
43
|
openapi_service.validation_api.validation_add_validation_rapid_post(
|
|
42
44
|
model=model[0], files=model[1]
|
|
43
45
|
)
|
|
@@ -45,7 +47,7 @@ class Rapid():
|
|
|
45
47
|
else:
|
|
46
48
|
raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
|
|
47
49
|
|
|
48
|
-
def __to_media_model(self, validationSetId: str
|
|
50
|
+
def __to_media_model(self, validationSetId: str) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
|
|
49
51
|
assets: list[MediaAsset] = []
|
|
50
52
|
if isinstance(self.asset, MultiAsset):
|
|
51
53
|
for asset in self.asset.assets:
|
|
@@ -60,9 +62,6 @@ class Rapid():
|
|
|
60
62
|
if isinstance(self.asset, MediaAsset):
|
|
61
63
|
assets = [self.asset]
|
|
62
64
|
|
|
63
|
-
for asset in assets:
|
|
64
|
-
asset.session = session
|
|
65
|
-
|
|
66
65
|
return (AddValidationRapidModel(
|
|
67
66
|
validationSetId=validationSetId,
|
|
68
67
|
payload=AddValidationRapidModelPayload(self.payload),
|