rapidata 2.42.1__py3-none-any.whl → 2.42.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +1 -1
- rapidata/api_client/__init__.py +8 -1
- rapidata/api_client/api/__init__.py +1 -0
- rapidata/api_client/api/grouped_ranking_workflow_api.py +318 -0
- rapidata/api_client/models/__init__.py +7 -1
- rapidata/api_client/models/asset_metadata.py +2 -8
- rapidata/api_client/models/create_datapoint_model.py +10 -3
- rapidata/api_client/models/create_order_model_workflow.py +23 -9
- rapidata/api_client/models/file_asset.py +1 -3
- rapidata/api_client/models/file_asset_metadata_value.py +1 -3
- rapidata/api_client/models/get_grouped_ranking_workflow_results_model.py +106 -0
- rapidata/api_client/models/get_grouped_ranking_workflow_results_result.py +97 -0
- rapidata/api_client/models/get_grouped_ranking_workflow_results_result_paged_result.py +105 -0
- rapidata/api_client/models/get_workflow_by_id_result_workflow.py +23 -9
- rapidata/api_client/models/grouped_ranking_workflow_config.py +143 -0
- rapidata/api_client/models/grouped_ranking_workflow_model.py +135 -0
- rapidata/api_client/models/grouped_ranking_workflow_model1.py +121 -0
- rapidata/api_client/models/multi_asset.py +4 -4
- rapidata/api_client/models/multi_asset_assets_inner.py +170 -0
- rapidata/api_client/models/null_asset.py +1 -3
- rapidata/api_client/models/text_asset.py +1 -3
- rapidata/api_client/models/workflow_config_artifact_model_workflow_config.py +23 -9
- rapidata/api_client_README.md +8 -1
- rapidata/rapidata_client/order/_rapidata_order_builder.py +11 -12
- rapidata/rapidata_client/order/rapidata_order.py +16 -16
- rapidata/rapidata_client/order/rapidata_order_manager.py +16 -28
- rapidata/rapidata_client/validation/rapids/rapids.py +2 -2
- rapidata/rapidata_client/validation/validation_set_manager.py +3 -3
- rapidata/rapidata_client/workflow/_base_workflow.py +7 -0
- rapidata/rapidata_client/workflow/_classify_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_compare_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_draw_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_evaluation_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_free_text_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_locate_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_ranking_workflow.py +45 -3
- rapidata/rapidata_client/workflow/_select_words_workflow.py +3 -0
- rapidata/rapidata_client/workflow/_timestamp_workflow.py +3 -0
- {rapidata-2.42.1.dist-info → rapidata-2.42.3.dist-info}/METADATA +1 -1
- {rapidata-2.42.1.dist-info → rapidata-2.42.3.dist-info}/RECORD +42 -34
- {rapidata-2.42.1.dist-info → rapidata-2.42.3.dist-info}/WHEEL +0 -0
- {rapidata-2.42.1.dist-info → rapidata-2.42.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -53,14 +53,14 @@ class RapidataOrder:
|
|
|
53
53
|
self.id = order_id
|
|
54
54
|
self.name = name
|
|
55
55
|
self.__created_at: datetime | None = None
|
|
56
|
-
self.
|
|
56
|
+
self._openapi_service = openapi_service
|
|
57
57
|
self.__workflow_id: str = ""
|
|
58
58
|
self.__campaign_id: str = ""
|
|
59
59
|
self.__pipeline_id: str = ""
|
|
60
60
|
self._max_retries = 10
|
|
61
61
|
self._retry_delay = 2
|
|
62
62
|
self.order_details_page = (
|
|
63
|
-
f"https://app.{self.
|
|
63
|
+
f"https://app.{self._openapi_service.environment}/order/detail/{self.id}"
|
|
64
64
|
)
|
|
65
65
|
logger.debug("RapidataOrder initialized")
|
|
66
66
|
|
|
@@ -68,7 +68,7 @@ class RapidataOrder:
|
|
|
68
68
|
def created_at(self) -> datetime:
|
|
69
69
|
"""Returns the creation date of the order."""
|
|
70
70
|
if not self.__created_at:
|
|
71
|
-
self.__created_at = self.
|
|
71
|
+
self.__created_at = self._openapi_service.order_api.order_order_id_get(
|
|
72
72
|
self.id
|
|
73
73
|
).order_date
|
|
74
74
|
return self.__created_at
|
|
@@ -77,7 +77,7 @@ class RapidataOrder:
|
|
|
77
77
|
"""Runs the order to start collecting responses."""
|
|
78
78
|
with tracer.start_as_current_span("RapidataOrder.run"):
|
|
79
79
|
logger.info("Starting order '%s'", self)
|
|
80
|
-
self.
|
|
80
|
+
self._openapi_service.order_api.order_order_id_submit_post(
|
|
81
81
|
self.id, SubmitOrderModel(ignoreFailedDatapoints=True)
|
|
82
82
|
)
|
|
83
83
|
logger.debug("Order '%s' has been started.", self)
|
|
@@ -90,7 +90,7 @@ class RapidataOrder:
|
|
|
90
90
|
"""Pauses the order."""
|
|
91
91
|
with tracer.start_as_current_span("RapidataOrder.pause"):
|
|
92
92
|
logger.info("Pausing order '%s'", self)
|
|
93
|
-
self.
|
|
93
|
+
self._openapi_service.order_api.order_order_id_pause_post(self.id)
|
|
94
94
|
logger.debug("Order '%s' has been paused.", self)
|
|
95
95
|
managed_print(f"Order '{self}' has been paused.")
|
|
96
96
|
|
|
@@ -98,7 +98,7 @@ class RapidataOrder:
|
|
|
98
98
|
"""Unpauses/resumes the order."""
|
|
99
99
|
with tracer.start_as_current_span("RapidataOrder.unpause"):
|
|
100
100
|
logger.info("Unpausing order '%s'", self)
|
|
101
|
-
self.
|
|
101
|
+
self._openapi_service.order_api.order_order_id_resume_post(self.id)
|
|
102
102
|
logger.debug("Order '%s' has been unpaused.", self)
|
|
103
103
|
managed_print(f"Order '{self}' has been unpaused.")
|
|
104
104
|
|
|
@@ -106,7 +106,7 @@ class RapidataOrder:
|
|
|
106
106
|
"""Deletes the order."""
|
|
107
107
|
with tracer.start_as_current_span("RapidataOrder.delete"):
|
|
108
108
|
logger.info("Deleting order '%s'", self)
|
|
109
|
-
self.
|
|
109
|
+
self._openapi_service.order_api.order_order_id_delete(self.id)
|
|
110
110
|
logger.debug("Order '%s' has been deleted.", self)
|
|
111
111
|
managed_print(f"Order '{self}' has been deleted.")
|
|
112
112
|
|
|
@@ -125,7 +125,7 @@ class RapidataOrder:
|
|
|
125
125
|
Failed: The order has failed.
|
|
126
126
|
"""
|
|
127
127
|
with tracer.start_as_current_span("RapidataOrder.get_status"):
|
|
128
|
-
return self.
|
|
128
|
+
return self._openapi_service.order_api.order_order_id_get(self.id).state
|
|
129
129
|
|
|
130
130
|
def display_progress_bar(self, refresh_rate: int = 5) -> None:
|
|
131
131
|
"""
|
|
@@ -180,7 +180,7 @@ class RapidataOrder:
|
|
|
180
180
|
try:
|
|
181
181
|
with suppress_rapidata_error_logging():
|
|
182
182
|
workflow_id = self.__get_workflow_id()
|
|
183
|
-
progress = self.
|
|
183
|
+
progress = self._openapi_service.workflow_api.workflow_workflow_id_progress_get(
|
|
184
184
|
workflow_id
|
|
185
185
|
)
|
|
186
186
|
break
|
|
@@ -223,7 +223,7 @@ class RapidataOrder:
|
|
|
223
223
|
try:
|
|
224
224
|
return RapidataResults(
|
|
225
225
|
json.loads(
|
|
226
|
-
self.
|
|
226
|
+
self._openapi_service.order_api.order_order_id_download_results_get(
|
|
227
227
|
order_id=self.id
|
|
228
228
|
)
|
|
229
229
|
)
|
|
@@ -260,13 +260,13 @@ class RapidataOrder:
|
|
|
260
260
|
logger.info("Opening order preview in browser...")
|
|
261
261
|
if self.get_status() == OrderState.CREATED:
|
|
262
262
|
logger.info("Order is still in state created. Setting it to preview.")
|
|
263
|
-
self.
|
|
263
|
+
self._openapi_service.order_api.order_order_id_preview_post(
|
|
264
264
|
self.id, PreviewOrderModel(ignoreFailedDatapoints=True)
|
|
265
265
|
)
|
|
266
266
|
logger.info("Order is now in preview state.")
|
|
267
267
|
|
|
268
268
|
campaign_id = self.__get_campaign_id()
|
|
269
|
-
auth_url = f"https://app.{self.
|
|
269
|
+
auth_url = f"https://app.{self._openapi_service.environment}/order/detail/{self.id}/preview?campaignId={campaign_id}"
|
|
270
270
|
could_open_browser = webbrowser.open(auth_url)
|
|
271
271
|
if not could_open_browser:
|
|
272
272
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
@@ -282,7 +282,7 @@ class RapidataOrder:
|
|
|
282
282
|
for _ in range(self._max_retries):
|
|
283
283
|
try:
|
|
284
284
|
self.__pipeline_id = (
|
|
285
|
-
self.
|
|
285
|
+
self._openapi_service.order_api.order_order_id_get(
|
|
286
286
|
self.id
|
|
287
287
|
).pipeline_id
|
|
288
288
|
)
|
|
@@ -312,7 +312,7 @@ class RapidataOrder:
|
|
|
312
312
|
pipeline_id = self.__get_pipeline_id()
|
|
313
313
|
for _ in range(self._max_retries):
|
|
314
314
|
try:
|
|
315
|
-
pipeline = self.
|
|
315
|
+
pipeline = self._openapi_service.pipeline_api.pipeline_pipeline_id_get(
|
|
316
316
|
pipeline_id
|
|
317
317
|
)
|
|
318
318
|
self.__workflow_id = cast(
|
|
@@ -332,14 +332,14 @@ class RapidataOrder:
|
|
|
332
332
|
"""Internal method to fetch preliminary results."""
|
|
333
333
|
try:
|
|
334
334
|
pipeline_id = self.__get_pipeline_id()
|
|
335
|
-
download_id = self.
|
|
335
|
+
download_id = self._openapi_service.pipeline_api.pipeline_pipeline_id_preliminary_download_post(
|
|
336
336
|
pipeline_id, PreliminaryDownloadModel(sendEmail=False)
|
|
337
337
|
).download_id
|
|
338
338
|
|
|
339
339
|
elapsed = 0
|
|
340
340
|
timeout = 60
|
|
341
341
|
while elapsed < timeout:
|
|
342
|
-
preliminary_results = self.
|
|
342
|
+
preliminary_results = self._openapi_service.pipeline_api.pipeline_preliminary_download_preliminary_download_id_get(
|
|
343
343
|
preliminary_download_id=download_id
|
|
344
344
|
)
|
|
345
345
|
if preliminary_results:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Sequence, Optional, Literal
|
|
1
|
+
from typing import Sequence, Optional, Literal, get_args
|
|
2
2
|
from itertools import zip_longest
|
|
3
3
|
|
|
4
4
|
from rapidata.rapidata_client.config.tracer import tracer
|
|
@@ -41,7 +41,7 @@ from rapidata.api_client.models.filter import Filter
|
|
|
41
41
|
from rapidata.api_client.models.filter_operator import FilterOperator
|
|
42
42
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
43
43
|
from rapidata.api_client.models.sort_direction import SortDirection
|
|
44
|
-
|
|
44
|
+
from rapidata.rapidata_client.order._rapidata_order_builder import StickyStateLiteral
|
|
45
45
|
|
|
46
46
|
from tqdm import tqdm
|
|
47
47
|
|
|
@@ -61,7 +61,7 @@ class RapidataOrderManager:
|
|
|
61
61
|
self.settings = RapidataSettings
|
|
62
62
|
self.selections = RapidataSelections
|
|
63
63
|
self.__priority: int | None = None
|
|
64
|
-
self.__sticky_state:
|
|
64
|
+
self.__sticky_state: StickyStateLiteral | None = None
|
|
65
65
|
self.__asset_uploader = AssetUploader(openapi_service)
|
|
66
66
|
logger.debug("RapidataOrderManager initialized")
|
|
67
67
|
|
|
@@ -172,21 +172,21 @@ class RapidataOrderManager:
|
|
|
172
172
|
logger.debug("Order created: %s", order)
|
|
173
173
|
return order
|
|
174
174
|
|
|
175
|
-
def _set_priority(self, priority: int):
|
|
176
|
-
if not isinstance(priority, int):
|
|
177
|
-
raise TypeError("Priority must be an integer")
|
|
175
|
+
def _set_priority(self, priority: int | None):
|
|
176
|
+
if priority is not None and not isinstance(priority, int):
|
|
177
|
+
raise TypeError("Priority must be an integer or None")
|
|
178
178
|
|
|
179
|
-
if priority < 0:
|
|
180
|
-
raise ValueError("Priority must be greater than 0")
|
|
179
|
+
if priority is not None and priority < 0:
|
|
180
|
+
raise ValueError("Priority must be greater than 0 or None")
|
|
181
181
|
|
|
182
182
|
self.__priority = priority
|
|
183
183
|
|
|
184
|
-
def _set_sticky_state(
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
if sticky_state not
|
|
184
|
+
def _set_sticky_state(self, sticky_state: StickyStateLiteral | None):
|
|
185
|
+
sticky_state_valid_values = get_args(StickyStateLiteral)
|
|
186
|
+
|
|
187
|
+
if sticky_state is not None and sticky_state not in sticky_state_valid_values:
|
|
188
188
|
raise ValueError(
|
|
189
|
-
"Sticky state must be one of
|
|
189
|
+
f"Sticky state must be one of {sticky_state_valid_values} or None"
|
|
190
190
|
)
|
|
191
191
|
|
|
192
192
|
self.__sticky_state = sticky_state
|
|
@@ -392,27 +392,15 @@ class RapidataOrderManager:
|
|
|
392
392
|
if len(datapoints) < 2:
|
|
393
393
|
raise ValueError("At least two datapoints are required")
|
|
394
394
|
|
|
395
|
-
metadatas: list[Metadata] = []
|
|
396
|
-
if context:
|
|
397
|
-
if not isinstance(context, str) or context == "":
|
|
398
|
-
raise ValueError("Context must be a non-empty string")
|
|
399
|
-
metadatas.append(PromptMetadata(context))
|
|
400
|
-
if media_context:
|
|
401
|
-
if not isinstance(media_context, str) or media_context == "":
|
|
402
|
-
raise ValueError("Media context must be a non-empty string")
|
|
403
|
-
metadatas.append(
|
|
404
|
-
MediaAssetMetadata(
|
|
405
|
-
self.__asset_uploader.upload_asset(media_context)
|
|
406
|
-
)
|
|
407
|
-
)
|
|
408
|
-
|
|
409
395
|
return self._create_general_order(
|
|
410
396
|
name=name,
|
|
411
397
|
workflow=RankingWorkflow(
|
|
412
398
|
criteria=instruction,
|
|
413
399
|
total_comparison_budget=total_comparison_budget,
|
|
414
400
|
random_comparisons_ratio=random_comparisons_ratio,
|
|
415
|
-
|
|
401
|
+
context=context,
|
|
402
|
+
media_context=media_context,
|
|
403
|
+
file_uploader=self.__asset_uploader,
|
|
416
404
|
),
|
|
417
405
|
assets=datapoints,
|
|
418
406
|
data_type=data_type,
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
2
|
-
from typing import Literal,
|
|
2
|
+
from typing import Literal, Any, Sequence
|
|
3
3
|
from pydantic import BaseModel, model_validator, ConfigDict
|
|
4
4
|
|
|
5
5
|
|
|
@@ -20,7 +20,7 @@ class Rapid(BaseModel):
|
|
|
20
20
|
)
|
|
21
21
|
|
|
22
22
|
@model_validator(mode="after")
|
|
23
|
-
def check_sentence_and_context(self) ->
|
|
23
|
+
def check_sentence_and_context(self) -> "Rapid":
|
|
24
24
|
if isinstance(self.sentence, str) and isinstance(self.context, str):
|
|
25
25
|
raise ValueError(
|
|
26
26
|
"Both 'sentence' and 'context' cannot be strings at the same time."
|
|
@@ -61,7 +61,7 @@ class ValidationSetManager:
|
|
|
61
61
|
"ValidationSetManager._create_order_validation_set"
|
|
62
62
|
):
|
|
63
63
|
rapids: list[Rapid] = []
|
|
64
|
-
for datapoint in datapoints:
|
|
64
|
+
for datapoint in workflow._format_datapoints(datapoints):
|
|
65
65
|
rapids.append(
|
|
66
66
|
Rapid(
|
|
67
67
|
asset=datapoint.asset,
|
|
@@ -672,13 +672,13 @@ class ValidationSetManager:
|
|
|
672
672
|
)
|
|
673
673
|
|
|
674
674
|
def find_validation_sets(
|
|
675
|
-
self, name: str = "", amount: int =
|
|
675
|
+
self, name: str = "", amount: int = 10
|
|
676
676
|
) -> list[RapidataValidationSet]:
|
|
677
677
|
"""Find validation sets by name.
|
|
678
678
|
|
|
679
679
|
Args:
|
|
680
680
|
name (str, optional): The name to search for. Defaults to "" to match with any set.
|
|
681
|
-
amount (int, optional): The amount of validation sets to return. Defaults to
|
|
681
|
+
amount (int, optional): The amount of validation sets to return. Defaults to 10.
|
|
682
682
|
|
|
683
683
|
Returns:
|
|
684
684
|
list[RapidataValidationSet]: The list of validation sets.
|
|
@@ -44,12 +44,19 @@ class Workflow(ABC):
|
|
|
44
44
|
):
|
|
45
45
|
pass
|
|
46
46
|
|
|
47
|
+
@abstractmethod
|
|
48
|
+
def _get_instruction(self) -> str:
|
|
49
|
+
pass
|
|
50
|
+
|
|
47
51
|
@abstractmethod
|
|
48
52
|
def _to_model(
|
|
49
53
|
self,
|
|
50
54
|
) -> SimpleWorkflowModel | CompareWorkflowModel | EvaluationWorkflowModel:
|
|
51
55
|
pass
|
|
52
56
|
|
|
57
|
+
def _format_datapoints(self, datapoints: list[Datapoint]) -> list[Datapoint]:
|
|
58
|
+
return datapoints
|
|
59
|
+
|
|
53
60
|
def __str__(self) -> str:
|
|
54
61
|
return self._type
|
|
55
62
|
|
|
@@ -35,6 +35,9 @@ class ClassifyWorkflow(Workflow):
|
|
|
35
35
|
self._instruction = instruction
|
|
36
36
|
self._answer_options = answer_options
|
|
37
37
|
|
|
38
|
+
def _get_instruction(self) -> str:
|
|
39
|
+
return self._instruction
|
|
40
|
+
|
|
38
41
|
def _to_dict(self) -> dict[str, Any]:
|
|
39
42
|
return {
|
|
40
43
|
**super()._to_dict(),
|
|
@@ -16,6 +16,9 @@ class DrawWorkflow(Workflow):
|
|
|
16
16
|
super().__init__(type="SimpleWorkflowConfig")
|
|
17
17
|
self._target = target
|
|
18
18
|
|
|
19
|
+
def _get_instruction(self) -> str:
|
|
20
|
+
return self._target
|
|
21
|
+
|
|
19
22
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
20
23
|
blueprint = LineRapidBlueprint(_t="LineBlueprint", target=self._target)
|
|
21
24
|
|
|
@@ -22,6 +22,9 @@ class EvaluationWorkflow(Workflow):
|
|
|
22
22
|
self.validation_set_id = validation_set_id
|
|
23
23
|
self.should_accept_incorrect = should_accept_incorrect
|
|
24
24
|
|
|
25
|
+
def _get_instruction(self) -> str:
|
|
26
|
+
return ""
|
|
27
|
+
|
|
25
28
|
def _to_model(self):
|
|
26
29
|
return EvaluationWorkflowModel(
|
|
27
30
|
_t="EvaluationWorkflow",
|
|
@@ -33,6 +33,9 @@ class FreeTextWorkflow(Workflow):
|
|
|
33
33
|
self._instruction = instruction
|
|
34
34
|
self._validation_system_prompt = validation_system_prompt
|
|
35
35
|
|
|
36
|
+
def _get_instruction(self) -> str:
|
|
37
|
+
return self._instruction
|
|
38
|
+
|
|
36
39
|
def _to_dict(self) -> dict[str, Any]:
|
|
37
40
|
return {
|
|
38
41
|
**super()._to_dict(),
|
|
@@ -16,6 +16,9 @@ class LocateWorkflow(Workflow):
|
|
|
16
16
|
super().__init__(type="SimpleWorkflowConfig")
|
|
17
17
|
self._target = target
|
|
18
18
|
|
|
19
|
+
def _get_instruction(self) -> str:
|
|
20
|
+
return self._target
|
|
21
|
+
|
|
19
22
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
20
23
|
blueprint = LocateRapidBlueprint(_t="LocateBlueprint", target=self._target)
|
|
21
24
|
|
|
@@ -8,10 +8,17 @@ from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
|
8
8
|
from rapidata.api_client import ComparePayload
|
|
9
9
|
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
10
10
|
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
11
|
-
from rapidata.rapidata_client.datapoints.metadata import
|
|
11
|
+
from rapidata.rapidata_client.datapoints.metadata import (
|
|
12
|
+
MediaAssetMetadata,
|
|
13
|
+
PromptMetadata,
|
|
14
|
+
)
|
|
12
15
|
from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner import (
|
|
13
16
|
CreateDatapointFromFilesModelMetadataInner,
|
|
14
17
|
)
|
|
18
|
+
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
|
|
19
|
+
import itertools
|
|
20
|
+
import random
|
|
21
|
+
from typing import cast
|
|
15
22
|
|
|
16
23
|
|
|
17
24
|
class RankingWorkflow(Workflow):
|
|
@@ -25,11 +32,25 @@ class RankingWorkflow(Workflow):
|
|
|
25
32
|
elo_start: int = 1200,
|
|
26
33
|
elo_k_factor: int = 40,
|
|
27
34
|
elo_scaling_factor: int = 400,
|
|
28
|
-
|
|
35
|
+
media_context: str | None = None,
|
|
36
|
+
context: str | None = None,
|
|
37
|
+
file_uploader: AssetUploader | None = None,
|
|
29
38
|
):
|
|
30
39
|
super().__init__(type="CompareWorkflowConfig")
|
|
31
40
|
|
|
32
|
-
self.
|
|
41
|
+
self.media_context = media_context
|
|
42
|
+
self.context = context
|
|
43
|
+
|
|
44
|
+
self.metadatas = []
|
|
45
|
+
if media_context:
|
|
46
|
+
assert (
|
|
47
|
+
file_uploader is not None
|
|
48
|
+
), "File uploader is required if media_context is provided"
|
|
49
|
+
self.metadatas.append(
|
|
50
|
+
MediaAssetMetadata(file_uploader.upload_asset(media_context))
|
|
51
|
+
)
|
|
52
|
+
if context:
|
|
53
|
+
self.metadatas.append(PromptMetadata(context))
|
|
33
54
|
|
|
34
55
|
self.criteria = criteria
|
|
35
56
|
self.total_comparison_budget = total_comparison_budget
|
|
@@ -52,6 +73,9 @@ class RankingWorkflow(Workflow):
|
|
|
52
73
|
scalingFactor=elo_scaling_factor,
|
|
53
74
|
)
|
|
54
75
|
|
|
76
|
+
def _get_instruction(self) -> str:
|
|
77
|
+
return self.criteria
|
|
78
|
+
|
|
55
79
|
def _to_model(self) -> CompareWorkflowModel:
|
|
56
80
|
|
|
57
81
|
return CompareWorkflowModel(
|
|
@@ -71,6 +95,24 @@ class RankingWorkflow(Workflow):
|
|
|
71
95
|
criteria=self.criteria,
|
|
72
96
|
)
|
|
73
97
|
|
|
98
|
+
def _format_datapoints(self, datapoints: list[Datapoint]) -> list[Datapoint]:
|
|
99
|
+
if len(datapoints) < 3:
|
|
100
|
+
raise ValueError("RankingWorkflow requires at least three datapoints")
|
|
101
|
+
desired_length = len(datapoints)
|
|
102
|
+
assets = [datapoint.asset for datapoint in datapoints]
|
|
103
|
+
pairs = list(map(list, itertools.combinations(assets, 2)))
|
|
104
|
+
sampled_pairs = random.sample(pairs, desired_length)
|
|
105
|
+
formatted_datapoints = [
|
|
106
|
+
Datapoint(
|
|
107
|
+
asset=cast(list[str], pair),
|
|
108
|
+
data_type=datapoints[0].data_type,
|
|
109
|
+
context=self.context,
|
|
110
|
+
media_context=self.media_context,
|
|
111
|
+
)
|
|
112
|
+
for pair in sampled_pairs
|
|
113
|
+
]
|
|
114
|
+
return formatted_datapoints
|
|
115
|
+
|
|
74
116
|
def __str__(self) -> str:
|
|
75
117
|
return (
|
|
76
118
|
f"RankingWorkflow(criteria='{self.criteria}', metadatas={self.metadatas})"
|
|
@@ -31,6 +31,9 @@ class SelectWordsWorkflow(Workflow):
|
|
|
31
31
|
super().__init__(type="SimpleWorkflowConfig")
|
|
32
32
|
self._instruction = instruction
|
|
33
33
|
|
|
34
|
+
def _get_instruction(self) -> str:
|
|
35
|
+
return self._instruction
|
|
36
|
+
|
|
34
37
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
35
38
|
blueprint = TranscriptionRapidBlueprint(
|
|
36
39
|
_t="TranscriptionBlueprint", title=self._instruction
|
|
@@ -29,6 +29,9 @@ class TimestampWorkflow(Workflow):
|
|
|
29
29
|
super().__init__(type="SimpleWorkflowConfig")
|
|
30
30
|
self._instruction = instruction
|
|
31
31
|
|
|
32
|
+
def _get_instruction(self) -> str:
|
|
33
|
+
return self._instruction
|
|
34
|
+
|
|
32
35
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
33
36
|
blueprint = ScrubRapidBlueprint(_t="ScrubBlueprint", target=self._instruction)
|
|
34
37
|
|