rapidata 2.3.2__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,139 +1,86 @@
1
+ from pydantic import StrictBytes, StrictStr
1
2
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
2
3
  from rapidata.rapidata_client.metadata import Metadata
3
4
  from typing import Sequence
4
- from rapidata.rapidata_client.validation.rapids.box import Box
5
+ from typing import Any
6
+ from rapidata.api_client.models.add_validation_rapid_model import (
7
+ AddValidationRapidModel,
8
+ )
9
+ from rapidata.api_client.models.add_validation_text_rapid_model import (
10
+ AddValidationTextRapidModel,
11
+ )
12
+ from rapidata.api_client.models.add_validation_rapid_model_payload import (
13
+ AddValidationRapidModelPayload,
14
+ )
15
+ from rapidata.api_client.models.add_validation_rapid_model_truth import (
16
+ AddValidationRapidModelTruth,
17
+ )
5
18
 
6
- class Rapid:
7
- pass
19
+ from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import (
20
+ DatapointMetadataModelMetadataInner,
21
+ )
8
22
 
9
- class ClassificationRapid(Rapid):
10
- """
11
- A classification rapid. Used as a multiple choice question for the labeler to answer.
12
-
13
-
14
- Args:
15
- instruction (str): The instruction how to choose the options.
16
- answer_options (list[str]): The options that the labeler can choose from.
17
- truths (list[str]): The correct answers to the question.
18
- asset (MediaAsset | TextAsset): The asset that the labeler will be labeling.
19
- metadata (Sequence[Metadata]): The metadata that is attached to the rapid.
20
- """
21
23
 
22
- def __init__(self, instruction: str, answer_options: list[str], truths: list[str], asset: MediaAsset | TextAsset, metadata: Sequence[Metadata]):
23
- self.instruction = instruction
24
- self.answer_options = answer_options
25
- self.truths = truths
24
+ class Rapid():
25
+ def __init__(self, asset: MediaAsset | TextAsset | MultiAsset, metadata: Sequence[Metadata], payload: Any, truth: Any, randomCorrectProbability: float, explanation: str | None):
26
26
  self.asset = asset
27
27
  self.metadata = metadata
28
-
29
- class CompareRapid(Rapid):
30
- """
31
- Used as a comparison of two assets for the labeler to compare.
32
-
33
- Args:
34
- instruction (str): The instruction that the labeler will be comparing the assets on.
35
- truth (str): The correct answer to the comparison. (has to be one of the assets)
36
- asset (MultiAsset): The assets that the labeler will be comparing.
37
- metadata (Sequence[Metadata]): The metadata that is attached to the rapid.
38
- """
39
- def __init__(self, instruction: str, truth: str, asset: MultiAsset, metadata: Sequence[Metadata]):
40
- self.instruction = instruction
41
- self.asset = asset
28
+ self.payload = payload
42
29
  self.truth = truth
43
- self.metadata = metadata
30
+ self.randomCorrectProbability = randomCorrectProbability
31
+ self.explanation = explanation
44
32
 
