rapidata 2.26.0__py3-none-any.whl → 2.27.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +2 -2
- rapidata/api_client/__init__.py +11 -3
- rapidata/api_client/api/__init__.py +2 -1
- rapidata/api_client/api/client_api.py +0 -257
- rapidata/api_client/api/customer_rapid_api.py +1644 -0
- rapidata/api_client/api/dataset_api.py +358 -1
- rapidata/api_client/api/newsletter_api.py +11 -299
- rapidata/api_client/api/user_rapid_api.py +1385 -0
- rapidata/api_client/api/validation_set_api.py +6 -6
- rapidata/api_client/models/__init__.py +9 -2
- rapidata/api_client/models/add_campaign_model.py +5 -0
- rapidata/api_client/models/add_validation_rapid_model.py +3 -3
- rapidata/api_client/models/add_validation_rapid_model_truth.py +25 -11
- rapidata/api_client/models/add_validation_text_rapid_model.py +3 -3
- rapidata/api_client/models/asset_metadata_model.py +2 -8
- rapidata/api_client/models/compare_result.py +1 -10
- rapidata/api_client/models/compare_workflow_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
- rapidata/api_client/models/create_order_model.py +2 -4
- rapidata/api_client/models/datapoint.py +3 -3
- rapidata/api_client/models/datapoint_metadata_model.py +3 -3
- rapidata/api_client/models/datapoint_model.py +3 -3
- rapidata/api_client/models/dataset_dataset_id_datapoints_post_request_metadata_inner.py +182 -0
- rapidata/api_client/models/file_asset_model.py +1 -3
- rapidata/api_client/models/file_asset_model_metadata_value.py +1 -3
- rapidata/api_client/models/get_compare_workflow_results_result.py +3 -3
- rapidata/api_client/models/get_datapoint_by_id_result.py +3 -3
- rapidata/api_client/models/get_rapid_responses_result.py +5 -5
- rapidata/api_client/models/get_validation_rapids_result.py +12 -3
- rapidata/api_client/models/get_validation_rapids_result_truth.py +25 -11
- rapidata/api_client/models/get_workflow_results_result.py +5 -5
- rapidata/api_client/models/multi_asset_model.py +4 -4
- rapidata/api_client/models/multi_compare_truth.py +96 -0
- rapidata/api_client/models/naive_referee_info.py +96 -0
- rapidata/api_client/models/never_ending_referee_info.py +94 -0
- rapidata/api_client/models/null_asset_model.py +1 -3
- rapidata/api_client/models/probabilistic_attach_category_referee_info.py +98 -0
- rapidata/api_client/models/rapid_model.py +173 -0
- rapidata/api_client/models/rapid_model_paged_result.py +105 -0
- rapidata/api_client/models/rapid_model_referee.py +154 -0
- rapidata/api_client/models/rapid_state.py +1 -0
- rapidata/api_client/models/text_asset_model.py +1 -3
- rapidata/api_client/models/update_access_model.py +1 -1
- rapidata/api_client/models/update_validation_rapid_model.py +3 -8
- rapidata/api_client/models/update_validation_rapid_model_truth.py +26 -12
- rapidata/api_client/models/upload_files_from_s3_bucket_model.py +12 -2
- rapidata/api_client/models/upload_text_sources_to_dataset_model.py +3 -3
- rapidata/api_client_README.md +21 -13
- rapidata/rapidata_client/__init__.py +1 -1
- rapidata/rapidata_client/assets/_multi_asset.py +0 -5
- rapidata/rapidata_client/logging/__init__.py +1 -1
- rapidata/rapidata_client/logging/output_manager.py +2 -2
- rapidata/rapidata_client/order/_rapidata_dataset.py +16 -39
- rapidata/rapidata_client/order/rapidata_order.py +21 -16
- rapidata/rapidata_client/order/rapidata_order_manager.py +2 -3
- rapidata/rapidata_client/validation/rapids/rapids.py +3 -4
- rapidata/rapidata_client/validation/validation_set_manager.py +2 -3
- rapidata/rapidata_client/workflow/_ranking_workflow.py +2 -6
- {rapidata-2.26.0.dist-info → rapidata-2.27.0.dist-info}/METADATA +1 -1
- {rapidata-2.26.0.dist-info → rapidata-2.27.0.dist-info}/RECORD +64 -54
- {rapidata-2.26.0.dist-info → rapidata-2.27.0.dist-info}/LICENSE +0 -0
- {rapidata-2.26.0.dist-info → rapidata-2.27.0.dist-info}/WHEEL +0 -0
|
@@ -1,13 +1,7 @@
|
|
|
1
1
|
from itertools import zip_longest
|
|
2
2
|
|
|
3
|
-
from rapidata.api_client.models.datapoint_metadata_model import DatapointMetadataModel
|
|
4
|
-
from rapidata.api_client.models.create_datapoint_from_urls_model import (
|
|
5
|
-
CreateDatapointFromUrlsModel,
|
|
6
|
-
)
|
|
7
|
-
from rapidata.api_client.models.create_datapoint_from_files_model import CreateDatapointFromFilesModel
|
|
8
|
-
from rapidata.api_client.models.create_datapoint_from_urls_model import CreateDatapointFromUrlsModel
|
|
9
3
|
from rapidata.api_client.models.create_datapoint_from_text_sources_model import CreateDatapointFromTextSourcesModel
|
|
10
|
-
from rapidata.api_client.models.
|
|
4
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
11
5
|
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
12
6
|
from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
|
|
13
7
|
from rapidata.service import LocalFileService
|
|
@@ -15,9 +9,8 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
15
9
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
16
10
|
from tqdm import tqdm
|
|
17
11
|
|
|
18
|
-
from pydantic import StrictStr
|
|
19
12
|
from typing import cast, Sequence, Generator
|
|
20
|
-
from rapidata.rapidata_client.logging import logger
|
|
13
|
+
from rapidata.rapidata_client.logging import logger, RapidataOutputManager
|
|
21
14
|
import time
|
|
22
15
|
import threading
|
|
23
16
|
|
|
@@ -59,7 +52,7 @@ class RapidataDataset:
|
|
|
59
52
|
for meta in metadata_per_datapoint:
|
|
60
53
|
meta_model = meta.to_model() if meta else None
|
|
61
54
|
if meta_model:
|
|
62
|
-
metadata.append(
|
|
55
|
+
metadata.append(DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model))
|
|
63
56
|
|
|
64
57
|
model = CreateDatapointFromTextSourcesModel(
|
|
65
58
|
textSources=texts,
|
|
@@ -76,7 +69,7 @@ class RapidataDataset:
|
|
|
76
69
|
for i, (text_asset, metadata) in enumerate(zip_longest(text_assets, metadata_list or []))
|
|
77
70
|
]
|
|
78
71
|
|
|
79
|
-
with tqdm(total=total_uploads, desc="Uploading text datapoints") as pbar:
|
|
72
|
+
with tqdm(total=total_uploads, desc="Uploading text datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
80
73
|
for future in as_completed(futures):
|
|
81
74
|
future.result() # This will raise any exceptions that occurred during execution
|
|
82
75
|
pbar.update(1)
|
|
@@ -117,39 +110,23 @@ class RapidataDataset:
|
|
|
117
110
|
else:
|
|
118
111
|
raise ValueError(f"Unsupported asset type: {type(media_asset)}")
|
|
119
112
|
|
|
120
|
-
|
|
121
|
-
metadata = []
|
|
113
|
+
metadata: list[DatasetDatasetIdDatapointsPostRequestMetadataInner] = []
|
|
122
114
|
if meta_list:
|
|
123
115
|
for meta in meta_list:
|
|
124
116
|
meta_model = meta.to_model() if meta else None
|
|
125
117
|
if meta_model:
|
|
126
|
-
metadata.append(
|
|
118
|
+
metadata.append(DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model))
|
|
127
119
|
|
|
128
|
-
local_paths
|
|
129
|
-
|
|
130
|
-
for asset in assets:
|
|
131
|
-
if isinstance(asset, MediaAsset):
|
|
132
|
-
files.append(asset.path)
|
|
120
|
+
local_paths = [asset.to_file() for asset in assets if asset.is_local()]
|
|
121
|
+
urls = [asset.path for asset in assets if not asset.is_local()]
|
|
133
122
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
model=model,
|
|
142
|
-
files=files # type: ignore
|
|
143
|
-
)
|
|
144
|
-
else:
|
|
145
|
-
upload_response = self.openapi_service.dataset_api.dataset_dataset_id_datapoints_urls_post(
|
|
146
|
-
dataset_id=self.dataset_id,
|
|
147
|
-
create_datapoint_from_urls_model=CreateDatapointFromUrlsModel(
|
|
148
|
-
urls=files,
|
|
149
|
-
metadata=metadata,
|
|
150
|
-
sortIndex=index
|
|
151
|
-
),
|
|
152
|
-
)
|
|
123
|
+
self.openapi_service.dataset_api.dataset_dataset_id_datapoints_post(
|
|
124
|
+
dataset_id=self.dataset_id,
|
|
125
|
+
file=local_paths,
|
|
126
|
+
url=urls,
|
|
127
|
+
metadata=metadata,
|
|
128
|
+
sort_index=index,
|
|
129
|
+
)
|
|
153
130
|
|
|
154
131
|
local_successful.extend(identifiers_to_track)
|
|
155
132
|
|
|
@@ -183,7 +160,7 @@ class RapidataDataset:
|
|
|
183
160
|
def progress_tracking_thread():
|
|
184
161
|
try:
|
|
185
162
|
# Initialize progress bar with 0 completions
|
|
186
|
-
with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
|
|
163
|
+
with tqdm(total=total_uploads, desc="Uploading datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
187
164
|
prev_ready = 0
|
|
188
165
|
prev_failed = 0
|
|
189
166
|
stall_count = 0
|
|
@@ -16,7 +16,7 @@ from rapidata.api_client.models.preliminary_download_model import PreliminaryDow
|
|
|
16
16
|
from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
|
|
17
17
|
from rapidata.rapidata_client.order.rapidata_results import RapidataResults
|
|
18
18
|
from rapidata.service.openapi_service import OpenAPIService
|
|
19
|
-
from rapidata.rapidata_client.logging import logger, managed_print
|
|
19
|
+
from rapidata.rapidata_client.logging import logger, managed_print, RapidataOutputManager
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class RapidataOrder:
|
|
@@ -38,7 +38,7 @@ class RapidataOrder:
|
|
|
38
38
|
order_id: str,
|
|
39
39
|
openapi_service: OpenAPIService,
|
|
40
40
|
):
|
|
41
|
-
self.
|
|
41
|
+
self.id = order_id
|
|
42
42
|
self.name = name
|
|
43
43
|
self.__created_at: datetime | None = None
|
|
44
44
|
self.__openapi_service = openapi_service
|
|
@@ -47,20 +47,25 @@ class RapidataOrder:
|
|
|
47
47
|
self.__pipeline_id: str = ""
|
|
48
48
|
self._max_retries = 10
|
|
49
49
|
self._retry_delay = 2
|
|
50
|
-
self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.
|
|
50
|
+
self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.id}"
|
|
51
51
|
logger.debug("RapidataOrder initialized")
|
|
52
52
|
|
|
53
|
+
@property
|
|
54
|
+
def order_id(self) -> str:
|
|
55
|
+
managed_print(f"order_id is deprecated. Use id instead.")
|
|
56
|
+
return self.id
|
|
57
|
+
|
|
53
58
|
@property
|
|
54
59
|
def created_at(self) -> datetime:
|
|
55
60
|
"""Returns the creation date of the order."""
|
|
56
61
|
if not self.__created_at:
|
|
57
|
-
self.__created_at = self.__openapi_service.order_api.order_order_id_get(self.
|
|
62
|
+
self.__created_at = self.__openapi_service.order_api.order_order_id_get(self.id).order_date
|
|
58
63
|
return self.__created_at
|
|
59
64
|
|
|
60
65
|
def run(self) -> "RapidataOrder":
|
|
61
66
|
"""Runs the order to start collecting responses."""
|
|
62
67
|
logger.info(f"Starting order '{self}'")
|
|
63
|
-
self.__openapi_service.order_api.order_order_id_submit_post(self.
|
|
68
|
+
self.__openapi_service.order_api.order_order_id_submit_post(self.id)
|
|
64
69
|
logger.debug(f"Order '{self}' has been started.")
|
|
65
70
|
managed_print(f"Order '{self.name}' is now viewable under: {self.order_details_page}")
|
|
66
71
|
return self
|
|
@@ -68,21 +73,21 @@ class RapidataOrder:
|
|
|
68
73
|
def pause(self) -> None:
|
|
69
74
|
"""Pauses the order."""
|
|
70
75
|
logger.info(f"Pausing order '{self}'")
|
|
71
|
-
self.__openapi_service.order_api.order_order_id_pause_post(self.
|
|
76
|
+
self.__openapi_service.order_api.order_order_id_pause_post(self.id)
|
|
72
77
|
logger.debug(f"Order '{self}' has been paused.")
|
|
73
78
|
managed_print(f"Order '{self}' has been paused.")
|
|
74
79
|
|
|
75
80
|
def unpause(self) -> None:
|
|
76
81
|
"""Unpauses/resumes the order."""
|
|
77
82
|
logger.info(f"Unpausing order '{self}'")
|
|
78
|
-
self.__openapi_service.order_api.order_resume_post(self.
|
|
83
|
+
self.__openapi_service.order_api.order_resume_post(self.id)
|
|
79
84
|
logger.debug(f"Order '{self}' has been unpaused.")
|
|
80
85
|
managed_print(f"Order '{self}' has been unpaused.")
|
|
81
86
|
|
|
82
87
|
def delete(self) -> None:
|
|
83
88
|
"""Deletes the order."""
|
|
84
89
|
logger.info(f"Deleting order '{self}'")
|
|
85
|
-
self.__openapi_service.order_api.order_order_id_delete(self.
|
|
90
|
+
self.__openapi_service.order_api.order_order_id_delete(self.id)
|
|
86
91
|
logger.debug(f"Order '{self}' has been deleted.")
|
|
87
92
|
managed_print(f"Order '{self}' has been deleted.")
|
|
88
93
|
|
|
@@ -100,7 +105,7 @@ class RapidataOrder:
|
|
|
100
105
|
Completed: The order has been completed.\n
|
|
101
106
|
Failed: The order has failed.
|
|
102
107
|
"""
|
|
103
|
-
return self.__openapi_service.order_api.order_order_id_get(self.
|
|
108
|
+
return self.__openapi_service.order_api.order_order_id_get(self.id).state
|
|
104
109
|
|
|
105
110
|
def display_progress_bar(self, refresh_rate: int=5) -> None:
|
|
106
111
|
"""
|
|
@@ -125,8 +130,8 @@ class RapidataOrder:
|
|
|
125
130
|
"To speed up the process, contact support (info@rapidata.ai).\n"
|
|
126
131
|
"Once started, run this method again to display the progress bar."
|
|
127
132
|
)
|
|
128
|
-
|
|
129
|
-
with tqdm(total=100, desc="Processing order", unit="%", bar_format="{desc}: {percentage:3.0f}%|{bar}| completed [{elapsed}<{remaining}, {rate_fmt}]") as pbar:
|
|
133
|
+
|
|
134
|
+
with tqdm(total=100, desc="Processing order", unit="%", bar_format="{desc}: {percentage:3.0f}%|{bar}| completed [{elapsed}<{remaining}, {rate_fmt}]", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
130
135
|
last_percentage = 0
|
|
131
136
|
while True:
|
|
132
137
|
current_percentage = self._workflow_progress.completion_percentage
|
|
@@ -177,7 +182,7 @@ class RapidataOrder:
|
|
|
177
182
|
sleep(5)
|
|
178
183
|
|
|
179
184
|
try:
|
|
180
|
-
return RapidataResults(json.loads(self.__openapi_service.order_api.order_order_id_download_results_get(order_id=self.
|
|
185
|
+
return RapidataResults(json.loads(self.__openapi_service.order_api.order_order_id_download_results_get(order_id=self.id)))
|
|
181
186
|
except (ApiException, json.JSONDecodeError) as e:
|
|
182
187
|
raise Exception(f"Failed to get order results: {str(e)}") from e
|
|
183
188
|
|
|
@@ -203,7 +208,7 @@ class RapidataOrder:
|
|
|
203
208
|
"""
|
|
204
209
|
logger.info("Opening order preview in browser...")
|
|
205
210
|
campaign_id = self.__get_campaign_id()
|
|
206
|
-
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.
|
|
211
|
+
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.id}/preview?campaignId={campaign_id}"
|
|
207
212
|
could_open_browser = webbrowser.open(auth_url)
|
|
208
213
|
if not could_open_browser:
|
|
209
214
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
@@ -214,7 +219,7 @@ class RapidataOrder:
|
|
|
214
219
|
if not self.__pipeline_id:
|
|
215
220
|
for _ in range(self._max_retries):
|
|
216
221
|
try:
|
|
217
|
-
self.__pipeline_id = self.__openapi_service.order_api.order_order_id_get(self.
|
|
222
|
+
self.__pipeline_id = self.__openapi_service.order_api.order_order_id_get(self.id).pipeline_id
|
|
218
223
|
break
|
|
219
224
|
except Exception:
|
|
220
225
|
sleep(self._retry_delay)
|
|
@@ -272,7 +277,7 @@ class RapidataOrder:
|
|
|
272
277
|
raise Exception(f"Failed to get preliminary results: {str(e)}") from e
|
|
273
278
|
|
|
274
279
|
def __str__(self) -> str:
|
|
275
|
-
return f"RapidataOrder(name='{self.name}', order_id='{self.
|
|
280
|
+
return f"RapidataOrder(name='{self.name}', order_id='{self.id}')"
|
|
276
281
|
|
|
277
282
|
def __repr__(self) -> str:
|
|
278
|
-
return f"RapidataOrder(name='{self.name}', order_id='{self.
|
|
283
|
+
return f"RapidataOrder(name='{self.name}', order_id='{self.id}')"
|
|
@@ -28,9 +28,8 @@ from rapidata.rapidata_client.filter import RapidataFilter
|
|
|
28
28
|
from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters
|
|
29
29
|
from rapidata.rapidata_client.settings import RapidataSettings, RapidataSetting
|
|
30
30
|
from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections
|
|
31
|
-
from rapidata.rapidata_client.logging import logger
|
|
31
|
+
from rapidata.rapidata_client.logging import logger, RapidataOutputManager
|
|
32
32
|
|
|
33
|
-
from rapidata.api_client.exceptions import BadRequestException
|
|
34
33
|
from rapidata.api_client.models.query_model import QueryModel
|
|
35
34
|
from rapidata.api_client.models.page_info import PageInfo
|
|
36
35
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
@@ -604,7 +603,7 @@ class RapidataOrderManager:
|
|
|
604
603
|
|
|
605
604
|
assets = [MediaAsset(path=path) for path in datapoints]
|
|
606
605
|
|
|
607
|
-
for asset in tqdm(assets, desc="Downloading assets and checking duration"):
|
|
606
|
+
for asset in tqdm(assets, desc="Downloading assets and checking duration", disable=RapidataOutputManager.silent_mode):
|
|
608
607
|
if not asset.get_duration():
|
|
609
608
|
raise ValueError("The datapoints for this order must have a duration. (e.g. video or audio)")
|
|
610
609
|
|
|
@@ -15,8 +15,7 @@ from rapidata.api_client.models.add_validation_rapid_model_payload import (
|
|
|
15
15
|
from rapidata.api_client.models.add_validation_rapid_model_truth import (
|
|
16
16
|
AddValidationRapidModelTruth,
|
|
17
17
|
)
|
|
18
|
-
from rapidata.api_client.models.
|
|
19
|
-
|
|
18
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
20
19
|
from rapidata.service.openapi_service import OpenAPIService
|
|
21
20
|
|
|
22
21
|
from rapidata.rapidata_client.logging import logger
|
|
@@ -69,7 +68,7 @@ class Rapid():
|
|
|
69
68
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
70
69
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
71
70
|
metadata=[
|
|
72
|
-
|
|
71
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
73
72
|
for meta in self.metadata
|
|
74
73
|
],
|
|
75
74
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
@@ -96,7 +95,7 @@ class Rapid():
|
|
|
96
95
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
97
96
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
98
97
|
metadata=[
|
|
99
|
-
|
|
98
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
100
99
|
for meta in self.metadata
|
|
101
100
|
],
|
|
102
101
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
@@ -11,12 +11,11 @@ from rapidata.api_client.models.page_info import PageInfo
|
|
|
11
11
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
12
12
|
from rapidata.api_client.models.filter import Filter
|
|
13
13
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
14
|
-
from rapidata.api_client.exceptions import BadRequestException
|
|
15
14
|
from urllib3._collections import HTTPHeaderDict # type: ignore[import]
|
|
16
15
|
|
|
17
16
|
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
18
17
|
|
|
19
|
-
from rapidata.rapidata_client.logging import logger, managed_print
|
|
18
|
+
from rapidata.rapidata_client.logging import logger, managed_print, RapidataOutputManager
|
|
20
19
|
from tqdm import tqdm
|
|
21
20
|
|
|
22
21
|
|
|
@@ -453,7 +452,7 @@ class ValidationSetManager:
|
|
|
453
452
|
)
|
|
454
453
|
|
|
455
454
|
logger.debug("Adding rapids to validation set")
|
|
456
|
-
for rapid in tqdm(rapids, desc="Uploading validation tasks"):
|
|
455
|
+
for rapid in tqdm(rapids, desc="Uploading validation tasks", disable=RapidataOutputManager.silent_mode):
|
|
457
456
|
validation_set.add_rapid(rapid)
|
|
458
457
|
|
|
459
458
|
managed_print()
|
|
@@ -2,14 +2,10 @@ from rapidata.api_client import CompareWorkflowModelPairMakerConfig, OnlinePairM
|
|
|
2
2
|
from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
|
|
3
3
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
4
4
|
from rapidata.rapidata_client.metadata import PromptMetadata
|
|
5
|
-
from rapidata.api_client.models.
|
|
6
|
-
CreateDatapointFromFilesModelMetadataInner,
|
|
7
|
-
)
|
|
5
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
class RankingWorkflow(Workflow):
|
|
11
|
-
|
|
12
|
-
|
|
13
9
|
def __init__(self,
|
|
14
10
|
criteria: str,
|
|
15
11
|
total_comparison_budget: int,
|
|
@@ -21,7 +17,7 @@ class RankingWorkflow(Workflow):
|
|
|
21
17
|
):
|
|
22
18
|
super().__init__(type="CompareWorkflowConfig")
|
|
23
19
|
|
|
24
|
-
self.context = [
|
|
20
|
+
self.context = [DatasetDatasetIdDatapointsPostRequestMetadataInner(
|
|
25
21
|
PromptMetadata(context).to_model())
|
|
26
22
|
] if context else None
|
|
27
23
|
|