rapidata 2.3.2__py3-none-any.whl → 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +1 -1
- rapidata/api_client/__init__.py +2 -0
- rapidata/api_client/api/rapid_api.py +268 -0
- rapidata/api_client/models/__init__.py +2 -0
- rapidata/api_client/models/add_validation_rapid_model.py +9 -2
- rapidata/api_client/models/add_validation_text_rapid_model.py +9 -2
- rapidata/api_client/models/query_validation_rapids_result.py +7 -3
- rapidata/api_client/models/rapid_issue.py +41 -0
- rapidata/api_client/models/report_model.py +103 -0
- rapidata/api_client_README.md +3 -0
- rapidata/rapidata_client/assets/_media_asset.py +0 -1
- rapidata/rapidata_client/selection/rapidata_selections.py +11 -1
- rapidata/rapidata_client/validation/rapidata_validation_set.py +2 -294
- rapidata/rapidata_client/validation/rapids/rapids.py +72 -125
- rapidata/rapidata_client/validation/rapids/rapids_manager.py +206 -42
- rapidata/rapidata_client/validation/validation_set_manager.py +108 -69
- {rapidata-2.3.2.dist-info → rapidata-2.4.0.dist-info}/METADATA +1 -1
- {rapidata-2.3.2.dist-info → rapidata-2.4.0.dist-info}/RECORD +20 -20
- rapidata/rapidata_client/validation/_validation_rapid_parts.py +0 -61
- rapidata/rapidata_client/validation/_validation_set_builder.py +0 -481
- {rapidata-2.3.2.dist-info → rapidata-2.4.0.dist-info}/LICENSE +0 -0
- {rapidata-2.3.2.dist-info → rapidata-2.4.0.dist-info}/WHEEL +0 -0
|
@@ -1,481 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
|
|
3
|
-
from rapidata.api_client.models.classify_payload import ClassifyPayload
|
|
4
|
-
from rapidata.api_client.models.compare_payload import ComparePayload
|
|
5
|
-
from rapidata.api_client.models.compare_truth import CompareTruth
|
|
6
|
-
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
|
|
7
|
-
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
|
|
8
|
-
from rapidata.api_client.models.transcription_word import TranscriptionWord
|
|
9
|
-
from rapidata.api_client.models.scrub_payload import ScrubPayload
|
|
10
|
-
from rapidata.api_client.models.scrub_truth import ScrubTruth
|
|
11
|
-
from rapidata.api_client.models.scrub_range import ScrubRange
|
|
12
|
-
from rapidata.api_client.models.locate_payload import LocatePayload
|
|
13
|
-
from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
|
|
14
|
-
from rapidata.api_client.models.line_payload import LinePayload
|
|
15
|
-
from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
|
|
16
|
-
from rapidata.api_client.models.box_shape import BoxShape
|
|
17
|
-
from rapidata.rapidata_client.validation.rapidata_validation_set import (
|
|
18
|
-
RapidataValidationSet,
|
|
19
|
-
)
|
|
20
|
-
from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
|
|
21
|
-
from rapidata.rapidata_client.validation._validation_rapid_parts import ValidatioRapidParts
|
|
22
|
-
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
23
|
-
from rapidata.service.openapi_service import OpenAPIService
|
|
24
|
-
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
25
|
-
|
|
26
|
-
from rapidata.rapidata_client.validation.rapids.rapids import (
|
|
27
|
-
Rapid,
|
|
28
|
-
ClassificationRapid,
|
|
29
|
-
CompareRapid,
|
|
30
|
-
SelectWordsRapid,
|
|
31
|
-
LocateRapid,
|
|
32
|
-
DrawRapid,
|
|
33
|
-
TimestampRapid
|
|
34
|
-
)
|
|
35
|
-
from typing import Sequence
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class ValidationSetBuilder:
    """The ValidationSetBuilder is used to build a validation set.

    Give the validation set a name, queue rapids with ``_add_rapid`` (classification,
    compare, select-words, locate, draw or timestamp rapids), then call ``_submit`` to
    create the set on the server and upload every queued rapid.
    Get a `ValidationSetBuilder` by calling [`rapi.new_validation_set()`](../rapidata_client.md/#rapidata.rapidata_client.rapidata_client.RapidataClient.new_validation_set).

    Args:
        name (str): The name of the validation set.
        openapi_service (OpenAPIService): An instance of OpenAPIService to interact with the API.
    """

    # Annotations naming project types are written as strings (forward references) so
    # they are not evaluated when the class body is executed.

    def __init__(self, name: str, openapi_service: "OpenAPIService"):
        self.name = name
        self.openapi_service = openapi_service
        # Filled in by _submit once the server has created the set.
        self.validation_set_id: str | None = None
        # Rapids queued by _add_rapid; uploaded in insertion order by _submit.
        self._rapid_parts: "list[ValidatioRapidParts]" = []

    def _submit(self, print_confirmation: bool = True) -> "RapidataValidationSet":
        """Create the validation set by executing all HTTP requests.

        This should be the last method called on the builder.

        Args:
            print_confirmation (bool, optional): Print the name and ID of the created
                validation set when done. Defaults to True.

        Returns:
            RapidataValidationSet: A RapidataValidationSet instance.

        Raises:
            ValueError: If the server does not return a validation set ID.
        """
        result = self.openapi_service.validation_api.validation_create_validation_set_post(
            name=self.name
        )
        self.validation_set_id = result.validation_set_id

        if self.validation_set_id is None:
            raise ValueError("Failed to create validation set")

        validation_set = RapidataValidationSet(
            validation_set_id=self.validation_set_id,
            openapi_service=self.openapi_service,
            name=self.name,
        )

        # NOTE(review): the set is created before any rapid is uploaded and nothing is
        # rolled back, so presumably a failing upload leaves a partially filled set.
        for rapid_part in self._rapid_parts:
            validation_set._add_general_validation_rapid(
                payload=rapid_part.payload,
                truths=rapid_part.truths,
                metadata=rapid_part.metadata,
                asset=rapid_part.asset,
                randomCorrectProbability=rapid_part.randomCorrectProbability,
            )

        if print_confirmation:
            print(f"Validation set '{self.name}' created with ID {self.validation_set_id}")

        return validation_set

    def _add_rapid(self, rapid: "Rapid"):
        """Add a rapid to the validation set.

        To create the Rapid, use the RapidataClient.rapid_builder instance.

        Args:
            rapid (Rapid): The rapid to add to the validation set.

        Returns:
            ValidationSetBuilder: This builder, so calls can be chained.

        Raises:
            ValueError: If ``rapid`` is not a Rapid, is an unsupported Rapid subtype,
                or fails the checks of the matching ``__add_*`` method.
        """
        if not isinstance(rapid, Rapid):
            raise ValueError("This method only accepts Rapid instances")

        if isinstance(rapid, ClassificationRapid):
            self.__add_classify_rapid(
                rapid.asset, rapid.instruction, rapid.answer_options, rapid.truths, rapid.metadata
            )
        elif isinstance(rapid, CompareRapid):
            self.__add_compare_rapid(rapid.asset, rapid.instruction, rapid.truth, rapid.metadata)
        elif isinstance(rapid, SelectWordsRapid):
            self.__add_select_words_rapid(
                rapid.asset,
                rapid.instruction,
                rapid.sentence,
                rapid.truths,
                rapid.required_precision,
                rapid.required_completeness,
                rapid.metadata,
            )
        elif isinstance(rapid, LocateRapid):
            self.__add_locate_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
        elif isinstance(rapid, DrawRapid):
            self.__add_draw_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
        elif isinstance(rapid, TimestampRapid):
            self.__add_timestamp_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
        else:
            raise ValueError("Unsupported rapid type")

        return self

    def __add_classify_rapid(
        self,
        asset: "MediaAsset | TextAsset",
        instruction: str,
        answer_options: list[str],
        truths: list[str],
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a classify rapid.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            answer_options (list[str]): The list of answer options for the rapid.
            truths (list[str]): The answer options that count as correct.
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If ``answer_options`` is empty or a truth is not one of the answer options.
        """
        if not answer_options:
            raise ValueError("Answer options must not be empty")
        if not all(truth in answer_options for truth in truths):
            raise ValueError("Truths must be part of the answer options")

        payload = ClassifyPayload(
            _t="ClassifyPayload", possibleCategories=answer_options, title=instruction
        )
        model_truth = AttachCategoryTruth(correctCategories=truths, _t="AttachCategoryTruth")

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                # Share of answer options that are correct (a single uniform pick,
                # assuming distinct options).
                randomCorrectProbability=len(truths) / len(answer_options),
                asset=asset,
            )
        )

    def __add_compare_rapid(
        self,
        asset: "MultiAsset",
        instruction: str,
        truth: str,
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a compare rapid.

        Args:
            asset (MultiAsset): The assets for the rapid.
            instruction (str): The instruction for the comparison.
            truth (str): The path/identifier of the winning asset; only its last path
                component is sent.
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If the number of assets is not exactly two.
        """
        if len(asset) != 2:
            raise ValueError("Compare rapid requires exactly two media paths")

        payload = ComparePayload(_t="ComparePayload", criteria=instruction)
        # Only the last part of the truth path is used as the winner ID.
        model_truth = CompareTruth(_t="CompareTruth", winnerId=os.path.basename(truth))

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=1 / len(asset),
                asset=asset,
            )
        )

    def __add_select_words_rapid(
        self,
        asset: "MediaAsset | TextAsset",
        instruction: str,
        select_words: str,
        truths: list[int],
        required_precision: float,
        required_completeness: float,
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a select-words rapid.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            select_words (str): The sentence to select words from; split on whitespace.
            truths (list[int]): The indices of the correct words in ``select_words``.
            required_precision (float): The required precision for the rapid (minimum ratio of selected words that are correct).
            required_completeness (float): The required completeness for the rapid (minimum ratio of total correct words selected).
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If the sentence has no words, a truth is not an integer, or a
                truth index is outside the sentence.
        """
        transcription_words = [
            TranscriptionWord(word=word, wordIndex=i)
            for i, word in enumerate(select_words.split())
        ]
        if not transcription_words:
            raise ValueError("The sentence must contain at least one word")

        true_words = []
        for idx in truths:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not isinstance(idx, int):
                raise ValueError("truths must be a list of integers")
            # Reject negative indices too; they would silently pick words from the end.
            if not 0 <= idx < len(transcription_words):
                raise ValueError(f"Index {idx} is out of bounds")
            true_words.append(transcription_words[idx])

        payload = TranscriptionPayload(
            _t="TranscriptionPayload", title=instruction, transcription=transcription_words
        )

        model_truth = TranscriptionTruth(
            _t="TranscriptionTruth",
            correctWords=true_words,
            requiredPrecision=required_precision,
            requiredCompleteness=required_completeness,
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                asset=asset,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=1 / len(transcription_words),
            )
        )

    def __add_locate_rapid(
        self,
        asset: "MediaAsset",
        instruction: str,
        truths: "list[Box]",
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a locate rapid.

        Args:
            asset (MediaAsset): The image asset for the rapid.
            instruction (str): The instruction for the locate rapid.
            truths (list[Box]): The truth boxes; presumably in pixel coordinates, since
                they are divided by the image dimensions below.
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If the image dimensions cannot be determined.
        """
        payload = LocatePayload(_t="LocatePayload", target=instruction)

        img_dimensions = asset.get_image_dimension()
        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")
        width, height = img_dimensions[0], img_dimensions[1]

        # Boxes are sent as percentages (0-100) of the image width/height.
        model_truth = LocateBoxTruth(
            _t="LocateBoxTruth",
            boundingBoxes=[
                BoxShape(
                    _t="BoxShape",
                    xMin=truth.x_min / width * 100,
                    xMax=truth.x_max / width * 100,
                    yMax=truth.y_max / height * 100,
                    yMin=truth.y_min / height * 100,
                )
                for truth in truths
            ],
        )

        coverage = self._calculate_boxes_coverage(truths, width, height)

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def __add_draw_rapid(
        self,
        asset: "MediaAsset",
        instruction: str,
        truths: "list[Box]",
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a draw rapid.

        Args:
            asset (MediaAsset): The image asset for the rapid.
            instruction (str): The instruction for the draw rapid.
            truths (list[Box]): The truth boxes; only the first one is sent as the truth.
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If ``truths`` is empty or the image dimensions cannot be determined.
        """
        if not truths:
            raise ValueError("Draw rapid requires at least one truth box")

        payload = LinePayload(_t="LinePayload", target=instruction)

        img_dimensions = asset.get_image_dimension()
        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")
        width, height = img_dimensions[0], img_dimensions[1]

        # NOTE(review): only truths[0] is submitted, and as fractions (0-1) of the image
        # size, whereas locate rapids send percentages (0-100); the coverage below still
        # uses every box. Confirm both against the BoundingBoxTruth API contract.
        first = truths[0]
        model_truth = BoundingBoxTruth(
            _t="BoundingBoxTruth",
            xMax=first.x_max / width,
            xMin=first.x_min / width,
            yMax=first.y_max / height,
            yMin=first.y_min / height,
        )

        coverage = self._calculate_boxes_coverage(truths, width, height)

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def __add_timestamp_rapid(
        self,
        asset: "MediaAsset",
        instruction: str,
        truths: list[tuple[int, int]],
        metadata: "Sequence[Metadata] | None" = None,
    ):
        """Queue a timestamp rapid.

        Args:
            asset (MediaAsset): The asset for the rapid.
            instruction (str): The instruction for the timestamp rapid.
            truths (list[tuple[int, int]]): The truths for the rapid.
                This is a list of tuples where the first element is the start of the interval and the second element is the end of the interval.
                The intervals are in milliseconds.
            metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to no metadata.

        Raises:
            ValueError: If a truth is not a pair, starts after it ends, or the asset
                duration is not positive.
        """
        for truth in truths:
            if len(truth) != 2:
                raise ValueError("The truths per datapoint must be a tuple of exactly two integers.")
            if truth[0] > truth[1]:
                raise ValueError("The start of the interval must be smaller than the end of the interval.")

        payload = ScrubPayload(_t="ScrubPayload", target=instruction)

        model_truth = ScrubTruth(
            _t="ScrubTruth",
            validRanges=[ScrubRange(start=truth[0], end=truth[1]) for truth in truths],
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=self._calculate_coverage_ratio(asset.get_duration(), truths),
                asset=asset,
            )
        )

    def _calculate_boxes_coverage(self, boxes: "list[Box]", image_width: int, image_height: int) -> float:
        """Return the fraction of image pixels covered by the union of ``boxes``.

        Each box covers the pixel columns ``int(x_min) .. int(x_max + 1) - 1`` and rows
        ``int(y_min) .. int(y_max + 1) - 1``, clipped to the image. The union area is
        counted via coordinate compression instead of enumerating every pixel, so memory
        does not grow with the image size.

        Args:
            boxes (list[Box]): Boxes with ``x_min``, ``x_max``, ``y_min``, ``y_max``.
            image_width (int): Image width in pixels.
            image_height (int): Image height in pixels.

        Returns:
            float: Covered pixels divided by ``image_width * image_height``; 0.0 for no boxes.
        """
        if not boxes:
            return 0.0

        # Half-open integer rectangles [x0, x1) x [y0, y1), clipped to the image.
        rects = []
        for box in boxes:
            x0 = max(int(box.x_min), 0)
            x1 = min(int(box.x_max + 1), image_width)
            y0 = max(int(box.y_min), 0)
            y1 = min(int(box.y_max + 1), image_height)
            if x0 < x1 and y0 < y1:
                rects.append((x0, x1, y0, y1))

        # Every rectangle edge is a grid line, so each grid cell is either fully inside
        # or fully outside each rectangle.
        xs = sorted({x for rect in rects for x in rect[:2]})
        ys = sorted({y for rect in rects for y in rect[2:]})

        covered = 0
        for xa, xb in zip(xs, xs[1:]):
            for ya, yb in zip(ys, ys[1:]):
                if any(
                    x0 <= xa and xb <= x1 and y0 <= ya and yb <= y1
                    for x0, x1, y0, y1 in rects
                ):
                    covered += (xb - xa) * (yb - ya)

        return covered / (image_width * image_height)

    def _calculate_coverage_ratio(self, total_duration: int, subsections: list[tuple[int, int]]) -> float:
        """
        Calculate the ratio of total_duration that is covered by subsections, handling overlaps.

        Both ends of every subsection are clamped to ``[0, total_duration]``; parts
        outside that range (and resulting empty ranges) do not count.

        Args:
            total_duration: The total duration to consider
            subsections: List of tuples containing (start, end) times

        Returns:
            float: Ratio of coverage (0 to 1)

        Raises:
            ValueError: If ``subsections`` is non-empty and ``total_duration`` is not positive.
        """
        if not subsections:
            return 0.0
        if total_duration <= 0:
            raise ValueError("The total duration must be positive to compute the coverage ratio.")

        # Clamp both ends so ranges outside the media cannot contribute negative lengths.
        sorted_ranges = sorted(
            (clamped_start, clamped_end)
            for clamped_start, clamped_end in (
                (min(max(start, 0), total_duration), min(max(end, 0), total_duration))
                for start, end in subsections
            )
            if clamped_start < clamped_end
        )
        if not sorted_ranges:
            return 0.0

        # Merge overlapping or adjacent ranges.
        merged_ranges = []
        current_start, current_end = sorted_ranges[0]
        for next_start, next_end in sorted_ranges[1:]:
            if next_start <= current_end:
                current_end = max(current_end, next_end)
            else:
                merged_ranges.append((current_start, current_end))
                current_start, current_end = next_start, next_end
        merged_ranges.append((current_start, current_end))

        total_coverage = sum(end - start for start, end in merged_ranges)
        return total_coverage / total_duration
|
|
File without changes
|
|
File without changes
|