rapidata 2.3.2__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,8 @@ from PIL import Image
12
12
  from tinytag import TinyTag
13
13
  import tempfile
14
14
  from pydantic import StrictStr, StrictBytes
15
+ from typing import Optional
16
+ import logging
15
17
 
16
18
 
17
19
  class MediaAsset(BaseAsset):
@@ -26,6 +28,7 @@ class MediaAsset(BaseAsset):
26
28
  Raises:
27
29
  FileNotFoundError: If the provided file path does not exist.
28
30
  """
31
+ _logger = logging.getLogger(__name__ + '.MediaAsset')
29
32
 
30
33
  ALLOWED_TYPES = [
31
34
  'image/',
@@ -33,6 +36,28 @@ class MediaAsset(BaseAsset):
33
36
  'video/mp4', # MP4
34
37
  ]
35
38
 
39
+ MIME_TYPES = {
40
+ 'jpg': 'image/jpeg',
41
+ 'jpeg': 'image/jpeg',
42
+ 'png': 'image/png',
43
+ 'gif': 'image/gif',
44
+ 'webp': 'image/webp',
45
+ 'mp3': 'audio/mp3',
46
+ 'mp4': 'video/mp4'
47
+ }
48
+
49
+ FILE_SIGNATURES = {
50
+ b'\xFF\xD8\xFF': 'image/jpeg',
51
+ b'\x89PNG\r\n\x1a\n': 'image/png',
52
+ b'GIF87a': 'image/gif',
53
+ b'GIF89a': 'image/gif',
54
+ b'RIFF': 'image/webp',
55
+ b'ID3': 'audio/mp3',
56
+ b'\xFF\xFB': 'audio/mp3',
57
+ b'\xFF\xF3': 'audio/mp3',
58
+ b'ftyp': 'video/mp4',
59
+ }
60
+
36
61
  def __init__(self, path: str):
37
62
  """
38
63
  Initialize a MediaAsset instance.
@@ -134,38 +159,113 @@ class MediaAsset(BaseAsset):
134
159
  name = name + '.jpg'
135
160
  return name
136
161
 
162
+ def __get_media_type_from_extension(self, url: str) -> Optional[str]:
163
+ """
164
+ Determine media type from URL file extension.
165
+
166
+ Args:
167
+ url: The URL to check
168
+
169
+ Returns:
170
+ Optional[str]: MIME type if valid extension found, None otherwise
171
+ """
172
+ try:
173
+ ext = url.lower().split('?')[0].split('.')[-1]
174
+ return self.MIME_TYPES.get(ext)
175
+ except IndexError:
176
+ return None
177
+
178
+ def __validate_image_content(self, content: bytes) -> bool:
179
+ """
180
+ Validate image content using PIL.
181
+
182
+ Args:
183
+ content: Image bytes to validate
184
+
185
+ Returns:
186
+ bool: True if valid image, False otherwise
187
+ """
188
+ try:
189
+ img = Image.open(BytesIO(content))
190
+ img.verify()
191
+ return True
192
+ except Exception as e:
193
+ self._logger.debug(f"Image validation failed: {str(e)}")
194
+ return False
195
+
196
+ def __get_media_type_from_signature(self, content: bytes) -> Optional[str]:
197
+ """
198
+ Determine media type from file signature.
199
+
200
+ Args:
201
+ content: File content bytes
202
+
203
+ Returns:
204
+ Optional[str]: MIME type if valid signature found, None otherwise
205
+ """
206
+ file_start = content[:32]
207
+ for signature, mime_type in self.FILE_SIGNATURES.items():
208
+ if file_start.startswith(signature) or (signature in file_start[:10]):
209
+ return mime_type
210
+ return None
211
+
137
212
  def __get_media_bytes(self, url: str) -> bytes:
