rapidata 0.5.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (28)
  1. rapidata/__init__.py +3 -0
  2. rapidata/api_client/__init__.py +3 -1
  3. rapidata/api_client/api/validation_api.py +276 -0
  4. rapidata/api_client/models/__init__.py +3 -1
  5. rapidata/api_client/models/add_campaign_model.py +3 -3
  6. rapidata/api_client/models/add_validation_text_rapid_model.py +118 -0
  7. rapidata/api_client/models/capped_selection.py +108 -0
  8. rapidata/api_client/models/capped_selection_selections_inner.py +198 -0
  9. rapidata/api_client/models/create_order_model.py +3 -3
  10. rapidata/api_client_README.md +4 -1
  11. rapidata/rapidata_client/__init__.py +1 -0
  12. rapidata/rapidata_client/assets/__init__.py +8 -0
  13. rapidata/rapidata_client/assets/base_asset.py +11 -0
  14. rapidata/rapidata_client/assets/media_asset.py +33 -0
  15. rapidata/rapidata_client/assets/multi_asset.py +44 -0
  16. rapidata/rapidata_client/assets/text_asset.py +25 -0
  17. rapidata/rapidata_client/dataset/rapidata_dataset.py +4 -4
  18. rapidata/rapidata_client/dataset/rapidata_validation_set.py +54 -27
  19. rapidata/rapidata_client/dataset/validation_rapid_parts.py +4 -1
  20. rapidata/rapidata_client/dataset/validation_set_builder.py +41 -24
  21. rapidata/rapidata_client/order/rapidata_order_builder.py +98 -33
  22. rapidata/rapidata_client/rapidata_client.py +24 -0
  23. rapidata/rapidata_client/simple_builders/simple_classification_builders.py +122 -0
  24. rapidata/rapidata_client/simple_builders/simple_compare_builders.py +86 -0
  25. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/METADATA +1 -1
  26. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/RECORD +28 -18
  27. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/LICENSE +0 -0
  28. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/WHEEL +0 -0
@@ -6,6 +6,9 @@ from rapidata.api_client.models.compare_truth import CompareTruth
6
6
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
7
7
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
8
8
  from rapidata.api_client.models.transcription_word import TranscriptionWord
9
+ from rapidata.rapidata_client.assets.media_asset import MediaAsset
10
+ from rapidata.rapidata_client.assets.multi_asset import MultiAsset
11
+ from rapidata.rapidata_client.assets.text_asset import TextAsset
9
12
  from rapidata.rapidata_client.dataset.rapidata_validation_set import (
10
13
  RapidataValidationSet,
11
14
  )
@@ -16,21 +19,31 @@ from rapidata.service.openapi_service import OpenAPIService
16
19
 
17
20
  class ValidationSetBuilder:
18
21
  """The ValidationSetBuilder is used to build a validation set.
19
- Give the validation set a name and then add classify, compare or transcription rapid parts to it.
22
+ Give the validation set a name and then add classify, compare, or transcription rapid parts to it.
20
23
  Get a `ValidationSetBuilder` by calling [`rapi.new_validation_set()`](../rapidata_client.md/#rapidata.rapidata_client.rapidata_client.RapidataClient.new_validation_set).
21
24
  """
22
25
 
23
26
  def __init__(self, name: str, openapi_service: OpenAPIService):
27
+ """
28
+ Initialize the ValidationSetBuilder.
29
+
30
+ Args:
31
+ name (str): The name of the validation set.
32
+ openapi_service (OpenAPIService): An instance of OpenAPIService to interact with the API.
33
+ """
24
34
  self.name = name
25
35
  self.openapi_service = openapi_service
26
36
  self.validation_set_id: str | None = None
27
37
  self._rapid_parts: list[ValidatioRapidParts] = []
28
38
 
29
39
  def create(self):
30
- """This creates the validation set by executing all http requests. This should be the last method called on the builder.
40
+ """Create the validation set by executing all HTTP requests. This should be the last method called on the builder.
31
41
 
32
42
  Returns:
33
43
  RapidataValidationSet: A RapidataValidationSet instance.
