rapidata 0.1.24__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,8 @@ from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
11
11
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
12
12
  from rapidata.service import LocalFileService
13
13
  from rapidata.service.openapi_service import OpenAPIService
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from tqdm import tqdm
14
16
 
15
17
 
16
18
  class RapidataDataset:
@@ -32,19 +34,18 @@ class RapidataDataset:
32
34
  self,
33
35
  image_paths: list[str | list[str]],
34
36
  metadata: list[Metadata] | None = None,
37
+ max_workers: int = 10,
35
38
  ):
36
39
  if metadata is not None and len(metadata) != len(image_paths):
37
40
  raise ValueError(
38
41
  "metadata must be None or have the same length as image_paths"
39
42
  )
40
43
 
41
- for media_paths_rapid, meta in zip_longest(image_paths, metadata or []):
42
-
44
+ def upload_datapoint(media_paths_rapid: str | list[str], meta: Metadata | None) -> None:
43
45
  if isinstance(media_paths_rapid, list) and not all(
44
46
  os.path.exists(media_path) for media_path in media_paths_rapid
45
47
  ):
46
48
  raise FileNotFoundError(f"File not found: {media_paths_rapid}")
47
-
48
49
  elif isinstance(media_paths_rapid, str) and not os.path.exists(
49
50
  media_paths_rapid
50
51
  ):
@@ -60,4 +61,19 @@ class RapidataDataset:
60
61
  ),
61
62
  )
62
63
 