138
213
  """
139
- Downloads media files from URL and validates type and duration.
214
+ Downloads and validates media files from URL.
140
215
 
141
216
  Args:
142
217
  url: URL of the media file
143
-
218
+
144
219
  Returns:
145
- bytes: Media data
220
+ bytes: Validated media content
146
221
 
147
222
  Raises:
148
- ValueError: If media type is unsupported or duration exceeds limit
223
+ ValueError: If media type is unsupported or content validation fails
149
224
  requests.exceptions.RequestException: If download fails
150
225
  """
151
- response = requests.get(url, stream=False) # Don't stream, we need full file
152
- response.raise_for_status()
226
+ try:
227
+ response = requests.get(url, stream=False)
228
+ response.raise_for_status()
229
+ except requests.exceptions.RequestException as e:
230
+ self._logger.error(f"Failed to download media from {url}: {str(e)}")
231
+ raise
153
232
 
233
+ content = response.content
154
234
  content_type = response.headers.get('content-type', '').lower()
155
-
156
- # Validate content type
157
- if not any(content_type.startswith(t) for t in self.ALLOWED_TYPES):
158
- raise ValueError(
159
- f'URL does not point to an allowed media type.\n'
160
- f'Content-Type: {content_type}\n'
161
- f'Allowed types: {self.ALLOWED_TYPES}'
162
- )
163
-
164
- content = BytesIO(response.content)
165
- return content.getvalue()
235
+
236
+ # Case 1: Content-type is already allowed
237
+ if any(content_type.startswith(t) for t in self.ALLOWED_TYPES):
238
+ self._logger.debug(f"Content-type {content_type} is allowed")
239
+ return content
240
+
241
+ # Case 2: Try to validate based on extension
242
+ mime_type = self.__get_media_type_from_extension(url)
243
+ if mime_type and mime_type.startswith(tuple(self.ALLOWED_TYPES)):
244
+ self._logger.debug(f"Found valid mime type from extension: {mime_type}")
245
+ return content
246
+
247
+ # Case 3: Try to validate based on file signature
248
+ mime_type = self.__get_media_type_from_signature(content)
249
+ if mime_type and mime_type.startswith(tuple(self.ALLOWED_TYPES)):
250
+ self._logger.debug(f"Found valid mime type from signature: {mime_type}")
251
+ return content
252
+
253
+ # Case 4: Last resort - try direct image validation
254
+ if self.__validate_image_content(content):
255
+ self._logger.debug("Content validated as image through direct validation")
256
+ return content
257
+
258
+ # If we get here, validation failed
259
+ error_msg = (
260
+ f'Could not validate media type from content.\n'
261
+ f'Content-Type: {content_type}\n'
262
+ f'URL extension: {url.split("?")[0].split(".")[-1]}\n'
263
+ f'Allowed types: {self.ALLOWED_TYPES}'
264
+ )
265
+ self._logger.error(error_msg)
266
+ raise ValueError(error_msg)
166
267
 
167
268
  def to_file(self) -> StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes: # types for autogenerated models
168
- files = []
169
269
  if isinstance(self.path, str):
170
270
  return self.path
171
271
  else: # isinstance(self.path, bytes)