44
+
45
+ Raises:
46
+ ValueError: If the validation set creation fails.
34
47
  """
35
48
  result = (
36
49
  self.openapi_service.validation_api.validation_create_validation_set_post(
@@ -52,7 +65,7 @@ class ValidationSetBuilder:
52
65
  payload=rapid_part.payload,
53
66
  truths=rapid_part.truths,
54
67
  metadata=rapid_part.metadata,
55
- media_paths=rapid_part.media_paths,
68
+ asset=rapid_part.asset,
56
69
  randomCorrectProbability=rapid_part.randomCorrectProbability,
57
70
  )
58
71
 
@@ -60,7 +73,7 @@ class ValidationSetBuilder:
60
73
 
61
74
  def add_classify_rapid(
62
75
  self,
63
- media_path: str,
76
+ asset: MediaAsset | TextAsset,
64
77
  question: str,
65
78
  categories: list[str],
66
79
  truths: list[str],
@@ -69,14 +82,17 @@ class ValidationSetBuilder:
69
82
  """Add a classify rapid to the validation set.
70
83
 
71
84
  Args:
72
- media_path (str): The path to the media file.
85
+ asset (MediaAsset | TextAsset): The asset for the rapid.
73
86
  question (str): The question for the rapid.
74
87
  categories (list[str]): The list of categories for the rapid.
75
88
  truths (list[str]): The list of truths for the rapid.
76
- metadata (list[Metadata], optional): The metadata for the rapid.
89
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
77
90
 
78
91
  Returns:
79
92
  ValidationSetBuilder: The ValidationSetBuilder instance.
93
+
94
+ Raises:
95
+ ValueError: If the lengths of categories and truths are inconsistent.
80
96
  """
81
97
  payload = ClassifyPayload(
82
98
  _t="ClassifyPayload", possibleCategories=categories, title=question
@@ -88,11 +104,11 @@ class ValidationSetBuilder:
88
104
  self._rapid_parts.append(
89
105
  ValidatioRapidParts(
90
106
  question=question,
91
- media_paths=media_path,
92
107
  payload=payload,
93
108
  truths=model_truth,
94
109
  metadata=metadata,
95
110
  randomCorrectProbability=len(truths) / len(categories),
111
+ asset=asset,
96
112
  )
97
113
  )
98
114
 
@@ -100,7 +116,7 @@ class ValidationSetBuilder:
100
116
 
101
117
  def add_compare_rapid(
102
118
  self,
103
- media_paths: list[str],
119
+ asset: MultiAsset,
104
120
  question: str,
105
121
  truth: str,
106
122
  metadata: list[Metadata] = [],
@@ -108,35 +124,33 @@ class ValidationSetBuilder:
108
124
  """Add a compare rapid to the validation set.
109
125
 
110
126
  Args:
111
- media_paths (list[str]): The list of media paths for the rapid.
127
+ asset (MultiAsset): The assets for the rapid.
112
128
  question (str): The question for the rapid.
113
- truth (str): The path to the truth file.
114
- metadata (list[Metadata], optional): The metadata for the rapid.
129
+ truth (str): The truth identifier for the rapid.
130
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
115
131
 
116
132
  Returns:
117
133
  ValidationSetBuilder: The ValidationSetBuilder instance.
134
+
135
+ Raises:
136
+ ValueError: If the number of assets is not exactly two.
118
137
  """
119
138
  payload = ComparePayload(_t="ComparePayload", criteria=question)
120
139
  # take only last part of truth path
121
140
  truth = os.path.basename(truth)
122
141
  model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
123
142
 
124
- if len(media_paths) != 2:
143
+ if len(asset) != 2:
125
144
  raise ValueError("Compare rapid requires exactly two media paths")
126
145
 
127
- # check that files exist
128
- for media_path in media_paths:
129
- if not os.path.exists(media_path):
130
- raise FileNotFoundError(f"File not found: {media_path}")
131
-
132
146
  self._rapid_parts.append(
133
147
  ValidatioRapidParts(
134
148
  question=question,
135
- media_paths=media_paths,
136
149
  payload=payload,
137
150
  truths=model_truth,
138
151
  metadata=metadata,
139
- randomCorrectProbability=1 / len(media_paths),
152
+ randomCorrectProbability=1 / len(asset),
153
+ asset=asset,
140
154
  )
141
155
  )
142
156
 
@@ -144,7 +158,7 @@ class ValidationSetBuilder:
144
158
 
145
159
  def add_transcription_rapid(
146
160
  self,
147
- media_path: str,
161
+ asset: MediaAsset | TextAsset,
148
162
  question: str,
149
163
  transcription: list[str],
150
164
  correct_words: list[str],
@@ -154,15 +168,18 @@ class ValidationSetBuilder:
154
168
  """Add a transcription rapid to the validation set.
