rapidata 1.10.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +21 -17
- rapidata/api_client/__init__.py +15 -5
- rapidata/api_client/api/coco_api.py +14 -29
- rapidata/api_client/api/dataset_api.py +6 -6
- rapidata/api_client/api/identity_api.py +3 -3
- rapidata/api_client/api/pipeline_api.py +1008 -95
- rapidata/api_client/api/rapid_api.py +6 -6
- rapidata/api_client/api/validation_api.py +12 -42
- rapidata/api_client/models/__init__.py +15 -5
- rapidata/api_client/models/add_campaign_model.py +1 -3
- rapidata/api_client/models/add_validation_text_rapid_model.py +1 -1
- rapidata/api_client/models/age_group.py +5 -4
- rapidata/api_client/models/base_error.py +1 -4
- rapidata/api_client/models/compare_workflow_config.py +9 -24
- rapidata/api_client/models/compare_workflow_config_model.py +9 -29
- rapidata/api_client/models/compare_workflow_config_model_pair_maker_config.py +140 -0
- rapidata/api_client/models/compare_workflow_config_pair_maker_config.py +140 -0
- rapidata/api_client/models/compare_workflow_model.py +7 -3
- rapidata/api_client/models/compare_workflow_model1.py +7 -3
- rapidata/api_client/models/compare_workflow_model1_pair_maker_information.py +140 -0
- rapidata/api_client/models/compare_workflow_model_pair_maker_config.py +140 -0
- rapidata/api_client/models/file_asset_model_metadata_inner.py +8 -22
- rapidata/api_client/models/get_classify_workflow_result_overview_result.py +144 -0
- rapidata/api_client/models/get_pipeline_by_id_result.py +13 -3
- rapidata/api_client/models/identity_read_bridge_token_get202_response.py +140 -0
- rapidata/api_client/models/not_available_yet_result.py +96 -0
- rapidata/api_client/models/online_pair_maker_config.py +98 -0
- rapidata/api_client/models/online_pair_maker_config_model.py +98 -0
- rapidata/api_client/models/online_pair_maker_information.py +100 -0
- rapidata/api_client/models/pipeline_id_workflow_put_request.py +140 -0
- rapidata/api_client/models/pre_arranged_pair_maker_config.py +100 -0
- rapidata/api_client/models/pre_arranged_pair_maker_config_model.py +96 -0
- rapidata/api_client/models/pre_arranged_pair_maker_information.py +102 -0
- rapidata/api_client/models/read_bridge_token_keys_result.py +11 -2
- rapidata/api_client/models/simple_workflow_config.py +7 -26
- rapidata/api_client/models/simple_workflow_config_model.py +4 -28
- rapidata/api_client/models/simple_workflow_get_result_overview_get200_response.py +16 -16
- rapidata/api_client/models/simple_workflow_model1.py +3 -3
- rapidata/api_client/models/update_campaign_model.py +99 -0
- rapidata/api_client/models/validation_import_post_request_blueprint.py +1 -1
- rapidata/api_client_README.md +20 -7
- rapidata/rapidata_client/__init__.py +18 -9
- rapidata/rapidata_client/assets/__init__.py +5 -4
- rapidata/rapidata_client/assets/{media_asset.py → _media_asset.py} +32 -11
- rapidata/rapidata_client/assets/{multi_asset.py → _multi_asset.py} +1 -1
- rapidata/rapidata_client/assets/{text_asset.py → _text_asset.py} +1 -1
- rapidata/rapidata_client/assets/data_type_enum.py +7 -0
- rapidata/rapidata_client/filter/__init__.py +1 -1
- rapidata/rapidata_client/filter/_base_filter.py +10 -0
- rapidata/rapidata_client/filter/age_filter.py +12 -5
- rapidata/rapidata_client/filter/campaign_filter.py +12 -3
- rapidata/rapidata_client/filter/country_filter.py +10 -3
- rapidata/rapidata_client/filter/gender_filter.py +12 -5
- rapidata/rapidata_client/filter/language_filter.py +14 -3
- rapidata/rapidata_client/filter/models/age_group.py +26 -0
- rapidata/rapidata_client/filter/models/gender.py +19 -0
- rapidata/rapidata_client/filter/rapidata_filters.py +31 -0
- rapidata/rapidata_client/filter/user_score_filter.py +20 -4
- rapidata/rapidata_client/metadata/__init__.py +5 -5
- rapidata/rapidata_client/metadata/{base_metadata.py → _base_metadata.py} +2 -1
- rapidata/rapidata_client/metadata/{private_text_metadata.py → _private_text_metadata.py} +2 -2
- rapidata/rapidata_client/metadata/{prompt_metadata.py → _prompt_metadata.py} +3 -2
- rapidata/rapidata_client/metadata/{public_text_metadata.py → _public_text_metadata.py} +2 -2
- rapidata/rapidata_client/metadata/{select_words_metadata.py → _select_words_metadata.py} +3 -2
- rapidata/rapidata_client/{dataset/rapidata_dataset.py → order/_rapidata_dataset.py} +7 -8
- rapidata/rapidata_client/order/_rapidata_order_builder.py +365 -0
- rapidata/rapidata_client/order/rapidata_order.py +49 -31
- rapidata/rapidata_client/order/rapidata_order_manager.py +461 -0
- rapidata/rapidata_client/rapidata_client.py +12 -201
- rapidata/rapidata_client/referee/__init__.py +3 -3
- rapidata/rapidata_client/referee/{base_referee.py → _base_referee.py} +3 -3
- rapidata/rapidata_client/referee/{early_stopping_referee.py → _early_stopping_referee.py} +14 -11
- rapidata/rapidata_client/referee/{naive_referee.py → _naive_referee.py} +9 -9
- rapidata/rapidata_client/selection/__init__.py +1 -1
- rapidata/rapidata_client/{filter/base_filter.py → selection/_base_selection.py} +2 -2
- rapidata/rapidata_client/selection/capped_selection.py +15 -5
- rapidata/rapidata_client/selection/conditional_validation_selection.py +17 -4
- rapidata/rapidata_client/selection/demographic_selection.py +18 -7
- rapidata/rapidata_client/selection/labeling_selection.py +10 -3
- rapidata/rapidata_client/selection/rapidata_selections.py +21 -0
- rapidata/rapidata_client/selection/validation_selection.py +11 -4
- rapidata/rapidata_client/settings/__init__.py +9 -2
- rapidata/rapidata_client/settings/_rapidata_setting.py +11 -0
- rapidata/rapidata_client/settings/alert_on_fast_response.py +21 -0
- rapidata/rapidata_client/settings/custom_setting.py +16 -0
- rapidata/rapidata_client/settings/free_text_minimum_characters.py +16 -0
- rapidata/rapidata_client/settings/models/__init__.py +1 -0
- rapidata/rapidata_client/settings/models/translation_behaviour_options.py +14 -0
- rapidata/rapidata_client/settings/no_shuffle.py +16 -0
- rapidata/rapidata_client/settings/play_video_until_the_end.py +16 -0
- rapidata/rapidata_client/settings/rapidata_settings.py +31 -0
- rapidata/rapidata_client/settings/translation_behaviour.py +18 -0
- rapidata/rapidata_client/validation/__init__.py +1 -0
- rapidata/rapidata_client/{dataset/validation_rapid_parts.py → validation/_validation_rapid_parts.py} +7 -6
- rapidata/rapidata_client/validation/_validation_set_builder.py +371 -0
- rapidata/rapidata_client/{dataset → validation}/rapidata_validation_set.py +54 -50
- rapidata/rapidata_client/validation/rapids/__init__.py +1 -0
- rapidata/rapidata_client/validation/rapids/box.py +17 -0
- rapidata/rapidata_client/validation/rapids/rapids.py +94 -0
- rapidata/rapidata_client/validation/rapids/rapids_manager.py +163 -0
- rapidata/rapidata_client/validation/validation_set_manager.py +335 -0
- rapidata/rapidata_client/workflow/__init__.py +8 -6
- rapidata/rapidata_client/workflow/_base_workflow.py +25 -0
- rapidata/rapidata_client/workflow/{classify_workflow.py → _classify_workflow.py} +6 -6
- rapidata/rapidata_client/workflow/{compare_workflow.py → _compare_workflow.py} +10 -16
- rapidata/rapidata_client/workflow/_draw_workflow.py +22 -0
- rapidata/rapidata_client/workflow/_evaluation_workflow.py +26 -0
- rapidata/rapidata_client/workflow/{free_text_workflow.py → _free_text_workflow.py} +10 -16
- rapidata/rapidata_client/workflow/_locate_workflow.py +22 -0
- rapidata/rapidata_client/workflow/{select_words_workflow.py → _select_words_workflow.py} +2 -8
- rapidata/service/credential_manager.py +11 -1
- rapidata/service/openapi_service.py +23 -4
- {rapidata-1.10.0.dist-info → rapidata-2.0.0.dist-info}/METADATA +2 -1
- {rapidata-1.10.0.dist-info → rapidata-2.0.0.dist-info}/RECORD +118 -94
- rapidata/constants.py +0 -1
- rapidata/rapidata_client/dataset/rapid_builders/__init__.py +0 -4
- rapidata/rapidata_client/dataset/rapid_builders/base_rapid_builder.py +0 -33
- rapidata/rapidata_client/dataset/rapid_builders/classify_rapid_builders.py +0 -166
- rapidata/rapidata_client/dataset/rapid_builders/compare_rapid_builders.py +0 -145
- rapidata/rapidata_client/dataset/rapid_builders/rapids.py +0 -33
- rapidata/rapidata_client/dataset/rapid_builders/select_words_rapid_builders.py +0 -124
- rapidata/rapidata_client/dataset/validation_set_builder.py +0 -336
- rapidata/rapidata_client/order/order_builder.py +0 -25
- rapidata/rapidata_client/order/rapidata_order_builder.py +0 -463
- rapidata/rapidata_client/selection/base_selection.py +0 -9
- rapidata/rapidata_client/settings/feature_flags.py +0 -125
- rapidata/rapidata_client/settings/settings.py +0 -124
- rapidata/rapidata_client/simple_builders/__init__.py +0 -0
- rapidata/rapidata_client/simple_builders/simple_classification_builders.py +0 -271
- rapidata/rapidata_client/simple_builders/simple_compare_builders.py +0 -267
- rapidata/rapidata_client/simple_builders/simple_free_text_builders.py +0 -192
- rapidata/rapidata_client/simple_builders/simple_select_words_builders.py +0 -196
- rapidata/rapidata_client/workflow/base_workflow.py +0 -42
- rapidata/rapidata_client/workflow/evaluation_workflow.py +0 -15
- /rapidata/rapidata_client/assets/{base_asset.py → _base_asset.py} +0 -0
- /rapidata/rapidata_client/{dataset → filter/models}/__init__.py +0 -0
- {rapidata-1.10.0.dist-info → rapidata-2.0.0.dist-info}/LICENSE +0 -0
- {rapidata-1.10.0.dist-info → rapidata-2.0.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
class TranslationBehaviourOptions(Enum):
    """Choices for how translated text is presented in the task UI.

    Attributes:
        BOTH: Display the original text together with its translation.
            This can crowd the screen when the options are long.
        ONLY_ORIGINAL: Display the original text only.
        ONLY_TRANSLATED: Display the translated text only."""

    # Member order and string values are part of the public contract; keep them stable.
    BOTH = "both"
    ONLY_ORIGINAL = "only original"
    ONLY_TRANSLATED = "only translated"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
2
|
+
|
|
3
|
+
class NoShuffle(RapidataSetting):
    """
    Keep the categories in their given order (classify tasks only).

    When this setting is not added to an order, the categories are shuffled.

    Args:
        value (bool, optional): Whether to disable shuffling. Defaults to True.

    Raises:
        ValueError: If ``value`` is not a ``bool``.
    """
    def __init__(self, value: bool = True):
        if isinstance(value, bool):
            super().__init__(key="no_shuffle", value=value)
        else:
            raise ValueError("The value must be a boolean.")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
2
|
+
|
|
3
|
+
class PlayVideoUntilTheEnd(RapidataSetting):
    """
    Only let users answer once the video has finished playing.

    ``additional_time`` is added on top of the video duration; a negative value
    allows answers before the video ends.

    Args:
        additional_time (int, optional): Additional time in milliseconds, from
            -25000 to 25000 inclusive. Defaults to 0.

    Raises:
        ValueError: If ``additional_time`` is below -25000 or above 25000.
    """

    def __init__(self, additional_time: int = 0):
        lowest, highest = -25000, 25000
        out_of_range = additional_time < lowest or additional_time > highest
        if out_of_range:
            raise ValueError("The additional time must be between -25000 and 25000.")

        # NOTE(review): the key name reads like an AlertOnFastResponse flag;
        # confirm this is the intended backend key for this setting.
        super().__init__(
            key="alert_on_fast_response_add_media_duration",
            value=additional_time,
        )
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from rapidata.rapidata_client.settings import (
|
|
2
|
+
AlertOnFastResponse,
|
|
3
|
+
TranslationBehaviour,
|
|
4
|
+
FreeTextMinimumCharacters,
|
|
5
|
+
NoShuffle,
|
|
6
|
+
PlayVideoUntilTheEnd,
|
|
7
|
+
CustomSetting,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
class RapidataSettings:
    """
    Namespace that groups every available setting class.

    Settings can be added to an order to determine the behaviour of the task.
    Each attribute below is the setting class itself (not an instance); call it
    to create the setting.

    Attributes:
        alert_on_fast_response (type[AlertOnFastResponse]): The AlertOnFastResponse class.
        translation_behaviour (type[TranslationBehaviour]): The TranslationBehaviour class.
        free_text_minimum_characters (type[FreeTextMinimumCharacters]): The FreeTextMinimumCharacters class.
        no_shuffle (type[NoShuffle]): The NoShuffle class.
        play_video_until_the_end (type[PlayVideoUntilTheEnd]): The PlayVideoUntilTheEnd class.
        custom_setting (type[CustomSetting]): The CustomSetting class.
    """

    # Plain class aliases so users can discover all settings from one place.
    alert_on_fast_response = AlertOnFastResponse
    translation_behaviour = TranslationBehaviour
    free_text_minimum_characters = FreeTextMinimumCharacters
    no_shuffle = NoShuffle
    play_video_until_the_end = PlayVideoUntilTheEnd
    custom_setting = CustomSetting
|
|
31
|
+
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from rapidata.rapidata_client.settings.models.translation_behaviour_options import TranslationBehaviourOptions
|
|
2
|
+
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
3
|
+
|
|
4
|
+
class TranslationBehaviour(RapidataSetting):
    """
    Defines how translations are shown in the UI.
    Text datapoints and sentences are not translated.

    Args:
        value (TranslationBehaviourOptions): The translation behaviour.

    Raises:
        ValueError: If ``value`` is not a ``TranslationBehaviourOptions`` member.
    """

    def __init__(self, value: TranslationBehaviourOptions):
        if not isinstance(value, TranslationBehaviourOptions):
            raise ValueError("The value must be a TranslationBehaviourOptions.")

        # Pass the member's underlying string (e.g. "both") rather than the Enum
        # member itself, so this setting carries a plain value like the other
        # settings do (bool for NoShuffle, int for PlayVideoUntilTheEnd).
        super().__init__(key="translation_behaviour", value=value.value)
|
|
18
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .rapids import Box
|
rapidata/rapidata_client/{dataset/validation_rapid_parts.py → validation/_validation_rapid_parts.py}
RENAMED
|
@@ -19,15 +19,16 @@ from rapidata.api_client.models.polygon_payload import PolygonPayload
|
|
|
19
19
|
from rapidata.api_client.models.polygon_truth import PolygonTruth
|
|
20
20
|
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
|
|
21
21
|
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
|
|
22
|
-
from rapidata.rapidata_client.assets.
|
|
23
|
-
from rapidata.rapidata_client.assets.
|
|
24
|
-
from rapidata.rapidata_client.assets.
|
|
25
|
-
from rapidata.rapidata_client.metadata.
|
|
22
|
+
from rapidata.rapidata_client.assets._media_asset import MediaAsset
|
|
23
|
+
from rapidata.rapidata_client.assets._multi_asset import MultiAsset
|
|
24
|
+
from rapidata.rapidata_client.assets._text_asset import TextAsset
|
|
25
|
+
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
26
|
+
from typing import Sequence
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
@dataclass
|
|
29
30
|
class ValidatioRapidParts:
|
|
30
|
-
|
|
31
|
+
instruction: str
|
|
31
32
|
asset: MediaAsset | TextAsset | MultiAsset
|
|
32
33
|
payload: (
|
|
33
34
|
BoundingBoxPayload
|
|
@@ -51,5 +52,5 @@ class ValidatioRapidParts:
|
|
|
51
52
|
| PolygonTruth
|
|
52
53
|
| TranscriptionTruth
|
|
53
54
|
)
|
|
54
|
-
metadata:
|
|
55
|
+
metadata: Sequence[Metadata]
|
|
55
56
|
randomCorrectProbability: float
|
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
|
|
3
|
+
from rapidata.api_client.models.classify_payload import ClassifyPayload
|
|
4
|
+
from rapidata.api_client.models.compare_payload import ComparePayload
|
|
5
|
+
from rapidata.api_client.models.compare_truth import CompareTruth
|
|
6
|
+
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
|
|
7
|
+
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
|
|
8
|
+
from rapidata.api_client.models.transcription_word import TranscriptionWord
|
|
9
|
+
from rapidata.api_client.models.locate_payload import LocatePayload
|
|
10
|
+
from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
|
|
11
|
+
from rapidata.api_client.models.line_payload import LinePayload
|
|
12
|
+
from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
|
|
13
|
+
from rapidata.api_client.models.box_shape import BoxShape
|
|
14
|
+
from rapidata.rapidata_client.validation.rapidata_validation_set import (
|
|
15
|
+
RapidataValidationSet,
|
|
16
|
+
)
|
|
17
|
+
from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
|
|
18
|
+
from rapidata.rapidata_client.validation._validation_rapid_parts import ValidatioRapidParts
|
|
19
|
+
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
20
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
21
|
+
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
22
|
+
|
|
23
|
+
from rapidata.rapidata_client.validation.rapids.rapids import (
|
|
24
|
+
Rapid,
|
|
25
|
+
ClassificationRapid,
|
|
26
|
+
CompareRapid,
|
|
27
|
+
SelectWordsRapid,
|
|
28
|
+
LocateRapid,
|
|
29
|
+
DrawRapid
|
|
30
|
+
)
|
|
31
|
+
from typing import Sequence
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ValidationSetBuilder:
    """The ValidationSetBuilder is used to build a validation set.
    Give the validation set a name and then add classify, compare, select words,
    locate, or draw rapid parts to it.
    Get a `ValidationSetBuilder` by calling [`rapi.new_validation_set()`](../rapidata_client.md/#rapidata.rapidata_client.rapidata_client.RapidataClient.new_validation_set).

    Args:
        name (str): The name of the validation set.
        openapi_service (OpenAPIService): An instance of OpenAPIService to interact with the API.
    """

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self.name = name
        self.openapi_service = openapi_service
        self.validation_set_id: str | None = None
        # Rapid parts are only collected here; they are uploaded in `_submit`.
        self._rapid_parts: list[ValidatioRapidParts] = []

    def _submit(self, print_confirmation: bool = True) -> RapidataValidationSet:
        """Create the validation set by executing all HTTP requests. This should be the last method called on the builder.

        Args:
            print_confirmation (bool, optional): Whether to print the name and ID of
                the created validation set. Defaults to True.

        Returns:
            RapidataValidationSet: A RapidataValidationSet instance.

        Raises:
            ValueError: If the validation set creation fails.
        """
        result = self.openapi_service.validation_api.validation_create_validation_set_post(
            name=self.name
        )
        self.validation_set_id = result.validation_set_id

        if self.validation_set_id is None:
            raise ValueError("Failed to create validation set")

        validation_set = RapidataValidationSet(
            validation_set_id=self.validation_set_id,
            openapi_service=self.openapi_service,
            name=self.name,
        )

        # Upload every collected rapid part to the new set, in insertion order.
        for rapid_part in self._rapid_parts:
            validation_set._add_general_validation_rapid(
                payload=rapid_part.payload,
                truths=rapid_part.truths,
                metadata=rapid_part.metadata,
                asset=rapid_part.asset,
                randomCorrectProbability=rapid_part.randomCorrectProbability,
            )

        if print_confirmation:
            print(f"Validation set '{self.name}' created with ID {self.validation_set_id}")

        return validation_set

    def _add_rapid(self, rapid: Rapid):
        """Add a rapid to the validation set.
        To create the Rapid, use the RapidataClient.rapid_builder instance.

        Args:
            rapid (Rapid): The rapid to add to the validation set.

        Returns:
            ValidationSetBuilder: This builder, so calls can be chained.

        Raises:
            ValueError: If ``rapid`` is not a Rapid, is an unsupported Rapid type,
                or fails the validation performed for its specific type.
        """
        if not isinstance(rapid, Rapid):
            raise ValueError("This method only accepts Rapid instances")

        elif isinstance(rapid, ClassificationRapid):
            self.__add_classify_rapid(rapid.asset, rapid.instruction, rapid.answer_options, rapid.truths, rapid.metadata)

        elif isinstance(rapid, CompareRapid):
            self.__add_compare_rapid(rapid.asset, rapid.instruction, rapid.truth, rapid.metadata)

        elif isinstance(rapid, SelectWordsRapid):
            self.__add_select_words_rapid(rapid.asset, rapid.instruction, rapid.sentence, rapid.truths, rapid.strict_grading)

        elif isinstance(rapid, LocateRapid):
            self.__add_locate_rapid(rapid.asset, rapid.instruction, rapid.truths)

        elif isinstance(rapid, DrawRapid):
            self.__add_draw_rapid(rapid.asset, rapid.instruction, rapid.truths)

        else:
            raise ValueError("Unsupported rapid type")

        return self

    def __add_classify_rapid(
        self,
        asset: MediaAsset | TextAsset,
        instruction: str,
        answer_options: list[str],
        truths: list[str],
        metadata: Sequence[Metadata] | None = None,
    ):
        """Add a classify rapid to the validation set.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            answer_options (list[str]): The list of answer_options for the rapid.
            truths (list[str]): The list of truths for the rapid.
            metadata (Sequence[Metadata] | None, optional): The metadata for the rapid. Defaults to an empty list.

        Raises:
            ValueError: If a truth is not one of the answer options, or if there are no answer options.
        """
        if not all(truth in answer_options for truth in truths):
            raise ValueError("Truths must be part of the answer options")
        if not answer_options:
            # The random-correct probability below divides by this length.
            raise ValueError("Answer options must not be empty")

        payload = ClassifyPayload(
            _t="ClassifyPayload", possibleCategories=answer_options, title=instruction
        )
        model_truth = AttachCategoryTruth(
            correctCategories=truths, _t="AttachCategoryTruth"
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=len(truths) / len(answer_options),
                asset=asset,
            )
        )

    def __add_compare_rapid(
        self,
        asset: MultiAsset,
        instruction: str,
        truth: str,
        metadata: Sequence[Metadata] | None = None,
    ):
        """Add a compare rapid to the validation set.

        Args:
            asset (MultiAsset): The assets for the rapid.
            instruction (str): The instruction for the comparison.
            truth (str): The truth identifier for the rapid; only its final path
                component (``os.path.basename``) is used as the winner id.
            metadata (Sequence[Metadata] | None, optional): The metadata for the rapid. Defaults to an empty list.

        Raises:
            ValueError: If the number of assets is not exactly two.
        """
        payload = ComparePayload(_t="ComparePayload", criteria=instruction)
        # take only last part of truth path
        truth = os.path.basename(truth)
        model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)

        if len(asset) != 2:
            raise ValueError("Compare rapid requires exactly two media paths")

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=1 / len(asset),
                asset=asset,
            )
        )

    def __add_select_words_rapid(
        self,
        asset: MediaAsset | TextAsset,
        instruction: str,
        select_words: str,
        truths: list[int],
        strict_grading: bool | None = None,
        metadata: Sequence[Metadata] | None = None,
    ):
        """Add a select words rapid to the validation set.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            select_words (str): The sentence whose whitespace-separated words can be selected.
            truths (list[int]): The list of indices of the true word selections.
            strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
            metadata (Sequence[Metadata] | None, optional): The metadata for the rapid. Defaults to an empty list.

        Raises:
            ValueError: If a truth is not an integer, a truth index is out of
                bounds (including negative indices), or the sentence has no words.
        """
        transcription_words = [
            TranscriptionWord(word=word, wordIndex=i)
            for i, word in enumerate(select_words.split())
        ]

        true_words = []
        for idx in truths:
            # Raise instead of assert: asserts are stripped under `python -O`.
            if not isinstance(idx, int):
                raise ValueError("truths must be a list of integers")
            # Reject negative indices too; they would silently select from the end.
            if not 0 <= idx < len(transcription_words):
                raise ValueError(f"Index {idx} is out of bounds")
            true_words.append(transcription_words[idx])

        if not transcription_words:
            # The random-correct probability below divides by the word count.
            raise ValueError("The sentence must contain at least one word")

        payload = TranscriptionPayload(
            _t="TranscriptionPayload", title=instruction, transcription=transcription_words
        )

        model_truth = TranscriptionTruth(
            _t="TranscriptionTruth",
            correctWords=true_words,
            strictGrading=strict_grading,
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                asset=asset,
                payload=payload,
                truths=model_truth,
                metadata=[] if metadata is None else metadata,
                randomCorrectProbability=1 / len(transcription_words),
            )
        )

    def __add_locate_rapid(
        self,
        asset: MediaAsset,
        instruction: str,
        truths: list[Box]
    ):
        """Add a locate rapid to the validation set.

        Box coordinates are given in pixels and converted to percentages of the
        image width/height.

        Args:
            asset (MediaAsset): The asset for the rapid.
            instruction (str): The instruction for the locate rapid.
            truths (list[Box]): The truths for the rapid.

        Raises:
            ValueError: If the image dimensions cannot be determined.
        """
        payload = LocatePayload(
            _t="LocatePayload", target=instruction
        )

        img_dimensions = asset.get_image_dimension()

        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")

        model_truth = LocateBoxTruth(
            _t="LocateBoxTruth",
            boundingBoxes=[
                BoxShape(
                    _t="BoxShape",
                    xMin=truth.x_min / img_dimensions[0] * 100,
                    xMax=truth.x_max / img_dimensions[0] * 100,
                    yMax=truth.y_max / img_dimensions[1] * 100,
                    yMin=truth.y_min / img_dimensions[1] * 100,
                )
                for truth in truths
            ],
        )

        coverage = self._calculate_boxes_coverage(truths, img_dimensions[0], img_dimensions[1])

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[],
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def __add_draw_rapid(
        self,
        asset: MediaAsset,
        instruction: str,
        truths: list[Box]
    ):
        """Add a draw rapid to the validation set.

        Args:
            asset (MediaAsset): The asset for the rapid.
            instruction (str): The instruction for the draw rapid.
            truths (list[Box]): The truths for the rapid. Only the first box is
                used for the truth; all boxes are used for the coverage.

        Raises:
            ValueError: If the image dimensions cannot be determined or
                ``truths`` is empty.
        """
        payload = LinePayload(
            _t="LinePayload", target=instruction
        )

        img_dimensions = asset.get_image_dimension()

        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")

        if not truths:
            raise ValueError("Draw rapid requires at least one truth box")

        # NOTE(review): the original author marked this truth as provisional
        # ("TO BE CHANGED BEFORE MERGING"). It uses only the first box, and the
        # coordinates are fractions of the image size, whereas the locate rapid
        # uses percentages; confirm both against the BoundingBoxTruth API.
        model_truth = BoundingBoxTruth(
            _t="BoundingBoxTruth",
            xMax=truths[0].x_max / img_dimensions[0],
            xMin=truths[0].x_min / img_dimensions[0],
            yMax=truths[0].y_max / img_dimensions[1],
            yMin=truths[0].y_min / img_dimensions[1],
        )

        coverage = self._calculate_boxes_coverage(truths, img_dimensions[0], img_dimensions[1])

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[],
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def _calculate_boxes_coverage(self, boxes: list[Box], image_width: int, image_height: int) -> float:
        """Return the fraction of image pixels covered by at least one box.

        A box covers the integer pixel columns ``int(x_min) .. int(x_max + 1) - 1``
        and rows ``int(y_min) .. int(y_max + 1) - 1``, clipped to the image.
        Overlapping boxes are counted once.

        The union area is computed over the clipped pixel rectangles (coordinate
        compression) instead of enumerating every pixel, so the cost depends on
        the number of boxes rather than on their area.

        Args:
            boxes (list[Box]): Boxes in pixel coordinates.
            image_width (int): Image width in pixels.
            image_height (int): Image height in pixels.

        Returns:
            float: Covered pixels divided by ``image_width * image_height``;
            0.0 when ``boxes`` is empty.
        """
        if not boxes:
            return 0.0

        # Half-open pixel rectangles [x0, x1) x [y0, y1), clipped to the image.
        rects = []
        for box in boxes:
            x0 = max(int(box.x_min), 0)
            x1 = min(int(box.x_max + 1), image_width)
            y0 = max(int(box.y_min), 0)
            y1 = min(int(box.y_max + 1), image_height)
            if x0 < x1 and y0 < y1:
                rects.append((x0, x1, y0, y1))

        # Every rectangle edge is a slab boundary, so each vertical slab is
        # either fully inside or fully outside any given rectangle.
        x_edges = sorted({edge for x0, x1, _, _ in rects for edge in (x0, x1)})
        total_covered = 0
        for left, right in zip(x_edges, x_edges[1:]):
            spans = [(y0, y1) for x0, x1, y0, y1 in rects if x0 <= left and right <= x1]
            total_covered += self._covered_length(spans) * (right - left)

        return total_covered / (image_width * image_height)

    @staticmethod
    def _covered_length(spans: list[tuple[int, int]]) -> int:
        """Return the total length of the union of half-open ``[start, end)`` spans."""
        covered = 0
        reach = float("-inf")  # furthest end counted so far
        for start, end in sorted(spans):
            start = max(start, reach)
            if end > start:
                covered += end - start
            reach = max(reach, end)
        return covered
|