45
- class SelectWordsRapid(Rapid):
46
- """
47
- Used to give the labeler a text and have them select words from it.
48
-
49
- Args:
50
- instruction (str): The instruction for the labeler.
51
- truths (list[int]): The indices of the words that are the correct answers.
52
- asset (MediaAsset): The asset that the labeler will be selecting words from.
53
- sentence (str): The sentence that the labeler will be selecting words from. (split up by spaces)
54
- strict_grading (bool): Whether the grading should be strict or not.
55
- True means that all correct words and no wrong words have to be selected for the rapid to be marked as correct.
56
- False means that at least one correct word and no wrong words have to be selected for the rapid to be marked as correct.
57
- """
58
- def __init__(self, instruction: str, truths: list[int], asset: MediaAsset, sentence: str, required_precision: float, required_completeness: float, metadata: Sequence[Metadata]):
59
- if not isinstance(truths, list):
60
- raise ValueError("The truths must be a list of integers.")
61
- if not all(isinstance(x, int) for x in truths):
62
- raise ValueError("The truths must be a list of integers.")
63
- if required_completeness <= 0 or required_completeness > 1:
64
- raise ValueError("The required completeness must be > 0 and <= 1.")
65
- if required_precision <= 0 or required_precision > 1:
66
- raise ValueError("The required precision must be > 0 and <= 1.")
67
-
68
- self.instruction = instruction
69
- self.truths = truths
70
- self.asset = asset
71
- self.sentence = sentence
72
- self.required_precision = required_precision
73
- self.required_completeness = required_completeness
74
- self.metadata = metadata
33
+ def to_media_model(self, validationSetId: str) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
34
+ assets: list[MediaAsset] = []
35
+ if isinstance(self.asset, MultiAsset):
36
+ for asset in self.asset.assets:
37
+ if isinstance(asset, MediaAsset):
38
+ assets.append(asset)
39
+ else:
40
+ raise TypeError("The asset is a multiasset, but not all assets are MediaAssets")
75
41
 
76
- class LocateRapid(Rapid):
77
- """
78
- Used to have the labeler locate a specific object in an image.
79
-
80
- Args:
81
- instruction (str): The instructions on what the labeler should do.
82
- truths (list[Box]): The boxes that the object is located in.
83
- asset (MediaAsset): The image that the labeler is locating the object in.
84
- metadata (Sequence[Metadata]): The metadata that is attached to the rapid.
85
- """
86
- def __init__(self, instruction: str, truths: list[Box], asset: MediaAsset, metadata: Sequence[Metadata]):
87
- self.instruction = instruction
88
- self.asset = asset
89
- self.truths = truths
90
- self.metadata = metadata
42
+ if isinstance(self.asset, TextAsset):
43
+ raise TypeError("The asset must contain Media")
91
44
 
92
- class DrawRapid(Rapid):
93
- """
94
- Used to have the labeler draw a specific object in an image.
95
-
96
- Args:
97
- instruction (str): The instructions on what the labeler should do.
98
- truths (list[Box]): The boxes that the object is located in.
99
- asset (MediaAsset): The image that the labeler is drawing the object in.
100
- metadata (Sequence[Metadata]): The metadata that is attached to the rapid.
101
- """
102
- def __init__(self, instruction: str, truths: list[Box], asset: MediaAsset, metadata: Sequence[Metadata]):
103
- self.instruction = instruction
104
- self.asset = asset
105
- self.truths = truths
106
- self.metadata = metadata
45
+ if isinstance(self.asset, MediaAsset):
46
+ assets = [self.asset]
107
47
 
108
- class TimestampRapid(Rapid):
109
- """
110
- Used to have the labeler timestamp a video or audio file.
111
-
112
- Args:
113
- instruction (str): The instruction for the labeler.
114
- truths (list[tuple[int, int]]): The possible accepted timestamps intervals for the labeler (in miliseconds).
115
- The first element of the tuple is the start of the interval and the second element is the end of the interval.
116
- asset (MediaAsset): The asset that the labeler is timestamping.
117
- metadata (Sequence[Metadata]): The metadata that is attached to the rapid.
118
- """
119
- def __init__(self, instruction: str, truths: list[tuple[int, int]], asset: MediaAsset, metadata: Sequence[Metadata]):
120
- if not asset.get_duration():
121
- raise ValueError("The datapoints must have a duration. (e.g. video or audio)")
122
-
123
- if not isinstance(truths, list):
124
- raise ValueError("The truths must be a list of tuples.")
48
+ return (AddValidationRapidModel(
49
+ validationSetId=validationSetId,
50
+ payload=AddValidationRapidModelPayload(self.payload),
51
+ truth=AddValidationRapidModelTruth(self.truth),
52
+ metadata=[
53
+ DatapointMetadataModelMetadataInner(meta._to_model())
54
+ for meta in self.metadata
55
+ ],
56
+ randomCorrectProbability=self.randomCorrectProbability,
57
+ explanation=self.explanation
58
+ ), [asset.to_file() for asset in assets])
125
59
 