155
169
 
156
170
  Args:
157
- media_path (str): The path to the media file.
171
+ asset (MediaAsset | TextAsset): The asset for the rapid.
158
172
  question (str): The question for the rapid.
159
173
  transcription (list[str]): The transcription for the rapid.
160
174
  correct_words (list[str]): The list of correct words for the rapid.
161
- strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
162
- metadata (list[Metadata], optional): The metadata for the rapid.
175
+ strict_grading (bool | None, optional): The strict grading flag for the rapid. Defaults to None.
176
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
163
177
 
164
178
  Returns:
165
179
  ValidationSetBuilder: The ValidationSetBuilder instance.
180
+
181
+ Raises:
182
+ ValueError: If a correct word is not found in the transcription.
166
183
  """
167
184
  transcription_words = [
168
185
  TranscriptionWord(word=word, wordIndex=i)
@@ -190,7 +207,7 @@ class ValidationSetBuilder:
190
207
  self._rapid_parts.append(
191
208
  ValidatioRapidParts(
192
209
  question=question,
193
- media_paths=media_path,
210
+ asset=asset,
194
211
  payload=payload,
195
212
  truths=model_truth,
196
213
  metadata=metadata,
@@ -1,11 +1,11 @@
1
1
  from rapidata.api_client.models.aggregator_type import AggregatorType
2
+ from rapidata.api_client.models.capped_selection_selections_inner import (
3
+ CappedSelectionSelectionsInner,
4
+ )
2
5
  from rapidata.api_client.models.create_order_model import CreateOrderModel
3
6
  from rapidata.api_client.models.create_order_model_referee import (
4
7
  CreateOrderModelReferee,
5
8
  )
6
- from rapidata.api_client.models.create_order_model_selections_inner import (
7
- CreateOrderModelSelectionsInner,
8
- )
9
9
  from rapidata.api_client.models.create_order_model_user_filters_inner import (
10
10
  CreateOrderModelUserFiltersInner,
11
11
  )
@@ -41,6 +41,13 @@ class RapidataOrderBuilder:
41
41
  openapi_service: OpenAPIService,
42
42
  name: str,
43
43
  ):
44
+ """
45
+ Initialize the RapidataOrderBuilder.
46
+
47
+ Args:
48
+ openapi_service (OpenAPIService): The OpenAPIService instance.
49
+ name (str): The name of the order.
50
+ """
44
51
  self._name = name
45
52
  self._openapi_service = openapi_service
46
53
  self._workflow: Workflow | None = None
@@ -54,8 +61,19 @@ class RapidataOrderBuilder:
54
61
  self._selections: list[Selection] = []
55
62
  self._rapids_per_bag: int = 2
56
63
  self._priority: int = 50
64
+ self._texts: list[str] | None = None
65
+ self._media_paths: list[str | list[str]] = []
57
66
 
58
67
  def _to_model(self) -> CreateOrderModel:
68
+ """
69
+ Convert the builder configuration to a CreateOrderModel.
70
+
71
+ Raises:
72
+ ValueError: If no workflow is provided.
73
+
74
+ Returns:
75
+ CreateOrderModel: The model representing the order configuration.
76
+ """
59
77
  if self._workflow is None:
60
78
  raise ValueError("You must provide a workflow to create an order.")
61
79
 
@@ -86,27 +104,34 @@ class RapidataOrderBuilder:
86
104
  else None
87
105
  ),
88
106
  selections=[
89
- CreateOrderModelSelectionsInner(selection.to_model())
107
+ CappedSelectionSelectionsInner(selection.to_model())
90
108
  for selection in self._selections
91
109
  ],
92
110
  priority=self._priority,
93
111
  )
94
112
 
95
- def create(self, submit=True, max_workers=10) -> RapidataOrder:
96
- """Actually makes the API calls to create the order based on how the order builder was configured.
113
+ def create(self, submit: bool = True, max_workers: int = 10) -> RapidataOrder:
114
+ """
115
+ Create the Rapidata order by making the necessary API calls based on the builder's configuration.
97
116
 
98
117
  Args:
99
- submit (bool, optional): Whether to submit the order. Defaults to True.
118
+ submit (bool, optional): Whether to submit the order upon creation. Defaults to True.
119
+ max_workers (int, optional): The maximum number of worker threads for processing media paths. Defaults to 10.
120
+
121
+ Raises:
122
+ ValueError: If both media paths and texts are provided, or if neither is provided.
123
+ AssertionError: If the workflow is a CompareWorkflow and media paths are not in pairs.
100
124
 
101
125
  Returns:
102
126
  RapidataOrder: The created RapidataOrder instance.
103
-
104
- Raises:
105
- ValueError: If no workflow is provided.
106
127
  """
