orca-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +19 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +193 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +159 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +194 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +63 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +499 -0
- orca_sdk/classification_model_test.py +266 -0
- orca_sdk/conftest.py +117 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +95 -0
- orca_sdk/embedding_model.py +336 -0
- orca_sdk/embedding_model_test.py +173 -0
- orca_sdk/labeled_memoryset.py +1154 -0
- orca_sdk/labeled_memoryset_test.py +271 -0
- orca_sdk/orca_credentials.py +75 -0
- orca_sdk/orca_credentials_test.py +37 -0
- orca_sdk/telemetry.py +386 -0
- orca_sdk/telemetry_test.py +100 -0
- orca_sdk-0.1.0.dist-info/METADATA +39 -0
- orca_sdk-0.1.0.dist-info/RECORD +175 -0
- orca_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Any, Literal
|
|
2
|
+
|
|
3
|
+
CreateMode = Literal["error", "open"]
|
|
4
|
+
"""
|
|
5
|
+
Mode for creating a resource.
|
|
6
|
+
|
|
7
|
+
**Options:**
|
|
8
|
+
|
|
9
|
+
- `"error"`: raise an error if a resource with the same name already exists
|
|
10
|
+
- `"open"`: open the resource with the same name if it exists
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
DropMode = Literal["error", "ignore"]
|
|
14
|
+
"""
|
|
15
|
+
Mode for deleting a resource.
|
|
16
|
+
|
|
17
|
+
**Options:**
|
|
18
|
+
|
|
19
|
+
- `"error"`: raise an error if the resource does not exist
|
|
20
|
+
- `"ignore"`: do nothing if the resource does not exist
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _UnsetSentinel:
    """
    Type of the `UNSET` marker value: falsy, and printed as `UNSET`.

    See corresponding class in orcalib.pydantic_utils
    """

    def __repr__(self) -> str:
        # Fixed text so the marker is recognizable in logs and signatures
        return "UNSET"

    def __bool__(self) -> bool:
        # The marker is always falsy
        return False


