rapidata 2.13.0__py3-none-any.whl → 2.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rapidata/__init__.py +29 -0
- rapidata/api_client/__init__.py +6 -4
- rapidata/api_client/api/__init__.py +1 -0
- rapidata/api_client/api/dataset_api.py +265 -0
- rapidata/api_client/api/workflow_api.py +298 -1
- rapidata/api_client/models/__init__.py +5 -4
- rapidata/api_client/models/add_campaign_model.py +3 -3
- rapidata/api_client/models/add_validation_rapid_model.py +3 -3
- rapidata/api_client/models/add_validation_text_rapid_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_urls_model.py +26 -4
- rapidata/api_client/models/create_datapoint_from_urls_model_metadata_inner.py +168 -0
- rapidata/api_client/models/create_datapoint_result.py +5 -3
- rapidata/api_client/models/datapoint.py +7 -30
- rapidata/api_client/models/datapoint_asset.py +40 -40
- rapidata/api_client/models/datapoint_metadata_model.py +3 -3
- rapidata/api_client/models/datapoint_model.py +3 -3
- rapidata/api_client/models/get_compare_workflow_results_result.py +3 -3
- rapidata/api_client/models/get_datapoint_by_id_result.py +3 -3
- rapidata/api_client/models/get_failed_datapoints_result.py +95 -0
- rapidata/api_client/models/get_responses_result.py +95 -0
- rapidata/api_client/models/get_simple_workflow_results_result.py +3 -3
- rapidata/api_client/models/multi_asset_model.py +3 -3
- rapidata/api_client/models/upload_text_sources_to_dataset_model.py +17 -2
- rapidata/api_client_README.md +8 -4
- rapidata/rapidata_client/assets/_media_asset.py +5 -1
- rapidata/rapidata_client/assets/_multi_asset.py +6 -1
- rapidata/rapidata_client/filter/country_filter.py +1 -1
- rapidata/rapidata_client/order/_rapidata_dataset.py +311 -108
- rapidata/rapidata_client/order/rapidata_order_manager.py +64 -6
- rapidata/rapidata_client/validation/rapids/rapids.py +4 -5
- rapidata/rapidata_client/workflow/__init__.py +1 -0
- rapidata/rapidata_client/workflow/_ranking_workflow.py +40 -0
- {rapidata-2.13.0.dist-info → rapidata-2.14.0.dist-info}/METADATA +1 -1
- {rapidata-2.13.0.dist-info → rapidata-2.14.0.dist-info}/RECORD +36 -32
- {rapidata-2.13.0.dist-info → rapidata-2.14.0.dist-info}/LICENSE +0 -0
- {rapidata-2.13.0.dist-info → rapidata-2.14.0.dist-info}/WHEEL +0 -0
rapidata/rapidata_client/order/_rapidata_dataset.py

@@ -1,9 +1,10 @@
 from itertools import zip_longest

 from rapidata.api_client.models.datapoint_metadata_model import DatapointMetadataModel
-from rapidata.api_client.models.
-
+from rapidata.api_client.models.create_datapoint_from_urls_model import (
+    CreateDatapointFromUrlsModelMetadataInner,
 )
+from rapidata.api_client.models.create_datapoint_from_urls_model import CreateDatapointFromUrlsModel
 from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
     UploadTextSourcesToDatasetModel,
 )
@@ -14,12 +15,11 @@ from rapidata.service.openapi_service import OpenAPIService
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from tqdm import tqdm

-from pydantic import
-from typing import
+from pydantic import StrictStr
+from typing import cast, Sequence, Generator
 from logging import Logger
-from requests.adapters import HTTPAdapter, Retry
 import time
-import
+import threading


 def chunk_list(lst: list, chunk_size: int) -> Generator:
@@ -77,149 +77,352 @@ class RapidataDataset:
                 future.result()  # This will raise any exceptions that occurred during execution
                 pbar.update(1)

-    def
+    def _process_single_upload(
         self,
-
-
-
-        max_retries: int = 5,
-        chunk_size: int = 50,
+        media_asset: MediaAsset | MultiAsset,
+        meta: Metadata | None,
+        index: int,
     ) -> tuple[list[str], list[str]]:
         """
-
+        Process single upload with error tracking.

         Args:
-
-
-
-
-            chunk_size: Number of items to process in each batch
+            media_asset: MediaAsset or MultiAsset to upload
+            meta: Optional metadata for the asset
+            index: Sort index for the upload
+            session: Requests session for HTTP requests

         Returns:
-            tuple[list[str], list[str]]: Lists of successful and failed
-
-        Raises:
-            ValueError: If metadata length doesn't match media_paths length
+            tuple[list[str], list[str]]: Lists of successful and failed identifiers
         """
-
-
-
-        # Configure session with retry logic
-        session = requests.Session()
-        retries = Retry(
-            total=max_retries,
-            backoff_factor=1,
-            status_forcelist=[500, 502, 503, 504],
-            allowed_methods=["GET"],
-            respect_retry_after_header=True
-        )
+        local_successful: list[str] = []
+        local_failed: list[str] = []
+        identifiers_to_track: list[str] = []

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        local_failed: list[str] = []
-        identifiers_to_track: list[str] = []
-
-        try:
-            # Get identifier for this upload (URL or file path)
-            if isinstance(media_asset, MediaAsset):
-                media_asset.session = session
-                assets = [media_asset]
-                identifier = media_asset._url if media_asset._url else media_asset.path
-                identifiers_to_track = [identifier] if identifier else []
-            elif isinstance(media_asset, MultiAsset):
-                assets = cast(list[MediaAsset], media_asset.assets)
-                for asset in assets:
-                    asset.session = session
-                identifiers_to_track: list[str] = [
-                    (asset._url if asset._url else cast(str, asset.path))
-                    for asset in assets
-                ]
-            else:
-                raise ValueError(f"Unsupported asset type: {type(media_asset)}")
+        try:
+            # Get identifier for this upload (URL or file path)
+            if isinstance(media_asset, MediaAsset):
+                assets = [media_asset]
+                identifier = media_asset._url if media_asset._url else media_asset.path
+                identifiers_to_track = [identifier] if identifier else []
+            elif isinstance(media_asset, MultiAsset):
+                assets = cast(list[MediaAsset], media_asset.assets)
+                identifiers_to_track: list[str] = [
+                    (asset._url if asset._url else cast(str, asset.path))
+                    for asset in assets
+                ]
+            else:
+                raise ValueError(f"Unsupported asset type: {type(media_asset)}")
+
+            meta_model = meta.to_model() if meta else None

-
+            metadata = [CreateDatapointFromUrlsModelMetadataInner(meta_model)] if meta_model else []
+
+            local_paths: bool = assets[0].is_local()
+            files: list[StrictStr] = []
+            for asset in assets:
+                if isinstance(asset, MediaAsset):
+                    files.append(asset.path)
+
+            if local_paths:
                 model = DatapointMetadataModel(
                     datasetId=self.dataset_id,
-                    metadata=
+                    metadata=metadata,
                     sortIndex=index,
                 )
-
-            files: list[tuple[StrictStr, StrictBytes] | StrictStr | StrictBytes] = []
-            for asset in assets:
-                if isinstance(asset, MediaAsset):
-                    files.append(asset.to_file())
-
                 upload_response = self.openapi_service.dataset_api.dataset_create_datapoint_post(
                     model=model,
-                    files=files
+                    files=files  # type: ignore
                 )
+            else:
+                upload_response = self.openapi_service.dataset_api.dataset_dataset_id_datapoints_urls_post(
+                    dataset_id=self.dataset_id,
+                    create_datapoint_from_urls_model=CreateDatapointFromUrlsModel(
+                        urls=files,
+                        metadata=metadata,
+                        sortIndex=index
+                    ),
+                )
+
+            if upload_response.errors:
+                error_msg = f"Error uploading datapoint: {upload_response.errors}"
+                self._logger.error(error_msg)
+                local_failed.extend(identifiers_to_track)
+                raise ValueError(error_msg)

-
-                error_msg = f"Error uploading datapoint: {upload_response.errors}"
-                self._logger.error(error_msg)
-                local_failed.extend(identifiers_to_track)
-                raise ValueError(error_msg)
-
-            local_successful.extend(identifiers_to_track)
+            local_successful.extend(identifiers_to_track)

-
-
-
+        except Exception as e:
+            self._logger.error(f"\nUpload failed for {identifiers_to_track}: {str(e)}")  # \n to avoid same line as tqdm
+            local_failed.extend(identifiers_to_track)

-
+        return local_successful, local_failed

+    def _get_progress_tracker(
+        self,
+        total_uploads: int,
+        stop_event: threading.Event,
+        progress_error_event: threading.Event,
+        progress_poll_interval: float,
+    ) -> threading.Thread:
+        """
+        Create and return a progress tracking thread that shows actual API progress.
+
+        Args:
+            total_uploads: Total number of uploads to track
+            initial_ready: Initial number of ready items
+            initial_progress: Initial progress state
+            stop_event: Event to signal thread to stop
+            progress_error_event: Event to signal an error in progress tracking
+            progress_poll_interval: Time between progress checks
+
+        Returns:
+            threading.Thread: The progress tracking thread
+        """
+        def progress_tracking_thread():
+            try:
+                # Initialize progress bar with 0 completions
+                with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
+                    prev_ready = 0
+                    prev_failed = 0
+                    stall_count = 0
+                    last_progress_time = time.time()
+
+                    # We'll wait for all uploads to finish + some extra time
+                    # for the backend to fully process everything
+                    all_uploads_complete = threading.Event()
+
+                    while not stop_event.is_set() or not all_uploads_complete.is_set():
+                        try:
+                            current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+
+                            # Calculate items completed since our initialization
+                            completed_ready = current_progress.ready
+                            completed_failed = current_progress.failed
+                            total_completed = completed_ready + completed_failed
+
+                            # Calculate newly completed items since our last check
+                            new_ready = current_progress.ready - prev_ready
+                            new_failed = current_progress.failed - prev_failed
+
+                            # Update progress bar position to show actual completed items
+                            # First reset to match the actual completed count
+                            pbar.n = total_completed
+                            pbar.refresh()
+
+                            if new_ready > 0 or new_failed > 0:
+                                # We saw progress
+                                stall_count = 0
+                                last_progress_time = time.time()
+                            else:
+                                stall_count += 1
+
+                            # Update our tracking variables
+                            prev_ready = current_progress.ready
+                            prev_failed = current_progress.failed or 0
+
+                            # Check if stop_event was set (all uploads submitted)
+                            if stop_event.is_set():
+                                elapsed_since_last_progress = time.time() - last_progress_time
+
+                                # If we haven't seen progress for a while after all uploads were submitted
+                                if elapsed_since_last_progress > 5.0:
+                                    # If we're at 100%, we're done
+                                    if total_completed >= total_uploads:
+                                        all_uploads_complete.set()
+                                        break
+
+                                    # If we're not at 100% but it's been a while with no progress
+                                    if stall_count > 5:
+                                        # We've polled several times with no progress, assume we're done
+                                        self._logger.warning(f"\nProgress seems stalled at {total_completed}/{total_uploads}. Please try again.")
+                                        break
+
+                        except Exception as e:
+                            self._logger.error(f"\nError checking progress: {str(e)}")
+                            stall_count += 1
+
+                            if stall_count > 10:  # Too many consecutive errors
+                                progress_error_event.set()
+                                break
+
+                        # Sleep before next poll
+                        time.sleep(progress_poll_interval)
+
+            except Exception as e:
+                self._logger.error(f"Progress tracking thread error: {str(e)}")
+                progress_error_event.set()
+
+        # Create and return the thread
+        progress_thread = threading.Thread(target=progress_tracking_thread)
+        progress_thread.daemon = True
+        return progress_thread

-
+    def _process_uploads_in_chunks(
+        self,
+        media_paths: list[MediaAsset] | list[MultiAsset],
+        metadata: Sequence[Metadata] | None,
+        max_workers: int,
+        chunk_size: int,
+        stop_progress_tracking: threading.Event,
+        progress_tracking_error: threading.Event
+    ) -> tuple[list[str], list[str]]:
+        """
+        Process uploads in chunks with a ThreadPoolExecutor.
+
+        Args:
+            media_paths: List of assets to upload
+            metadata: Optional sequence of metadata
+            session: Requests session for HTTP requests
+            max_workers: Maximum number of concurrent workers
+            chunk_size: Number of items to process in each batch
+            stop_progress_tracking: Event to signal progress tracking to stop
+            progress_tracking_error: Event to detect progress tracking errors
+
+        Returns:
+            tuple[list[str], list[str]]: Lists of successful and failed uploads
+        """
         successful_uploads: list[str] = []
         failed_uploads: list[str] = []
-
-
-
-
-
-
-
+
+        try:
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Process uploads in chunks to avoid overwhelming the system
+                for chunk_idx, chunk in enumerate(chunk_list(media_paths, chunk_size)):
+                    chunk_metadata = metadata[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] if metadata else None
+
                     futures = [
                         executor.submit(
-
+                            self._process_single_upload,
                             media_asset,
                             meta,
-                            index=(chunk_idx * chunk_size + i)
-                            session=session
+                            index=(chunk_idx * chunk_size + i)
                         )
                         for i, (media_asset, meta) in enumerate(zip_longest(chunk, chunk_metadata or []))
                     ]
-
+
+                    # Wait for this chunk to complete before starting the next one
                     for future in as_completed(futures):
+                        if progress_tracking_error.is_set():
+                            raise RuntimeError("Progress tracking failed, aborting uploads")
+
                         try:
                             chunk_successful, chunk_failed = future.result()
                             successful_uploads.extend(chunk_successful)
                             failed_uploads.extend(chunk_failed)
                         except Exception as e:
                             self._logger.error(f"Future execution failed: {str(e)}")
-
-
+        finally:
+            # Signal to the progress tracking thread that all uploads have been submitted
+            stop_progress_tracking.set()
+
+        return successful_uploads, failed_uploads

-
-
-
+    def _log_final_progress(
+        self,
+        total_uploads: int,
+        progress_poll_interval: float,
+        successful_uploads: list[str],
+        failed_uploads: list[str]
+    ) -> None:
+        """
+        Log the final progress of the upload operation.
+
+        Args:
+            total_uploads: Total number of uploads
+            initial_ready: Initial number of ready items
+            initial_progress: Initial progress state
+            progress_poll_interval: Time between progress checks
+            successful_uploads: List of successful uploads for fallback reporting
+            failed_uploads: List of failed uploads for fallback reporting
+        """
+        try:
+            # Get final progress
+            final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+            total_ready = final_progress.ready
+            total_failed = final_progress.failed
+
+            # Make sure we account for all uploads
+            if total_ready + total_failed < total_uploads:
+                # Try one more time after a longer wait
+                time.sleep(5 * progress_poll_interval)
+                final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+                total_ready = final_progress.ready
+                total_failed = final_progress.failed
+
+            success_rate = (total_ready / total_uploads * 100) if total_uploads > 0 else 0
+
+            self._logger.info(f"Upload complete: {total_ready} ready, {total_uploads-total_ready} failed ({success_rate:.1f}% success rate)")
+            print(f"Upload complete, {total_ready} ready, {total_uploads-total_ready} failed ({success_rate:.1f}% success rate)")
+        except Exception as e:
+            self._logger.error(f"Error getting final progress: {str(e)}")
+            self._logger.info(f"Upload summary from local tracking: {len(successful_uploads)} succeeded, {len(failed_uploads)} failed")

         if failed_uploads:
             print(f"Failed uploads: {failed_uploads}")

-
+    def _add_media_from_paths(
+        self,
+        media_paths: list[MediaAsset] | list[MultiAsset],
+        metadata: Sequence[Metadata] | None = None,
+        max_workers: int = 5,
+        chunk_size: int = 50,
+        progress_poll_interval: float = 0.5,
+    ) -> tuple[list[str], list[str]]:
+        """
+        Upload media paths in chunks with managed resources.
+
+        Args:
+            media_paths: List of MediaAsset or MultiAsset objects to upload
+            metadata: Optional sequence of metadata matching media_paths length
+            max_workers: Maximum number of concurrent upload workers
+            chunk_size: Number of items to process in each batch
+            progress_poll_interval: Time in seconds between progress checks
+
+        Returns:
+            tuple[list[str], list[str]]: Lists of successful and failed URLs
+
+        Raises:
+            ValueError: If metadata length doesn't match media_paths length
+        """
+        if metadata is not None and len(metadata) != len(media_paths):
+            raise ValueError("metadata must be None or have the same length as media_paths")
+
+        # Setup tracking variables
+        total_uploads = len(media_paths)
+
+        # Create thread control events
+        stop_progress_tracking = threading.Event()
+        progress_tracking_error = threading.Event()
+
+        # Create and start progress tracking thread
+        progress_thread = self._get_progress_tracker(
+            total_uploads,
+            stop_progress_tracking,
+            progress_tracking_error,
+            progress_poll_interval
+        )
+        progress_thread.start()
+
+        # Process uploads in chunks
+        try:
+            successful_uploads, failed_uploads = self._process_uploads_in_chunks(
+                media_paths,
+                metadata,
+                max_workers,
+                chunk_size,
+                stop_progress_tracking,
+                progress_tracking_error
+            )
+        finally:
+            progress_thread.join(10)  # Add margin to the timeout for tqdm
+
+        # Log final progress
+        self._log_final_progress(
+            total_uploads,
+            progress_poll_interval,
+            successful_uploads,
+            failed_uploads
+        )

+        return successful_uploads, failed_uploads
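The rewritten upload path above replaces the per-request retry session with a coordination pattern: worker threads submit datapoints in chunks through a ThreadPoolExecutor while a daemon thread polls the dataset's server-side progress endpoint and drives a single tqdm bar, with threading.Event objects linking the two sides ("all uploads submitted" and "progress tracking failed"). The following is a minimal, self-contained sketch of that pattern only; fake_upload, server_progress, and upload_all are hypothetical stand-ins for the Rapidata API calls, not part of the package.

import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm

# Hypothetical stand-ins for the backend: a shared "server-side" progress counter
# and a fake per-datapoint upload call.
server_progress = {"ready": 0, "failed": 0}
progress_lock = threading.Lock()


def fake_upload(item: str) -> str:
    """Stand-in for one datapoint upload; the real code calls the dataset API."""
    time.sleep(random.uniform(0.01, 0.05))
    with progress_lock:
        server_progress["ready"] += 1
    return item


def chunk_list(lst: list, chunk_size: int):
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i + chunk_size]


def make_progress_tracker(total: int, stop_event: threading.Event, poll_interval: float = 0.1) -> threading.Thread:
    """Poll the (simulated) server progress and drive a single tqdm bar."""
    def run():
        with tqdm(total=total, desc="Uploading datapoints") as pbar:
            while True:
                with progress_lock:
                    done = server_progress["ready"] + server_progress["failed"]
                pbar.n = done
                pbar.refresh()
                # Stop once all uploads were submitted and the server has caught up.
                if stop_event.is_set() and done >= total:
                    break
                time.sleep(poll_interval)
    return threading.Thread(target=run, daemon=True)


def upload_all(items: list[str], max_workers: int = 5, chunk_size: int = 50):
    stop_event = threading.Event()
    tracker = make_progress_tracker(len(items), stop_event)
    tracker.start()

    successful, failed = [], []
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # One chunk at a time, mirroring _process_uploads_in_chunks.
            for chunk in chunk_list(items, chunk_size):
                future_to_item = {executor.submit(fake_upload, item): item for item in chunk}
                for future in as_completed(future_to_item):
                    try:
                        successful.append(future.result())
                    except Exception:
                        failed.append(future_to_item[future])
    finally:
        stop_event.set()  # all uploads submitted; let the tracker wind down
        tracker.join(10)

    return successful, failed


if __name__ == "__main__":
    ok, bad = upload_all([f"item_{i}" for i in range(120)], max_workers=8, chunk_size=40)
    print(f"{len(ok)} succeeded, {len(bad)} failed")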
rapidata/rapidata_client/order/rapidata_order_manager.py

@@ -1,17 +1,14 @@
-from typing import Sequence
+from typing import Sequence, Optional
 from urllib3._collections import HTTPHeaderDict

 from rapidata.service.openapi_service import OpenAPIService
 from rapidata.rapidata_client.assets.data_type_enum import RapidataDataTypes
-from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
 from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
 from rapidata.rapidata_client.order._rapidata_order_builder import RapidataOrderBuilder
 from rapidata.rapidata_client.metadata import PromptMetadata, SelectWordsMetadata
 from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
 from rapidata.rapidata_client.referee._early_stopping_referee import EarlyStoppingReferee
 from rapidata.rapidata_client.selection._base_selection import RapidataSelection
-from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
-from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
 from rapidata.rapidata_client.workflow import (
     Workflow,
     ClassifyWorkflow,
@@ -20,7 +17,9 @@ from rapidata.rapidata_client.workflow import (
     SelectWordsWorkflow,
     LocateWorkflow,
     DrawWorkflow,
-    TimestampWorkflow
+    TimestampWorkflow,
+    RankingWorkflow
+)
 from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
 from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
 from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
@@ -243,7 +242,66 @@ class RapidataOrderManager:
             selections=selections,
             settings=settings
         )
-
+
+    def create_ranking_order(self,
+            name: str,
+            instruction: str,
+            datapoints: list[str],
+            responses_per_comparison: int,
+            total_comparison_budget: int,
+            random_comparisons_ratio: float = 0.5,
+            elo_start: int = 1200,
+            elo_k_factor: int = 40,
+            elo_scaling_factor: int = 400,
+            contexts: Optional[list[str]] = None,
+            validation_set_id: Optional[str] = None,
+            filters: Sequence[RapidataFilter] = [],
+            settings: Sequence[RapidataSetting] = [],
+            selections: Optional[Sequence[RapidataSelection]] = None) -> RapidataOrder:
+        """
+        Create a ranking order.
+
+        Args:
+            name (str): The name of the order.
+            instruction (str): The question asked from People when They see two datapoints.
+            datapoints (list[str]): A list of datapoints that will participate in the ranking.
+            total_comparison_budget (int): The total number of (pairwise-)comparisons that can be made.
+            random_comparisons_ratio (float, optional): The fraction of random comparisons in the ranking process.
+                The rest will focus on pairing similarly ranked datapoints.
+            elo_start (int, optional): The initial ELO rating assigned to each datapoint.
+            elo_k_factor (int, optional): The K-factor used for ELO updates.
+            elo_scaling_factor (int, optional): The scaling factor used in the ELO calculation.
+            responses_per_comparison (int, optional): The number of responses collected per comparison.
+            contexts (list[str], optional): The list of contexts for the comparison. Defaults to None.\n
+                If provided has to be the same length as datapoints and will be shown in addition to the instruction.
+                (Therefore will be different for each datapoint) Will be match up with the datapoints using the list index.
+            validation_set_id (str, optional): The ID of the validation set. Defaults to None.\n
+                If provided, one validation task will be shown infront of the datapoints that will be labeled.
+            filters (Sequence[RapidataFilter], optional): The list of filters for the order. Defaults to []. Decides who the tasks should be shown to.
+            settings (Sequence[RapidataSetting], optional): The list of settings for the order. Defaults to []. Decides how the tasks should be shown.
+            selections (Sequence[RapidataSelection], optional): The list of selections for the order. Defaults to None. Decides in what order the tasks should be shown.
+        """
+
+        assets = [MediaAsset(path=path) for path in datapoints]
+        return self._create_general_order(
+            name=name,
+            workflow=RankingWorkflow(
+                criteria=instruction,
+                elo_start=elo_start,
+                elo_k_factor=elo_k_factor,
+                elo_scaling_factor=elo_scaling_factor,
+                total_comparison_budget=total_comparison_budget,
+                random_comparisons_ratio=random_comparisons_ratio
+            ),
+            assets=assets,
+            responses_per_datapoint=responses_per_comparison,
+            contexts=contexts,
+            validation_set_id=validation_set_id,
+            filters=filters,
+            selections=selections,
+            settings=settings
+        )
+
     def create_free_text_order(self,
             name: str,
             instruction: str,
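The new create_ranking_order helper builds a RankingWorkflow from the Elo and pair-maker parameters and hands it to the existing general order pipeline. A call might look like the sketch below, assuming the order manager is reached through the client's order attribute as with the other create_*_order helpers; the file names and instruction are placeholders.

from rapidata import RapidataClient

# Assumption: RapidataOrderManager is exposed as client.order.
client = RapidataClient()

order = client.order.create_ranking_order(
    name="Example ranking order",
    instruction="Which image looks more realistic?",
    datapoints=["image_a.jpg", "image_b.jpg", "image_c.jpg"],
    responses_per_comparison=3,
    total_comparison_budget=50,
    random_comparisons_ratio=0.5,
)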
rapidata/rapidata_client/validation/rapids/rapids.py

@@ -16,10 +16,9 @@ from rapidata.api_client.models.add_validation_rapid_model_truth import (
     AddValidationRapidModelTruth,
 )

-from rapidata.api_client.models.
-
+from rapidata.api_client.models.create_datapoint_from_urls_model import (
+    CreateDatapointFromUrlsModelMetadataInner,
 )
-
 from rapidata.service.openapi_service import OpenAPIService

 import requests
@@ -71,7 +70,7 @@ class Rapid():
             payload=AddValidationRapidModelPayload(self.payload),
             truth=AddValidationRapidModelTruth(self.truth),
             metadata=[
-
+                CreateDatapointFromUrlsModelMetadataInner(meta.to_model())
                 for meta in self.metadata
             ],
             randomCorrectProbability=self.randomCorrectProbability,
@@ -98,7 +97,7 @@ class Rapid():
             payload=AddValidationRapidModelPayload(self.payload),
             truth=AddValidationRapidModelTruth(self.truth),
             metadata=[
-
+                CreateDatapointFromUrlsModelMetadataInner(meta.to_model())
                 for meta in self.metadata
             ],
             randomCorrectProbability=self.randomCorrectProbability,
rapidata/rapidata_client/workflow/_ranking_workflow.py (new file)

@@ -0,0 +1,40 @@
+from rapidata.api_client import CompareWorkflowModelPairMakerConfig, OnlinePairMakerConfigModel, EloConfigModel
+from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
+from rapidata.rapidata_client.workflow._base_workflow import Workflow
+
+class RankingWorkflow(Workflow):
+
+
+    def __init__(self,
+                 criteria: str,
+                 total_comparison_budget: int,
+                 random_comparisons_ratio,
+                 elo_start: int,
+                 elo_k_factor: int,
+                 elo_scaling_factor: int,
+                 ):
+        super().__init__(type="CompareWorkflowConfig")
+
+        self.criteria = criteria
+        self.pair_maker_config = CompareWorkflowModelPairMakerConfig(
+            OnlinePairMakerConfigModel(
+                _t='OnlinePairMaker',
+                totalComparisonBudget=total_comparison_budget,
+                randomMatchesRatio=random_comparisons_ratio,
+            )
+        )
+
+        self.elo_config = EloConfigModel(
+            startingElo=elo_start,
+            kFactor=elo_k_factor,
+            scalingFactor=elo_scaling_factor,
+        )
+
+    def _to_model(self) -> CompareWorkflowModel:
+
+        return CompareWorkflowModel(
+            _t="CompareWorkflow",
+            criteria=self.criteria,
+            eloConfig=self.elo_config,
+            pairMakerConfig=self.pair_maker_config
+        )
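RankingWorkflow is a thin wrapper over the existing compare workflow: the ranking parameters are mapped onto an OnlinePairMaker pair-maker config plus an Elo config inside a CompareWorkflowModel. Direct construction is normally handled by create_ranking_order; the snippet below is only an illustration of how the constructor arguments end up in the backend model, using placeholder values.

from rapidata.rapidata_client.workflow import RankingWorkflow

# Normally built for you by create_ranking_order; shown here to illustrate the mapping.
workflow = RankingWorkflow(
    criteria="Which image looks more realistic?",
    total_comparison_budget=50,
    random_comparisons_ratio=0.5,
    elo_start=1200,
    elo_k_factor=40,
    elo_scaling_factor=400,
)

model = workflow._to_model()  # CompareWorkflowModel with eloConfig and pairMakerConfig populated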