107
128
  order_model = self._to_model()
108
- if isinstance(self._workflow, CompareWorkflow): # temp fix, will be handeled by backend in the future
109
- assert all([len(path) == 2 for path in self._media_paths]), "The media paths must come in pairs for comparison tasks."
129
+ if isinstance(
130
+ self._workflow, CompareWorkflow
131
+ ): # Temporary fix; will be handled by backend in the future
132
+ assert all(
133
+ [len(path) == 2 for path in self._media_paths]
134
+ ), "The media paths must come in pairs for comparison tasks."
110
135
 
111
136
  result = self._openapi_service.order_api.order_create_post(
112
137
  create_order_model=order_model
@@ -120,14 +145,32 @@ class RapidataOrderBuilder:
120
145
  openapi_service=self._openapi_service,
121
146
  )
122
147
 
123
- order.dataset.add_media_from_paths(self._media_paths, self._metadata, max_workers)
148
+ if self._media_paths and self._texts:
149
+ raise ValueError(
150
+ "You cannot provide both media paths and texts to the same order."
151
+ )
152
+
153
+ if not self._media_paths and not self._texts:
154
+ raise ValueError(
155
+ "You must provide either media paths or texts to the order."
156
+ )
157
+
158
+ if self._texts:
159
+ order.dataset.add_texts(self._texts)
160
+
161
+ if self._media_paths:
162
+ order.dataset.add_media_from_paths(
163
+ self._media_paths, self._metadata, max_workers
164
+ )
165
+
124
166
  if submit:
125
167
  order.submit()
126
168
 
127
169
  return order
128
170
 
129
- def workflow(self, workflow: Workflow):
130
- """Set the workflow for the order.
171
+ def workflow(self, workflow: Workflow) -> "RapidataOrderBuilder":
172
+ """
173
+ Set the workflow for the order.
131
174
 
132
175
  Args:
133
176
  workflow (Workflow): The workflow to be set.
@@ -138,8 +181,9 @@ class RapidataOrderBuilder:
138
181
  self._workflow = workflow
139
182
  return self
140
183
 
141
- def referee(self, referee: Referee):
142
- """Set the referee for the order.
184
+ def referee(self, referee: Referee) -> "RapidataOrderBuilder":
185
+ """
186
+ Set the referee for the order.
143
187
 
144
188
  Args:
145
189
  referee (Referee): The referee to be set.
@@ -154,8 +198,9 @@ class RapidataOrderBuilder:
154
198
  self,
155
199
  media_paths: list[str | list[str]],
156
200
  metadata: list[Metadata] | None = None,
157
- ):
158
- """Set the media assets for the order.
201
+ ) -> "RapidataOrderBuilder":
202
+ """
203
+ Set the media assets for the order.
159
204
 
160
205
  Args:
161
206
  media_paths (list[str | list[str]]): The paths of the media assets to be set.
@@ -168,8 +213,22 @@ class RapidataOrderBuilder:
168
213
  self._metadata = metadata
169
214
  return self
170
215
 
171
- def feature_flags(self, feature_flags: FeatureFlags):
172
- """Set the feature flags for the order.
216
def texts(self, texts: list[str]) -> "RapidataOrderBuilder":
    """
    Use plain text strings as the datapoints of this order.

    The texts are uploaded to the order's dataset when ``create()`` runs.
    ``create()`` raises if both texts and media paths were configured.

    Args:
        texts (list[str]): The texts to be set.

    Returns:
        RapidataOrderBuilder: This same builder, so calls can be chained.
    """
    builder = self
    builder._texts = texts
    return builder
+ return self
228
+
229
+ def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
230
+ """
231
+ Set the feature flags for the order.
173
232
 
174
233
  Args:
175
234
  feature_flags (FeatureFlags): The feature flags to be set.
@@ -180,8 +239,9 @@ class RapidataOrderBuilder:
180
239
  self._feature_flags = feature_flags
