rapidata 0.5.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rapidata/__init__.py +3 -0
- rapidata/api_client/__init__.py +3 -1
- rapidata/api_client/api/validation_api.py +276 -0
- rapidata/api_client/models/__init__.py +3 -1
- rapidata/api_client/models/add_campaign_model.py +3 -3
- rapidata/api_client/models/add_validation_text_rapid_model.py +118 -0
- rapidata/api_client/models/capped_selection.py +108 -0
- rapidata/api_client/models/capped_selection_selections_inner.py +198 -0
- rapidata/api_client/models/create_order_model.py +3 -3
- rapidata/api_client_README.md +4 -1
- rapidata/rapidata_client/__init__.py +1 -0
- rapidata/rapidata_client/assets/__init__.py +8 -0
- rapidata/rapidata_client/assets/base_asset.py +11 -0
- rapidata/rapidata_client/assets/media_asset.py +33 -0
- rapidata/rapidata_client/assets/multi_asset.py +44 -0
- rapidata/rapidata_client/assets/text_asset.py +25 -0
- rapidata/rapidata_client/dataset/rapidata_dataset.py +4 -4
- rapidata/rapidata_client/dataset/rapidata_validation_set.py +54 -27
- rapidata/rapidata_client/dataset/validation_rapid_parts.py +4 -1
- rapidata/rapidata_client/dataset/validation_set_builder.py +41 -24
- rapidata/rapidata_client/order/rapidata_order_builder.py +98 -33
- rapidata/rapidata_client/rapidata_client.py +24 -0
- rapidata/rapidata_client/simple_builders/simple_classification_builders.py +122 -0
- rapidata/rapidata_client/simple_builders/simple_compare_builders.py +86 -0
- {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/METADATA +1 -1
- {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/RECORD +28 -18
- {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/LICENSE +0 -0
- {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/WHEEL +0 -0
|
@@ -6,6 +6,9 @@ from rapidata.api_client.models.compare_truth import CompareTruth
|
|
|
6
6
|
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
|
|
7
7
|
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
|
|
8
8
|
from rapidata.api_client.models.transcription_word import TranscriptionWord
|
|
9
|
+
from rapidata.rapidata_client.assets.media_asset import MediaAsset
|
|
10
|
+
from rapidata.rapidata_client.assets.multi_asset import MultiAsset
|
|
11
|
+
from rapidata.rapidata_client.assets.text_asset import TextAsset
|
|
9
12
|
from rapidata.rapidata_client.dataset.rapidata_validation_set import (
|
|
10
13
|
RapidataValidationSet,
|
|
11
14
|
)
|
|
@@ -16,21 +19,31 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
16
19
|
|
|
17
20
|
class ValidationSetBuilder:
|
|
18
21
|
"""The ValidationSetBuilder is used to build a validation set.
|
|
19
|
-
Give the validation set a name and then add classify, compare or transcription rapid parts to it.
|
|
22
|
+
Give the validation set a name and then add classify, compare, or transcription rapid parts to it.
|
|
20
23
|
Get a `ValidationSetBuilder` by calling [`rapi.new_validation_set()`](../rapidata_client.md/#rapidata.rapidata_client.rapidata_client.RapidataClient.new_validation_set).
|
|
21
24
|
"""
|
|
22
25
|
|
|
23
26
|
def __init__(self, name: str, openapi_service: OpenAPIService):
    """
    Set up an empty builder for a named validation set.

    Args:
        name (str): The name the validation set will be created with.
        openapi_service (OpenAPIService): The service wrapper used to talk to the Rapidata API.
    """
    # Start with no rapids queued; the server-side id is unknown until the set is created.
    self._rapid_parts: list[ValidatioRapidParts] = []
    self.validation_set_id: str | None = None
    self.openapi_service = openapi_service
    self.name = name
|
|
28
38
|
|
|
29
39
|
def create(self):
|
|
30
|
-
"""
|
|
40
|
+
"""Create the validation set by executing all HTTP requests. This should be the last method called on the builder.
|
|
31
41
|
|
|
32
42
|
Returns:
|
|
33
43
|
RapidataValidationSet: A RapidataValidationSet instance.
|
|
44
|
+
|
|
45
|
+
Raises:
|
|
46
|
+
ValueError: If the validation set creation fails.
|
|
34
47
|
"""
|
|
35
48
|
result = (
|
|
36
49
|
self.openapi_service.validation_api.validation_create_validation_set_post(
|
|
@@ -52,7 +65,7 @@ class ValidationSetBuilder:
|
|
|
52
65
|
payload=rapid_part.payload,
|
|
53
66
|
truths=rapid_part.truths,
|
|
54
67
|
metadata=rapid_part.metadata,
|
|
55
|
-
|
|
68
|
+
asset=rapid_part.asset,
|
|
56
69
|
randomCorrectProbability=rapid_part.randomCorrectProbability,
|
|
57
70
|
)
|
|
58
71
|
|
|
@@ -60,7 +73,7 @@ class ValidationSetBuilder:
|
|
|
60
73
|
|
|
61
74
|
def add_classify_rapid(
|
|
62
75
|
self,
|
|
63
|
-
|
|
76
|
+
asset: MediaAsset | TextAsset,
|
|
64
77
|
question: str,
|
|
65
78
|
categories: list[str],
|
|
66
79
|
truths: list[str],
|
|
@@ -69,14 +82,17 @@ class ValidationSetBuilder:
|
|
|
69
82
|
"""Add a classify rapid to the validation set.
|
|
70
83
|
|
|
71
84
|
Args:
|
|
72
|
-
|
|
85
|
+
asset (MediaAsset | TextAsset): The asset for the rapid.
|
|
73
86
|
question (str): The question for the rapid.
|
|
74
87
|
categories (list[str]): The list of categories for the rapid.
|
|
75
88
|
truths (list[str]): The list of truths for the rapid.
|
|
76
|
-
metadata (list[Metadata], optional): The metadata for the rapid.
|
|
89
|
+
metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
|
|
77
90
|
|
|
78
91
|
Returns:
|
|
79
92
|
ValidationSetBuilder: The ValidationSetBuilder instance.
|
|
93
|
+
|
|
94
|
+
Raises:
|
|
95
|
+
ValueError: If the lengths of categories and truths are inconsistent.
|
|
80
96
|
"""
|
|
81
97
|
payload = ClassifyPayload(
|
|
82
98
|
_t="ClassifyPayload", possibleCategories=categories, title=question
|
|
@@ -88,11 +104,11 @@ class ValidationSetBuilder:
|
|
|
88
104
|
self._rapid_parts.append(
|
|
89
105
|
ValidatioRapidParts(
|
|
90
106
|
question=question,
|
|
91
|
-
media_paths=media_path,
|
|
92
107
|
payload=payload,
|
|
93
108
|
truths=model_truth,
|
|
94
109
|
metadata=metadata,
|
|
95
110
|
randomCorrectProbability=len(truths) / len(categories),
|
|
111
|
+
asset=asset,
|
|
96
112
|
)
|
|
97
113
|
)
|
|
98
114
|
|
|
@@ -100,7 +116,7 @@ class ValidationSetBuilder:
|
|
|
100
116
|
|
|
101
117
|
def add_compare_rapid(
|
|
102
118
|
self,
|
|
103
|
-
|
|
119
|
+
asset: MultiAsset,
|
|
104
120
|
question: str,
|
|
105
121
|
truth: str,
|
|
106
122
|
metadata: list[Metadata] = [],
|
|
@@ -108,35 +124,33 @@ class ValidationSetBuilder:
|
|
|
108
124
|
"""Add a compare rapid to the validation set.
|
|
109
125
|
|
|
110
126
|
Args:
|
|
111
|
-
|
|
127
|
+
asset (MultiAsset): The assets for the rapid.
|
|
112
128
|
question (str): The question for the rapid.
|
|
113
|
-
truth (str): The
|
|
114
|
-
metadata (list[Metadata], optional): The metadata for the rapid.
|
|
129
|
+
truth (str): The truth identifier for the rapid.
|
|
130
|
+
metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
|
|
115
131
|
|
|
116
132
|
Returns:
|
|
117
133
|
ValidationSetBuilder: The ValidationSetBuilder instance.
|
|
134
|
+
|
|
135
|
+
Raises:
|
|
136
|
+
ValueError: If the number of assets is not exactly two.
|
|
118
137
|
"""
|
|
119
138
|
payload = ComparePayload(_t="ComparePayload", criteria=question)
|
|
120
139
|
# take only last part of truth path
|
|
121
140
|
truth = os.path.basename(truth)
|
|
122
141
|
model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
|
|
123
142
|
|
|
124
|
-
if len(
|
|
143
|
+
if len(asset) != 2:
|
|
125
144
|
raise ValueError("Compare rapid requires exactly two media paths")
|
|
126
145
|
|
|
127
|
-
# check that files exist
|
|
128
|
-
for media_path in media_paths:
|
|
129
|
-
if not os.path.exists(media_path):
|
|
130
|
-
raise FileNotFoundError(f"File not found: {media_path}")
|
|
131
|
-
|
|
132
146
|
self._rapid_parts.append(
|
|
133
147
|
ValidatioRapidParts(
|
|
134
148
|
question=question,
|
|
135
|
-
media_paths=media_paths,
|
|
136
149
|
payload=payload,
|
|
137
150
|
truths=model_truth,
|
|
138
151
|
metadata=metadata,
|
|
139
|
-
randomCorrectProbability=1 / len(
|
|
152
|
+
randomCorrectProbability=1 / len(asset),
|
|
153
|
+
asset=asset,
|
|
140
154
|
)
|
|
141
155
|
)
|
|
142
156
|
|
|
@@ -144,7 +158,7 @@ class ValidationSetBuilder:
|
|
|
144
158
|
|
|
145
159
|
def add_transcription_rapid(
|
|
146
160
|
self,
|
|
147
|
-
|
|
161
|
+
asset: MediaAsset | TextAsset,
|
|
148
162
|
question: str,
|
|
149
163
|
transcription: list[str],
|
|
150
164
|
correct_words: list[str],
|
|
@@ -154,15 +168,18 @@ class ValidationSetBuilder:
|
|
|
154
168
|
"""Add a transcription rapid to the validation set.
|
|
155
169
|
|
|
156
170
|
Args:
|
|
157
|
-
|
|
171
|
+
asset (MediaAsset | TextAsset): The asset for the rapid.
|
|
158
172
|
question (str): The question for the rapid.
|
|
159
173
|
transcription (list[str]): The transcription for the rapid.
|
|
160
174
|
correct_words (list[str]): The list of correct words for the rapid.
|
|
161
|
-
strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
|
|
162
|
-
metadata (list[Metadata], optional): The metadata for the rapid.
|
|
175
|
+
strict_grading (bool | None, optional): The strict grading flag for the rapid. Defaults to None.
|
|
176
|
+
metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
|
|
163
177
|
|
|
164
178
|
Returns:
|
|
165
179
|
ValidationSetBuilder: The ValidationSetBuilder instance.
|
|
180
|
+
|
|
181
|
+
Raises:
|
|
182
|
+
ValueError: If a correct word is not found in the transcription.
|
|
166
183
|
"""
|
|
167
184
|
transcription_words = [
|
|
168
185
|
TranscriptionWord(word=word, wordIndex=i)
|
|
@@ -190,7 +207,7 @@ class ValidationSetBuilder:
|
|
|
190
207
|
self._rapid_parts.append(
|
|
191
208
|
ValidatioRapidParts(
|
|
192
209
|
question=question,
|
|
193
|
-
|
|
210
|
+
asset=asset,
|
|
194
211
|
payload=payload,
|
|
195
212
|
truths=model_truth,
|
|
196
213
|
metadata=metadata,
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from rapidata.api_client.models.aggregator_type import AggregatorType
|
|
2
|
+
from rapidata.api_client.models.capped_selection_selections_inner import (
|
|
3
|
+
CappedSelectionSelectionsInner,
|
|
4
|
+
)
|
|
2
5
|
from rapidata.api_client.models.create_order_model import CreateOrderModel
|
|
3
6
|
from rapidata.api_client.models.create_order_model_referee import (
|
|
4
7
|
CreateOrderModelReferee,
|
|
5
8
|
)
|
|
6
|
-
from rapidata.api_client.models.create_order_model_selections_inner import (
|
|
7
|
-
CreateOrderModelSelectionsInner,
|
|
8
|
-
)
|
|
9
9
|
from rapidata.api_client.models.create_order_model_user_filters_inner import (
|
|
10
10
|
CreateOrderModelUserFiltersInner,
|
|
11
11
|
)
|
|
@@ -41,6 +41,13 @@ class RapidataOrderBuilder:
|
|
|
41
41
|
openapi_service: OpenAPIService,
|
|
42
42
|
name: str,
|
|
43
43
|
):
|
|
44
|
+
"""
|
|
45
|
+
Initialize the RapidataOrderBuilder.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
openapi_service (OpenAPIService): The OpenAPIService instance.
|
|
49
|
+
name (str): The name of the order.
|
|
50
|
+
"""
|
|
44
51
|
self._name = name
|
|
45
52
|
self._openapi_service = openapi_service
|
|
46
53
|
self._workflow: Workflow | None = None
|
|
@@ -54,8 +61,19 @@ class RapidataOrderBuilder:
|
|
|
54
61
|
self._selections: list[Selection] = []
|
|
55
62
|
self._rapids_per_bag: int = 2
|
|
56
63
|
self._priority: int = 50
|
|
64
|
+
self._texts: list[str] | None = None
|
|
65
|
+
self._media_paths: list[str | list[str]] = []
|
|
57
66
|
|
|
58
67
|
def _to_model(self) -> CreateOrderModel:
|
|
68
|
+
"""
|
|
69
|
+
Convert the builder configuration to a CreateOrderModel.
|
|
70
|
+
|
|
71
|
+
Raises:
|
|
72
|
+
ValueError: If no workflow is provided.
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
CreateOrderModel: The model representing the order configuration.
|
|
76
|
+
"""
|
|
59
77
|
if self._workflow is None:
|
|
60
78
|
raise ValueError("You must provide a workflow to create an order.")
|
|
61
79
|
|
|
@@ -86,27 +104,34 @@ class RapidataOrderBuilder:
|
|
|
86
104
|
else None
|
|
87
105
|
),
|
|
88
106
|
selections=[
|
|
89
|
-
|
|
107
|
+
CappedSelectionSelectionsInner(selection.to_model())
|
|
90
108
|
for selection in self._selections
|
|
91
109
|
],
|
|
92
110
|
priority=self._priority,
|
|
93
111
|
)
|
|
94
112
|
|
|
95
|
-
def create(self, submit=True, max_workers=10) -> RapidataOrder:
|
|
96
|
-
"""
|
|
113
|
+
def create(self, submit: bool = True, max_workers: int = 10) -> RapidataOrder:
|
|
114
|
+
"""
|
|
115
|
+
Create the Rapidata order by making the necessary API calls based on the builder's configuration.
|
|
97
116
|
|
|
98
117
|
Args:
|
|
99
|
-
submit (bool, optional): Whether to submit the order. Defaults to True.
|
|
118
|
+
submit (bool, optional): Whether to submit the order upon creation. Defaults to True.
|
|
119
|
+
max_workers (int, optional): The maximum number of worker threads for processing media paths. Defaults to 10.
|
|
120
|
+
|
|
121
|
+
Raises:
|
|
122
|
+
ValueError: If both media paths and texts are provided, or if neither is provided.
|
|
123
|
+
AssertionError: If the workflow is a CompareWorkflow and media paths are not in pairs.
|
|
100
124
|
|
|
101
125
|
Returns:
|
|
102
126
|
RapidataOrder: The created RapidataOrder instance.
|
|
103
|
-
|
|
104
|
-
Raises:
|
|
105
|
-
ValueError: If no workflow is provided.
|
|
106
127
|
"""
|
|
107
128
|
order_model = self._to_model()
|
|
108
|
-
if isinstance(
|
|
109
|
-
|
|
129
|
+
if isinstance(
|
|
130
|
+
self._workflow, CompareWorkflow
|
|
131
|
+
): # Temporary fix; will be handled by backend in the future
|
|
132
|
+
assert all(
|
|
133
|
+
[len(path) == 2 for path in self._media_paths]
|
|
134
|
+
), "The media paths must come in pairs for comparison tasks."
|
|
110
135
|
|
|
111
136
|
result = self._openapi_service.order_api.order_create_post(
|
|
112
137
|
create_order_model=order_model
|
|
@@ -120,14 +145,32 @@ class RapidataOrderBuilder:
|
|
|
120
145
|
openapi_service=self._openapi_service,
|
|
121
146
|
)
|
|
122
147
|
|
|
123
|
-
|
|
148
|
+
if self._media_paths and self._texts:
|
|
149
|
+
raise ValueError(
|
|
150
|
+
"You cannot provide both media paths and texts to the same order."
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
if not self._media_paths and not self._texts:
|
|
154
|
+
raise ValueError(
|
|
155
|
+
"You must provide either media paths or texts to the order."
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
if self._texts:
|
|
159
|
+
order.dataset.add_texts(self._texts)
|
|
160
|
+
|
|
161
|
+
if self._media_paths:
|
|
162
|
+
order.dataset.add_media_from_paths(
|
|
163
|
+
self._media_paths, self._metadata, max_workers
|
|
164
|
+
)
|
|
165
|
+
|
|
124
166
|
if submit:
|
|
125
167
|
order.submit()
|
|
126
168
|
|
|
127
169
|
return order
|
|
128
170
|
|
|
129
|
-
def workflow(self, workflow: Workflow):
|
|
130
|
-
"""
|
|
171
|
+
def workflow(self, workflow: Workflow) -> "RapidataOrderBuilder":
|
|
172
|
+
"""
|
|
173
|
+
Set the workflow for the order.
|
|
131
174
|
|
|
132
175
|
Args:
|
|
133
176
|
workflow (Workflow): The workflow to be set.
|
|
@@ -138,8 +181,9 @@ class RapidataOrderBuilder:
|
|
|
138
181
|
self._workflow = workflow
|
|
139
182
|
return self
|
|
140
183
|
|
|
141
|
-
def referee(self, referee: Referee):
|
|
142
|
-
"""
|
|
184
|
+
def referee(self, referee: Referee) -> "RapidataOrderBuilder":
|
|
185
|
+
"""
|
|
186
|
+
Set the referee for the order.
|
|
143
187
|
|
|
144
188
|
Args:
|
|
145
189
|
referee (Referee): The referee to be set.
|
|
@@ -154,8 +198,9 @@ class RapidataOrderBuilder:
|
|
|
154
198
|
self,
|
|
155
199
|
media_paths: list[str | list[str]],
|
|
156
200
|
metadata: list[Metadata] | None = None,
|
|
157
|
-
):
|
|
158
|
-
"""
|
|
201
|
+
) -> "RapidataOrderBuilder":
|
|
202
|
+
"""
|
|
203
|
+
Set the media assets for the order.
|
|
159
204
|
|
|
160
205
|
Args:
|
|
161
206
|
media_paths (list[str | list[str]]): The paths of the media assets to be set.
|
|
@@ -168,8 +213,22 @@ class RapidataOrderBuilder:
|
|
|
168
213
|
self._metadata = metadata
|
|
169
214
|
return self
|
|
170
215
|
|
|
171
|
-
def
|
|
172
|
-
"""
|
|
216
|
+
def texts(self, texts: list[str]) -> "RapidataOrderBuilder":
    """
    Use text items (instead of media files) as the data of the order.

    Args:
        texts (list[str]): The texts to be set.

    Returns:
        RapidataOrderBuilder: This builder, so calls can be chained.
    """
    builder = self
    builder._texts = texts
    return builder
|
|
228
|
+
|
|
229
|
+
def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
|
|
230
|
+
"""
|
|
231
|
+
Set the feature flags for the order.
|
|
173
232
|
|
|
174
233
|
Args:
|
|
175
234
|
feature_flags (FeatureFlags): The feature flags to be set.
|
|
@@ -180,8 +239,9 @@ class RapidataOrderBuilder:
|
|
|
180
239
|
self._feature_flags = feature_flags
|
|
181
240
|
return self
|
|
182
241
|
|
|
183
|
-
def country_filter(self, country_codes: list[str]):
|
|
184
|
-
"""
|
|
242
|
+
def country_filter(self, country_codes: list[str]) -> "RapidataOrderBuilder":
|
|
243
|
+
"""
|
|
244
|
+
Set the target country codes for the order.
|
|
185
245
|
|
|
186
246
|
Args:
|
|
187
247
|
country_codes (list[str]): The country codes to be set.
|
|
@@ -192,8 +252,9 @@ class RapidataOrderBuilder:
|
|
|
192
252
|
self._country_codes = country_codes
|
|
193
253
|
return self
|
|
194
254
|
|
|
195
|
-
def aggregator(self, aggregator: AggregatorType):
|
|
196
|
-
"""
|
|
255
|
+
def aggregator(self, aggregator: AggregatorType) -> "RapidataOrderBuilder":
|
|
256
|
+
"""
|
|
257
|
+
Set the aggregator for the order.
|
|
197
258
|
|
|
198
259
|
Args:
|
|
199
260
|
aggregator (AggregatorType): The aggregator to be set.
|
|
@@ -204,8 +265,9 @@ class RapidataOrderBuilder:
|
|
|
204
265
|
self._aggregator = aggregator
|
|
205
266
|
return self
|
|
206
267
|
|
|
207
|
-
def validation_set_id(self, validation_set_id: str):
|
|
208
|
-
"""
|
|
268
|
+
def validation_set_id(self, validation_set_id: str) -> "RapidataOrderBuilder":
|
|
269
|
+
"""
|
|
270
|
+
Set the validation set ID for the order.
|
|
209
271
|
|
|
210
272
|
Args:
|
|
211
273
|
validation_set_id (str): The validation set ID to be set.
|
|
@@ -216,8 +278,9 @@ class RapidataOrderBuilder:
|
|
|
216
278
|
self._validation_set_id = validation_set_id
|
|
217
279
|
return self
|
|
218
280
|
|
|
219
|
-
def rapids_per_bag(self, amount: int):
|
|
220
|
-
"""
|
|
281
|
+
def rapids_per_bag(self, amount: int) -> "RapidataOrderBuilder":
|
|
282
|
+
"""
|
|
283
|
+
Define the number of tasks a user sees in a single session.
|
|
221
284
|
|
|
222
285
|
Args:
|
|
223
286
|
amount (int): The number of tasks a user sees in a single session.
|
|
@@ -230,8 +293,9 @@ class RapidataOrderBuilder:
|
|
|
230
293
|
"""
|
|
231
294
|
raise NotImplementedError("Not implemented yet.")
|
|
232
295
|
|
|
233
|
-
def selections(self, selections: list[Selection]):
|
|
234
|
-
"""
|
|
296
|
+
def selections(self, selections: list[Selection]) -> "RapidataOrderBuilder":
|
|
297
|
+
"""
|
|
298
|
+
Set the selections for the order.
|
|
235
299
|
|
|
236
300
|
Args:
|
|
237
301
|
selections (list[Selection]): The selections to be set.
|
|
@@ -242,8 +306,9 @@ class RapidataOrderBuilder:
|
|
|
242
306
|
self._selections = selections
|
|
243
307
|
return self
|
|
244
308
|
|
|
245
|
-
def priority(self, priority: int):
|
|
246
|
-
"""
|
|
309
|
+
def priority(self, priority: int) -> "RapidataOrderBuilder":
|
|
310
|
+
"""
|
|
311
|
+
Set the priority for the order.
|
|
247
312
|
|
|
248
313
|
Args:
|
|
249
314
|
priority (int): The priority to be set.
|
|
@@ -8,6 +8,8 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
8
8
|
from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
|
|
9
9
|
from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
|
|
10
10
|
|
|
11
|
+
from rapidata.rapidata_client.simple_builders.simple_classification_builders import ClassificationQuestionBuilder
|
|
12
|
+
from rapidata.rapidata_client.simple_builders.simple_compare_builders import CompareCriteriaBuilder
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class RapidataClient:
|
|
@@ -81,6 +83,28 @@ class RapidataClient:
|
|
|
81
83
|
dataset=temp_dataset,
|
|
82
84
|
order_id=order_id,
|
|
83
85
|
openapi_service=self.openapi_service)
|
|
86
|
+
|
|
87
|
+
def create_classify_order(self, name: str) -> ClassificationQuestionBuilder:
    """Start building a classification order in which people are asked to classify an image.

    The returned object is the first step of a chained builder: call
    `.question(...)`, then `.options(...)`, then `.media(...)` to reach the
    object that can finally `create()` the order.

    Args:
        name (str): The name of the order.

    Returns:
        ClassificationQuestionBuilder: The first step of the classification builder chain.
    """
    first_step = ClassificationQuestionBuilder(
        name=name,
        openapi_service=self.openapi_service,
    )
    return first_step
|
|
97
|
+
|
|
98
|
+
def create_compare_order(self, name: str) -> CompareCriteriaBuilder:
    """Create a new comparison order where people are asked to compare two images.

    The returned object is the first step of a chained builder: call
    `.criteria(...)` and then `.media(...)` to reach the object that can
    finally `create()` the order.

    Args:
        name (str): The name of the order.

    Returns:
        CompareCriteriaBuilder: The first step of the comparison builder chain.
    """
    return CompareCriteriaBuilder(name=name, openapi_service=self.openapi_service)
|
|
84
108
|
|
|
85
109
|
@property
|
|
86
110
|
def utils(self) -> Utils:
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
|
|
2
|
+
from rapidata.rapidata_client.metadata.base_metadata import Metadata
|
|
3
|
+
from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
|
|
4
|
+
from rapidata.rapidata_client.referee.classify_early_stopping_referee import ClassifyEarlyStoppingReferee
|
|
5
|
+
from rapidata.rapidata_client.selection.base_selection import Selection
|
|
6
|
+
from rapidata.rapidata_client.workflow.classify_workflow import ClassifyWorkflow
|
|
7
|
+
from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
|
|
8
|
+
from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
|
|
9
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
10
|
+
|
|
11
|
+
class ClassificationOrderBuilder:
    """Final step of the simple classification builder chain.

    Holds the question, answer options and media paths collected by the
    previous steps, lets the caller tune optional settings through chainable
    setters, and builds the order with `create()`.
    """

    def __init__(self, name: str, question: str, options: list[str], media_paths: list[str], openapi_service: OpenAPIService):
        self._order_builder = RapidataOrderBuilder(name=name, openapi_service=openapi_service)
        self._question = question
        self._options = options
        self._media_paths = media_paths
        self._responses_required = 10
        self._probability_threshold: float | None = None
        self._metadata: list[Metadata] | None = None
        self._validation_set_id: str | None = None

    def metadata(self, metadata: list[Metadata]) -> 'ClassificationOrderBuilder':
        """Set the metadata for the classification order. Has to be the same length as the media paths."""
        self._metadata = metadata
        return self

    def responses(self, responses_required: int) -> 'ClassificationOrderBuilder':
        """Set the number of responses required for the classification order."""
        self._responses_required = responses_required
        return self

    def probability_threshold(self, probability_threshold: float) -> 'ClassificationOrderBuilder':
        """Set the probability threshold for early stopping."""
        self._probability_threshold = probability_threshold
        return self

    def validation_set_id(self, validation_set_id: str) -> 'ClassificationOrderBuilder':
        """Set the validation set ID for the classification order."""
        self._validation_set_id = validation_set_id
        return self

    def create(self, submit: bool = True, max_upload_workers: int = 10):
        """Build the classification order and optionally submit it.

        If a probability threshold was set, a `ClassifyEarlyStoppingReferee` is
        used, capped at the required number of responses. Otherwise a
        `NaiveReferee` waits for exactly that many responses. When a validation
        set ID was given, each session uses one validation rapid plus two
        labeling rapids. Otherwise it uses three labeling rapids.

        Args:
            submit (bool, optional): Whether to submit the order after creating it. Defaults to True.
            max_upload_workers (int, optional): Worker count passed on to the media upload. Defaults to 10.

        Returns:
            RapidataOrder: The order returned by `RapidataOrderBuilder.create`.
        """
        if self._probability_threshold and self._responses_required:
            referee = ClassifyEarlyStoppingReferee(
                max_vote_count=self._responses_required,
                threshold=self._probability_threshold
            )
        else:
            referee = NaiveReferee(required_guesses=self._responses_required)

        selection: list[Selection] = ([ValidationSelection(amount=1, validation_set_id=self._validation_set_id), LabelingSelection(amount=2)]
                                      if self._validation_set_id
                                      else [LabelingSelection(amount=3)])

        order = (self._order_builder
                 .workflow(
                     ClassifyWorkflow(
                         question=self._question,
                         options=self._options
                     )
                 )
                 .referee(referee)
                 # list[str] is not accepted as list[str | list[str]] by type checkers (list invariance)
                 .media(self._media_paths, metadata=self._metadata)  # type: ignore
                 .selections(selection)
                 .create(submit=submit, max_workers=max_upload_workers))

        return order
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ClassificationMediaBuilder:
    """Third step of the classification builder chain: collects the media paths."""

    def __init__(self, name: str, question: str, options: list[str], openapi_service: OpenAPIService):
        self._name = name
        self._question = question
        self._options = options
        self._openapi_service = openapi_service
        self._media_paths = None

    def media(self, media_paths: list[str]) -> ClassificationOrderBuilder:
        """Set the media assets for the classification order by providing the local paths to the files."""
        self._media_paths = media_paths
        return self._build()

    def _build(self) -> ClassificationOrderBuilder:
        """Hand the collected settings over to the final `ClassificationOrderBuilder`."""
        paths = self._media_paths
        if paths is None:
            raise ValueError("Media paths are required")
        return ClassificationOrderBuilder(
            self._name,
            self._question,
            self._options,
            paths,
            openapi_service=self._openapi_service,
        )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ClassificationOptionsBuilder:
    """Second step of the classification builder chain: collects the answer options."""

    def __init__(self, name: str, question: str, openapi_service: OpenAPIService):
        self._name = name
        self._question = question
        self._openapi_service = openapi_service
        self._options = None

    def options(self, options: list[str]) -> ClassificationMediaBuilder:
        """Set the answer options for the classification order."""
        self._options = options
        return self._build()

    def _build(self) -> ClassificationMediaBuilder:
        """Advance to the media step, carrying over name, question and options."""
        chosen = self._options
        if chosen is None:
            raise ValueError("Options are required")
        return ClassificationMediaBuilder(
            self._name,
            self._question,
            chosen,
            self._openapi_service,
        )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class ClassificationQuestionBuilder:
    """First step of the classification builder chain: collects the question."""

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self._name = name
        self._openapi_service = openapi_service
        self._question = None

    def question(self, question: str) -> ClassificationOptionsBuilder:
        """Set the question for the classification order."""
        self._question = question
        return self._build()

    def _build(self) -> ClassificationOptionsBuilder:
        """Advance to the options step, carrying over name and question."""
        asked = self._question
        if asked is None:
            raise ValueError("Question is required")
        return ClassificationOptionsBuilder(self._name, asked, self._openapi_service)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
2
|
+
from rapidata.rapidata_client.metadata.base_metadata import Metadata
|
|
3
|
+
from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
|
|
4
|
+
from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
|
|
5
|
+
from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
|
|
6
|
+
from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
|
|
7
|
+
from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
|
|
8
|
+
from rapidata.rapidata_client.selection.base_selection import Selection
|
|
9
|
+
|
|
10
|
+
class CompareOrderBuilder:
    """Final step of the simple comparison builder chain.

    Holds the criteria and media pairs collected by the previous steps, lets
    the caller tune optional settings through chainable setters, and builds the
    order with `create()`.
    """

    def __init__(self, name: str, criteria: str, media_paths: list[list[str]], openapi_service: OpenAPIService):
        self._order_builder = RapidataOrderBuilder(name=name, openapi_service=openapi_service)
        self._name = name
        self._criteria = criteria
        self._media_paths = media_paths
        self._responses_required = 10
        self._metadata: list[Metadata] | None = None
        self._validation_set_id: str | None = None

    def responses(self, responses_required: int) -> 'CompareOrderBuilder':
        """Set the number of responses required per matchup/pairing for the comparison order."""
        self._responses_required = responses_required
        return self

    def metadata(self, metadata: list[Metadata]) -> 'CompareOrderBuilder':
        """Set the metadata for the comparison order. Has to be the same shape as the media paths."""
        self._metadata = metadata
        return self

    def validation_set_id(self, validation_set_id: str) -> 'CompareOrderBuilder':
        """Set the validation set ID for the comparison order."""
        self._validation_set_id = validation_set_id
        return self

    def create(self, submit: bool = True, max_upload_workers: int = 10):
        """Build the comparison order and optionally submit it.

        Uses a `NaiveReferee` that waits for the required number of responses.
        When a validation set ID was given, each session uses one validation
        rapid plus two labeling rapids. Otherwise it uses three labeling rapids.

        Args:
            submit (bool, optional): Whether to submit the order after creating it. Defaults to True.
            max_upload_workers (int, optional): Worker count passed on to the media upload. Defaults to 10.

        Returns:
            RapidataOrder: The order returned by `RapidataOrderBuilder.create`.
        """
        selection: list[Selection] = ([ValidationSelection(amount=1, validation_set_id=self._validation_set_id), LabelingSelection(amount=2)]
                                      if self._validation_set_id
                                      else [LabelingSelection(amount=3)])

        order = (self._order_builder
                 .workflow(
                     CompareWorkflow(
                         criteria=self._criteria
                     )
                 )
                 .referee(NaiveReferee(required_guesses=self._responses_required))
                 # list[list[str]] is not accepted as list[str | list[str]] by type checkers (list invariance)
                 .media(self._media_paths, metadata=self._metadata)  # type: ignore
                 .selections(selection)
                 .create(submit=submit, max_workers=max_upload_workers))

        return order
|
|
52
|
+
|
|
53
|
+
class CompareMediaBuilder:
    """Second step of the comparison builder chain: collects the media pairs."""

    def __init__(self, name: str, criteria: str, openapi_service: OpenAPIService):
        self._openapi_service = openapi_service
        self._name = name
        self._criteria = criteria
        self._media_paths = None

    def media(self, media_paths: list[list[str]]) -> CompareOrderBuilder:
        """Set the media assets for the comparison order by providing the local paths to the files.

        Each entry of `media_paths` must contain exactly two paths (one matchup).

        Raises:
            ValueError: If any entry does not contain exactly two paths.
        """
        self._media_paths = media_paths
        return self._build()

    def _build(self) -> CompareOrderBuilder:
        """Validate the media pairs and hand everything to the final `CompareOrderBuilder`.

        Raises:
            ValueError: If no media paths were set or an entry is not a pair.
        """
        if self._media_paths is None:
            raise ValueError("Media paths are required")
        # Explicit check instead of `assert`: assertions are stripped when Python
        # runs with -O, which would let malformed pairs through silently.
        if not all(len(path) == 2 for path in self._media_paths):
            raise ValueError("The media paths must come in pairs for comparison tasks.")
        return CompareOrderBuilder(self._name, self._criteria, self._media_paths, self._openapi_service)
|
|
70
|
+
|
|
71
|
+
class CompareCriteriaBuilder:
    """First step of the comparison builder chain: collects the comparison criteria."""

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self._name = name
        self._openapi_service = openapi_service
        self._criteria = None

    def criteria(self, criteria: str) -> CompareMediaBuilder:
        """Set the criteria how the images should be compared."""
        self._criteria = criteria
        return self._build()

    def _build(self) -> CompareMediaBuilder:
        """Advance to the media step, carrying over name and criteria."""
        wanted = self._criteria
        if wanted is None:
            raise ValueError("Criteria are required")
        return CompareMediaBuilder(self._name, wanted, self._openapi_service)
|
|
86
|
+
|