UNSET: Any = _UnsetSentinel()
"""
Default value to indicate that no update should be applied to a field and it should not be set to None
"""
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import pickle
|
|
2
|
+
from dataclasses import asdict, is_dataclass
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
from datasets import Dataset
|
|
7
|
+
from torch.utils.data import DataLoader as TorchDataLoader
|
|
8
|
+
from torch.utils.data import Dataset as TorchDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
|
|
12
|
+
if isinstance(item, dict):
|
|
13
|
+
return item
|
|
14
|
+
|
|
15
|
+
if isinstance(item, tuple):
|
|
16
|
+
if column_names is not None:
|
|
17
|
+
assert len(item) == len(column_names)
|
|
18
|
+
return {column_names[i]: item[i] for i in range(len(item))}
|
|
19
|
+
elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields): # type: ignore
|
|
20
|
+
return {field: getattr(item, field) for field in item._fields} # type: ignore
|
|
21
|
+
else:
|
|
22
|
+
raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
|
|
23
|
+
|
|
24
|
+
if is_dataclass(item) and not isinstance(item, type):
|
|
25
|
+
return asdict(item)
|
|
26
|
+
|
|
27
|
+
raise ValueError(f"Cannot parse {type(item)}")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
    """
    Convert a batch of items into a list of row dictionaries.

    Params:
        batch: Either a list of items (row-oriented), each of which is converted with
            `parse_dict_like`, or a single dict-like object that maps each column to a
            sequence of values (column-oriented).
        column_names: Names for the values of tuple items, forwarded to `parse_dict_like`.

    Returns:
        One dictionary per row in the batch. An empty list if the batch has no columns.

    Raises:
        ValueError: If the columns of a column-oriented batch differ in length, or if the
            batch (or one of its items) cannot be parsed.
    """
    if isinstance(batch, list):
        return [parse_dict_like(item, column_names) for item in batch]

    columns = parse_dict_like(batch, column_names)
    keys = list(columns)
    if not keys:
        # Without columns there are no rows, and no first column to read the batch size from
        return []

    batch_size = len(columns[keys[0]])
    for key in keys:
        length = len(columns[key])
        if length != batch_size:
            raise ValueError(f"Batch must consist of values of the same length, but {key} has length {length}")

    return [{key: columns[key][idx] for key in keys} for idx in range(batch_size)]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
    """
    Create a HuggingFace Dataset from a PyTorch dataset or dataloader.

    Params:
        torch_data: A PyTorch dataloader, which is iterated as is, or a PyTorch dataset,
            which is wrapped in a dataloader that yields one item at a time.
        column_names: Names for the values of tuple items, forwarded to `parse_batch`.

    Returns:
        A HuggingFace Dataset built from the rows of all batches.
    """
    # A bare dataset is iterated via a dataloader with an identity collate function, so each
    # batch is a one-element list holding the raw item
    loader = (
        torch_data
        if isinstance(torch_data, TorchDataLoader)
        else TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
    )

    def generator():
        for batch in loader:
            for row in parse_batch(batch, column_names=column_names):
                yield row

    return cast(Dataset, Dataset.from_generator(generator))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
    """
    Load a dataset from disk into a HuggingFace Dataset object.

    Params:
        file_path: Path to the file on disk to create the memoryset from. The file type will
            be inferred from the file extension. The following file types are supported:

            - .pkl: [`Pickle`][pickle] files containing lists of dictionaries or dictionaries of columns
            - .json/.jsonl: [`JSON`][json] and JSON Lines files
            - .csv: [`CSV`][csv] files
            - .parquet: [`Parquet`](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile) files
            - dataset directory: Directory containing a saved HuggingFace [`Dataset`][datasets.Dataset]

    Returns:
        A HuggingFace Dataset object containing the loaded data.

    Raises:
        [`ValueError`][ValueError]: If the pickle file contains unsupported data types or if
            loading a saved dataset directory fails for any reason.
    """
    path = str(file_path)

    if path.endswith(".pkl"):
        # NOTE(security): unpickling can execute arbitrary code, only load pickle files from trusted sources
        # The context manager makes sure the file handle is closed again after loading
        with open(file_path, "rb") as pickle_file:
            data = pickle.load(pickle_file)
        if isinstance(data, list):
            return Dataset.from_list(data)
        elif isinstance(data, dict):
            return Dataset.from_dict(data)
        else:
            raise ValueError(f"Unsupported pickle file: {file_path}")
    elif path.endswith((".json", ".jsonl")):
        # JSON and JSON Lines files are handled by the same loader
        hf_dataset = Dataset.from_json(file_path)
    elif path.endswith(".csv"):
        hf_dataset = Dataset.from_csv(file_path)
    elif path.endswith(".parquet"):
        hf_dataset = Dataset.from_parquet(file_path)
    else:
        # Anything else is treated as a directory with a saved HuggingFace dataset
        try:
            hf_dataset = Dataset.load_from_disk(file_path)
        except Exception as e:
            # Chain the original exception so its traceback is not lost
            raise ValueError(f"Failed to load dataset from disk: {e}") from e

    return cast(Dataset, hf_dataset)
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import pickle
|
|
3
|
+
import tempfile
|
|
4
|
+
from collections import namedtuple
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import pytest
|
|
9
|
+
from datasets import Dataset
|
|
10
|
+
from datasets.exceptions import DatasetGenerationError
|
|
11
|
+
from torch.utils.data import DataLoader as TorchDataLoader
|
|
12
|
+
from torch.utils.data import Dataset as TorchDataset
|
|
13
|
+
|
|
14
|
+
from ..conftest import SAMPLE_DATA
|
|
15
|
+
from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PytorchDictDataset(TorchDataset):
    """Pytorch dataset that returns the entries of `SAMPLE_DATA` unchanged."""

    def __init__(self):
        self.data = SAMPLE_DATA

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        return row
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_hf_dataset_from_torch_dict():
    # Given a Pytorch dataset that returns a dictionary for each item
    torch_dataset = PytorchDictDataset()
    # When it is converted without passing column names
    converted = hf_dataset_from_torch(torch_dataset)
    # Then the HF dataset should be created successfully
    expected_columns = {"text", "label", "key", "score", "source_id"}
    assert isinstance(converted, Dataset)
    assert len(converted) == len(torch_dataset)
    assert set(converted.column_names) == expected_columns
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PytorchTupleDataset(TorchDataset):
    """Pytorch dataset that returns an unnamed `(text, label)` tuple for each item."""

    def __init__(self):
        self.data = SAMPLE_DATA

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        return row["text"], row["label"]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_hf_dataset_from_torch_tuple():
    # Given a Pytorch dataset that returns a tuple for each item
    torch_dataset = PytorchTupleDataset()
    # And the correct number of column names passed in
    names = ["text", "label"]
    converted = hf_dataset_from_torch(torch_dataset, column_names=names)
    # Then the HF dataset should be created successfully
    assert isinstance(converted, Dataset)
    assert len(converted) == len(torch_dataset)
    assert converted.column_names == ["text", "label"]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_hf_dataset_from_torch_tuple_error():
    # Given a Pytorch dataset that returns a tuple for each item
    torch_dataset = PytorchTupleDataset()
    # Then the HF dataset should raise an error if no column names are passed in
    # (the error raised while parsing surfaces as a dataset generation error)
    expected_error = DatasetGenerationError
    with pytest.raises(expected_error):
        hf_dataset_from_torch(torch_dataset)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
    # Given a Pytorch dataset that returns a tuple for each item
    torch_dataset = PytorchTupleDataset()
    # And fewer column names than there are values in each tuple
    too_few_names = ["value"]
    # Then the HF dataset should raise an error if not enough column names are passed in
    with pytest.raises(DatasetGenerationError):
        hf_dataset_from_torch(torch_dataset, column_names=too_few_names)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])