181
240
  return self
182
241
 
183
- def country_filter(self, country_codes: list[str]):
184
- """Set the target country codes for the order.
242
+ def country_filter(self, country_codes: list[str]) -> "RapidataOrderBuilder":
243
+ """
244
+ Set the target country codes for the order.
185
245
 
186
246
  Args:
187
247
  country_codes (list[str]): The country codes to be set.
@@ -192,8 +252,9 @@ class RapidataOrderBuilder:
192
252
  self._country_codes = country_codes
193
253
  return self
194
254
 
195
- def aggregator(self, aggregator: AggregatorType):
196
- """Set the aggregator for the order.
255
+ def aggregator(self, aggregator: AggregatorType) -> "RapidataOrderBuilder":
256
+ """
257
+ Set the aggregator for the order.
197
258
 
198
259
  Args:
199
260
  aggregator (AggregatorType): The aggregator to be set.
@@ -204,8 +265,9 @@ class RapidataOrderBuilder:
204
265
  self._aggregator = aggregator
205
266
  return self
206
267
 
207
- def validation_set_id(self, validation_set_id: str):
208
- """Set the validation set for the order.
268
+ def validation_set_id(self, validation_set_id: str) -> "RapidataOrderBuilder":
269
+ """
270
+ Set the validation set ID for the order.
209
271
 
210
272
  Args:
211
273
  validation_set_id (str): The validation set ID to be set.
@@ -216,8 +278,9 @@ class RapidataOrderBuilder:
216
278
  self._validation_set_id = validation_set_id
217
279
  return self
218
280
 
219
- def rapids_per_bag(self, amount: int):
220
- """Defines the number of tasks a user sees in a single session.
281
+ def rapids_per_bag(self, amount: int) -> "RapidataOrderBuilder":
282
+ """
283
+ Define the number of tasks a user sees in a single session.
221
284
 
222
285
  Args:
223
286
  amount (int): The number of tasks a user sees in a single session.