126
- for truth in truths:
127
- if len(truth) != 2 or not all(isinstance(x, int) for x in truth):
128
- raise ValueError("The truths per datapoint must be a tuple of exactly two integers.")
129
- if truth[0] >= truth[1]:
130
- raise ValueError("The start of the interval must be smaller than the end of the interval.")
131
- if truth[0] < 0:
132
- raise ValueError("The start of the interval must be greater than or equal to 0.")
133
- if truth[1] > asset.get_duration():
134
- raise ValueError("The end of the interval can not be greater than the duration of the datapoint.")
135
-
136
- self.instruction = instruction
137
- self.truths = truths
138
- self.asset = asset
139
- self.metadata = metadata
60
+ def to_text_model(self, validationSetId: str) -> AddValidationTextRapidModel:
61
+ texts: list[str] = []
62
+ if isinstance(self.asset, MultiAsset):
63
+ for asset in self.asset.assets:
64
+ if isinstance(asset, TextAsset):
65
+ texts.append(asset.text)
66
+ else:
67
+ raise TypeError("The asset is a multiasset, but not all assets are TextAssets")
68
+
69
+ if isinstance(self.asset, MediaAsset):
70
+ raise TypeError("The asset must contain Text")
71
+
72
+ if isinstance(self.asset, TextAsset):
73
+ texts = [self.asset.text]
74
+
75
+ return AddValidationTextRapidModel(
76
+ validationSetId=validationSetId,
77
+ payload=AddValidationRapidModelPayload(self.payload),
78
+ truth=AddValidationRapidModelTruth(self.truth),
79
+ metadata=[
80
+ DatapointMetadataModelMetadataInner(meta._to_model())
81
+ for meta in self.metadata
82
+ ],
83
+ randomCorrectProbability=self.randomCorrectProbability,
84
+ texts=texts,
85
+ explanation=self.explanation
86
+ )
@@ -1,17 +1,14 @@
1
+ import os
2
+ from rapidata.api_client import AttachCategoryTruth, BoundingBoxTruth, BoxShape, ClassifyPayload, ComparePayload, CompareTruth, LinePayload, LocateBoxTruth, LocatePayload, ScrubPayload, ScrubRange, ScrubTruth, TranscriptionPayload, TranscriptionTruth, TranscriptionWord
1
3
  from rapidata.rapidata_client.assets.data_type_enum import RapidataDataTypes
2
- from rapidata.rapidata_client.validation.rapids.rapids import (
3
- ClassificationRapid,
4
- CompareRapid,
5
- SelectWordsRapid,
6
- LocateRapid,
7
- DrawRapid,
8
- TimestampRapid)
9
4
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
10
5
  from rapidata.rapidata_client.metadata import Metadata
11
6
  from rapidata.rapidata_client.validation.rapids.box import Box
12
7
 
13
8
  from typing import Sequence
14
9
 
10
+ from rapidata.rapidata_client.validation.rapids.rapids import Rapid
11
+
15
12
  class RapidsManager:
16
13
  """
17
14
  Can be used to build different types of rapids. That can then be added to Validation sets
@@ -26,7 +23,8 @@ class RapidsManager:
26
23
  truths: list[str],
27
24
  data_type: str = RapidataDataTypes.MEDIA,
28
25
  metadata: Sequence[Metadata] = [],
29
- ) -> ClassificationRapid:
26
+ explanation: str | None = None,
27
+ ) -> Rapid:
30
28
  """Build a classification rapid
31
29
 
32
30
  Args:
@@ -45,13 +43,24 @@ class RapidsManager:
45
43
  else:
46
44
  raise ValueError(f"Unsupported data type: {data_type}")
47
45
 
48
- return ClassificationRapid(
49
- instruction=instruction,
50
- answer_options=answer_options,
46
+ if not all(truth in answer_options for truth in truths):
47
+ raise ValueError("Truths must be part of the answer options")
48
+
49
+ payload = ClassifyPayload(
50
+ _t="ClassifyPayload", possibleCategories=answer_options, title=instruction
51
+ )
52
+ model_truth = AttachCategoryTruth(
53
+ correctCategories=truths, _t="AttachCategoryTruth"
54
+ )
55
+
56
+ return Rapid(
51
57
  asset=asset,
52
- truths=truths,
53
58
  metadata=metadata,
54
- )
59
+ explanation=explanation,
60
+ payload=payload,
61
+ truth=model_truth,
62
+ randomCorrectProbability=len(truths) / len(answer_options)
63
+ )
55
64
 
56
65
  def compare_rapid(self,
57
66
  instruction: str,
@@ -59,7 +68,8 @@ class RapidsManager:
59
68
  datapoint: list[str],
60
69
  data_type: str = RapidataDataTypes.MEDIA,
61
70
  metadata: Sequence[Metadata] = [],
62
- ) -> CompareRapid:
71
+ explanation: str | None = None,
72
+ ) -> Rapid:
63
73
  """Build a compare rapid
64
74
 
65
75
  Args:
@@ -78,12 +88,23 @@ class RapidsManager:
78
88
  raise ValueError(f"Unsupported data type: {data_type}")
79
89
 
80
90
  asset = MultiAsset(assets)
91
+
92
+ payload = ComparePayload(_t="ComparePayload", criteria=instruction)
93
+ # take only last part of truth path
94
+ truth = os.path.basename(truth)
95
+ model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
81
96
 
82
- return CompareRapid(
83
- instruction=instruction,
97
+ if len(asset) != 2:
98
+ raise ValueError("Compare rapid requires exactly two media paths")
99
+
100
+
101
+ return Rapid(
84
102
  asset=asset,
85
- truth=truth,
103
+ truth=model_truth,
86
104
  metadata=metadata,
105
+ payload=payload,
106
+ explanation=explanation,
107
+ randomCorrectProbability= 1 / len(asset.assets)
87
108
  )
88
109
 
89
110
  def select_words_rapid(self,
@@ -94,7 +115,8 @@ class RapidsManager:
94
115
  required_precision: float = 1,
95
116
  required_completeness: float = 1,
96
117
  metadata: Sequence[Metadata] = [],
97
- ) -> SelectWordsRapid:
118
+ explanation: str | None = None,
119
+ ) -> Rapid:
98
120
  """Build a select words rapid
99
121
 
100
122
  Args:
@@ -108,23 +130,44 @@ class RapidsManager:
108
130
  """
109
131
 
110
132
  asset = MediaAsset(datapoint)
133
+ transcription_words = [
134
+ TranscriptionWord(word=word, wordIndex=i)
135
+ for i, word in enumerate(sentence)
136
+ ]
137
+
138
+ correct_transcription_words: list[TranscriptionWord] = []
139
+ for index in truths:
140
+ correct_transcription_words.append(
141
+ TranscriptionWord(word=transcription_words[index].word, wordIndex=index)
142
+ )
143
+
144
+ payload = TranscriptionPayload(
145
+ _t="TranscriptionPayload", title=instruction, transcription=transcription_words
146
+ )
147
+
148
+ model_truth = TranscriptionTruth(
149
+ _t="TranscriptionTruth",
150
+ correctWords=correct_transcription_words,
151
+ requiredPrecision=required_precision,
152
+ requiredCompleteness=required_completeness,
153
+ )
111
154
 
112
- return SelectWordsRapid(
113
- instruction=instruction,
114
- truths=truths,
155
+ return Rapid(
156
+ payload=payload,
157
+ truth=model_truth,
115
158
  asset=asset,
116
- sentence=sentence,
117
- required_precision=required_precision,
118
- required_completeness=required_completeness,
119
159
  metadata=metadata,
120
- )
160
+ explanation=explanation,
161
+ randomCorrectProbability= len(correct_transcription_words) / len(transcription_words)
162
+ )
121
163
 
122
164
  def locate_rapid(self,
123
165
  instruction: str,
124
166
  truths: list[Box],
125
167
  datapoint: str,
126
168
  metadata: Sequence[Metadata] = [],
127
- ) -> LocateRapid:
169
+ explanation: str | None = None,
170
+ ) -> Rapid:
128
171
  """Build a locate rapid
