rapidata 1.10.1__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic. Click here for more details.

Files changed (142) hide show
  1. rapidata/__init__.py +22 -17
  2. rapidata/api_client/__init__.py +16 -5
  3. rapidata/api_client/api/coco_api.py +14 -29
  4. rapidata/api_client/api/dataset_api.py +6 -6
  5. rapidata/api_client/api/identity_api.py +3 -3
  6. rapidata/api_client/api/pipeline_api.py +1008 -95
  7. rapidata/api_client/api/rapid_api.py +6 -6
  8. rapidata/api_client/api/validation_api.py +12 -42
  9. rapidata/api_client/models/__init__.py +16 -5
  10. rapidata/api_client/models/add_campaign_model.py +5 -5
  11. rapidata/api_client/models/add_validation_text_rapid_model.py +1 -1
  12. rapidata/api_client/models/age_group.py +5 -4
  13. rapidata/api_client/models/base_error.py +1 -4
  14. rapidata/api_client/models/compare_workflow_config.py +9 -24
  15. rapidata/api_client/models/compare_workflow_config_model.py +9 -29
  16. rapidata/api_client/models/compare_workflow_config_model_pair_maker_config.py +140 -0
  17. rapidata/api_client/models/compare_workflow_config_pair_maker_config.py +140 -0
  18. rapidata/api_client/models/compare_workflow_model.py +7 -3
  19. rapidata/api_client/models/compare_workflow_model1.py +7 -3
  20. rapidata/api_client/models/compare_workflow_model1_pair_maker_information.py +140 -0
  21. rapidata/api_client/models/compare_workflow_model_pair_maker_config.py +140 -0
  22. rapidata/api_client/models/create_order_model.py +4 -2
  23. rapidata/api_client/models/create_order_model_user_filters_inner.py +25 -11
  24. rapidata/api_client/models/custom_user_filter_model.py +98 -0
  25. rapidata/api_client/models/file_asset_model_metadata_inner.py +8 -22
  26. rapidata/api_client/models/get_classify_workflow_result_overview_result.py +144 -0
  27. rapidata/api_client/models/get_pipeline_by_id_result.py +13 -3
  28. rapidata/api_client/models/identity_read_bridge_token_get202_response.py +140 -0
  29. rapidata/api_client/models/not_available_yet_result.py +96 -0
  30. rapidata/api_client/models/online_pair_maker_config.py +98 -0
  31. rapidata/api_client/models/online_pair_maker_config_model.py +98 -0
  32. rapidata/api_client/models/online_pair_maker_information.py +100 -0
  33. rapidata/api_client/models/pipeline_id_workflow_put_request.py +140 -0
  34. rapidata/api_client/models/pre_arranged_pair_maker_config.py +100 -0
  35. rapidata/api_client/models/pre_arranged_pair_maker_config_model.py +96 -0
  36. rapidata/api_client/models/pre_arranged_pair_maker_information.py +102 -0
  37. rapidata/api_client/models/read_bridge_token_keys_result.py +11 -2
  38. rapidata/api_client/models/simple_workflow_config.py +7 -26
  39. rapidata/api_client/models/simple_workflow_config_model.py +4 -28
  40. rapidata/api_client/models/simple_workflow_get_result_overview_get200_response.py +16 -16
  41. rapidata/api_client/models/simple_workflow_model1.py +3 -3
  42. rapidata/api_client/models/update_campaign_model.py +99 -0
  43. rapidata/api_client/models/validation_import_post_request_blueprint.py +1 -1
  44. rapidata/api_client_README.md +21 -7
  45. rapidata/rapidata_client/__init__.py +20 -10
  46. rapidata/rapidata_client/assets/__init__.py +5 -4
  47. rapidata/rapidata_client/assets/{media_asset.py → _media_asset.py} +32 -11
  48. rapidata/rapidata_client/assets/{multi_asset.py → _multi_asset.py} +1 -1
  49. rapidata/rapidata_client/assets/{text_asset.py → _text_asset.py} +1 -1
  50. rapidata/rapidata_client/assets/data_type_enum.py +7 -0
  51. rapidata/rapidata_client/filter/__init__.py +2 -1
  52. rapidata/rapidata_client/filter/_base_filter.py +10 -0
  53. rapidata/rapidata_client/filter/age_filter.py +12 -5
  54. rapidata/rapidata_client/filter/campaign_filter.py +12 -3
  55. rapidata/rapidata_client/filter/country_filter.py +10 -3
  56. rapidata/rapidata_client/filter/custom_filter.py +29 -0
  57. rapidata/rapidata_client/filter/gender_filter.py +12 -5
  58. rapidata/rapidata_client/filter/language_filter.py +14 -3
  59. rapidata/rapidata_client/filter/models/age_group.py +26 -0
  60. rapidata/rapidata_client/filter/models/gender.py +19 -0
  61. rapidata/rapidata_client/filter/rapidata_filters.py +35 -0
  62. rapidata/rapidata_client/filter/user_score_filter.py +20 -4
  63. rapidata/rapidata_client/metadata/__init__.py +5 -5
  64. rapidata/rapidata_client/metadata/{base_metadata.py → _base_metadata.py} +2 -1
  65. rapidata/rapidata_client/metadata/{private_text_metadata.py → _private_text_metadata.py} +2 -2
  66. rapidata/rapidata_client/metadata/{prompt_metadata.py → _prompt_metadata.py} +3 -2
  67. rapidata/rapidata_client/metadata/{public_text_metadata.py → _public_text_metadata.py} +2 -2
  68. rapidata/rapidata_client/metadata/{select_words_metadata.py → _select_words_metadata.py} +3 -2
  69. rapidata/rapidata_client/{dataset/rapidata_dataset.py → order/_rapidata_dataset.py} +7 -8
  70. rapidata/rapidata_client/order/_rapidata_order_builder.py +365 -0
  71. rapidata/rapidata_client/order/rapidata_order.py +50 -32
  72. rapidata/rapidata_client/order/rapidata_order_manager.py +461 -0
  73. rapidata/rapidata_client/rapidata_client.py +12 -201
  74. rapidata/rapidata_client/referee/__init__.py +3 -3
  75. rapidata/rapidata_client/referee/{base_referee.py → _base_referee.py} +3 -3
  76. rapidata/rapidata_client/referee/{early_stopping_referee.py → _early_stopping_referee.py} +14 -11
  77. rapidata/rapidata_client/referee/{naive_referee.py → _naive_referee.py} +9 -9
  78. rapidata/rapidata_client/selection/__init__.py +1 -1
  79. rapidata/rapidata_client/{filter/base_filter.py → selection/_base_selection.py} +2 -2
  80. rapidata/rapidata_client/selection/capped_selection.py +15 -5
  81. rapidata/rapidata_client/selection/conditional_validation_selection.py +17 -4
  82. rapidata/rapidata_client/selection/demographic_selection.py +18 -7
  83. rapidata/rapidata_client/selection/labeling_selection.py +10 -3
  84. rapidata/rapidata_client/selection/rapidata_selections.py +21 -0
  85. rapidata/rapidata_client/selection/validation_selection.py +11 -4
  86. rapidata/rapidata_client/settings/__init__.py +9 -2
  87. rapidata/rapidata_client/settings/_rapidata_setting.py +11 -0
  88. rapidata/rapidata_client/settings/alert_on_fast_response.py +21 -0
  89. rapidata/rapidata_client/settings/custom_setting.py +16 -0
  90. rapidata/rapidata_client/settings/free_text_minimum_characters.py +16 -0
  91. rapidata/rapidata_client/settings/models/__init__.py +1 -0
  92. rapidata/rapidata_client/settings/models/translation_behaviour_options.py +14 -0
  93. rapidata/rapidata_client/settings/no_shuffle.py +16 -0
  94. rapidata/rapidata_client/settings/play_video_until_the_end.py +16 -0
  95. rapidata/rapidata_client/settings/rapidata_settings.py +31 -0
  96. rapidata/rapidata_client/settings/translation_behaviour.py +18 -0
  97. rapidata/rapidata_client/validation/__init__.py +1 -0
  98. rapidata/rapidata_client/{dataset/validation_rapid_parts.py → validation/_validation_rapid_parts.py} +7 -6
  99. rapidata/rapidata_client/validation/_validation_set_builder.py +371 -0
  100. rapidata/rapidata_client/{dataset → validation}/rapidata_validation_set.py +54 -50
  101. rapidata/rapidata_client/validation/rapids/__init__.py +1 -0
  102. rapidata/rapidata_client/validation/rapids/box.py +17 -0
  103. rapidata/rapidata_client/validation/rapids/rapids.py +94 -0
  104. rapidata/rapidata_client/validation/rapids/rapids_manager.py +163 -0
  105. rapidata/rapidata_client/validation/validation_set_manager.py +338 -0
  106. rapidata/rapidata_client/workflow/__init__.py +8 -6
  107. rapidata/rapidata_client/workflow/_base_workflow.py +25 -0
  108. rapidata/rapidata_client/workflow/{classify_workflow.py → _classify_workflow.py} +6 -6
  109. rapidata/rapidata_client/workflow/{compare_workflow.py → _compare_workflow.py} +10 -16
  110. rapidata/rapidata_client/workflow/_draw_workflow.py +22 -0
  111. rapidata/rapidata_client/workflow/_evaluation_workflow.py +26 -0
  112. rapidata/rapidata_client/workflow/{free_text_workflow.py → _free_text_workflow.py} +10 -16
  113. rapidata/rapidata_client/workflow/_locate_workflow.py +22 -0
  114. rapidata/rapidata_client/workflow/{select_words_workflow.py → _select_words_workflow.py} +2 -8
  115. rapidata/service/credential_manager.py +11 -1
  116. rapidata/service/openapi_service.py +23 -4
  117. {rapidata-1.10.1.dist-info → rapidata-2.1.0.dist-info}/METADATA +2 -1
  118. {rapidata-1.10.1.dist-info → rapidata-2.1.0.dist-info}/RECORD +122 -96
  119. rapidata/constants.py +0 -1
  120. rapidata/rapidata_client/dataset/rapid_builders/__init__.py +0 -4
  121. rapidata/rapidata_client/dataset/rapid_builders/base_rapid_builder.py +0 -33
  122. rapidata/rapidata_client/dataset/rapid_builders/classify_rapid_builders.py +0 -166
  123. rapidata/rapidata_client/dataset/rapid_builders/compare_rapid_builders.py +0 -145
  124. rapidata/rapidata_client/dataset/rapid_builders/rapids.py +0 -33
  125. rapidata/rapidata_client/dataset/rapid_builders/select_words_rapid_builders.py +0 -124
  126. rapidata/rapidata_client/dataset/validation_set_builder.py +0 -336
  127. rapidata/rapidata_client/order/order_builder.py +0 -25
  128. rapidata/rapidata_client/order/rapidata_order_builder.py +0 -463
  129. rapidata/rapidata_client/selection/base_selection.py +0 -9
  130. rapidata/rapidata_client/settings/feature_flags.py +0 -125
  131. rapidata/rapidata_client/settings/settings.py +0 -124
  132. rapidata/rapidata_client/simple_builders/__init__.py +0 -0
  133. rapidata/rapidata_client/simple_builders/simple_classification_builders.py +0 -271
  134. rapidata/rapidata_client/simple_builders/simple_compare_builders.py +0 -267
  135. rapidata/rapidata_client/simple_builders/simple_free_text_builders.py +0 -192
  136. rapidata/rapidata_client/simple_builders/simple_select_words_builders.py +0 -196
  137. rapidata/rapidata_client/workflow/base_workflow.py +0 -42
  138. rapidata/rapidata_client/workflow/evaluation_workflow.py +0 -15
  139. /rapidata/rapidata_client/assets/{base_asset.py → _base_asset.py} +0 -0
  140. /rapidata/rapidata_client/{dataset → filter/models}/__init__.py +0 -0
  141. {rapidata-1.10.1.dist-info → rapidata-2.1.0.dist-info}/LICENSE +0 -0
  142. {rapidata-1.10.1.dist-info → rapidata-2.1.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,14 @@
1
+ from enum import Enum
2
+
3
class TranslationBehaviourOptions(Enum):
    """Choices accepted by the translation behaviour setting.

    Attributes:
        BOTH: Display the original text together with its translation.
            Long options may clutter the screen in this mode.
        ONLY_ORIGINAL: Display the original text only.
        ONLY_TRANSLATED: Display the translated text only.
    """

    BOTH = "both"
    ONLY_ORIGINAL = "only original"
    ONLY_TRANSLATED = "only translated"
@@ -0,0 +1,16 @@
1
+ from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
2
+
3
class NoShuffle(RapidataSetting):
    """Setting that keeps classify answer options in a fixed order.

    Only for classify tasks. If true, the order of the categories will be
    the same. If this setting is not added to the order, shuffling stays
    active.

    Args:
        value (bool, optional): Whether to disable shuffling. Defaults to
            True when the setting is constructed without arguments.

    Raises:
        ValueError: If `value` is not a `bool`.
    """

    def __init__(self, value: bool = True):
        # Reject anything that is not strictly a bool before registering.
        value_is_bool = isinstance(value, bool)
        if not value_is_bool:
            raise ValueError("The value must be a boolean.")

        super().__init__(key="no_shuffle", value=value)
@@ -0,0 +1,16 @@
1
+ from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
2
+
3
class PlayVideoUntilTheEnd(RapidataSetting):
    """Setting that lets users answer only after the video has played.

    The additional time is added on top of the video duration. It can be
    negative to allow answers before the video ends.

    Args:
        additional_time (int, optional): Additional time in milliseconds.
            Defaults to 0.

    Raises:
        ValueError: If `additional_time` is outside -25000..25000.
    """

    def __init__(self, additional_time: int = 0):
        # Accepted window for the offset (inclusive on both ends).
        lowest, highest = -25000, 25000
        if additional_time < lowest or additional_time > highest:
            raise ValueError("The additional time must be between -25000 and 25000.")

        super().__init__(
            key="alert_on_fast_response_add_media_duration",
            value=additional_time,
        )
@@ -0,0 +1,31 @@
1
+ from rapidata.rapidata_client.settings import (
2
+ AlertOnFastResponse,
3
+ TranslationBehaviour,
4
+ FreeTextMinimumCharacters,
5
+ NoShuffle,
6
+ PlayVideoUntilTheEnd,
7
+ CustomSetting,
8
+ )
9
+
10
class RapidataSettings:
    """Namespace that groups every available setting class in one place.

    Settings can be added to an order to determine the behaviour of the
    task. Each attribute below is an alias of the setting class itself (not
    an instance); calling it constructs that setting.

    Attributes:
        alert_on_fast_response: Alias of the `AlertOnFastResponse` class.
        translation_behaviour: Alias of the `TranslationBehaviour` class.
        free_text_minimum_characters: Alias of the
            `FreeTextMinimumCharacters` class.
        no_shuffle: Alias of the `NoShuffle` class.
        play_video_until_the_end: Alias of the `PlayVideoUntilTheEnd` class.
        custom_setting: Alias of the `CustomSetting` class.
    """

    alert_on_fast_response = AlertOnFastResponse
    translation_behaviour = TranslationBehaviour
    free_text_minimum_characters = FreeTextMinimumCharacters
    no_shuffle = NoShuffle
    play_video_until_the_end = PlayVideoUntilTheEnd
    custom_setting = CustomSetting
31
+
@@ -0,0 +1,18 @@
1
+ from rapidata.rapidata_client.settings.models.translation_behaviour_options import TranslationBehaviourOptions
2
+ from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
3
+
4
class TranslationBehaviour(RapidataSetting):
    """Setting that selects how translations are shown in the UI.

    Text datapoints and sentences are not translated.

    Args:
        value (TranslationBehaviourOptions): The translation behaviour.

    Raises:
        ValueError: If `value` is not a `TranslationBehaviourOptions` member.
    """

    def __init__(self, value: TranslationBehaviourOptions):
        # Only enum members are accepted; raw strings are rejected.
        is_valid_option = isinstance(value, TranslationBehaviourOptions)
        if not is_valid_option:
            raise ValueError("The value must be a TranslationBehaviourOptions.")

        super().__init__(key="translation_behaviour", value=value)
18
+
@@ -0,0 +1 @@
1
+ from .rapids import Box
@@ -19,15 +19,16 @@ from rapidata.api_client.models.polygon_payload import PolygonPayload
19
19
  from rapidata.api_client.models.polygon_truth import PolygonTruth
20
20
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
21
21
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
22
- from rapidata.rapidata_client.assets.media_asset import MediaAsset
23
- from rapidata.rapidata_client.assets.multi_asset import MultiAsset
24
- from rapidata.rapidata_client.assets.text_asset import TextAsset
25
- from rapidata.rapidata_client.metadata.base_metadata import Metadata
22
+ from rapidata.rapidata_client.assets._media_asset import MediaAsset
23
+ from rapidata.rapidata_client.assets._multi_asset import MultiAsset
24
+ from rapidata.rapidata_client.assets._text_asset import TextAsset
25
+ from rapidata.rapidata_client.metadata._base_metadata import Metadata
26
+ from typing import Sequence
26
27
 
27
28
 
28
29
  @dataclass
29
30
  class ValidatioRapidParts:
30
- question: str
31
+ instruction: str
31
32
  asset: MediaAsset | TextAsset | MultiAsset
32
33
  payload: (
33
34
  BoundingBoxPayload
@@ -51,5 +52,5 @@ class ValidatioRapidParts:
51
52
  | PolygonTruth
52
53
  | TranscriptionTruth
53
54
  )
54
- metadata: list[Metadata]
55
+ metadata: Sequence[Metadata]
55
56
  randomCorrectProbability: float
@@ -0,0 +1,371 @@
1
+ import os
2
+ from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
3
+ from rapidata.api_client.models.classify_payload import ClassifyPayload
4
+ from rapidata.api_client.models.compare_payload import ComparePayload
5
+ from rapidata.api_client.models.compare_truth import CompareTruth
6
+ from rapidata.api_client.models.transcription_payload import TranscriptionPayload
7
+ from rapidata.api_client.models.transcription_truth import TranscriptionTruth
8
+ from rapidata.api_client.models.transcription_word import TranscriptionWord
9
+ from rapidata.api_client.models.locate_payload import LocatePayload
10
+ from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
11
+ from rapidata.api_client.models.line_payload import LinePayload
12
+ from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
13
+ from rapidata.api_client.models.box_shape import BoxShape
14
+ from rapidata.rapidata_client.validation.rapidata_validation_set import (
15
+ RapidataValidationSet,
16
+ )
17
+ from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
18
+ from rapidata.rapidata_client.validation._validation_rapid_parts import ValidatioRapidParts
19
+ from rapidata.rapidata_client.metadata._base_metadata import Metadata
20
+ from rapidata.service.openapi_service import OpenAPIService
21
+ from rapidata.rapidata_client.validation.rapids.box import Box
22
+
23
+ from rapidata.rapidata_client.validation.rapids.rapids import (
24
+ Rapid,
25
+ ClassificationRapid,
26
+ CompareRapid,
27
+ SelectWordsRapid,
28
+ LocateRapid,
29
+ DrawRapid
30
+ )
31
+ from typing import Sequence
32
+
33
+
34
class ValidationSetBuilder:
    """The ValidationSetBuilder is used to build a validation set.

    Rapids are queued locally with `_add_rapid` and only sent to the API when
    `_submit` is called.
    Get a `ValidationSetBuilder` by calling [`rapi.new_validation_set()`](../rapidata_client.md/#rapidata.rapidata_client.rapidata_client.RapidataClient.new_validation_set).

    Args:
        name (str): The name of the validation set.
        openapi_service (OpenAPIService): An instance of OpenAPIService to interact with the API.
    """

    def __init__(self, name: str, openapi_service: OpenAPIService):
        self.name = name
        self.openapi_service = openapi_service
        self.validation_set_id: str | None = None
        # Rapids queued locally until `_submit` uploads them.
        self._rapid_parts: list[ValidatioRapidParts] = []

    def _submit(self, print_confirmation: bool = True) -> RapidataValidationSet:
        """Create the validation set by executing all HTTP requests.

        This should be the last method called on the builder.

        Args:
            print_confirmation (bool, optional): Whether to print the name and
                ID of the created validation set. Defaults to True.

        Returns:
            RapidataValidationSet: A RapidataValidationSet instance.

        Raises:
            ValueError: If the validation set creation fails.
        """
        result = (
            self.openapi_service.validation_api.validation_create_validation_set_post(
                name=self.name
            )
        )
        self.validation_set_id = result.validation_set_id

        if self.validation_set_id is None:
            raise ValueError("Failed to create validation set")

        validation_set = RapidataValidationSet(
            validation_set_id=self.validation_set_id,
            openapi_service=self.openapi_service,
            name=self.name,
        )

        # Upload every queued rapid to the freshly created set.
        for rapid_part in self._rapid_parts:
            validation_set._add_general_validation_rapid(
                payload=rapid_part.payload,
                truths=rapid_part.truths,
                metadata=rapid_part.metadata,
                asset=rapid_part.asset,
                randomCorrectProbability=rapid_part.randomCorrectProbability,
            )

        if print_confirmation:
            print(f"Validation set '{self.name}' created with ID {self.validation_set_id}")

        return validation_set

    def _add_rapid(self, rapid: Rapid):
        """Add a rapid to the validation set.

        To create the Rapid, use the RapidataClient.rapid_builder instance.

        Args:
            rapid (Rapid): The rapid to add to the validation set.

        Returns:
            ValidationSetBuilder: This builder, to allow chaining.

        Raises:
            ValueError: If `rapid` is not a `Rapid` or is of an unsupported
                `Rapid` subtype.
        """
        if not isinstance(rapid, Rapid):
            raise ValueError("This method only accepts Rapid instances")

        if isinstance(rapid, ClassificationRapid):
            self.__add_classify_rapid(rapid.asset, rapid.instruction, rapid.answer_options, rapid.truths, rapid.metadata)

        elif isinstance(rapid, CompareRapid):
            self.__add_compare_rapid(rapid.asset, rapid.instruction, rapid.truth, rapid.metadata)

        elif isinstance(rapid, SelectWordsRapid):
            self.__add_select_words_rapid(rapid.asset, rapid.instruction, rapid.sentence, rapid.truths, rapid.strict_grading)

        elif isinstance(rapid, LocateRapid):
            self.__add_locate_rapid(rapid.asset, rapid.instruction, rapid.truths)

        elif isinstance(rapid, DrawRapid):
            self.__add_draw_rapid(rapid.asset, rapid.instruction, rapid.truths)

        else:
            raise ValueError("Unsupported rapid type")

        return self

    def __add_classify_rapid(
        self,
        asset: MediaAsset | TextAsset,
        instruction: str,
        answer_options: list[str],
        truths: list[str],
        metadata: Sequence[Metadata] | None = None,
    ):
        """Queue a classify rapid for the validation set.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            answer_options (list[str]): The list of answer_options for the rapid.
            truths (list[str]): The list of truths for the rapid.
            metadata (Sequence[Metadata] | None, optional): The metadata for the
                rapid. Defaults to an empty list.

        Raises:
            ValueError: If there are no answer options, or if a truth is not
                one of the answer options.
        """
        # Guard the division below against an empty option list.
        if not answer_options:
            raise ValueError("At least one answer option is required")

        if not all(truth in answer_options for truth in truths):
            raise ValueError("Truths must be part of the answer options")

        payload = ClassifyPayload(
            _t="ClassifyPayload", possibleCategories=answer_options, title=instruction
        )
        model_truth = AttachCategoryTruth(
            correctCategories=truths, _t="AttachCategoryTruth"
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=metadata if metadata is not None else [],
                randomCorrectProbability=len(truths) / len(answer_options),
                asset=asset,
            )
        )

    def __add_compare_rapid(
        self,
        asset: MultiAsset,
        instruction: str,
        truth: str,
        metadata: Sequence[Metadata] | None = None,
    ):
        """Queue a compare rapid for the validation set.

        Args:
            asset (MultiAsset): The assets for the rapid.
            instruction (str): The instruction for the comparison.
            truth (str): The truth identifier for the rapid. Only the last
                path component is used as the winner id.
            metadata (Sequence[Metadata] | None, optional): The metadata for the
                rapid. Defaults to an empty list.

        Raises:
            ValueError: If the number of assets is not exactly two.
        """
        # Validate first so no API models are built for invalid input.
        if len(asset) != 2:
            raise ValueError("Compare rapid requires exactly two media paths")

        payload = ComparePayload(_t="ComparePayload", criteria=instruction)
        # take only last part of truth path
        truth = os.path.basename(truth)
        model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=metadata if metadata is not None else [],
                randomCorrectProbability=1 / len(asset),
                asset=asset,
            )
        )

    def __add_select_words_rapid(
        self,
        asset: MediaAsset | TextAsset,
        instruction: str,
        select_words: str,
        truths: list[int],
        strict_grading: bool | None = None,
        metadata: Sequence[Metadata] | None = None,
    ):
        """Queue a select words rapid for the validation set.

        Args:
            asset (MediaAsset | TextAsset): The asset for the rapid.
            instruction (str): The instruction for the rapid.
            select_words (str): The sentence to select from; it is split on
                whitespace into individually selectable words.
            truths (list[int]): The list of indices of the true word selections.
            strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
            metadata (Sequence[Metadata] | None, optional): The metadata for the
                rapid. Defaults to an empty list.

        Raises:
            ValueError: If the sentence contains no words, if a truth is not
                an integer, or if a truth index is out of bounds.
        """
        words = select_words.split()
        # Guard the probability division below against an empty sentence.
        if not words:
            raise ValueError("The select words must contain at least one word")

        for idx in truths:
            # Explicit raise instead of `assert`, which is stripped under -O.
            if not isinstance(idx, int):
                raise ValueError("truths must be a list of integers")
            # Negative indices would silently select words from the end.
            if not 0 <= idx < len(words):
                raise ValueError(f"Index {idx} is out of bounds")

        transcription_words = [
            TranscriptionWord(word=word, wordIndex=i)
            for i, word in enumerate(words)
        ]
        true_words = [transcription_words[idx] for idx in truths]

        payload = TranscriptionPayload(
            _t="TranscriptionPayload", title=instruction, transcription=transcription_words
        )

        model_truth = TranscriptionTruth(
            _t="TranscriptionTruth",
            correctWords=true_words,
            strictGrading=strict_grading,
        )

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                asset=asset,
                payload=payload,
                truths=model_truth,
                metadata=metadata if metadata is not None else [],
                randomCorrectProbability=1 / len(transcription_words),
            )
        )

    def __add_locate_rapid(
        self,
        asset: MediaAsset,
        instruction: str,
        truths: list[Box]
    ):
        """Queue a locate rapid for the validation set.

        Box coordinates are divided by the image dimensions and multiplied
        by 100 before being stored in the truth.

        Args:
            asset (MediaAsset): The asset for the rapid.
            instruction (str): The instruction for the locate rapid.
            truths (list[Box]): The truths for the rapid.

        Raises:
            ValueError: If the image dimensions cannot be determined.
        """
        payload = LocatePayload(
            _t="LocatePayload", target=instruction
        )

        img_dimensions = asset.get_image_dimension()

        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")

        width, height = img_dimensions[0], img_dimensions[1]

        model_truth = LocateBoxTruth(
            _t="LocateBoxTruth",
            boundingBoxes=[BoxShape(
                _t="BoxShape",
                xMin=truth.x_min / width * 100,
                xMax=truth.x_max / width * 100,
                yMax=truth.y_max / height * 100,
                yMin=truth.y_min / height * 100,
            ) for truth in truths]
        )

        coverage = self._calculate_boxes_coverage(truths, width, height)

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[],
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def __add_draw_rapid(
        self,
        asset: MediaAsset,
        instruction: str,
        truths: list[Box]
    ):
        """Queue a draw rapid for the validation set.

        Only the first box in `truths` is stored in the truth model; all
        boxes contribute to the random-correct probability.

        Args:
            asset (MediaAsset): The asset for the rapid.
            instruction (str): The instruction for the draw rapid.
            truths (list[Box]): The truths for the rapid.

        Raises:
            ValueError: If `truths` is empty or the image dimensions cannot
                be determined.
        """
        # `truths[0]` is read below, so an empty list must be rejected.
        if not truths:
            raise ValueError("Draw rapid requires at least one truth box")

        payload = LinePayload(
            _t="LinePayload", target=instruction
        )

        img_dimensions = asset.get_image_dimension()

        if not img_dimensions:
            raise ValueError("Failed to get image dimensions")

        width, height = img_dimensions[0], img_dimensions[1]
        first_box = truths[0]

        # NOTE(review): values here are 0-1 fractions, whereas the locate
        # rapid multiplies by 100 - confirm which scale the API expects.
        model_truth = BoundingBoxTruth(
            _t="BoundingBoxTruth",
            xMax=first_box.x_max / width,
            xMin=first_box.x_min / width,
            yMax=first_box.y_max / height,
            yMin=first_box.y_min / height,
        )  # TO BE CHANGED BEFORE MERGING

        coverage = self._calculate_boxes_coverage(truths, width, height)

        self._rapid_parts.append(
            ValidatioRapidParts(
                instruction=instruction,
                payload=payload,
                truths=model_truth,
                metadata=[],
                randomCorrectProbability=coverage,
                asset=asset,
            )
        )

    def _calculate_boxes_coverage(self, boxes: list[Box], image_width: int, image_height: int) -> float:
        """Return the fraction of image pixels covered by at least one box.

        Each box covers the integer pixels from `int(min)` to `int(max)`
        inclusive on both axes, clipped to the image. Overlapping areas are
        counted once. The union is computed by sweeping over x-strips, so
        the cost depends on the number of boxes rather than their area.

        Args:
            boxes (list[Box]): Boxes with `x_min`, `x_max`, `y_min`, `y_max`.
            image_width (int): Image width in pixels.
            image_height (int): Image height in pixels.

        Returns:
            float: Covered pixels divided by total pixels; 0.0 for no boxes.

        Raises:
            ValueError: If boxes are given and a dimension is not positive.
        """
        if not boxes:
            return 0.0

        if image_width <= 0 or image_height <= 0:
            raise ValueError("Image dimensions must be positive")

        # Half-open pixel rectangles (x0, x1, y0, y1), clipped to the image.
        rects = []
        for box in boxes:
            x0 = max(int(box.x_min), 0)
            x1 = min(int(box.x_max + 1), image_width)
            y0 = max(int(box.y_min), 0)
            y1 = min(int(box.y_max + 1), image_height)
            if x0 < x1 and y0 < y1:
                rects.append((x0, x1, y0, y1))

        # Between consecutive x-edges the set of active rectangles is fixed.
        x_edges = sorted({edge for rect in rects for edge in rect[:2]})

        total_covered = 0
        for left, right in zip(x_edges, x_edges[1:]):
            spans = sorted(
                (y0, y1) for x0, x1, y0, y1 in rects if x0 <= left and right <= x1
            )
            # Merge the y-intervals; `reach` is the highest y counted so far.
            strip_height = 0
            reach = 0
            for y0, y1 in spans:
                if y1 > reach:
                    strip_height += y1 - max(y0, reach)
                    reach = y1
            total_covered += (right - left) * strip_height

        return total_covered / (image_width * image_height)