class PytorchNamedTupleDataset(TorchDataset):
    """Pytorch dataset that returns a `DatasetTuple` named tuple for each item."""

    def __init__(self):
        self.data = SAMPLE_DATA

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        # The "text" entry is exposed under the named tuple field "value"
        return DatasetTuple(row["text"], row["label"])
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def test_hf_dataset_from_torch_named_tuple():
    # Given a Pytorch dataset that returns a namedtuple for each item
    torch_dataset = PytorchNamedTupleDataset()
    # And no column names are passed in
    converted = hf_dataset_from_torch(torch_dataset)
    # Then the HF dataset should be created successfully
    # with the columns named after the fields of the namedtuple
    assert isinstance(converted, Dataset)
    assert len(converted) == len(torch_dataset)
    assert converted.column_names == ["value", "label"]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass
class DatasetItem:
    """Dataclass item returned by `PytorchDataclassDataset`."""

    text: str
    label: int


class PytorchDataclassDataset(TorchDataset):
    """Pytorch dataset that returns a `DatasetItem` dataclass instance for each item."""

    def __init__(self):
        self.data = SAMPLE_DATA

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        return DatasetItem(text=row["text"], label=row["label"])
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_hf_dataset_from_torch_dataclass():
    # Given a Pytorch dataset that returns a dataclass for each item
    torch_dataset = PytorchDataclassDataset()
    # When it is converted without passing column names
    converted = hf_dataset_from_torch(torch_dataset)
    # Then the HF dataset should be created successfully
    # with the columns named after the fields of the dataclass
    assert isinstance(converted, Dataset)
    assert len(converted) == len(torch_dataset)
    assert converted.column_names == ["text", "label"]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class PytorchInvalidDataset(TorchDataset):
    """Pytorch dataset that returns a `[text, label]` list for each item, which is not supported."""

    def __init__(self):
        self.data = SAMPLE_DATA

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        return [row["text"], row["label"]]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_hf_dataset_from_torch_invalid_dataset():
    # Given a Pytorch dataset that returns a list for each item
    torch_dataset = PytorchInvalidDataset()
    # Then the HF dataset should raise an error
    expected_error = DatasetGenerationError
    with pytest.raises(expected_error):
        hf_dataset_from_torch(torch_dataset)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_hf_dataset_from_torchdataloader():
    # Given a Pytorch dataloader that returns a column-oriented batch of items
    torch_dataset = PytorchDictDataset()

    def collate_fn(x: list[dict]):
        # Turn the list of row dictionaries into one list of values per column
        values = []
        labels = []
        for item in x:
            values.append(item["text"])
            labels.append(item["label"])
        return {"value": values, "label": labels}

    loader = TorchDataLoader(torch_dataset, batch_size=3, collate_fn=collate_fn)
    converted = hf_dataset_from_torch(loader)
    # Then the HF dataset should be created successfully
    assert isinstance(converted, Dataset)
    assert len(converted) == len(torch_dataset)
    assert converted.column_names == ["value", "label"]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def test_hf_dataset_from_disk_pickle_list():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.pkl")
        # Given a pickle file with test data that is a list
        test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
        with open(file_path, "wb") as f:
            pickle.dump(test_data, f)
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def test_hf_dataset_from_disk_pickle_dict():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.pkl")
        # Given a pickle file with test data that is a dict
        test_data = {"value": [f"test_{i}" for i in range(30)], "label": [i % 2 for i in range(30)]}
        with open(file_path, "wb") as f:
            pickle.dump(test_data, f)
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def test_hf_dataset_from_disk_json():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.json")
        # Given a JSON file with test data
        test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
        with open(file_path, "w") as f:
            json.dump(test_data, f)
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def test_hf_dataset_from_disk_jsonl():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.jsonl")
        # Given a JSONL file with test data
        test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
        with open(file_path, "w") as f:
            for item in test_data:
                f.write(json.dumps(item) + "\n")
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_hf_dataset_from_disk_csv():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.csv")
        # Given a CSV file with test data
        test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
        with open(file_path, "w") as f:
            f.write("value,label\n")
            for item in test_data:
                f.write(f"{item['value']},{item['label']}\n")
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def test_hf_dataset_from_disk_parquet():
    import os  # local import, only needed to build the path of the temporary file

    # A temporary directory is used instead of a NamedTemporaryFile, because a named temporary
    # file that is still open cannot be opened a second time by name on Windows
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "data.parquet")
        # Given a Parquet file with test data
        data = {
            "value": [f"test_{i}" for i in range(30)],
            "label": [i % 2 for i in range(30)],
        }
        df = pd.DataFrame(data)
        df.to_parquet(file_path)
        dataset = hf_dataset_from_disk(file_path)
        # Then the HF dataset should be created successfully
        assert isinstance(dataset, Dataset)
        assert len(dataset) == 30
        assert dataset.column_names == ["value", "label"]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
