rapidata 2.13.1__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rapidata might be problematic.

Files changed (35)
  1. rapidata/api_client/__init__.py +6 -4
  2. rapidata/api_client/api/__init__.py +1 -0
  3. rapidata/api_client/api/dataset_api.py +265 -0
  4. rapidata/api_client/api/workflow_api.py +298 -1
  5. rapidata/api_client/models/__init__.py +5 -4
  6. rapidata/api_client/models/add_campaign_model.py +3 -3
  7. rapidata/api_client/models/add_validation_rapid_model.py +3 -3
  8. rapidata/api_client/models/add_validation_text_rapid_model.py +3 -3
  9. rapidata/api_client/models/create_datapoint_from_urls_model.py +26 -4
  10. rapidata/api_client/models/create_datapoint_from_urls_model_metadata_inner.py +168 -0
  11. rapidata/api_client/models/create_datapoint_result.py +5 -3
  12. rapidata/api_client/models/datapoint.py +7 -30
  13. rapidata/api_client/models/datapoint_asset.py +40 -40
  14. rapidata/api_client/models/datapoint_metadata_model.py +3 -3
  15. rapidata/api_client/models/datapoint_model.py +3 -3
  16. rapidata/api_client/models/get_compare_workflow_results_result.py +3 -3
  17. rapidata/api_client/models/get_datapoint_by_id_result.py +3 -3
  18. rapidata/api_client/models/get_failed_datapoints_result.py +95 -0
  19. rapidata/api_client/models/get_responses_result.py +95 -0
  20. rapidata/api_client/models/get_simple_workflow_results_result.py +3 -3
  21. rapidata/api_client/models/multi_asset_model.py +3 -3
  22. rapidata/api_client/models/upload_text_sources_to_dataset_model.py +17 -2
  23. rapidata/api_client_README.md +8 -4
  24. rapidata/rapidata_client/assets/_media_asset.py +5 -1
  25. rapidata/rapidata_client/assets/_multi_asset.py +6 -1
  26. rapidata/rapidata_client/filter/country_filter.py +1 -1
  27. rapidata/rapidata_client/order/_rapidata_dataset.py +311 -108
  28. rapidata/rapidata_client/order/rapidata_order_manager.py +64 -6
  29. rapidata/rapidata_client/validation/rapids/rapids.py +4 -5
  30. rapidata/rapidata_client/workflow/__init__.py +1 -0
  31. rapidata/rapidata_client/workflow/_ranking_workflow.py +40 -0
  32. {rapidata-2.13.1.dist-info → rapidata-2.14.0.dist-info}/METADATA +1 -1
  33. {rapidata-2.13.1.dist-info → rapidata-2.14.0.dist-info}/RECORD +35 -31
  34. {rapidata-2.13.1.dist-info → rapidata-2.14.0.dist-info}/LICENSE +0 -0
  35. {rapidata-2.13.1.dist-info → rapidata-2.14.0.dist-info}/WHEEL +0 -0
rapidata/rapidata_client/order/_rapidata_dataset.py
@@ -1,9 +1,10 @@
  from itertools import zip_longest
 
  from rapidata.api_client.models.datapoint_metadata_model import DatapointMetadataModel
- from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import (
- DatapointMetadataModelMetadataInner,
+ from rapidata.api_client.models.create_datapoint_from_urls_model import (
+ CreateDatapointFromUrlsModelMetadataInner,
  )
+ from rapidata.api_client.models.create_datapoint_from_urls_model import CreateDatapointFromUrlsModel
  from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
  UploadTextSourcesToDatasetModel,
  )
@@ -14,12 +15,11 @@ from rapidata.service.openapi_service import OpenAPIService
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from tqdm import tqdm
 
- from pydantic import StrictBytes, StrictStr
- from typing import Optional, cast, Sequence, Generator
+ from pydantic import StrictStr
+ from typing import cast, Sequence, Generator
  from logging import Logger
- from requests.adapters import HTTPAdapter, Retry
  import time
- import requests
+ import threading
 
 
  def chunk_list(lst: list, chunk_size: int) -> Generator:
@@ -77,149 +77,352 @@ class RapidataDataset:
  future.result() # This will raise any exceptions that occurred during execution
  pbar.update(1)
 
- def _add_media_from_paths(
+ def _process_single_upload(
  self,
- media_paths: list[MediaAsset] | list[MultiAsset],
- metadata: Sequence[Metadata] | None = None,
- max_workers: int = 10,
- max_retries: int = 5,
- chunk_size: int = 50,
+ media_asset: MediaAsset | MultiAsset,
+ meta: Metadata | None,
+ index: int,
  ) -> tuple[list[str], list[str]]:
  """
- Upload media paths in chunks with managed resources.
+ Process single upload with error tracking.
 
  Args:
- media_paths: List of MediaAsset or MultiAsset objects to upload
- metadata: Optional sequence of metadata matching media_paths length
- max_workers: Maximum number of concurrent upload workers
- max_retries: Maximum number of retry attempts per failed request
- chunk_size: Number of items to process in each batch
+ media_asset: MediaAsset or MultiAsset to upload
+ meta: Optional metadata for the asset
+ index: Sort index for the upload
+ session: Requests session for HTTP requests
 
  Returns:
- tuple[list[str], list[str]]: Lists of successful and failed URLs
-
- Raises:
- ValueError: If metadata length doesn't match media_paths length
+ tuple[list[str], list[str]]: Lists of successful and failed identifiers
  """
- if metadata is not None and len(metadata) != len(media_paths):
- raise ValueError("metadata must be None or have the same length as media_paths")
-
- # Configure session with retry logic
- session = requests.Session()
- retries = Retry(
- total=max_retries,
- backoff_factor=1,
- status_forcelist=[500, 502, 503, 504],
- allowed_methods=["GET"],
- respect_retry_after_header=True
- )
+ local_successful: list[str] = []
+ local_failed: list[str] = []
+ identifiers_to_track: list[str] = []
 
- adapter = HTTPAdapter(
- pool_connections=max_workers * 2,
- pool_maxsize=max_workers * 4,
- max_retries=retries
- )
- session.mount('http://', adapter)
- session.mount('https://', adapter)
-
- def upload_datapoint(
- media_asset: MediaAsset | MultiAsset,
- meta: Metadata | None,
- index: int,
- session: requests.Session
- ) -> tuple[list[str], list[str]]:
- """Process single upload with error tracking."""
- local_successful: list[str] = []
- local_failed: list[str] = []
- identifiers_to_track: list[str] = []
-
- try:
- # Get identifier for this upload (URL or file path)
- if isinstance(media_asset, MediaAsset):
- media_asset.session = session
- assets = [media_asset]
- identifier = media_asset._url if media_asset._url else media_asset.path
- identifiers_to_track = [identifier] if identifier else []
- elif isinstance(media_asset, MultiAsset):
- assets = cast(list[MediaAsset], media_asset.assets)
- for asset in assets:
- asset.session = session
- identifiers_to_track: list[str] = [
- (asset._url if asset._url else cast(str, asset.path))
- for asset in assets
- ]
- else:
- raise ValueError(f"Unsupported asset type: {type(media_asset)}")
+ try:
+ # Get identifier for this upload (URL or file path)
+ if isinstance(media_asset, MediaAsset):
+ assets = [media_asset]
+ identifier = media_asset._url if media_asset._url else media_asset.path
+ identifiers_to_track = [identifier] if identifier else []
+ elif isinstance(media_asset, MultiAsset):
+ assets = cast(list[MediaAsset], media_asset.assets)
+ identifiers_to_track: list[str] = [
+ (asset._url if asset._url else cast(str, asset.path))
+ for asset in assets
+ ]
+ else:
+ raise ValueError(f"Unsupported asset type: {type(media_asset)}")
+
+ meta_model = meta.to_model() if meta else None
 
- meta_model = meta.to_model() if meta else None
+ metadata = [CreateDatapointFromUrlsModelMetadataInner(meta_model)] if meta_model else []
+
+ local_paths: bool = assets[0].is_local()
+ files: list[StrictStr] = []
+ for asset in assets:
+ if isinstance(asset, MediaAsset):
+ files.append(asset.path)
+
+ if local_paths:
  model = DatapointMetadataModel(
  datasetId=self.dataset_id,
- metadata=([DatapointMetadataModelMetadataInner(meta_model)] if meta_model else []),
+ metadata=metadata,
  sortIndex=index,
  )
-
- files: list[tuple[StrictStr, StrictBytes] | StrictStr | StrictBytes] = []
- for asset in assets:
- if isinstance(asset, MediaAsset):
- files.append(asset.to_file())
-
  upload_response = self.openapi_service.dataset_api.dataset_create_datapoint_post(
  model=model,
- files=files
+ files=files # type: ignore
  )
+ else:
+ upload_response = self.openapi_service.dataset_api.dataset_dataset_id_datapoints_urls_post(
+ dataset_id=self.dataset_id,
+ create_datapoint_from_urls_model=CreateDatapointFromUrlsModel(
+ urls=files,
+ metadata=metadata,
+ sortIndex=index
+ ),
+ )
+
+ if upload_response.errors:
+ error_msg = f"Error uploading datapoint: {upload_response.errors}"
+ self._logger.error(error_msg)
+ local_failed.extend(identifiers_to_track)
+ raise ValueError(error_msg)
 
- if upload_response.errors:
- error_msg = f"Error uploading datapoint: {upload_response.errors}"
- self._logger.error(error_msg)
- local_failed.extend(identifiers_to_track)
- raise ValueError(error_msg)
-
- local_successful.extend(identifiers_to_track)
+ local_successful.extend(identifiers_to_track)
 
- except Exception as e:
- self._logger.error(f"\nUpload failed for {identifiers_to_track}: {str(e)}") # \n to avoid same line as tqdm
- local_failed.extend(identifiers_to_track)
+ except Exception as e:
+ self._logger.error(f"\nUpload failed for {identifiers_to_track}: {str(e)}") # \n to avoid same line as tqdm
+ local_failed.extend(identifiers_to_track)
 
- return local_successful, local_failed
+ return local_successful, local_failed
 
+ def _get_progress_tracker(
+ self,
+ total_uploads: int,
+ stop_event: threading.Event,
+ progress_error_event: threading.Event,
+ progress_poll_interval: float,
+ ) -> threading.Thread:
+ """
+ Create and return a progress tracking thread that shows actual API progress.
+
+ Args:
+ total_uploads: Total number of uploads to track
+ initial_ready: Initial number of ready items
+ initial_progress: Initial progress state
+ stop_event: Event to signal thread to stop
+ progress_error_event: Event to signal an error in progress tracking
+ progress_poll_interval: Time between progress checks
+
+ Returns:
+ threading.Thread: The progress tracking thread
+ """
+ def progress_tracking_thread():
+ try:
+ # Initialize progress bar with 0 completions
+ with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
+ prev_ready = 0
+ prev_failed = 0
+ stall_count = 0
+ last_progress_time = time.time()
+
+ # We'll wait for all uploads to finish + some extra time
+ # for the backend to fully process everything
+ all_uploads_complete = threading.Event()
+
+ while not stop_event.is_set() or not all_uploads_complete.is_set():
+ try:
+ current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+
+ # Calculate items completed since our initialization
+ completed_ready = current_progress.ready
+ completed_failed = current_progress.failed
+ total_completed = completed_ready + completed_failed
+
+ # Calculate newly completed items since our last check
+ new_ready = current_progress.ready - prev_ready
+ new_failed = current_progress.failed - prev_failed
+
+ # Update progress bar position to show actual completed items
+ # First reset to match the actual completed count
+ pbar.n = total_completed
+ pbar.refresh()
+
+ if new_ready > 0 or new_failed > 0:
+ # We saw progress
+ stall_count = 0
+ last_progress_time = time.time()
+ else:
+ stall_count += 1
+
+ # Update our tracking variables
+ prev_ready = current_progress.ready
+ prev_failed = current_progress.failed or 0
+
+ # Check if stop_event was set (all uploads submitted)
+ if stop_event.is_set():
+ elapsed_since_last_progress = time.time() - last_progress_time
+
+ # If we haven't seen progress for a while after all uploads were submitted
+ if elapsed_since_last_progress > 5.0:
+ # If we're at 100%, we're done
+ if total_completed >= total_uploads:
+ all_uploads_complete.set()
+ break
+
+ # If we're not at 100% but it's been a while with no progress
+ if stall_count > 5:
+ # We've polled several times with no progress, assume we're done
+ self._logger.warning(f"\nProgress seems stalled at {total_completed}/{total_uploads}. Please try again.")
+ break
+
+ except Exception as e:
+ self._logger.error(f"\nError checking progress: {str(e)}")
+ stall_count += 1
+
+ if stall_count > 10: # Too many consecutive errors
+ progress_error_event.set()
+ break
+
+ # Sleep before next poll
+ time.sleep(progress_poll_interval)
+
+ except Exception as e:
+ self._logger.error(f"Progress tracking thread error: {str(e)}")
+ progress_error_event.set()
+
+ # Create and return the thread
+ progress_thread = threading.Thread(target=progress_tracking_thread)
+ progress_thread.daemon = True
+ return progress_thread
 
- # Process uploads in chunks
+ def _process_uploads_in_chunks(
+ self,
+ media_paths: list[MediaAsset] | list[MultiAsset],
+ metadata: Sequence[Metadata] | None,
+ max_workers: int,
+ chunk_size: int,
+ stop_progress_tracking: threading.Event,
+ progress_tracking_error: threading.Event
+ ) -> tuple[list[str], list[str]]:
+ """
+ Process uploads in chunks with a ThreadPoolExecutor.
+
+ Args:
+ media_paths: List of assets to upload
+ metadata: Optional sequence of metadata
+ session: Requests session for HTTP requests
+ max_workers: Maximum number of concurrent workers
+ chunk_size: Number of items to process in each batch
+ stop_progress_tracking: Event to signal progress tracking to stop
+ progress_tracking_error: Event to detect progress tracking errors
+
+ Returns:
+ tuple[list[str], list[str]]: Lists of successful and failed uploads
+ """
  successful_uploads: list[str] = []
  failed_uploads: list[str] = []
- total_uploads = len(media_paths)
-
- with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
- for chunk_idx, chunk in enumerate(chunk_list(media_paths, chunk_size)):
- chunk_metadata = metadata[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] if metadata else None
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
+
+ try:
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Process uploads in chunks to avoid overwhelming the system
+ for chunk_idx, chunk in enumerate(chunk_list(media_paths, chunk_size)):
+ chunk_metadata = metadata[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] if metadata else None
+
  futures = [
  executor.submit(
- upload_datapoint,
+ self._process_single_upload,
  media_asset,
  meta,
- index=(chunk_idx * chunk_size + i),
- session=session
+ index=(chunk_idx * chunk_size + i)
  )
  for i, (media_asset, meta) in enumerate(zip_longest(chunk, chunk_metadata or []))
  ]
-
+
+ # Wait for this chunk to complete before starting the next one
  for future in as_completed(futures):
+ if progress_tracking_error.is_set():
+ raise RuntimeError("Progress tracking failed, aborting uploads")
+
  try:
  chunk_successful, chunk_failed = future.result()
  successful_uploads.extend(chunk_successful)
  failed_uploads.extend(chunk_failed)
  except Exception as e:
  self._logger.error(f"Future execution failed: {str(e)}")
- finally:
- pbar.update(1)
+ finally:
+ # Signal to the progress tracking thread that all uploads have been submitted
+ stop_progress_tracking.set()
+
+ return successful_uploads, failed_uploads
 
- # Log summary statistics
- success_rate = len(successful_uploads) / total_uploads * 100 if total_uploads > 0 else 0
- self._logger.info(f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)")
+ def _log_final_progress(
+ self,
+ total_uploads: int,
+ progress_poll_interval: float,
+ successful_uploads: list[str],
+ failed_uploads: list[str]
+ ) -> None:
+ """
+ Log the final progress of the upload operation.
+
+ Args:
+ total_uploads: Total number of uploads
+ initial_ready: Initial number of ready items
+ initial_progress: Initial progress state
+ progress_poll_interval: Time between progress checks
+ successful_uploads: List of successful uploads for fallback reporting
+ failed_uploads: List of failed uploads for fallback reporting
+ """
+ try:
+ # Get final progress
+ final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+ total_ready = final_progress.ready
+ total_failed = final_progress.failed
+
+ # Make sure we account for all uploads
+ if total_ready + total_failed < total_uploads:
+ # Try one more time after a longer wait
+ time.sleep(5 * progress_poll_interval)
+ final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.dataset_id)
+ total_ready = final_progress.ready
+ total_failed = final_progress.failed
+
+ success_rate = (total_ready / total_uploads * 100) if total_uploads > 0 else 0
+
+ self._logger.info(f"Upload complete: {total_ready} ready, {total_uploads-total_ready} failed ({success_rate:.1f}% success rate)")
+ print(f"Upload complete, {total_ready} ready, {total_uploads-total_ready} failed ({success_rate:.1f}% success rate)")
+ except Exception as e:
+ self._logger.error(f"Error getting final progress: {str(e)}")
+ self._logger.info(f"Upload summary from local tracking: {len(successful_uploads)} succeeded, {len(failed_uploads)} failed")
 
  if failed_uploads:
  print(f"Failed uploads: {failed_uploads}")
 
- return successful_uploads, failed_uploads
+ def _add_media_from_paths(
+ self,
+ media_paths: list[MediaAsset] | list[MultiAsset],
+ metadata: Sequence[Metadata] | None = None,
+ max_workers: int = 5,
+ chunk_size: int = 50,
+ progress_poll_interval: float = 0.5,
+ ) -> tuple[list[str], list[str]]:
+ """
+ Upload media paths in chunks with managed resources.
+
+ Args:
+ media_paths: List of MediaAsset or MultiAsset objects to upload
+ metadata: Optional sequence of metadata matching media_paths length
+ max_workers: Maximum number of concurrent upload workers
+ chunk_size: Number of items to process in each batch
+ progress_poll_interval: Time in seconds between progress checks
+
+ Returns:
+ tuple[list[str], list[str]]: Lists of successful and failed URLs
+
+ Raises:
+ ValueError: If metadata length doesn't match media_paths length
+ """
+ if metadata is not None and len(metadata) != len(media_paths):
+ raise ValueError("metadata must be None or have the same length as media_paths")
+
+ # Setup tracking variables
+ total_uploads = len(media_paths)
+
+ # Create thread control events
+ stop_progress_tracking = threading.Event()
+ progress_tracking_error = threading.Event()
+
+ # Create and start progress tracking thread
+ progress_thread = self._get_progress_tracker(
+ total_uploads,
+ stop_progress_tracking,
+ progress_tracking_error,
+ progress_poll_interval
+ )
+ progress_thread.start()
+
+ # Process uploads in chunks
+ try:
+ successful_uploads, failed_uploads = self._process_uploads_in_chunks(
+ media_paths,
+ metadata,
+ max_workers,
+ chunk_size,
+ stop_progress_tracking,
+ progress_tracking_error
+ )
+ finally:
+ progress_thread.join(10) # Add margin to the timeout for tqdm
+
+ # Log final progress
+ self._log_final_progress(
+ total_uploads,
+ progress_poll_interval,
+ successful_uploads,
+ failed_uploads
+ )
 
+ return successful_uploads, failed_uploads
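
For orientation, the refactor above replaces the old per-request retry session with two cooperating pieces: a ThreadPoolExecutor that submits datapoints chunk by chunk through _process_single_upload, and a daemon thread that polls the dataset progress endpoint until everything is accounted for. Below is a minimal, self-contained sketch of that coordination pattern; the upload function and the in-memory counter are stand-ins for the real rapidata API calls, not part of the package.

    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import Generator


    def chunk_list(lst: list, chunk_size: int) -> Generator:
        # Same helper as in _rapidata_dataset.py: yield successive chunk_size-sized slices.
        for i in range(0, len(lst), chunk_size):
            yield lst[i:i + chunk_size]


    completed = 0                      # stand-in for the dataset progress endpoint
    completed_lock = threading.Lock()


    def fake_upload(item: str) -> str:
        global completed
        time.sleep(0.05)               # pretend to do network I/O
        with completed_lock:
            completed += 1
        return item


    def upload_all(items: list[str], max_workers: int = 5, chunk_size: int = 50,
                   poll_interval: float = 0.2) -> list[str]:
        stop = threading.Event()

        def track_progress():
            # Poll until all uploads have been submitted *and* accounted for,
            # mirroring the loop in _get_progress_tracker.
            while not stop.is_set() or completed < len(items):
                print(f"progress: {completed}/{len(items)}", end="\r")
                time.sleep(poll_interval)

        tracker = threading.Thread(target=track_progress, daemon=True)
        tracker.start()

        done: list[str] = []
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for chunk in chunk_list(items, chunk_size):
                    futures = [executor.submit(fake_upload, item) for item in chunk]
                    for future in as_completed(futures):   # one chunk at a time
                        done.append(future.result())
        finally:
            stop.set()                 # all uploads submitted; let the tracker wind down
            tracker.join(10)
        return done


    if __name__ == "__main__":
        uploaded = upload_all([f"file_{i}.jpg" for i in range(120)])
        print(f"\nuploaded {len(uploaded)} items")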

rapidata/rapidata_client/order/rapidata_order_manager.py
@@ -1,17 +1,14 @@
- from typing import Sequence
+ from typing import Sequence, Optional
  from urllib3._collections import HTTPHeaderDict
 
  from rapidata.service.openapi_service import OpenAPIService
  from rapidata.rapidata_client.assets.data_type_enum import RapidataDataTypes
- from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
  from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
  from rapidata.rapidata_client.order._rapidata_order_builder import RapidataOrderBuilder
  from rapidata.rapidata_client.metadata import PromptMetadata, SelectWordsMetadata
  from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
  from rapidata.rapidata_client.referee._early_stopping_referee import EarlyStoppingReferee
  from rapidata.rapidata_client.selection._base_selection import RapidataSelection
- from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
- from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
  from rapidata.rapidata_client.workflow import (
  Workflow,
  ClassifyWorkflow,
@@ -20,7 +17,9 @@ from rapidata.rapidata_client.workflow import (
  SelectWordsWorkflow,
  LocateWorkflow,
  DrawWorkflow,
- TimestampWorkflow)
+ TimestampWorkflow,
+ RankingWorkflow
+ )
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
  from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
@@ -243,7 +242,66 @@ class RapidataOrderManager:
  selections=selections,
  settings=settings
  )
-
+
+ def create_ranking_order(self,
+ name: str,
+ instruction: str,
+ datapoints: list[str],
+ responses_per_comparison: int,
+ total_comparison_budget: int,
+ random_comparisons_ratio: float = 0.5,
+ elo_start: int = 1200,
+ elo_k_factor: int = 40,
+ elo_scaling_factor: int = 400,
+ contexts: Optional[list[str]] = None,
+ validation_set_id: Optional[str] = None,
+ filters: Sequence[RapidataFilter] = [],
+ settings: Sequence[RapidataSetting] = [],
+ selections: Optional[Sequence[RapidataSelection]] = None) -> RapidataOrder:
+ """
+ Create a ranking order.
+
+ Args:
+ name (str): The name of the order.
+ instruction (str): The question asked from People when They see two datapoints.
+ datapoints (list[str]): A list of datapoints that will participate in the ranking.
+ total_comparison_budget (int): The total number of (pairwise-)comparisons that can be made.
+ random_comparisons_ratio (float, optional): The fraction of random comparisons in the ranking process.
+ The rest will focus on pairing similarly ranked datapoints.
+ elo_start (int, optional): The initial ELO rating assigned to each datapoint.
+ elo_k_factor (int, optional): The K-factor used for ELO updates.
+ elo_scaling_factor (int, optional): The scaling factor used in the ELO calculation.
+ responses_per_comparison (int, optional): The number of responses collected per comparison.
+ contexts (list[str], optional): The list of contexts for the comparison. Defaults to None.\n
+ If provided has to be the same length as datapoints and will be shown in addition to the instruction.
+ (Therefore will be different for each datapoint) Will be match up with the datapoints using the list index.
+ validation_set_id (str, optional): The ID of the validation set. Defaults to None.\n
+ If provided, one validation task will be shown infront of the datapoints that will be labeled.
+ filters (Sequence[RapidataFilter], optional): The list of filters for the order. Defaults to []. Decides who the tasks should be shown to.
+ settings (Sequence[RapidataSetting], optional): The list of settings for the order. Defaults to []. Decides how the tasks should be shown.
+ selections (Sequence[RapidataSelection], optional): The list of selections for the order. Defaults to None. Decides in what order the tasks should be shown.
+ """
+
+ assets = [MediaAsset(path=path) for path in datapoints]
+ return self._create_general_order(
+ name=name,
+ workflow=RankingWorkflow(
+ criteria=instruction,
+ elo_start=elo_start,
+ elo_k_factor=elo_k_factor,
+ elo_scaling_factor=elo_scaling_factor,
+ total_comparison_budget=total_comparison_budget,
+ random_comparisons_ratio=random_comparisons_ratio
+ ),
+ assets=assets,
+ responses_per_datapoint=responses_per_comparison,
+ contexts=contexts,
+ validation_set_id=validation_set_id,
+ filters=filters,
+ selections=selections,
+ settings=settings
+ )
+
  def create_free_text_order(self,
  name: str,
  instruction: str,
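
For orientation, a hedged usage sketch of the new create_ranking_order API. It assumes the usual RapidataClient entry point with the order manager exposed as client.order (that wiring is not shown in this diff), and the file names are placeholders.

    from rapidata import RapidataClient  # assumed entry point; not shown in this diff

    client = RapidataClient()

    # Keyword arguments mirror the create_ranking_order signature added above.
    order = client.order.create_ranking_order(
        name="Example ranking order",
        instruction="Which image looks more realistic?",
        datapoints=["image_a.jpg", "image_b.jpg", "image_c.jpg"],  # placeholder paths
        responses_per_comparison=3,
        total_comparison_budget=100,
        random_comparisons_ratio=0.5,   # half random pairs, half similarly-ranked pairs
    )
    # The call wraps each datapoint in a MediaAsset, builds a RankingWorkflow,
    # and returns a RapidataOrder via _create_general_order.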

rapidata/rapidata_client/validation/rapids/rapids.py
@@ -16,10 +16,9 @@ from rapidata.api_client.models.add_validation_rapid_model_truth import (
  AddValidationRapidModelTruth,
  )
 
- from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import (
- DatapointMetadataModelMetadataInner,
+ from rapidata.api_client.models.create_datapoint_from_urls_model import (
+ CreateDatapointFromUrlsModelMetadataInner,
  )
-
  from rapidata.service.openapi_service import OpenAPIService
 
  import requests
@@ -71,7 +70,7 @@ class Rapid():
  payload=AddValidationRapidModelPayload(self.payload),
  truth=AddValidationRapidModelTruth(self.truth),
  metadata=[
- DatapointMetadataModelMetadataInner(meta.to_model())
+ CreateDatapointFromUrlsModelMetadataInner(meta.to_model())
  for meta in self.metadata
  ],
  randomCorrectProbability=self.randomCorrectProbability,
@@ -98,7 +97,7 @@
  payload=AddValidationRapidModelPayload(self.payload),
  truth=AddValidationRapidModelTruth(self.truth),
  metadata=[
- DatapointMetadataModelMetadataInner(meta.to_model())
+ CreateDatapointFromUrlsModelMetadataInner(meta.to_model())
  for meta in self.metadata
  ],
  randomCorrectProbability=self.randomCorrectProbability,

rapidata/rapidata_client/workflow/__init__.py
@@ -7,3 +7,4 @@ from ._free_text_workflow import FreeTextWorkflow
  from ._select_words_workflow import SelectWordsWorkflow
  from ._evaluation_workflow import EvaluationWorkflow
  from ._timestamp_workflow import TimestampWorkflow
+ from ._ranking_workflow import RankingWorkflow

rapidata/rapidata_client/workflow/_ranking_workflow.py
@@ -0,0 +1,40 @@
+ from rapidata.api_client import CompareWorkflowModelPairMakerConfig, OnlinePairMakerConfigModel, EloConfigModel
+ from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
+ from rapidata.rapidata_client.workflow._base_workflow import Workflow
+
+ class RankingWorkflow(Workflow):
+
+
+ def __init__(self,
+ criteria: str,
+ total_comparison_budget: int,
+ random_comparisons_ratio,
+ elo_start: int,
+ elo_k_factor: int,
+ elo_scaling_factor: int,
+ ):
+ super().__init__(type="CompareWorkflowConfig")
+
+ self.criteria = criteria
+ self.pair_maker_config = CompareWorkflowModelPairMakerConfig(
+ OnlinePairMakerConfigModel(
+ _t='OnlinePairMaker',
+ totalComparisonBudget=total_comparison_budget,
+ randomMatchesRatio=random_comparisons_ratio,
+ )
+ )
+
+ self.elo_config = EloConfigModel(
+ startingElo=elo_start,
+ kFactor=elo_k_factor,
+ scalingFactor=elo_scaling_factor,
+ )
+
+ def _to_model(self) -> CompareWorkflowModel:
+
+ return CompareWorkflowModel(
+ _t="CompareWorkflow",
+ criteria=self.criteria,
+ eloConfig=self.elo_config,
+ pairMakerConfig=self.pair_maker_config
+ )
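
For orientation, a hedged sketch of what the new workflow class produces when used on its own. The values mirror the defaults that create_ranking_order passes through, and _to_model() is the internal hook the order builder calls to obtain the API payload.

    from rapidata.rapidata_client.workflow import RankingWorkflow

    # Defaults as passed by create_ranking_order (elo_start=1200, k=40, scaling=400).
    workflow = RankingWorkflow(
        criteria="Which image looks more realistic?",
        total_comparison_budget=100,
        random_comparisons_ratio=0.5,
        elo_start=1200,
        elo_k_factor=40,
        elo_scaling_factor=400,
    )

    # Produces a CompareWorkflowModel carrying the Elo config and the
    # OnlinePairMaker config defined in the new file above.
    model = workflow._to_model()
    print(type(model).__name__)  # CompareWorkflowModel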

{rapidata-2.13.1.dist-info → rapidata-2.14.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rapidata
- Version: 2.13.1
+ Version: 2.14.0
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
  License: Apache-2.0
  Author: Rapidata AG