@@ -230,8 +293,9 @@ class RapidataOrderBuilder:
230
293
  """
231
294
  raise NotImplementedError("Not implemented yet.")
232
295
 
233
- def selections(self, selections: list[Selection]):
234
- """Set the selections for the order.
296
+ def selections(self, selections: list[Selection]) -> "RapidataOrderBuilder":
297
+ """
298
+ Set the selections for the order.
235
299
 
236
300
  Args:
237
301
  selections (list[Selection]): The selections to be set.
@@ -242,8 +306,9 @@ class RapidataOrderBuilder:
242
306
  self._selections = selections
243
307
  return self
244
308
 
245
- def priority(self, priority: int):
246
- """Set the priority for the order.
309
+ def priority(self, priority: int) -> "RapidataOrderBuilder":
310
+ """
311
+ Set the priority for the order.
247
312
 
248
313
  Args:
249
314
  priority (int): The priority to be set.
@@ -8,6 +8,8 @@ from rapidata.service.openapi_service import OpenAPIService
8
8
  from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
9
9
  from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
10
10
 
11
+ from rapidata.rapidata_client.simple_builders.simple_classification_builders import ClassificationQuestionBuilder
12
+ from rapidata.rapidata_client.simple_builders.simple_compare_builders import CompareCriteriaBuilder
11
13
 
12
14
 
13
15
  class RapidataClient:
@@ -81,6 +83,28 @@ class RapidataClient:
81
83
  dataset=temp_dataset,
82
84
  order_id=order_id,
83
85
  openapi_service=self.openapi_service)
86
+
87
def create_classify_order(self, name: str) -> ClassificationQuestionBuilder:
    """Start building a classification order in which people classify an image.

    This is the first step of a fluent chain. Call ``.question(...)``, then
    ``.options(...)``, then ``.media(...)``, and finally ``.create()`` on the
    resulting builder.

    Args:
        name (str): The name of the order.

    Returns:
        ClassificationQuestionBuilder: A ClassificationQuestionBuilder instance.
    """
    first_step = ClassificationQuestionBuilder(
        name=name,
        openapi_service=self.openapi_service,
    )
    return first_step
97
+
98
def create_compare_order(self, name: str) -> CompareCriteriaBuilder:
    """Create a new comparison order where people are asked to compare two images.

    This is the first step of a fluent chain. Call ``.criteria(...)``, then
    ``.media(...)`` with a list of path pairs, and finally ``.create()`` on the
    resulting builder.

    Args:
        name (str): The name of the order.

    Returns:
        CompareCriteriaBuilder: A CompareCriteriaBuilder instance.
    """
    return CompareCriteriaBuilder(name=name, openapi_service=self.openapi_service)
84
108
 
85
109
  @property
86
110
  def utils(self) -> Utils:
@@ -0,0 +1,122 @@
1
+ from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
2
+ from rapidata.rapidata_client.metadata.base_metadata import Metadata
3
+ from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
4
+ from rapidata.rapidata_client.referee.classify_early_stopping_referee import ClassifyEarlyStoppingReferee
5
+ from rapidata.rapidata_client.selection.base_selection import Selection
6
+ from rapidata.rapidata_client.workflow.classify_workflow import ClassifyWorkflow
7
+ from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
8
+ from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
9
+ from rapidata.service.openapi_service import OpenAPIService
10
+
11
class ClassificationOrderBuilder:
    """Final step of the simple classification builder chain.

    Holds the order name, question, answer options and media paths collected by
    the earlier builders, and assembles a classification order from them in
    ``create()``. Optional settings can be chained before calling ``create()``.
    """

    def __init__(self, name: str, question: str, options: list[str], media_paths: list[str], openapi_service: OpenAPIService):
        """
        Args:
            name (str): The name of the order.
            question (str): The question passed to the ClassifyWorkflow.
            options (list[str]): The answer options passed to the ClassifyWorkflow.
            media_paths (list[str]): Local paths of the media files to classify.
            openapi_service (OpenAPIService): Service handed to the underlying RapidataOrderBuilder.
        """
        self._order_builder = RapidataOrderBuilder(name=name, openapi_service=openapi_service)
        self._question = question
        self._options = options
        self._media_paths = media_paths
        self._responses_required = 10
        self._probability_threshold: float | None = None
        self._metadata: list[Metadata] | None = None
        self._validation_set_id: str | None = None

    def metadata(self, metadata: list[Metadata]) -> "ClassificationOrderBuilder":
        """Set the metadata for the classification order.

        Args:
            metadata (list[Metadata]): Metadata entries. Must have the same length
                as the media paths.

        Returns:
            ClassificationOrderBuilder: This builder, for chaining.

        Raises:
            ValueError: If the number of metadata entries differs from the number
                of media paths.
        """
        # Enforce the documented precondition so a mismatch is reported at the call site.
        if len(metadata) != len(self._media_paths):
            raise ValueError(
                "Metadata must have the same length as the media paths "
                f"(got {len(metadata)} metadata entries for {len(self._media_paths)} media paths)."
            )
        self._metadata = metadata
        return self

    def responses(self, responses_required: int) -> "ClassificationOrderBuilder":
        """Set the number of responses required for the classification order.

        When a probability threshold is also set, this value is used as the
        maximum vote count of the early-stopping referee.

        Returns:
            ClassificationOrderBuilder: This builder, for chaining.
        """
        self._responses_required = responses_required
        return self

    def probability_threshold(self, probability_threshold: float) -> "ClassificationOrderBuilder":
        """Set the probability threshold for early stopping.

        Setting a threshold switches ``create()`` from a NaiveReferee to a
        ClassifyEarlyStoppingReferee.

        Returns:
            ClassificationOrderBuilder: This builder, for chaining.
        """
        self._probability_threshold = probability_threshold
        return self

    def validation_set_id(self, validation_set_id: str) -> "ClassificationOrderBuilder":
        """Set the validation set ID for the classification order.

        Returns:
            ClassificationOrderBuilder: This builder, for chaining.
        """
        self._validation_set_id = validation_set_id
        return self

    def create(self, submit: bool = True, max_upload_workers: int = 10):
        """Assemble and create the classification order.

        The referee is a ClassifyEarlyStoppingReferee when a probability threshold
        was set and the response count is non-zero; otherwise it is a NaiveReferee.
        A ValidationSelection is added only when a validation set ID was provided.

        Args:
            submit (bool, optional): Forwarded to ``RapidataOrderBuilder.create``.
                Defaults to True.
            max_upload_workers (int, optional): Forwarded as ``max_workers`` to
                ``RapidataOrderBuilder.create``. Defaults to 10.

        Returns:
            The order returned by ``RapidataOrderBuilder.create``.
        """
        # Compare against None: an explicitly configured threshold must not be
        # silently ignored just because its value is falsy.
        if self._probability_threshold is not None and self._responses_required:
            referee = ClassifyEarlyStoppingReferee(
                max_vote_count=self._responses_required,
                threshold=self._probability_threshold,
            )
        else:
            referee = NaiveReferee(required_guesses=self._responses_required)

        selection: list[Selection] = (
            [
                ValidationSelection(amount=1, validation_set_id=self._validation_set_id),
                LabelingSelection(amount=2),
            ]
            if self._validation_set_id
            else [LabelingSelection(amount=3)]
        )

        order = (
            self._order_builder
            .workflow(ClassifyWorkflow(question=self._question, options=self._options))
            .referee(referee)
            .media(self._media_paths, metadata=self._metadata)  # type: ignore
            .selections(selection)
            .create(submit=submit, max_workers=max_upload_workers)
        )

        return order
69
+
70
+
71
class ClassificationMediaBuilder:
    """Third step of the simple classification chain: collects the media paths.

    Produced by ``ClassificationOptionsBuilder.options``; ``media`` returns the
    final ``ClassificationOrderBuilder``.
    """

    def __init__(self, name: str, question: str, options: list[str], openapi_service: OpenAPIService):
        self._name = name
        self._question = question
        self._options = options
        self._openapi_service = openapi_service
        self._media_paths: list[str] | None = None

    def media(self, media_paths: list[str]) -> ClassificationOrderBuilder:
        """Set the media assets for the classification order by providing the local paths to the files."""
        self._media_paths = media_paths
        return self._build()

    def _build(self) -> ClassificationOrderBuilder:
        paths = self._media_paths
        if paths is not None:
            return ClassificationOrderBuilder(
                self._name,
                self._question,
                self._options,
                paths,
                openapi_service=self._openapi_service,
            )
        raise ValueError("Media paths are required")
88
+
89
+
90
class ClassificationOptionsBuilder:
    """Second step of the simple classification chain: collects the answer options.

    Produced by ``ClassificationQuestionBuilder.question``; ``options`` returns a
    ``ClassificationMediaBuilder``.
    """

    def __init__(self, name: str, question: str, openapi_service: OpenAPIService):
        self._name = name
        self._question = question
        self._openapi_service = openapi_service
        self._options: list[str] | None = None

    def options(self, options: list[str]) -> ClassificationMediaBuilder:
        """Set the answer options for the classification order."""
        self._options = options
        return self._build()

    def _build(self) -> ClassificationMediaBuilder:
        chosen = self._options
        if chosen is not None:
            return ClassificationMediaBuilder(
                self._name, self._question, chosen, self._openapi_service
            )
        raise ValueError("Options are required")
106
+
107
+
108
class ClassificationQuestionBuilder:
    """Entry point of the simple classification chain: collects the question.

    ``question`` returns a ``ClassificationOptionsBuilder``.
    """

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self._name = name
        self._openapi_service = openapi_service
        self._question: str | None = None

    def question(self, question: str) -> ClassificationOptionsBuilder:
        """Set the question for the classification order."""
        self._question = question
        return self._build()

    def _build(self) -> ClassificationOptionsBuilder:
        text = self._question
        if text is not None:
            return ClassificationOptionsBuilder(self._name, text, self._openapi_service)
        raise ValueError("Question is required")
@@ -0,0 +1,86 @@
1
+ from rapidata.service.openapi_service import OpenAPIService
2
+ from rapidata.rapidata_client.metadata.base_metadata import Metadata
3
+ from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
4
+ from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
5
+ from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
6
+ from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
7
+ from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
8
+ from rapidata.rapidata_client.selection.base_selection import Selection
9
+
10
class CompareOrderBuilder:
    """Final step of the simple comparison builder chain.

    Holds the order name, comparison criteria and media path pairs collected by
    the earlier builders, and assembles a comparison order from them in
    ``create()``. Optional settings can be chained before calling ``create()``.
    """

    def __init__(self, name: str, criteria: str, media_paths: list[list[str]], openapi_service: OpenAPIService):
        """
        Args:
            name (str): The name of the order.
            criteria (str): The criteria passed to the CompareWorkflow.
            media_paths (list[list[str]]): Pairs of local media paths to compare.
            openapi_service (OpenAPIService): Service handed to the underlying RapidataOrderBuilder.
        """
        self._order_builder = RapidataOrderBuilder(name=name, openapi_service=openapi_service)
        self._name = name
        self._criteria = criteria
        self._media_paths = media_paths
        self._responses_required = 10
        self._metadata: list[Metadata] | None = None
        self._validation_set_id: str | None = None

    def responses(self, responses_required: int) -> 'CompareOrderBuilder':
        """Set the number of responses required per matchup/pairing for the comparison order.

        Returns:
            CompareOrderBuilder: This builder, for chaining.
        """
        self._responses_required = responses_required
        return self

    def metadata(self, metadata: list[Metadata]) -> 'CompareOrderBuilder':
        """Set the metadata for the comparison order. Has to be the same shape as the media paths.

        Returns:
            CompareOrderBuilder: This builder, for chaining.
        """
        self._metadata = metadata
        return self

    def validation_set_id(self, validation_set_id: str) -> 'CompareOrderBuilder':
        """Set the validation set ID for the comparison order.

        Returns:
            CompareOrderBuilder: This builder, for chaining.
        """
        self._validation_set_id = validation_set_id
        return self

    def create(self, submit: bool = True, max_upload_workers: int = 10):
        """Assemble and create the comparison order.

        Uses a NaiveReferee requiring the configured number of responses. A
        ValidationSelection is added only when a validation set ID was provided.

        Args:
            submit (bool, optional): Forwarded to ``RapidataOrderBuilder.create``.
                Defaults to True.
            max_upload_workers (int, optional): Forwarded as ``max_workers`` to
                ``RapidataOrderBuilder.create``. Defaults to 10.

        Returns:
            The order returned by ``RapidataOrderBuilder.create``.
        """
        selection: list[Selection] = (
            [
                ValidationSelection(amount=1, validation_set_id=self._validation_set_id),
                LabelingSelection(amount=2),
            ]
            if self._validation_set_id
            else [LabelingSelection(amount=3)]
        )

        order = (
            self._order_builder
            .workflow(CompareWorkflow(criteria=self._criteria))
            .referee(NaiveReferee(required_guesses=self._responses_required))
            .media(self._media_paths, metadata=self._metadata)  # type: ignore
            .selections(selection)
            .create(submit=submit, max_workers=max_upload_workers)
        )

        return order
52
+
53
class CompareMediaBuilder:
    """Second step of the simple comparison chain: collects the media path pairs.

    Produced by ``CompareCriteriaBuilder.criteria``; ``media`` returns the final
    ``CompareOrderBuilder``.
    """

    def __init__(self, name: str, criteria: str, openapi_service: OpenAPIService):
        self._openapi_service = openapi_service
        self._name = name
        self._criteria = criteria
        self._media_paths: list[list[str]] | None = None

    def media(self, media_paths: list[list[str]]) -> CompareOrderBuilder:
        """Set the media assets for the comparison order by providing the local paths to the files.

        Args:
            media_paths (list[list[str]]): Pairs of local media paths; every
                inner list must contain exactly two paths.

        Returns:
            CompareOrderBuilder: The final builder of the chain.

        Raises:
            AssertionError: If any entry does not contain exactly two paths.
        """
        self._media_paths = media_paths
        return self._build()

    def _build(self) -> CompareOrderBuilder:
        if self._media_paths is None:
            raise ValueError("Media paths are required")
        # Raise explicitly instead of using `assert`: assert statements are
        # stripped under `python -O`, which would silently skip this validation.
        # AssertionError is kept so existing callers see the same exception type.
        if not all(len(path) == 2 for path in self._media_paths):
            raise AssertionError("The media paths must come in pairs for comparison tasks.")
        return CompareOrderBuilder(self._name, self._criteria, self._media_paths, self._openapi_service)
70
+
71
class CompareCriteriaBuilder:
    """Entry point of the simple comparison chain: collects the comparison criteria.

    ``criteria`` returns a ``CompareMediaBuilder``.
    """

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self._name = name
        self._openapi_service = openapi_service
        self._criteria: str | None = None

    def criteria(self, criteria: str) -> CompareMediaBuilder:
        """Set the criteria how the images should be compared."""
        self._criteria = criteria
        return self._build()

    def _build(self) -> CompareMediaBuilder:
        chosen = self._criteria
        if chosen is not None:
            return CompareMediaBuilder(self._name, chosen, self._openapi_service)
        raise ValueError("Criteria are required")
86
+