129
172
 
130
173
  Args:
@@ -135,12 +178,35 @@ class RapidsManager:
135
178
  """
136
179
 
137
180
  asset = MediaAsset(datapoint)
181
+ payload = LocatePayload(
182
+ _t="LocatePayload", target=instruction
183
+ )
184
+
185
+ img_dimensions = asset.get_image_dimension()
186
+
187
+ if not img_dimensions:
188
+ raise ValueError("Failed to get image dimensions")
189
+
190
+ model_truth = LocateBoxTruth(
191
+ _t="LocateBoxTruth",
192
+ boundingBoxes=[BoxShape(
193
+ _t="BoxShape",
194
+ xMin=truth.x_min / img_dimensions[0] * 100,
195
+ xMax=truth.x_max / img_dimensions[0] * 100,
196
+ yMax=truth.y_max / img_dimensions[1] * 100,
197
+ yMin=truth.y_min / img_dimensions[1] * 100,
198
+ ) for truth in truths]
199
+ )
200
+
201
+ coverage = self._calculate_boxes_coverage(truths, img_dimensions[0], img_dimensions[1])
138
202
 
139
- return LocateRapid(
140
- instruction=instruction,
141
- truths=truths,
203
+ return Rapid(
204
+ payload=payload,
205
+ truth=model_truth,
142
206
  asset=asset,
143
207
  metadata=metadata,
208
+ explanation=explanation,
209
+ randomCorrectProbability=coverage
144
210
  )
145
211
 
146
212
  def draw_rapid(self,
@@ -148,7 +214,8 @@ class RapidsManager:
148
214
  truths: list[Box],
149
215
  datapoint: str,
150
216
  metadata: Sequence[Metadata] = [],
151
- ) -> DrawRapid:
217
+ explanation: str | None = None
218
+ ) -> Rapid:
152
219
  """Build a draw rapid
153
220
 
154
221
  Args:
@@ -160,19 +227,42 @@ class RapidsManager:
160
227
 
161
228
  asset = MediaAsset(datapoint)
162
229
 
163
- return DrawRapid(
164
- instruction=instruction,
165
- truths=truths,
166
- asset=asset,
167
- metadata=metadata,
168
- )
169
-
230
+ payload = LinePayload(
231
+ _t="LinePayload", target=instruction
232
+ )
233
+
234
+ img_dimensions = asset.get_image_dimension()
235
+
236
+ if not img_dimensions:
237
+ raise ValueError("Failed to get image dimensions")
238
+
239
+ model_truth = BoundingBoxTruth(
240
+ _t="BoundingBoxTruth",
241
+ xMax=truths[0].x_max / img_dimensions[0],
242
+ xMin=truths[0].x_min / img_dimensions[0],
243
+ yMax=truths[0].y_max / img_dimensions[1],
244
+ yMin=truths[0].y_min / img_dimensions[1],
245
+ )
246
+
247
+ coverage = self._calculate_boxes_coverage(truths, img_dimensions[0], img_dimensions[1])
248
+
249
+ return Rapid(
250
+ payload=payload,
251
+ truth=model_truth,
252
+ asset=asset,
253
+ metadata=metadata,
254
+ explanation=explanation,
255
+ randomCorrectProbability=coverage
256
+ )
257
+
258
+
170
259
  def timestamp_rapid(self,
171
260
  instruction: str,
172
261
  truths: list[tuple[int, int]],
173
262
  datapoint: str,
174
- metadata: Sequence[Metadata] = []
175
- ) -> TimestampRapid:
263
+ metadata: Sequence[Metadata] = [],
264
+ explanation: str | None = None
265
+ ) -> Rapid:
176
266
  """Build a timestamp rapid
177
267
 
178
268
  Args:
@@ -184,12 +274,86 @@ class RapidsManager:
184
274
  """
