rapidata 2.26.1__py3-none-any.whl → 2.27.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +2 -2
- rapidata/api_client/__init__.py +11 -3
- rapidata/api_client/api/__init__.py +2 -1
- rapidata/api_client/api/client_api.py +0 -257
- rapidata/api_client/api/customer_rapid_api.py +1644 -0
- rapidata/api_client/api/dataset_api.py +358 -1
- rapidata/api_client/api/newsletter_api.py +11 -299
- rapidata/api_client/api/user_rapid_api.py +1385 -0
- rapidata/api_client/api/validation_set_api.py +6 -6
- rapidata/api_client/models/__init__.py +9 -2
- rapidata/api_client/models/add_campaign_model.py +5 -0
- rapidata/api_client/models/add_validation_rapid_model.py +3 -3
- rapidata/api_client/models/add_validation_rapid_model_truth.py +25 -11
- rapidata/api_client/models/add_validation_text_rapid_model.py +3 -3
- rapidata/api_client/models/asset_metadata_model.py +2 -8
- rapidata/api_client/models/compare_result.py +1 -10
- rapidata/api_client/models/compare_workflow_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
- rapidata/api_client/models/create_order_model.py +2 -4
- rapidata/api_client/models/datapoint.py +3 -3
- rapidata/api_client/models/datapoint_metadata_model.py +3 -3
- rapidata/api_client/models/datapoint_model.py +3 -3
- rapidata/api_client/models/dataset_dataset_id_datapoints_post_request_metadata_inner.py +182 -0
- rapidata/api_client/models/file_asset_model.py +1 -3
- rapidata/api_client/models/file_asset_model_metadata_value.py +1 -3
- rapidata/api_client/models/get_compare_workflow_results_result.py +3 -3
- rapidata/api_client/models/get_datapoint_by_id_result.py +3 -3
- rapidata/api_client/models/get_rapid_responses_result.py +5 -5
- rapidata/api_client/models/get_validation_rapids_result.py +12 -3
- rapidata/api_client/models/get_validation_rapids_result_truth.py +25 -11
- rapidata/api_client/models/get_workflow_results_result.py +5 -5
- rapidata/api_client/models/multi_asset_model.py +4 -4
- rapidata/api_client/models/multi_compare_truth.py +96 -0
- rapidata/api_client/models/naive_referee_info.py +96 -0
- rapidata/api_client/models/never_ending_referee_info.py +94 -0
- rapidata/api_client/models/null_asset_model.py +1 -3
- rapidata/api_client/models/probabilistic_attach_category_referee_info.py +98 -0
- rapidata/api_client/models/rapid_model.py +173 -0
- rapidata/api_client/models/rapid_model_paged_result.py +105 -0
- rapidata/api_client/models/rapid_model_referee.py +154 -0
- rapidata/api_client/models/rapid_state.py +1 -0
- rapidata/api_client/models/text_asset_model.py +1 -3
- rapidata/api_client/models/update_access_model.py +1 -1
- rapidata/api_client/models/update_validation_rapid_model.py +3 -8
- rapidata/api_client/models/update_validation_rapid_model_truth.py +26 -12
- rapidata/api_client/models/upload_files_from_s3_bucket_model.py +12 -2
- rapidata/api_client/models/upload_text_sources_to_dataset_model.py +3 -3
- rapidata/api_client_README.md +21 -13
- rapidata/rapidata_client/__init__.py +1 -1
- rapidata/rapidata_client/assets/_multi_asset.py +0 -5
- rapidata/rapidata_client/logging/__init__.py +1 -1
- rapidata/rapidata_client/logging/output_manager.py +2 -2
- rapidata/rapidata_client/order/_rapidata_dataset.py +16 -39
- rapidata/rapidata_client/order/rapidata_order.py +20 -15
- rapidata/rapidata_client/order/rapidata_order_manager.py +5 -2
- rapidata/rapidata_client/validation/rapids/rapids.py +3 -4
- rapidata/rapidata_client/validation/validation_set_manager.py +2 -2
- rapidata/rapidata_client/workflow/_ranking_workflow.py +2 -6
- rapidata/service/credential_manager.py +6 -6
- {rapidata-2.26.1.dist-info → rapidata-2.27.1.dist-info}/METADATA +1 -1
- {rapidata-2.26.1.dist-info → rapidata-2.27.1.dist-info}/RECORD +65 -55
- {rapidata-2.26.1.dist-info → rapidata-2.27.1.dist-info}/LICENSE +0 -0
- {rapidata-2.26.1.dist-info → rapidata-2.27.1.dist-info}/WHEEL +0 -0
|
@@ -1,13 +1,7 @@
|
|
|
1
1
|
from itertools import zip_longest
|
|
2
2
|
|
|
3
|
-
from rapidata.api_client.models.datapoint_metadata_model import DatapointMetadataModel
|
|
4
|
-
from rapidata.api_client.models.create_datapoint_from_urls_model import (
|
|
5
|
-
CreateDatapointFromUrlsModel,
|
|
6
|
-
)
|
|
7
|
-
from rapidata.api_client.models.create_datapoint_from_files_model import CreateDatapointFromFilesModel
|
|
8
|
-
from rapidata.api_client.models.create_datapoint_from_urls_model import CreateDatapointFromUrlsModel
|
|
9
3
|
from rapidata.api_client.models.create_datapoint_from_text_sources_model import CreateDatapointFromTextSourcesModel
|
|
10
|
-
from rapidata.api_client.models.
|
|
4
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
11
5
|
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
12
6
|
from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
|
|
13
7
|
from rapidata.service import LocalFileService
|
|
@@ -15,9 +9,8 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
15
9
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
16
10
|
from tqdm import tqdm
|
|
17
11
|
|
|
18
|
-
from pydantic import StrictStr
|
|
19
12
|
from typing import cast, Sequence, Generator
|
|
20
|
-
from rapidata.rapidata_client.logging import logger,
|
|
13
|
+
from rapidata.rapidata_client.logging import logger, RapidataOutputManager
|
|
21
14
|
import time
|
|
22
15
|
import threading
|
|
23
16
|
|
|
@@ -59,7 +52,7 @@ class RapidataDataset:
|
|
|
59
52
|
for meta in metadata_per_datapoint:
|
|
60
53
|
meta_model = meta.to_model() if meta else None
|
|
61
54
|
if meta_model:
|
|
62
|
-
metadata.append(
|
|
55
|
+
metadata.append(DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model))
|
|
63
56
|
|
|
64
57
|
model = CreateDatapointFromTextSourcesModel(
|
|
65
58
|
textSources=texts,
|
|
@@ -76,7 +69,7 @@ class RapidataDataset:
|
|
|
76
69
|
for i, (text_asset, metadata) in enumerate(zip_longest(text_assets, metadata_list or []))
|
|
77
70
|
]
|
|
78
71
|
|
|
79
|
-
with tqdm(total=total_uploads, desc="Uploading text datapoints", disable=
|
|
72
|
+
with tqdm(total=total_uploads, desc="Uploading text datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
80
73
|
for future in as_completed(futures):
|
|
81
74
|
future.result() # This will raise any exceptions that occurred during execution
|
|
82
75
|
pbar.update(1)
|
|
@@ -117,39 +110,23 @@ class RapidataDataset:
|
|
|
117
110
|
else:
|
|
118
111
|
raise ValueError(f"Unsupported asset type: {type(media_asset)}")
|
|
119
112
|
|
|
120
|
-
|
|
121
|
-
metadata = []
|
|
113
|
+
metadata: list[DatasetDatasetIdDatapointsPostRequestMetadataInner] = []
|
|
122
114
|
if meta_list:
|
|
123
115
|
for meta in meta_list:
|
|
124
116
|
meta_model = meta.to_model() if meta else None
|
|
125
117
|
if meta_model:
|
|
126
|
-
metadata.append(
|
|
118
|
+
metadata.append(DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model))
|
|
127
119
|
|
|
128
|
-
local_paths
|
|
129
|
-
|
|
130
|
-
for asset in assets:
|
|
131
|
-
if isinstance(asset, MediaAsset):
|
|
132
|
-
files.append(asset.path)
|
|
120
|
+
local_paths = [asset.to_file() for asset in assets if asset.is_local()]
|
|
121
|
+
urls = [asset.path for asset in assets if not asset.is_local()]
|
|
133
122
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
model=model,
|
|
142
|
-
files=files # type: ignore
|
|
143
|
-
)
|
|
144
|
-
else:
|
|
145
|
-
upload_response = self.openapi_service.dataset_api.dataset_dataset_id_datapoints_urls_post(
|
|
146
|
-
dataset_id=self.dataset_id,
|
|
147
|
-
create_datapoint_from_urls_model=CreateDatapointFromUrlsModel(
|
|
148
|
-
urls=files,
|
|
149
|
-
metadata=metadata,
|
|
150
|
-
sortIndex=index
|
|
151
|
-
),
|
|
152
|
-
)
|
|
123
|
+
self.openapi_service.dataset_api.dataset_dataset_id_datapoints_post(
|
|
124
|
+
dataset_id=self.dataset_id,
|
|
125
|
+
file=local_paths,
|
|
126
|
+
url=urls,
|
|
127
|
+
metadata=metadata,
|
|
128
|
+
sort_index=index,
|
|
129
|
+
)
|
|
153
130
|
|
|
154
131
|
local_successful.extend(identifiers_to_track)
|
|
155
132
|
|
|
@@ -183,7 +160,7 @@ class RapidataDataset:
|
|
|
183
160
|
def progress_tracking_thread():
|
|
184
161
|
try:
|
|
185
162
|
# Initialize progress bar with 0 completions
|
|
186
|
-
with tqdm(total=total_uploads, desc="Uploading datapoints", disable=
|
|
163
|
+
with tqdm(total=total_uploads, desc="Uploading datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
187
164
|
prev_ready = 0
|
|
188
165
|
prev_failed = 0
|
|
189
166
|
stall_count = 0
|
|
@@ -16,7 +16,7 @@ from rapidata.api_client.models.preliminary_download_model import PreliminaryDow
|
|
|
16
16
|
from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
|
|
17
17
|
from rapidata.rapidata_client.order.rapidata_results import RapidataResults
|
|
18
18
|
from rapidata.service.openapi_service import OpenAPIService
|
|
19
|
-
from rapidata.rapidata_client.logging import logger, managed_print,
|
|
19
|
+
from rapidata.rapidata_client.logging import logger, managed_print, RapidataOutputManager
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class RapidataOrder:
|
|
@@ -38,7 +38,7 @@ class RapidataOrder:
|
|
|
38
38
|
order_id: str,
|
|
39
39
|
openapi_service: OpenAPIService,
|
|
40
40
|
):
|
|
41
|
-
self.
|
|
41
|
+
self.id = order_id
|
|
42
42
|
self.name = name
|
|
43
43
|
self.__created_at: datetime | None = None
|
|
44
44
|
self.__openapi_service = openapi_service
|
|
@@ -47,20 +47,25 @@ class RapidataOrder:
|
|
|
47
47
|
self.__pipeline_id: str = ""
|
|
48
48
|
self._max_retries = 10
|
|
49
49
|
self._retry_delay = 2
|
|
50
|
-
self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.
|
|
50
|
+
self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.id}"
|
|
51
51
|
logger.debug("RapidataOrder initialized")
|
|
52
52
|
|
|
53
|
+
@property
|
|
54
|
+
def order_id(self) -> str:
|
|
55
|
+
managed_print(f"order_id is deprecated. Use id instead.")
|
|
56
|
+
return self.id
|
|
57
|
+
|
|
53
58
|
@property
|
|
54
59
|
def created_at(self) -> datetime:
|
|
55
60
|
"""Returns the creation date of the order."""
|
|
56
61
|
if not self.__created_at:
|
|
57
|
-
self.__created_at = self.__openapi_service.order_api.order_order_id_get(self.
|
|
62
|
+
self.__created_at = self.__openapi_service.order_api.order_order_id_get(self.id).order_date
|
|
58
63
|
return self.__created_at
|
|
59
64
|
|
|
60
65
|
def run(self) -> "RapidataOrder":
|
|
61
66
|
"""Runs the order to start collecting responses."""
|
|
62
67
|
logger.info(f"Starting order '{self}'")
|
|
63
|
-
self.__openapi_service.order_api.order_order_id_submit_post(self.
|
|
68
|
+
self.__openapi_service.order_api.order_order_id_submit_post(self.id)
|
|
64
69
|
logger.debug(f"Order '{self}' has been started.")
|
|
65
70
|
managed_print(f"Order '{self.name}' is now viewable under: {self.order_details_page}")
|
|
66
71
|
return self
|
|
@@ -68,21 +73,21 @@ class RapidataOrder:
|
|
|
68
73
|
def pause(self) -> None:
|
|
69
74
|
"""Pauses the order."""
|
|
70
75
|
logger.info(f"Pausing order '{self}'")
|
|
71
|
-
self.__openapi_service.order_api.order_order_id_pause_post(self.
|
|
76
|
+
self.__openapi_service.order_api.order_order_id_pause_post(self.id)
|
|
72
77
|
logger.debug(f"Order '{self}' has been paused.")
|
|
73
78
|
managed_print(f"Order '{self}' has been paused.")
|
|
74
79
|
|
|
75
80
|
def unpause(self) -> None:
|
|
76
81
|
"""Unpauses/resumes the order."""
|
|
77
82
|
logger.info(f"Unpausing order '{self}'")
|
|
78
|
-
self.__openapi_service.order_api.order_resume_post(self.
|
|
83
|
+
self.__openapi_service.order_api.order_resume_post(self.id)
|
|
79
84
|
logger.debug(f"Order '{self}' has been unpaused.")
|
|
80
85
|
managed_print(f"Order '{self}' has been unpaused.")
|
|
81
86
|
|
|
82
87
|
def delete(self) -> None:
|
|
83
88
|
"""Deletes the order."""
|
|
84
89
|
logger.info(f"Deleting order '{self}'")
|
|
85
|
-
self.__openapi_service.order_api.order_order_id_delete(self.
|
|
90
|
+
self.__openapi_service.order_api.order_order_id_delete(self.id)
|
|
86
91
|
logger.debug(f"Order '{self}' has been deleted.")
|
|
87
92
|
managed_print(f"Order '{self}' has been deleted.")
|
|
88
93
|
|
|
@@ -100,7 +105,7 @@ class RapidataOrder:
|
|
|
100
105
|
Completed: The order has been completed.\n
|
|
101
106
|
Failed: The order has failed.
|
|
102
107
|
"""
|
|
103
|
-
return self.__openapi_service.order_api.order_order_id_get(self.
|
|
108
|
+
return self.__openapi_service.order_api.order_order_id_get(self.id).state
|
|
104
109
|
|
|
105
110
|
def display_progress_bar(self, refresh_rate: int=5) -> None:
|
|
106
111
|
"""
|
|
@@ -126,7 +131,7 @@ class RapidataOrder:
|
|
|
126
131
|
"Once started, run this method again to display the progress bar."
|
|
127
132
|
)
|
|
128
133
|
|
|
129
|
-
with tqdm(total=100, desc="Processing order", unit="%", bar_format="{desc}: {percentage:3.0f}%|{bar}| completed [{elapsed}<{remaining}, {rate_fmt}]", disable=
|
|
134
|
+
with tqdm(total=100, desc="Processing order", unit="%", bar_format="{desc}: {percentage:3.0f}%|{bar}| completed [{elapsed}<{remaining}, {rate_fmt}]", disable=RapidataOutputManager.silent_mode) as pbar:
|
|
130
135
|
last_percentage = 0
|
|
131
136
|
while True:
|
|
132
137
|
current_percentage = self._workflow_progress.completion_percentage
|
|
@@ -177,7 +182,7 @@ class RapidataOrder:
|
|
|
177
182
|
sleep(5)
|
|
178
183
|
|
|
179
184
|
try:
|
|
180
|
-
return RapidataResults(json.loads(self.__openapi_service.order_api.order_order_id_download_results_get(order_id=self.
|
|
185
|
+
return RapidataResults(json.loads(self.__openapi_service.order_api.order_order_id_download_results_get(order_id=self.id)))
|
|
181
186
|
except (ApiException, json.JSONDecodeError) as e:
|
|
182
187
|
raise Exception(f"Failed to get order results: {str(e)}") from e
|
|
183
188
|
|
|
@@ -203,7 +208,7 @@ class RapidataOrder:
|
|
|
203
208
|
"""
|
|
204
209
|
logger.info("Opening order preview in browser...")
|
|
205
210
|
campaign_id = self.__get_campaign_id()
|
|
206
|
-
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.
|
|
211
|
+
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.id}/preview?campaignId={campaign_id}"
|
|
207
212
|
could_open_browser = webbrowser.open(auth_url)
|
|
208
213
|
if not could_open_browser:
|
|
209
214
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
@@ -214,7 +219,7 @@ class RapidataOrder:
|
|
|
214
219
|
if not self.__pipeline_id:
|
|
215
220
|
for _ in range(self._max_retries):
|
|
216
221
|
try:
|
|
217
|
-
self.__pipeline_id = self.__openapi_service.order_api.order_order_id_get(self.
|
|
222
|
+
self.__pipeline_id = self.__openapi_service.order_api.order_order_id_get(self.id).pipeline_id
|
|
218
223
|
break
|
|
219
224
|
except Exception:
|
|
220
225
|
sleep(self._retry_delay)
|
|
@@ -272,7 +277,7 @@ class RapidataOrder:
|
|
|
272
277
|
raise Exception(f"Failed to get preliminary results: {str(e)}") from e
|
|
273
278
|
|
|
274
279
|
def __str__(self) -> str:
|
|
275
|
-
return f"RapidataOrder(name='{self.name}', order_id='{self.
|
|
280
|
+
return f"RapidataOrder(name='{self.name}', order_id='{self.id}')"
|
|
276
281
|
|
|
277
282
|
def __repr__(self) -> str:
|
|
278
|
-
return f"RapidataOrder(name='{self.name}', order_id='{self.
|
|
283
|
+
return f"RapidataOrder(name='{self.name}', order_id='{self.id}')"
|
|
@@ -28,7 +28,7 @@ from rapidata.rapidata_client.filter import RapidataFilter
|
|
|
28
28
|
from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters
|
|
29
29
|
from rapidata.rapidata_client.settings import RapidataSettings, RapidataSetting
|
|
30
30
|
from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections
|
|
31
|
-
from rapidata.rapidata_client.logging import logger,
|
|
31
|
+
from rapidata.rapidata_client.logging import logger, RapidataOutputManager
|
|
32
32
|
|
|
33
33
|
from rapidata.api_client.models.query_model import QueryModel
|
|
34
34
|
from rapidata.api_client.models.page_info import PageInfo
|
|
@@ -77,6 +77,9 @@ class RapidataOrderManager:
|
|
|
77
77
|
private_notes: list[str] | None = None,
|
|
78
78
|
default_labeling_amount: int = 3
|
|
79
79
|
) -> RapidataOrder:
|
|
80
|
+
|
|
81
|
+
if not assets:
|
|
82
|
+
raise ValueError("No datapoints provided")
|
|
80
83
|
|
|
81
84
|
if contexts and len(contexts) != len(assets):
|
|
82
85
|
raise ValueError("Number of contexts must match number of datapoints")
|
|
@@ -603,7 +606,7 @@ class RapidataOrderManager:
|
|
|
603
606
|
|
|
604
607
|
assets = [MediaAsset(path=path) for path in datapoints]
|
|
605
608
|
|
|
606
|
-
for asset in tqdm(assets, desc="Downloading assets and checking duration", disable=
|
|
609
|
+
for asset in tqdm(assets, desc="Downloading assets and checking duration", disable=RapidataOutputManager.silent_mode):
|
|
607
610
|
if not asset.get_duration():
|
|
608
611
|
raise ValueError("The datapoints for this order must have a duration. (e.g. video or audio)")
|
|
609
612
|
|
|
@@ -15,8 +15,7 @@ from rapidata.api_client.models.add_validation_rapid_model_payload import (
|
|
|
15
15
|
from rapidata.api_client.models.add_validation_rapid_model_truth import (
|
|
16
16
|
AddValidationRapidModelTruth,
|
|
17
17
|
)
|
|
18
|
-
from rapidata.api_client.models.
|
|
19
|
-
|
|
18
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
20
19
|
from rapidata.service.openapi_service import OpenAPIService
|
|
21
20
|
|
|
22
21
|
from rapidata.rapidata_client.logging import logger
|
|
@@ -69,7 +68,7 @@ class Rapid():
|
|
|
69
68
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
70
69
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
71
70
|
metadata=[
|
|
72
|
-
|
|
71
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
73
72
|
for meta in self.metadata
|
|
74
73
|
],
|
|
75
74
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
@@ -96,7 +95,7 @@ class Rapid():
|
|
|
96
95
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
97
96
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
98
97
|
metadata=[
|
|
99
|
-
|
|
98
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
100
99
|
for meta in self.metadata
|
|
101
100
|
],
|
|
102
101
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
@@ -15,7 +15,7 @@ from urllib3._collections import HTTPHeaderDict # type: ignore[import]
|
|
|
15
15
|
|
|
16
16
|
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
17
17
|
|
|
18
|
-
from rapidata.rapidata_client.logging import logger, managed_print,
|
|
18
|
+
from rapidata.rapidata_client.logging import logger, managed_print, RapidataOutputManager
|
|
19
19
|
from tqdm import tqdm
|
|
20
20
|
|
|
21
21
|
|
|
@@ -452,7 +452,7 @@ class ValidationSetManager:
|
|
|
452
452
|
)
|
|
453
453
|
|
|
454
454
|
logger.debug("Adding rapids to validation set")
|
|
455
|
-
for rapid in tqdm(rapids, desc="Uploading validation tasks", disable=
|
|
455
|
+
for rapid in tqdm(rapids, desc="Uploading validation tasks", disable=RapidataOutputManager.silent_mode):
|
|
456
456
|
validation_set.add_rapid(rapid)
|
|
457
457
|
|
|
458
458
|
managed_print()
|
|
@@ -2,14 +2,10 @@ from rapidata.api_client import CompareWorkflowModelPairMakerConfig, OnlinePairM
|
|
|
2
2
|
from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
|
|
3
3
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
4
4
|
from rapidata.rapidata_client.metadata import PromptMetadata
|
|
5
|
-
from rapidata.api_client.models.
|
|
6
|
-
CreateDatapointFromFilesModelMetadataInner,
|
|
7
|
-
)
|
|
5
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
class RankingWorkflow(Workflow):
|
|
11
|
-
|
|
12
|
-
|
|
13
9
|
def __init__(self,
|
|
14
10
|
criteria: str,
|
|
15
11
|
total_comparison_budget: int,
|
|
@@ -21,7 +17,7 @@ class RankingWorkflow(Workflow):
|
|
|
21
17
|
):
|
|
22
18
|
super().__init__(type="CompareWorkflowConfig")
|
|
23
19
|
|
|
24
|
-
self.context = [
|
|
20
|
+
self.context = [DatasetDatasetIdDatapointsPostRequestMetadataInner(
|
|
25
21
|
PromptMetadata(context).to_model())
|
|
26
22
|
] if context else None
|
|
27
23
|
|
|
@@ -81,7 +81,9 @@ class CredentialManager:
|
|
|
81
81
|
|
|
82
82
|
# Ensure file is only readable by the user
|
|
83
83
|
os.chmod(self.config_path, 0o600)
|
|
84
|
-
logger.debug(
|
|
84
|
+
logger.debug(
|
|
85
|
+
f"Set permissions for {self.config_path} to read/write for user only."
|
|
86
|
+
)
|
|
85
87
|
|
|
86
88
|
def _store_credential(self, credential: ClientCredential) -> None:
|
|
87
89
|
credentials = self._read_credentials()
|
|
@@ -128,16 +130,14 @@ class CredentialManager:
|
|
|
128
130
|
if self.endpoint in credentials:
|
|
129
131
|
del credentials[self.endpoint]
|
|
130
132
|
self._write_credentials(credentials)
|
|
131
|
-
logger.info(
|
|
132
|
-
f"Credentials for {self.endpoint} have been reset."
|
|
133
|
-
)
|
|
133
|
+
logger.info(f"Credentials for {self.endpoint} have been reset.")
|
|
134
134
|
|
|
135
135
|
def _get_bridge_tokens(self) -> Optional[BridgeToken]:
|
|
136
136
|
"""Get bridge tokens from the identity endpoint."""
|
|
137
137
|
logger.debug("Getting bridge tokens")
|
|
138
138
|
try:
|
|
139
139
|
bridge_endpoint = (
|
|
140
|
-
f"{self.endpoint}/
|
|
140
|
+
f"{self.endpoint}/identity/bridge-token?clientId=rapidata-cli"
|
|
141
141
|
)
|
|
142
142
|
response = requests.post(bridge_endpoint, verify=self.cert_path)
|
|
143
143
|
if not response.ok:
|
|
@@ -152,7 +152,7 @@ class CredentialManager:
|
|
|
152
152
|
|
|
153
153
|
def _poll_read_key(self, read_key: str) -> Optional[str]:
|
|
154
154
|
"""Poll the read key endpoint until we get an access token."""
|
|
155
|
-
read_endpoint = f"{self.endpoint}/
|
|
155
|
+
read_endpoint = f"{self.endpoint}/identity/bridge-token"
|
|
156
156
|
start_time = time.time()
|
|
157
157
|
|
|
158
158
|
while time.time() - start_time < self.poll_timeout:
|