@@ -8,7 +8,16 @@ from rapidata.rapidata_client.selection import (
8
8
  class RapidataSelections:
9
9
  """RapidataSelections Classes
10
10
 
11
- Selections are used to define what type of tasks and in what order they are shown to the user.
11
+ Selections are used to define what type of tasks and in what order they are shown to the user.
12
+ All Tasks are called a "Session". A session can contain multiple tasks of different types.
13
+
14
+ Example:
15
+ ```python
16
+ selections=[ValidationSelection("your-validation-set-id", 1),
17
+ LabelingSelection(2)]
18
+ ```
19
+
20
+ The above example will create a session with a validation task followed by two labeling tasks.
12
21
 
13
22
  Attributes:
14
23
  labeling (LabelingSelection): The LabelingSelection instance.
@@ -16,6 +25,7 @@ class RapidataSelections:
16
25
  conditional_validation (ConditionalValidationSelection): The ConditionalValidationSelection instance.
17
26
  demographic (DemographicSelection): The DemographicSelection instance.
18
27
  capped (CappedSelection): The CappedSelection instance."""
28
+
19
29
  labeling = LabelingSelection
20
30
  validation = ValidationSelection
21
31
  conditional_validation = ConditionalValidationSelection
@@ -1,50 +1,3 @@
1
- import os
2
- from typing import Any
3
- from rapidata.api_client.models.add_validation_rapid_model import (
4
- AddValidationRapidModel,
5
- )
6
- from rapidata.api_client.models.add_validation_text_rapid_model import (
7
- AddValidationTextRapidModel,
8
- )
9
- from rapidata.api_client.models.add_validation_rapid_model_payload import (
10
- AddValidationRapidModelPayload,
11
- )
12
- from rapidata.api_client.models.add_validation_rapid_model_truth import (
13
- AddValidationRapidModelTruth,
14
- )
15
- from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
16
- from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload
17
- from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
18
- from rapidata.api_client.models.classify_payload import ClassifyPayload
19
- from rapidata.api_client.models.compare_payload import ComparePayload
20
- from rapidata.api_client.models.compare_truth import CompareTruth
21
- from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import (
22
- DatapointMetadataModelMetadataInner,
23
- )
24
- from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth
25
- from rapidata.api_client.models.free_text_payload import FreeTextPayload
26
- from rapidata.api_client.models.line_payload import LinePayload
27
- from rapidata.api_client.models.line_truth import LineTruth
28
- from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
29
- from rapidata.api_client.models.locate_payload import LocatePayload
30
- from rapidata.api_client.models.named_entity_payload import NamedEntityPayload
31
- from rapidata.api_client.models.named_entity_truth import NamedEntityTruth
32
- from rapidata.api_client.models.polygon_payload import PolygonPayload
33
- from rapidata.api_client.models.polygon_truth import PolygonTruth
34
- from rapidata.api_client.models.transcription_payload import TranscriptionPayload
35
- from rapidata.api_client.models.transcription_truth import TranscriptionTruth
36
- from rapidata.api_client.models.transcription_word import TranscriptionWord
37
- from rapidata.api_client.models.scrub_payload import ScrubPayload
38
- from rapidata.api_client.models.scrub_truth import ScrubTruth
39
- from rapidata.rapidata_client.assets._media_asset import MediaAsset
40
- from rapidata.rapidata_client.assets._multi_asset import MultiAsset
41
- from rapidata.rapidata_client.assets._text_asset import TextAsset
42
- from rapidata.rapidata_client.metadata._base_metadata import Metadata
43
- from rapidata.service.openapi_service import OpenAPIService
44
-
45
- from typing import Sequence
46
-
47
-
48
1
  class RapidataValidationSet:
49
2
  """A class for interacting with a Rapidata validation set.
50
3
 
@@ -57,257 +10,12 @@ class RapidataValidationSet:
57
10
  name (str): The name of the validation set.
58
11
  """
59
12
 
60
- def __init__(self, validation_set_id, openapi_service: OpenAPIService, name: str):
13
+ def __init__(self, validation_set_id, name: str):
61
14
  self.id = validation_set_id
62
15
  self.name = name
63
- self.__openapi_service = openapi_service
64
-
65
- def __upload_files(self, model: AddValidationRapidModel, assets: list[MediaAsset]):
66
- """Upload a file to the validation set.
67
-
68
- Args:
69
- assets: list[(MediaAsset)]: The asset to upload.
70
- """
71
- files = []
72
- for asset in assets:
73
- files.append(asset.to_file())
74
-
75
- self.__openapi_service.validation_api.validation_add_validation_rapid_post(
76
- model=model, files=files
77
- )
78
-
79
- def _add_general_validation_rapid(
80
- self,
81
- payload: (
82
- BoundingBoxPayload
83
- | ClassifyPayload
84
- | ComparePayload
85
- | FreeTextPayload
86
- | LinePayload
87
- | LocatePayload
88
- | NamedEntityPayload
89
- | PolygonPayload
90
- | TranscriptionPayload
91
- | ScrubPayload
92
- ),
93
- truths: (
94
- AttachCategoryTruth
95
- | BoundingBoxTruth
96
- | CompareTruth
97
- | EmptyValidationTruth
98
- | LineTruth
99
- | LocateBoxTruth
100
- | NamedEntityTruth
101
- | PolygonTruth
102
- | TranscriptionTruth
103
- | ScrubTruth
104
- ),
105
- metadata: Sequence[Metadata],
106
- asset: MediaAsset | TextAsset | MultiAsset,
107
- randomCorrectProbability: float,
108
- ) -> None:
109
- """Add a validation rapid to the validation set.
110
-
111
- Args:
112
- payload: The payload for the rapid.
113
- truths: The truths for the rapid.
114
- metadata (list[Metadata]): The metadata for the rapid.
115
- asset: The asset(s) for the rapid.
116
- randomCorrectProbability (float): The random correct probability for the rapid.
117
-
118
- Returns:
119
- None
120
-
121
- Raises:
122
- ValueError: If an invalid asset type is provided.
123
- """
124
-
125
- model = AddValidationRapidModel(
126
- validationSetId=self.id,
127
- payload=AddValidationRapidModelPayload(payload),
128
- truth=AddValidationRapidModelTruth(truths),
129
- metadata=[
130
- DatapointMetadataModelMetadataInner(meta._to_model())
131
- for meta in metadata
132
- ],
133
- randomCorrectProbability=randomCorrectProbability,
134
- )
135
- if isinstance(asset, MediaAsset):
136
- self.__upload_files(model=model, assets=[asset])
137
-
138
- elif isinstance(asset, TextAsset):
139
- model = AddValidationTextRapidModel(
140
- validationSetId=self.id,
141
- payload=AddValidationRapidModelPayload(payload),
142
- truth=AddValidationRapidModelTruth(truths),
143
- metadata=[
144
- DatapointMetadataModelMetadataInner(meta._to_model())
145
- for meta in metadata
146
- ],
147
- randomCorrectProbability=randomCorrectProbability,
148
- texts=[asset.text],
149
- )
150
- self.__openapi_service.validation_api.validation_add_validation_text_rapid_post(
151
- add_validation_text_rapid_model=model
152
- )
153
-
154
- elif isinstance(asset, MultiAsset):
155
- files = [a for a in asset if isinstance(a, MediaAsset)]
156
- texts = [a.text for a in asset if isinstance(a, TextAsset)]
157
- if files:
158
- self.__upload_files(model=model, assets=files)
159
- if texts:
160
- model = AddValidationTextRapidModel(
161
- validationSetId=self.id,
162
- payload=AddValidationRapidModelPayload(payload),
163
- truth=AddValidationRapidModelTruth(truths),
164
- metadata=[
165
- DatapointMetadataModelMetadataInner(meta._to_model())
166
- for meta in metadata
167
- ],
168
- randomCorrectProbability=randomCorrectProbability,
169
- texts=texts,
170
- )
171
- self.__openapi_service.validation_api.validation_add_validation_text_rapid_post(
172
- add_validation_text_rapid_model=model
173
- )
174
-
175
- else:
176
- raise ValueError("Invalid asset type")
177
-
178
- def _add_classify_rapid(
179
- self,
180
- asset: MediaAsset | TextAsset,
181
- instruction: str,
182
- categories: list[str],
183
- truths: list[str],
184
- metadata: Sequence[Metadata] = [],
185
- ) -> None:
186
- """Add a classify rapid to the validation set.
187
-
188
- Args:
189
- asset (MediaAsset | TextAsset): The asset for the rapid.
190
- instruction (str): The instruction for the rapid.
191
- categories (list[str]): The list of categories for the rapid.
192
- truths (list[str]): The list of truths for the rapid.
193
- metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
194
-
195
- Returns:
196
- None
197
- """
198
- payload = ClassifyPayload(
199
- _t="ClassifyPayload", possibleCategories=categories, title=instruction
200
- )
201
- model_truth = AttachCategoryTruth(
202
- correctCategories=truths, _t="AttachCategoryTruth"
203
- )
204
-
205
- self._add_general_validation_rapid(
206
- payload=payload,
207
- truths=model_truth,
208
- metadata=metadata,
209
- asset=asset,
210
- randomCorrectProbability=len(truths) / len(categories),
211
- )
212
-
213
- def _add_compare_rapid(
214
- self,
215
- asset: MultiAsset,
216
- instruction: str,
217
- truth: str,
218
- metadata: Sequence[Metadata] = [],
219
- ) -> None:
220
- """Add a compare rapid to the validation set.
221
-
222
- Args:
223
- asset (MultiAsset): The assets for the rapid.
224
- instruction (str): The instruction for the rapid.
225
- truth (str): The path to the truth file.
226
- metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
227
-
228
- Returns:
229
- None
230
-
231
- Raises:
232
- ValueError: If the number of assets is not exactly two.
233
- """
234
- payload = ComparePayload(_t="ComparePayload", criteria=instruction)
235
- # take only last part of truth path
236
- truth = os.path.basename(truth)
237
- model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
238
-
239
- if len(asset) != 2:
240
- raise ValueError("Compare rapid requires exactly two media paths")
241
-
242
- self._add_general_validation_rapid(
243
- payload=payload,
244
- truths=model_truth,
245
- metadata=metadata,
246
- asset=asset,
247
- randomCorrectProbability=1 / len(asset),
248
- )
249
-
250
- def _add_transcription_rapid(
251
- self,
252
- asset: MediaAsset | TextAsset,
253
- instruction: str,
254
- text: list[str],
255
- correct_words: list[str],
256
- required_precision: float = 1,
257
- required_completeness: float = 1,
258
- metadata: Sequence[Metadata] = [],
259
- ) -> None:
260
- """Add a transcription rapid to the validation set.
261
-
262
- Args:
263
- asset (MediaAsset | TextAsset): The asset for the rapid.
264
- instruction (str): The instruction for the rapid.
265
- text (list[str]): The text for the rapid.
266
- correct_words (list[str]): The list of correct words for the rapid.
267
- required_precision (float, optional): The required precision for the rapid. Defaults to 1.
268
- required_completeness (float, optional): The required completeness for the rapid. Defaults to 1.
269
- metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
270
-
271
- Returns:
272
- None
273
-
274
- Raises:
275
- ValueError: If a correct word is not found in the transcription.
276
- """
277
- transcription_words = [
278
- TranscriptionWord(word=word, wordIndex=i)
279
- for i, word in enumerate(text)
280
- ]
281
-
282
- correct_transcription_words = []
283
- for word in correct_words:
284
- if word not in text:
285
- raise ValueError(f"Correct word '{word}' not found in transcription")
286
- correct_transcription_words.append(
287
- TranscriptionWord(word=word, wordIndex=text.index(word))
288
- )
289
-
290
- payload = TranscriptionPayload(
291
- _t="TranscriptionPayload", title=instruction, transcription=transcription_words
292
- )
293
-
294
- model_truth = TranscriptionTruth(
295
- _t="TranscriptionTruth",
296
- correctWords=correct_transcription_words,
297
- requiredPrecision=required_precision,
298
- requiredCompleteness=required_completeness,
299
- )
300
-
301
- self._add_general_validation_rapid(
302
- payload=payload,
303
- truths=model_truth,
304
- metadata=metadata,
305
- asset=asset,
306
- randomCorrectProbability=len(correct_words) / len(text),
307
- )
308
16
 
309
17
  def __str__(self):
310
18
  return f"name: '{self.name}' id: {self.id}"
311
-
19
+
312
20
  def __repr__(self):
313
21
  return f"name: '{self.name}' id: {self.id}"