185
275
 
186
276
  asset = MediaAsset(datapoint)
277
+
278
+ for truth in truths:
279
+ if len(truth) != 2:
280
+ raise ValueError("The truths per datapoint must be a tuple of exactly two integers.")
281
+ if truth[0] > truth[1]:
282
+ raise ValueError("The start of the interval must be smaller than the end of the interval.")
283
+
284
+ payload = ScrubPayload(
285
+ _t="ScrubPayload",
286
+ target=instruction
287
+ )
288
+
289
+ model_truth = ScrubTruth(
290
+ _t="ScrubTruth",
291
+ validRanges=[ScrubRange(
292
+ start=truth[0],
293
+ end=truth[1]
294
+ ) for truth in truths]
295
+ )
187
296
 
188
- return TimestampRapid(
189
- instruction=instruction,
190
- truths=truths,
297
+ return Rapid(
298
+ payload=payload,
299
+ truth=model_truth,
191
300
  asset=asset,
192
301
  metadata=metadata,
302
+ explanation=explanation,
303
+ randomCorrectProbability=self._calculate_coverage_ratio(asset.get_duration(), truths),
193
304
  )
194
305
 
306
+ def _calculate_boxes_coverage(self, boxes: list[Box], image_width: int, image_height: int) -> float:
307
+ if not boxes:
308
+ return 0.0
309
+ # Convert all coordinates to integers for pixel-wise coverage
310
+ pixels = set()
311
+ for box in boxes:
312
+ for x in range(int(box.x_min), int(box.x_max + 1)):
313
+ for y in range(int(box.y_min), int(box.y_max + 1)):
314
+ if 0 <= x < image_width and 0 <= y < image_height:
315
+ pixels.add((x,y))
316
+
317
+ total_covered = len(pixels)
318
+ return total_covered / (image_width * image_height)
319
+
320
+ def _calculate_coverage_ratio(self, total_duration: int, subsections: list[tuple[int, int]]) -> float:
321
+ """
322
+ Calculate the ratio of total_duration that is covered by subsections, handling overlaps.
323
+
324
+ Args:
325
+ total_duration: The total duration to consider
326
+ subsections: List of tuples containing (start, end) times
327
+
328
+ Returns:
329
+ float: Ratio of coverage (0 to 1)
330
+ """
331
+ if not subsections:
332
+ return 0.0
333
+
334
+ # Sort subsections by start time and clamp to valid range
335
+ sorted_ranges = sorted(
336
+ (max(0, start), min(end, total_duration))
337
+ for start, end in subsections
338
+ )
339
+
340
+ # Merge overlapping ranges
341
+ merged_ranges = []
342
+ current_range = list(sorted_ranges[0])
343
+
344
+ for next_start, next_end in sorted_ranges[1:]:
345
+ current_start, current_end = current_range
346
+
347
+ # If ranges overlap or are adjacent
348
+ if next_start <= current_end:
349
+ current_range[1] = max(current_end, next_end)
350
+ else:
351
+ merged_ranges.append(current_range)
352
+ current_range = [next_start, next_end]
353
+
354
+ merged_ranges.append(current_range)
355
+
356
+ # Calculate total coverage
357
+ total_coverage = sum(end - start for start, end in merged_ranges)
195
358
 
359
+ return total_coverage / total_duration