63
- self.openapi_service.dataset_api.dataset_create_datapoint_post(model=model, files=media_paths_rapid if isinstance(media_paths_rapid, list) else [media_paths_rapid]) # type: ignore
64
+ self.openapi_service.dataset_api.dataset_create_datapoint_post(
65
+ model=model,
66
+ files=media_paths_rapid if isinstance(media_paths_rapid, list) else [media_paths_rapid] # type: ignore
67
+ )
68
+
69
+ total_uploads = len(image_paths)
70
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
71
+ futures = [
72
+ executor.submit(upload_datapoint, media_paths, meta)
73
+ for media_paths, meta in zip_longest(image_paths, metadata or [])
74
+ ]
75
+
76
+ with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
77
+ for future in as_completed(futures):
78
+ future.result() # This will raise any exceptions that occurred during execution
79
+ pbar.update(1)
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from typing import Any
2
3
  from rapidata.api_client.models.add_validation_rapid_model import (
3
4
  AddValidationRapidModel,
@@ -29,6 +30,7 @@ from rapidata.api_client.models.polygon_payload import PolygonPayload
29
30
  from rapidata.api_client.models.polygon_truth import PolygonTruth
30
31
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
31
32
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
33
+ from rapidata.api_client.models.transcription_word import TranscriptionWord
32
34
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
33
35
  from rapidata.service.openapi_service import OpenAPIService
34
36
 
@@ -43,7 +45,7 @@ class RapidataValidationSet:
43
45
  self.id = validation_set_id
44
46
  self.openapi_service = openapi_service
45
47
 
46
- def add_validation_rapid(
48
+ def add_general_validation_rapid(
47
49
  self,
48
50
  payload: (
49
51
  BoundingBoxPayload
@@ -70,15 +72,18 @@ class RapidataValidationSet:
70
72
  metadata: list[Metadata],
71
73
  media_paths: str | list[str],
72
74
  randomCorrectProbability: float,
73
- ):
75
+ ) -> None:
74
76
  """Add a validation rapid to the validation set.
75
77
 
76
78
  Args:
77
- payload (BoundingBoxPayload | ClassifyPayload | ComparePayload | FreeTextPayload | LinePayload | LocatePayload | NamedEntityPayload | PolygonPayload | TranscriptionPayload): The payload for the rapid.
78
- truths (AttachCategoryTruth | BoundingBoxTruth | CompareTruth | EmptyValidationTruth | LineTruth | LocateBoxTruth | NamedEntityTruth | PolygonTruth | TranscriptionTruth): The truths for the rapid.
79
+ payload (Union[BoundingBoxPayload, ClassifyPayload, ComparePayload, FreeTextPayload, LinePayload, LocatePayload, NamedEntityPayload, PolygonPayload, TranscriptionPayload]): The payload for the rapid.
80
+ truths (Union[AttachCategoryTruth, BoundingBoxTruth, CompareTruth, EmptyValidationTruth, LineTruth, LocateBoxTruth, NamedEntityTruth, PolygonTruth, TranscriptionTruth]): The truths for the rapid.
79
81
  metadata (list[Metadata]): The metadata for the rapid.
80
- media_paths (str | list[str]): The media paths for the rapid.
82
+ media_paths (Union[str, list[str]]): The media paths for the rapid.
81
83
  randomCorrectProbability (float): The random correct probability for the rapid.
84
+
85
+ Returns:
86
+ None
82
87
  """
83
88
  model = AddValidationRapidModel(
84
89
  validationSetId=self.id,
@@ -94,3 +99,137 @@ class RapidataValidationSet:
94
99
  self.openapi_service.validation_api.validation_add_validation_rapid_post(
95
100
  model=model, files=media_paths if isinstance(media_paths, list) else [media_paths] # type: ignore
96
101
  )
102
+
103
+ def add_classify_validation_rapid(
104
+ self,
105
+ media_path: str,
106
+ question: str,
107
+ categories: list[str],
108
+ truths: list[str],
109
+ metadata: list[Metadata] = [],
110
+ ) -> None:
111
+ """Add a classify rapid to the validation set.
112
+
113
+ Args:
114
+ media_path (str): The path to the media file.
115
+ question (str): The question for the rapid.
116
+ categories (list[str]): The list of categories for the rapid.
117
+ truths (list[str]): The list of truths for the rapid.
118
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
119
+
120
+ Returns:
121
+ None
122
+ """
123
+ payload = ClassifyPayload(
124
+ _t="ClassifyPayload", possibleCategories=categories, title=question
125
+ )
126
+ model_truth = AttachCategoryTruth(
127
+ correctCategories=truths, _t="AttachCategoryTruth"
128
+ )
129
+
130
+ self.add_general_validation_rapid(
131
+ payload=payload,
132
+ truths=model_truth,
133
+ metadata=metadata,
134
+ media_paths=media_path,
135
+ randomCorrectProbability=len(truths) / len(categories),
136
+ )
137
+
138
+ def add_compare_validation_rapid(
139
+ self,
140
+ media_paths: list[str],
141
+ question: str,
142
+ truth: str,
143
+ metadata: list[Metadata] = [],
144
+ ) -> None:
145
+ """Add a compare rapid to the validation set.
146
+
147
+ Args:
148
+ media_paths (list[str]): The list of media paths for the rapid.
149
+ question (str): The question for the rapid.
150
+ truth (str): The path to the truth file.
151
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
152
+
153
+ Returns:
154
+ None
155
+
156
+ Raises:
157
+ ValueError: If the number of media paths is not exactly two.
158
+ FileNotFoundError: If any of the specified files are not found.
159
+ """
160
+ payload = ComparePayload(_t="ComparePayload", criteria=question)
161
+ # take only last part of truth path
162
+ truth = os.path.basename(truth)
163
+ model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
164
+
165
+ if len(media_paths) != 2:
166
+ raise ValueError("Compare rapid requires exactly two media paths")
167
+
168
+ # check that files exist
169
+ for media_path in media_paths:
170
+ if not os.path.exists(media_path):
171
+ raise FileNotFoundError(f"File not found: {media_path}")
172
+
173
+ self.add_general_validation_rapid(
174
+ payload=payload,
175
+ truths=model_truth,
176
+ metadata=metadata,
177
+ media_paths=media_paths,
178
+ randomCorrectProbability=1 / len(media_paths),
179
+ )
180
+
181
+ def add_transcription_validation_rapid(
182
+ self,
183
+ media_path: str,
184
+ question: str,
185
+ transcription: list[str],
186
+ correct_words: list[str],
187
+ strict_grading: bool | None = None,
188
+ metadata: list[Metadata] = [],
189
+ ) -> None:
190
+ """Add a transcription rapid to the validation set.
191
+
192
+ Args:
193
+ media_path (str): The path to the media file.
194
+ question (str): The question for the rapid.
195
+ transcription (list[str]): The transcription for the rapid.
196
+ correct_words (list[str]): The list of correct words for the rapid.
197
+ strict_grading (Optional[bool], optional): The strict grading for the rapid. Defaults to None.
198
+ metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
199
+
200
+ Returns:
201
+ None
202
+
203
+ Raises:
204
+ ValueError: If a correct word is not found in the transcription.
205
+ """
206
+ transcription_words = [
207
+ TranscriptionWord(word=word, wordIndex=i)
208
+ for i, word in enumerate(transcription)
209
+ ]
210
+
211
+ correct_transcription_words = []
212
+ for word in correct_words:
213
+ if word not in transcription:
214
+ raise ValueError(f"Correct word '{word}' not found in transcription")
215
+ correct_transcription_words.append(
216
+ TranscriptionWord(word=word, wordIndex=transcription.index(word))
217
+ )
218
+
219
+ payload = TranscriptionPayload(
220
+ _t="TranscriptionPayload", title=question, transcription=transcription_words
221
+ )
222
+
223
+ model_truth = TranscriptionTruth(
224
+ _t="TranscriptionTruth",
225
+ correctWords=correct_transcription_words,
226
+ strictGrading=strict_grading,
227
+ )
228
+
229
+ self.add_general_validation_rapid(
230
+ payload=payload,
231
+ truths=model_truth,
232
+ metadata=metadata,
233
+ media_paths=media_path,
234
+ randomCorrectProbability=len(correct_words) / len(transcription),
235
+ )
@@ -48,7 +48,7 @@ class ValidationSetBuilder:
48
48
  )
49
49
 
50
50
  for rapid_part in self._rapid_parts:
51
- validation_set.add_validation_rapid(
51
+ validation_set.add_general_validation_rapid(
52
52
  payload=rapid_part.payload,
53
53
  truths=rapid_part.truths,
54
54
  metadata=rapid_part.metadata,
@@ -87,7 +87,7 @@ class RapidataOrderBuilder:
87
87
  ],
88
88
  )
89
89
 
90
- def create(self, submit=True) -> RapidataOrder:
90
+ def create(self, submit=True, max_workers=10) -> RapidataOrder:
91
91
  """Actually makes the API calls to create the order based on how the order builder was configured.
92
92
 
93
93
  Args:
@@ -113,7 +113,7 @@ class RapidataOrderBuilder:
113
113
  openapi_service=self._openapi_service,
114
114
  )
115
115
 
116
- order.dataset.add_media_from_paths(self._media_paths, self._metadata)
116
+ order.dataset.add_media_from_paths(self._media_paths, self._metadata, max_workers)
117
117
 
118
118
  if submit:
119
119
  order.submit()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rapidata
3
- Version: 0.1.24
3
+ Version: 0.2.0
4
4
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
5
5
  License: Apache-2.0
6
6
  Author: Rapidata AG
@@ -222,10 +222,10 @@ rapidata/rapidata_client/config.py,sha256=tQLgN6k_ATOX1GzZh38At2rgBDLStV6rJ6z0vs
222
222
  rapidata/rapidata_client/country_codes/__init__.py,sha256=Y8qeG2IMjvMGvhaPydq0nhwRQHb6dQqilctlEXu0_PE,55
223
223
  rapidata/rapidata_client/country_codes/country_codes.py,sha256=Q0HMX7uHJQDeLCFPP5bq4iYi6pgcDWEcl2ONGhjgoeU,286
224
224
  rapidata/rapidata_client/dataset/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
225
- rapidata/rapidata_client/dataset/rapidata_dataset.py,sha256=J5jlSIbdswhqXHOq8Qf9gtzVF70cImq9sioDG7S9Mxs,2457
226
- rapidata/rapidata_client/dataset/rapidata_validation_set.py,sha256=oQUtAF9ouLWg9AXkHXlnbEsPb8w9zxeKTw2YbSIk3ic,4475
225
+ rapidata/rapidata_client/dataset/rapidata_dataset.py,sha256=QDsl7ZCZxuG02yfEpBSphfSDZh_qHz6m5HUCGwZflfw,3205
226
+ rapidata/rapidata_client/dataset/rapidata_validation_set.py,sha256=YrnUzia9AXgq2z917FtztFxj4fD5EgTWXVPBzLVIujY,9374
227
227
  rapidata/rapidata_client/dataset/validation_rapid_parts.py,sha256=SIeQesEXPPOW5kclxYLNWaKllBXHm7DQKBdMU-GXnfc,2104
228
- rapidata/rapidata_client/dataset/validation_set_builder.py,sha256=_g7acP7lqYMI_5U9q1T6YHuUQ1ZDfZbLJ-QoqsGME4o,7463
228
+ rapidata/rapidata_client/dataset/validation_set_builder.py,sha256=B9D-uNCo_PO0NCUHju_7dsWtz_KcOmvFIsxUgQ67Q2M,7471
229
229
  rapidata/rapidata_client/feature_flags/__init__.py,sha256=BNG_NQ4CrrC61fAWliImr8r581pIvegrkepVVbxcBw8,55
230
230
  rapidata/rapidata_client/feature_flags/feature_flags.py,sha256=hcS9YRzpsPWpZfw-3QwSuf2TaVg-MOHBxY788oNqIW4,3957
231
231
  rapidata/rapidata_client/metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -235,7 +235,7 @@ rapidata/rapidata_client/metadata/prompt_metadata.py,sha256=_FypjKWrC3iKUO_G2CVw
235
235
  rapidata/rapidata_client/metadata/transcription_metadata.py,sha256=THtDEVCON4UlcXHmXrjilaOLHys4TrktUOPGWnXaCcc,631
236
236
  rapidata/rapidata_client/order/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
237
237
  rapidata/rapidata_client/order/rapidata_order.py,sha256=VRDLTPBf2k6UihF0DnWq3nfBLWExfWHzh3T-cJgFF1w,2437
238
- rapidata/rapidata_client/order/rapidata_order_builder.py,sha256=qGfdmO3wdQMDwlW5NTM8x5ll8YmC0UUuUkGtyzGMObU,8105
238
+ rapidata/rapidata_client/order/rapidata_order_builder.py,sha256=Wx7lhThTd6SwJNSbXzhGCsEgMhizjsriGz2zMjbQyEI,8134
239
239
  rapidata/rapidata_client/rapidata_client.py,sha256=z3vz5_GNivnShj7kqii-eUff16rvwSy62zwi8WZqAWo,2776
240
240
  rapidata/rapidata_client/referee/__init__.py,sha256=x0AxGCsR6TlDjfqQ00lB9V7QVS9EZCJzweNEIzx42PI,207
241
241
  rapidata/rapidata_client/referee/base_referee.py,sha256=bMy7cw0a-pGNbFu6u_1_Jplu0A483Ubj4oDQzh8vu8k,493
@@ -259,7 +259,7 @@ rapidata/service/local_file_service.py,sha256=pgorvlWcx52Uh3cEG6VrdMK_t__7dacQ_5
259
259
  rapidata/service/openapi_service.py,sha256=-vrM2jEzQxr9KAerOYkVhpvMEeHwjzRwm9L_VFyzOT0,1537
260
260
  rapidata/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
261
261
  rapidata/utils/image_utils.py,sha256=TldO3eJWG8IhfJjm5MfNGO0mEDm1mQTsRoA0HLU1Uxs,404
262
- rapidata-0.1.24.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
263
- rapidata-0.1.24.dist-info/METADATA,sha256=vrxhs737rJg2Z4F7GLvOgpDhxXc0tJcxs_WPpxL7ZjY,924
264
- rapidata-0.1.24.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
265
- rapidata-0.1.24.dist-info/RECORD,,
262
+ rapidata-0.2.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
263
+ rapidata-0.2.0.dist-info/METADATA,sha256=Oymi81UjDR9fM9f-axaTKUWOCnlExpeA1i1voXC7ruU,923
264
+ rapidata-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
265
+ rapidata-0.2.0.dist-info/RECORD,,