/* Styles for the prediction result inspection UI */

/* White background, applied to every other memory lookup row via the "white" class */
.white {
  background-color: white;
}

/* Small gray text with a fixed height, used together with "no-padding" on the status element */
/* NOTE(review): presumably this is where the "Changes saved" message is shown - confirm */
.success {
  color: gray;
  font-size: 12px;
  height: 24px;
}

/* Remove the padding of an HTML container that holds a "no-padding" element */
.html-container:has(.no-padding) {
  padding: 0;
  height: 24px;
}

/* Green progress bar */
/* NOTE(review): presumably targets the progress bar rendered by Gradio - confirm */
.progress-bar {
  background-color: #2b9a66;
}

/* Hide the inner progress level element */
/* NOTE(review): presumably the textual progress indicator rendered by Gradio - confirm */
.progress-level-inner {
  display: none;
}
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import gradio as gr
|
|
7
|
+
|
|
8
|
+
from ..labeled_memoryset import LabeledMemoryLookup
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..telemetry import LabelPrediction
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def inspect_prediction_result(prediction_result: "LabelPrediction"):
    """
    Build and launch a Gradio UI for inspecting a prediction and its memory lookups.

    The page shows the input value and predicted label, followed by one row per
    memory lookup with the memory's value and a dropdown for changing its label.
    Picking a different entry calls ``update(label=...)`` on that memory lookup.

    Args:
        prediction_result: Prediction to display. Its ``memoryset.label_names``,
            ``input_value``, ``label`` and ``memory_lookups`` attributes are read.
    """
    label_names = prediction_result.memoryset.label_names
    # The dropdown entries are the same for every row, so build them once instead of per row.
    label_choices = [f"{label_name} ({label_id})" for label_id, label_name in enumerate(label_names)]
    # A dropdown entry ends with the numeric label in parentheses, e.g. "name (3)".
    label_id_pattern = re.compile(r".*\((\d+)\)$")

    def update_label(val: str, memory: LabeledMemoryLookup, progress=gr.Progress(track_tqdm=True)):
        """Parse the label id out of the dropdown entry and store it on the memory."""
        progress(0)
        match = label_id_pattern.search(val)
        if not match:
            logging.error(f"Invalid label format: {val}")
            # Report the failure in the status area instead of silently returning None.
            return "❌ Invalid label"
        progress(0.5)
        new_label = int(match.group(1))
        memory.update(label=new_label)
        progress(1)
        return "✅ Changes saved"

    with gr.Blocks(
        fill_width=True,
        title="Prediction Results",
        css_paths=str(Path(__file__).parent / "prediction_result_ui.css"),
    ) as prediction_result_ui:
        gr.Markdown("# Prediction Results")
        gr.Markdown(f"**Input:** {prediction_result.input_value}")
        gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
        gr.Markdown("### Memory Lookups")

        # Header row
        with gr.Row(equal_height=True, variant="panel"):
            with gr.Column(scale=7):
                gr.Markdown("**Value**")
            with gr.Column(scale=3, min_width=150):
                gr.Markdown("**Label**")

        for row_index, mem_lookup in enumerate(prediction_result.memory_lookups):
            # Even rows get the "white" class for alternating row backgrounds.
            with gr.Row(equal_height=True, variant="panel", elem_classes="white" if row_index % 2 == 0 else None):
                with gr.Column(scale=7):
                    gr.Markdown(mem_lookup.value, label="Value", height=50)
                with gr.Column(scale=3, min_width=150):
                    dropdown = gr.Dropdown(
                        choices=label_choices,
                        label="Label",
                        value=f"{label_names[mem_lookup.label]} ({mem_lookup.label})",
                        interactive=True,
                        container=False,
                    )
                    # Status area; the callable re-evaluated every 15s resets it to empty.
                    changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
                    dropdown.change(
                        # Bind mem_lookup as a default so each row updates its own memory
                        # rather than the last one of the loop.
                        lambda val, mem_lookup=mem_lookup: update_label(val, mem_lookup),
                        inputs=[dropdown],
                        outputs=[changes_saved],
                        show_progress="full",
                    )

    prediction_result_ui.launch()
