rapidata 1.4.3__py3-none-any.whl → 1.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic. Click here for more details.

rapidata/__init__.py CHANGED
@@ -10,7 +10,7 @@ from .rapidata_client import (
10
10
  ConditionalValidationSelection,
11
11
  CappedSelection,
12
12
  NaiveReferee,
13
- ClassifyEarlyStoppingReferee,
13
+ EarlyStoppingReferee,
14
14
  PrivateTextMetadata,
15
15
  PublicTextMetadata,
16
16
  PromptMetadata,
@@ -12,7 +12,7 @@ from .selection import (
12
12
  ConditionalValidationSelection,
13
13
  CappedSelection,
14
14
  )
15
- from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
15
+ from .referee import NaiveReferee, EarlyStoppingReferee
16
16
  from .metadata import (
17
17
  PrivateTextMetadata,
18
18
  PublicTextMetadata,
@@ -1,4 +1,4 @@
1
- import time
1
+ from time import sleep
2
2
  from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
3
3
  from rapidata.service.openapi_service import OpenAPIService
4
4
  import json
@@ -61,11 +61,11 @@ class RapidataOrder:
61
61
  :param refresh_rate: How often to refresh the progress bar, in seconds.
62
62
  :type refresh_rate: float
63
63
  """
64
- total_rapids = self._get_total_rapids()
64
+ total_rapids = self._get_workflow_progress().total
65
65
  with tqdm(total=total_rapids, desc="Processing order", unit="rapids") as pbar:
66
66
  completed_rapids = 0
67
67
  while True:
68
- current_completed = self._get_completed_rapids()
68
+ current_completed = self._get_workflow_progress().completed
69
69
  if current_completed > completed_rapids:
70
70
  pbar.update(current_completed - completed_rapids)
71
71
  completed_rapids = current_completed
@@ -73,7 +73,7 @@ class RapidataOrder:
73
73
  if completed_rapids >= total_rapids:
74
74
  break
75
75
 
76
- time.sleep(refresh_rate)
76
+ sleep(refresh_rate)
77
77
 
78
78
  def _get_workflow_id(self):
79
79
  if self._workflow_id:
@@ -86,31 +86,38 @@ class RapidataOrder:
86
86
  self._workflow_id = cast(WorkflowArtifactModel, pipeline.artifacts["workflow-artifact"].actual_instance).workflow_id
87
87
  break
88
88
  except Exception:
89
- time.sleep(2)
89
+ sleep(2)
90
90
  if not self._workflow_id:
91
- raise Exception("Order has not started yet. Please wait for a few seconds and try again.")
91
+ raise Exception("Order has not started yet. Please start it or wait for a few seconds and try again.")
92
92
  return self._workflow_id
93
-
94
- def _get_total_rapids(self):
95
- workflow_id = self._get_workflow_id()
96
- return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).total
97
-
98
- def _get_completed_rapids(self):
93
+
94
+ def _get_workflow_progress(self):
99
95
  workflow_id = self._get_workflow_id()
100
- return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).completed
96
+ progress = None
97
+ for _ in range(2):
98
+ try:
99
+ progress = self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id)
100
+ break
101
+ except Exception:
102
+ sleep(5)
101
103
 
102
- def get_progress_percentage(self):
103
- workflow_id = self._get_workflow_id()
104
- progress = self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id)
105
- return progress.completion_percentage
104
+ if not progress:
105
+ raise Exception(f"Failed to get progress. Please try again in a few seconds.")
106
+
107
+ return progress
106
108
 
109
+
107
110
  def get_results(self):
108
111
  """
109
- Gets the results of the order.
112
+ Gets the results of the order.
113
+ If the order is still processing, this method will block until the order is completed and then return the results.
110
114
 
111
115
  :return: The results of the order.
112
116
  :rtype: dict
113
117
  """
118
+ while self.get_status().state == "Processing":
119
+ sleep(5)
120
+
114
121
  try:
115
122
  # Get the raw result string
116
123
  result_str = self.openapi_service.order_api.order_result_get(id=self.order_id)
@@ -28,8 +28,8 @@ class RapidataClient:
28
28
  self,
29
29
  client_id: str,
30
30
  client_secret: str,
31
- endpoint: str = "https://api.app.rapidata.ai",
32
- token_url: str = "https://api.app.rapidata.ai/connect/token",
31
+ endpoint: str = "https://auth.rapidata.ai",
32
+ token_url: str = "https://auth.rapidata.ai/connect/token",
33
33
  oauth_scope: str = "openid",
34
34
  cert_path: str | None = None,
35
35
  ):
@@ -1,3 +1,3 @@
1
1
  from .base_referee import Referee
2
2
  from .naive_referee import NaiveReferee #as MaxVoteReferee
3
- from .classify_early_stopping_referee import ClassifyEarlyStoppingReferee
3
+ from .early_stopping_referee import EarlyStoppingReferee
@@ -5,7 +5,7 @@ from rapidata.api_client.models.early_stopping_referee_model import (
5
5
  )
6
6
 
7
7
 
8
- class ClassifyEarlyStoppingReferee(Referee):
8
+ class EarlyStoppingReferee(Referee):
9
9
  """A referee that stops the task when confidence in the winning category exceeds a threshold.
10
10
 
11
11
  This referee implements an early stopping mechanism for classification tasks.
@@ -15,6 +15,8 @@ class ClassifyEarlyStoppingReferee(Referee):
15
15
  The threshold behaves logarithmically, meaning small increments (e.g., from 0.99
16
16
  to 0.999) can significantly impact the stopping criteria.
17
17
 
18
+ This referee is supported for the classification and compare tasks.
19
+
18
20
  Attributes:
19
21
  threshold (float): The confidence threshold for early stopping.
20
22
  max_vote_count (int): The maximum number of votes allowed before stopping.
@@ -3,31 +3,31 @@ from rapidata.rapidata_client.referee.base_referee import Referee
3
3
 
4
4
 
5
5
  class NaiveReferee(Referee):
6
- """A simple referee that completes a task after a fixed number of guesses.
6
+ """A simple referee that completes a task after a fixed number of responses.
7
7
 
8
8
  This referee implements a straightforward approach to task completion,
9
9
  where the task is considered finished after a predetermined number of
10
- guesses have been made, regardless of the content or quality of those guesses.
10
+ responses have been made, regardless of the content or quality of those responses.
11
11
 
12
12
  Attributes:
13
- required_guesses (int): The number of guesses required to complete the task.
13
+ responses (int): The number of responses required to complete the task.
14
14
  """
15
15
 
16
- def __init__(self, required_guesses: int = 10):
16
+ def __init__(self, responses: int = 10):
17
17
  """Initialize the NaiveReferee.
18
18
 
19
19
  Args:
20
- required_guesses (int, optional): The number of guesses required
21
- to complete the task. Defaults to 10.
20
+ responses (int, optional): The number of responses required
21
+ to complete the task. Defaults to 10. This is per media item.
22
22
  """
23
23
  super().__init__()
24
- self.required_guesses = required_guesses
24
+ self.responses = responses
25
25
 
26
26
  def to_dict(self):
27
27
  return {
28
28
  "_t": "NaiveRefereeConfig",
29
- "guessesRequired": self.required_guesses,
29
+ "guessesRequired": self.responses,
30
30
  }
31
31
 
32
32
  def to_model(self):
33
- return NaiveRefereeModel(_t="NaiveReferee", totalVotes=self.required_guesses)
33
+ return NaiveRefereeModel(_t="NaiveReferee", totalVotes=self.responses)
@@ -1,7 +1,7 @@
1
1
  from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
2
2
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
3
3
  from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
4
- from rapidata.rapidata_client.referee.classify_early_stopping_referee import ClassifyEarlyStoppingReferee
4
+ from rapidata.rapidata_client.referee.early_stopping_referee import EarlyStoppingReferee
5
5
  from rapidata.rapidata_client.selection.base_selection import Selection
6
6
  from rapidata.rapidata_client.workflow.classify_workflow import ClassifyWorkflow
7
7
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
@@ -43,13 +43,13 @@ class ClassificationOrderBuilder:
43
43
 
44
44
  def create(self, submit: bool = True, max_upload_workers: int = 10):
45
45
  if self._probability_threshold and self._responses_required:
46
- referee = ClassifyEarlyStoppingReferee(
46
+ referee = EarlyStoppingReferee(
47
47
  max_vote_count=self._responses_required,
48
48
  threshold=self._probability_threshold
49
49
  )
50
50
 
51
51
  else:
52
- referee = NaiveReferee(required_guesses=self._responses_required)
52
+ referee = NaiveReferee(responses=self._responses_required)
53
53
 
54
54
  assets = [MediaAsset(path=media_path) for media_path in self._media_paths]
55
55
 
@@ -2,7 +2,7 @@ from rapidata.service.openapi_service import OpenAPIService
2
2
  from rapidata.rapidata_client.metadata import Metadata
3
3
  from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
4
4
  from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
5
- from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
5
+ from rapidata.rapidata_client.referee import NaiveReferee, EarlyStoppingReferee
6
6
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
7
7
  from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
8
8
  from rapidata.rapidata_client.selection.base_selection import Selection
@@ -18,6 +18,7 @@ class CompareOrderBuilder:
18
18
  self._responses_required = 10
19
19
  self._metadata = None
20
20
  self._validation_set_id = None
21
+ self._probability_threshold = None
21
22
 
22
23
  def responses(self, responses_required: int) -> 'CompareOrderBuilder':
23
24
  """Set the number of responses required per matchup/pairing for the comparison order."""
@@ -34,7 +35,20 @@ class CompareOrderBuilder:
34
35
  self._validation_set_id = validation_set_id
35
36
  return self
36
37
 
38
+ def probability_threshold(self, probability_threshold: float) -> 'CompareOrderBuilder':
39
+ """Set the probability threshold for early stopping."""
40
+ self._probability_threshold = probability_threshold
41
+ return self
42
+
37
43
  def create(self, submit: bool = True, max_upload_workers: int = 10):
44
+ if self._probability_threshold and self._responses_required:
45
+ referee = EarlyStoppingReferee(
46
+ max_vote_count=self._responses_required,
47
+ threshold=self._probability_threshold
48
+ )
49
+
50
+ else:
51
+ referee = NaiveReferee(responses=self._responses_required)
38
52
  selection: list[Selection] = ([ValidationSelection(amount=1, validation_set_id=self._validation_set_id), LabelingSelection(amount=2)]
39
53
  if self._validation_set_id
40
54
  else [LabelingSelection(amount=3)])
@@ -46,7 +60,7 @@ class CompareOrderBuilder:
46
60
  criteria=self._criteria
47
61
  )
48
62
  )
49
- .referee(NaiveReferee(required_guesses=self._responses_required))
63
+ .referee(referee)
50
64
  .media(media_paths, metadata=self._metadata) # type: ignore
51
65
  .selections(selection)
52
66
  .create(submit=submit, max_workers=max_upload_workers))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rapidata
3
- Version: 1.4.3
3
+ Version: 1.4.5
4
4
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
5
5
  License: Apache-2.0
6
6
  Author: Rapidata AG
@@ -1,4 +1,4 @@
1
- rapidata/__init__.py,sha256=vjmq4p45annpd9K5QbD_0RdcLIkDnNFkNb55ZVws56A,587
1
+ rapidata/__init__.py,sha256=xkV2cgV_RBZCW8CNxWFMlf4EFJUWXcJnqfzQtJsG6NA,579
2
2
  rapidata/api_client/__init__.py,sha256=kVvFCI9LJojctHrtdas3VH3EPpbuT8S2OrERxSew5v4,23913
3
3
  rapidata/api_client/api/__init__.py,sha256=S0oVoAVMys10M-Z1SqirMdnHMYSHH3Lz6iph1CfILc0,1004
4
4
  rapidata/api_client/api/campaign_api.py,sha256=DxPFqt9F6c9OpXu_Uxhsrib2NVwnbcZFa3Vkrj7cIuA,40474
@@ -310,7 +310,7 @@ rapidata/api_client/models/workflow_split_model_filter_configs_inner.py,sha256=1
310
310
  rapidata/api_client/models/workflow_state.py,sha256=5LAK1se76RCoozeVB6oxMPb8p_5bhLZJqn7q5fFQWis,850
311
311
  rapidata/api_client/rest.py,sha256=zmCIFQC2l1t-KZcq-TgEm3vco3y_LK6vRm3Q07K-xRI,9423
312
312
  rapidata/api_client_README.md,sha256=58aoLkfxLpoYpBu7uOQkdfPB5wwrkW1cix_LZ1GLQQQ,37571
313
- rapidata/rapidata_client/__init__.py,sha256=mX25RNyeNqboIzCMz_Lm3SnMTXM8LJ_jpXuIRraafww,803
313
+ rapidata/rapidata_client/__init__.py,sha256=2SsfOqLpkTh6b0wTKwUaBboxwuU9_SvP4TK-eUoocGo,795
314
314
  rapidata/rapidata_client/assets/__init__.py,sha256=T-XKvMSkmyI8iYLUYDdZ3LrrSInHsGMUY_Tz77hhnlE,240
315
315
  rapidata/rapidata_client/assets/base_asset.py,sha256=B2YWH1NgaeYUYHDW3OPpHM_bqawHbH4EjnRCE2BYwiM,298
316
316
  rapidata/rapidata_client/assets/media_asset.py,sha256=4xU1k2abdHGwbkJAYNOZYyOPB__5VBRDvRjklegFufQ,887
@@ -341,13 +341,13 @@ rapidata/rapidata_client/metadata/prompt_metadata.py,sha256=_FypjKWrC3iKUO_G2CVw
341
341
  rapidata/rapidata_client/metadata/public_text_metadata.py,sha256=LTiBQHs6izxQ6-C84d6Pf7lL4ENTDgg__HZnDKvzvMc,511
342
342
  rapidata/rapidata_client/metadata/transcription_metadata.py,sha256=THtDEVCON4UlcXHmXrjilaOLHys4TrktUOPGWnXaCcc,631
343
343
  rapidata/rapidata_client/order/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
344
- rapidata/rapidata_client/order/rapidata_order.py,sha256=u0VhNgnbT0dn3AJiITqjtqn_OWgdwUcXrgU4iblPXZI,4663
344
+ rapidata/rapidata_client/order/rapidata_order.py,sha256=IMozRlrYK_yqGLk0GQbLNYL_RRb7gUcSgDe6NUjOEEE,4812
345
345
  rapidata/rapidata_client/order/rapidata_order_builder.py,sha256=QvICzduLAuAkf8qFKxHV3zAag838WnV9lEzWdD4dxI0,15926
346
- rapidata/rapidata_client/rapidata_client.py,sha256=UKFDbck3TOcSIRM31gHpl7vpSCHd3jPrjiTODfBPH44,6956
347
- rapidata/rapidata_client/referee/__init__.py,sha256=Ow9MQsONhF4sX2wFK9jbvSBrpcJgtq3OglIQMkBUdIY,167
346
+ rapidata/rapidata_client/rapidata_client.py,sha256=lXOubU-d1AWX28fhFPdLgIM1p6CCLOGJPJGQNehEqX0,6950
347
+ rapidata/rapidata_client/referee/__init__.py,sha256=E1VODxTjoQRnxzdgMh3aRlDLouxe1nWuvozEHXD2gq4,150
348
348
  rapidata/rapidata_client/referee/base_referee.py,sha256=bMy7cw0a-pGNbFu6u_1_Jplu0A483Ubj4oDQzh8vu8k,493
349
- rapidata/rapidata_client/referee/classify_early_stopping_referee.py,sha256=B5wsqKM3_Oc1TU_MFGiIyiXjwK1LcmaVjhzLdaL8Cgw,1797
350
- rapidata/rapidata_client/referee/naive_referee.py,sha256=KWMLSc73gOdM8YT_ciFhfN7J4eKgtOFphBG9tIra9g0,1179
349
+ rapidata/rapidata_client/referee/early_stopping_referee.py,sha256=Dg2Kk7OiLBtS3kknsLxyJIlS27xmPvsikFR6g4xlbTE,1862
350
+ rapidata/rapidata_client/referee/naive_referee.py,sha256=uOmrzCfpDopm6ww8u4Am5vz5qGNM7621HNJ-QIeH9wM,1164
351
351
  rapidata/rapidata_client/selection/__init__.py,sha256=RFdVUeo2bjCL1gPIL6HXzbpOBqY9kYedkNgb2_ANNLs,321
352
352
  rapidata/rapidata_client/selection/base_selection.py,sha256=Y3HkROPm4I4HLNiR0HuHKpvk236KkRlsoDxQATm_chY,138
353
353
  rapidata/rapidata_client/selection/capped_selection.py,sha256=ikEIT1sKbwWSK8zVqabT-se5LwsEbFs9dkMbO3u_ERk,807
@@ -356,8 +356,8 @@ rapidata/rapidata_client/selection/demographic_selection.py,sha256=DnMbLhzRItZ0t
356
356
  rapidata/rapidata_client/selection/labeling_selection.py,sha256=cqDMQEXfQGMmgIvPgGOYgIGaXflV_J7LZsGOsakLXqo,425
357
357
  rapidata/rapidata_client/selection/validation_selection.py,sha256=HswzD2SvZZWisNLoGj--0sT_TIK8crYp3xGGndo6aLY,523
358
358
  rapidata/rapidata_client/simple_builders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
359
- rapidata/rapidata_client/simple_builders/simple_classification_builders.py,sha256=kaa8DNQJ3WOd_LdXItNkENMCID5v5WThbAEQVBNoNyc,5552
360
- rapidata/rapidata_client/simple_builders/simple_compare_builders.py,sha256=MOb_QwwjaERRks82LemtG46ovoMLNQmNLN_A7JpZK0M,4240
359
+ rapidata/rapidata_client/simple_builders/simple_classification_builders.py,sha256=gTb1OzJZjNW48ltrp0v9tcDjLee9_ZiLw78zYziVu_g,5520
360
+ rapidata/rapidata_client/simple_builders/simple_compare_builders.py,sha256=GmNIUoAkoAHp0OnWLOgeI5-wnDiycuye0o1CUa9c3Q0,4808
361
361
  rapidata/rapidata_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
362
362
  rapidata/rapidata_client/utils/utils.py,sha256=Fl99gCnh_HnieIp099xEvEv4g2kEIKiFcUp0G2iz6x8,815
363
363
  rapidata/rapidata_client/workflow/__init__.py,sha256=xWuzAhBnbcUFfWcgYrzj8ZYLSOXyFtgfepgMrf0hNhU,290
@@ -372,7 +372,7 @@ rapidata/service/local_file_service.py,sha256=pgorvlWcx52Uh3cEG6VrdMK_t__7dacQ_5
372
372
  rapidata/service/openapi_service.py,sha256=pGcOCttKZW0PVCSM7Kfehe5loh7CxmmDDbu4UJbamnI,2770
373
373
  rapidata/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
374
374
  rapidata/utils/image_utils.py,sha256=TldO3eJWG8IhfJjm5MfNGO0mEDm1mQTsRoA0HLU1Uxs,404
375
- rapidata-1.4.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
376
- rapidata-1.4.3.dist-info/METADATA,sha256=AgMjhm057wT3bZqvDtZZstb1jHm8o3OKxMhiiVL18og,1012
377
- rapidata-1.4.3.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
378
- rapidata-1.4.3.dist-info/RECORD,,
375
+ rapidata-1.4.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
376
+ rapidata-1.4.5.dist-info/METADATA,sha256=4tGW6aFOk8Zw8_fi7hFzdIWKoimwxRwy0VLmpVCpSW0,1012
377
+ rapidata-1.4.5.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
378
+ rapidata-1.4.5.dist-info/RECORD,,