|
orca_sdk/_utils/task.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from tqdm.auto import tqdm
|
|
4
|
+
|
|
5
|
+
from .._generated_api_client.api import abort_task as _abort_task
|
|
6
|
+
from .._generated_api_client.api import get_task_status_task
|
|
7
|
+
from .._generated_api_client.api import list_tasks as _list_tasks
|
|
8
|
+
from .._generated_api_client.models import Task, TaskStatus, TaskStatusInfo
|
|
9
|
+
|
|
10
|
+
task_config = {
|
|
11
|
+
"retry_interval": 3,
|
|
12
|
+
"show_progress": True,
|
|
13
|
+
"max_wait": 60 * 60,
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def set_task_config(
|
|
18
|
+
retry_interval: int | None = None,
|
|
19
|
+
show_progress: bool | None = None,
|
|
20
|
+
max_wait: int | None = None,
|
|
21
|
+
) -> None:
|
|
22
|
+
if retry_interval is not None:
|
|
23
|
+
task_config["retry_interval"] = retry_interval
|
|
24
|
+
if show_progress is not None:
|
|
25
|
+
task_config["show_progress"] = show_progress
|
|
26
|
+
if max_wait is not None:
|
|
27
|
+
task_config["max_wait"] = max_wait
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def wait_for_task(task_id: str, description: str | None = None, show_progress: bool = True) -> None:
    """
    Poll a task's status until it completes, fails, or the maximum wait is exceeded.

    The status is fetched with ``get_task_status_task`` and the loop sleeps for
    ``task_config["retry_interval"]`` between polls.

    Args:
        task_id: Identifier of the task to poll.
        description: Text used as the progress bar description.
        show_progress: Whether to show a progress bar. A bar is only created when
            ``task_config["show_progress"]`` is also truthy and the task status
            reports a ``steps_total``.

    Raises:
        RuntimeError: If the task status is ``FAILED``, or if more than
            ``task_config["max_wait"]`` seconds have passed since polling began.
    """
    start_time = time.time()
    pbar = None
    steps_total = None
    show_progress = show_progress and task_config["show_progress"]
    try:
        while True:
            task_status = get_task_status_task(task_id)

            # setup progress bar if steps total is known
            if steps_total is None and task_status.steps_total is not None:
                steps_total = task_status.steps_total
            # Test against None, not truthiness: a tqdm bar with total == 0 is falsy,
            # which would otherwise create a fresh bar on every poll.
            if pbar is None and steps_total is not None and show_progress:
                pbar = tqdm(total=steps_total, desc=description)

            # return if task is complete
            if task_status.status == TaskStatus.COMPLETED:
                if pbar is not None:
                    # Fill the bar to the end; the finally clause closes it.
                    pbar.update(steps_total - pbar.n)
                return

            # raise error if task failed
            if task_status.status == TaskStatus.FAILED:
                raise RuntimeError(f"Task failed with {task_status.exception}")

            # raise error if task timed out
            if (time.time() - start_time) > task_config["max_wait"]:
                raise RuntimeError(f"Task {task_id} timed out after {task_config['max_wait']}s")

            # update progress bar
            if pbar is not None and task_status.steps_completed is not None:
                pbar.update(task_status.steps_completed - pbar.n)

            # sleep before retrying
            time.sleep(task_config["retry_interval"])
    finally:
        # Close the bar on every exit path, including failure and timeout.
        if pbar is not None:
            pbar.close()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def abort_task(task_id: str) -> TaskStatusInfo:
    """
    Send an abort request for a task and return its status as fetched afterwards.

    Args:
        task_id: Identifier of the task to abort.

    Returns:
        The result of ``get_task_status_task`` called after the abort request.
    """
    _abort_task(task_id)
    status_after_abort = get_task_status_task(task_id)
    return status_after_abort
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def list_tasks() -> list[Task]:
    """Return whatever the generated API client's ``list_tasks`` call returns."""
    tasks = _list_tasks()
    return tasks
|