orca_sdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. orca_sdk/__init__.py +19 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +193 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
  22. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
  23. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
  24. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
  25. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  26. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
  27. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
  28. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
  30. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
  34. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  35. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
  36. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
  37. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
  38. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
  39. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
  40. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
  41. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
  42. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
  43. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
  45. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
  46. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
  47. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
  48. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
  49. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
  50. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
  51. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
  52. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  53. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
  54. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
  56. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
  58. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
  59. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
  60. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
  62. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
  63. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
  64. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
  65. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
  66. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
  67. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
  68. orca_sdk/_generated_api_client/client.py +216 -0
  69. orca_sdk/_generated_api_client/errors.py +38 -0
  70. orca_sdk/_generated_api_client/models/__init__.py +159 -0
  71. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
  72. orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
  73. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  74. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  75. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
  76. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  77. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  78. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  79. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  80. orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
  81. orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
  82. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
  83. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  84. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  85. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  86. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  87. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  88. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  89. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  90. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  91. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  92. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  93. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  94. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
  95. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  96. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  97. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  98. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  99. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  100. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  101. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  102. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  103. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  104. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  105. orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
  106. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
  107. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  108. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
  109. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  110. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  111. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  112. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  113. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
  114. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  115. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  116. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  117. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  118. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  119. orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
  120. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
  121. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
  122. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  123. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  124. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  125. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  126. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  127. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  128. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  129. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
  130. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  131. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  132. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  133. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  134. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  135. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  136. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
  137. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  138. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  139. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  140. orca_sdk/_generated_api_client/models/task.py +198 -0
  141. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  142. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  143. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  144. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  145. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  146. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  147. orca_sdk/_generated_api_client/py.typed +1 -0
  148. orca_sdk/_generated_api_client/types.py +56 -0
  149. orca_sdk/_utils/__init__.py +0 -0
  150. orca_sdk/_utils/analysis_ui.py +194 -0
  151. orca_sdk/_utils/analysis_ui_style.css +54 -0
  152. orca_sdk/_utils/auth.py +63 -0
  153. orca_sdk/_utils/auth_test.py +31 -0
  154. orca_sdk/_utils/common.py +37 -0
  155. orca_sdk/_utils/data_parsing.py +99 -0
  156. orca_sdk/_utils/data_parsing_test.py +244 -0
  157. orca_sdk/_utils/prediction_result_ui.css +18 -0
  158. orca_sdk/_utils/prediction_result_ui.py +64 -0
  159. orca_sdk/_utils/task.py +73 -0
  160. orca_sdk/classification_model.py +499 -0
  161. orca_sdk/classification_model_test.py +266 -0
  162. orca_sdk/conftest.py +117 -0
  163. orca_sdk/datasource.py +333 -0
  164. orca_sdk/datasource_test.py +95 -0
  165. orca_sdk/embedding_model.py +336 -0
  166. orca_sdk/embedding_model_test.py +173 -0
  167. orca_sdk/labeled_memoryset.py +1154 -0
  168. orca_sdk/labeled_memoryset_test.py +271 -0
  169. orca_sdk/orca_credentials.py +75 -0
  170. orca_sdk/orca_credentials_test.py +37 -0
  171. orca_sdk/telemetry.py +386 -0
  172. orca_sdk/telemetry_test.py +100 -0
  173. orca_sdk-0.1.0.dist-info/METADATA +39 -0
  174. orca_sdk-0.1.0.dist-info/RECORD +175 -0
  175. orca_sdk-0.1.0.dist-info/WHEEL +4 -0
orca_sdk/_utils/common.py
@@ -0,0 +1,37 @@
+ from typing import Any, Literal
+
+ CreateMode = Literal["error", "open"]
+ """
+ Mode for creating a resource.
+
+ **Options:**
+
+ - `"error"`: raise an error if a resource with the same name already exists
+ - `"open"`: open the resource with the same name if it exists
+ """
+
+ DropMode = Literal["error", "ignore"]
+ """
+ Mode for deleting a resource.
+
+ **Options:**
+
+ - `"error"`: raise an error if the resource does not exist
+ - `"ignore"`: do nothing if the resource does not exist
+ """
+
+
+ class _UnsetSentinel:
+     """See corresponding class in orcalib.pydantic_utils"""
+
+     def __bool__(self) -> bool:
+         return False
+
+     def __repr__(self) -> str:
+         return "UNSET"
+
+
+ UNSET: Any = _UnsetSentinel()
+ """
+ Default value to indicate that no update should be applied to a field and it should not be set to None
+ """
orca_sdk/_utils/data_parsing.py
@@ -0,0 +1,99 @@
+ import pickle
+ from dataclasses import asdict, is_dataclass
+ from os import PathLike
+ from typing import Any, cast
+
+ from datasets import Dataset
+ from torch.utils.data import DataLoader as TorchDataLoader
+ from torch.utils.data import Dataset as TorchDataset
+
+
+ def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
+     if isinstance(item, dict):
+         return item
+
+     if isinstance(item, tuple):
+         if column_names is not None:
+             assert len(item) == len(column_names)
+             return {column_names[i]: item[i] for i in range(len(item))}
+         elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
+             return {field: getattr(item, field) for field in item._fields}  # type: ignore
+         else:
+             raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
+
+     if is_dataclass(item) and not isinstance(item, type):
+         return asdict(item)
+
+     raise ValueError(f"Cannot parse {type(item)}")
+
+
+ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
+     if isinstance(batch, list):
+         return [parse_dict_like(item, column_names) for item in batch]
+
+     batch = parse_dict_like(batch, column_names)
+     keys = list(batch.keys())
+     batch_size = len(batch[keys[0]])
+     for key in keys:
+         if not len(batch[key]) == batch_size:
+             raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
+     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
+
+
+ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
+     if isinstance(torch_data, TorchDataLoader):
+         dataloader = torch_data
+     else:
+         dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
+
+     def generator():
+         for batch in dataloader:
+             yield from parse_batch(batch, column_names=column_names)
+
+     return cast(Dataset, Dataset.from_generator(generator))
+
+
+ def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
+     """
+     Load a dataset from disk into a HuggingFace Dataset object.
+
+     Params:
+         file_path: Path to the file on disk to create the memoryset from. The file type will
+             be inferred from the file extension. The following file types are supported:
+
+             - .pkl: [`Pickle`][pickle] files containing lists of dictionaries or dictionaries of columns
+             - .json/.jsonl: [`JSON`][json] and JSON Lines files
+             - .csv: [`CSV`][csv] files
+             - .parquet: [`Parquet`](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile) files
+             - dataset directory: Directory containing a saved HuggingFace [`Dataset`][datasets.Dataset]
+
+     Returns:
+         A HuggingFace Dataset object containing the loaded data.
+
+     Raises:
+         [`ValueError`][ValueError]: If the pickle file contains unsupported data types or if
+             loading the dataset fails for any reason.
+     """
+     if str(file_path).endswith(".pkl"):
+         data = pickle.load(open(file_path, "rb"))
+         if isinstance(data, list):
+             return Dataset.from_list(data)
+         elif isinstance(data, dict):
+             return Dataset.from_dict(data)
+         else:
+             raise ValueError(f"Unsupported pickle file: {file_path}")
+     elif str(file_path).endswith(".json"):
+         hf_dataset = Dataset.from_json(file_path)
+     elif str(file_path).endswith(".jsonl"):
+         hf_dataset = Dataset.from_json(file_path)
+     elif str(file_path).endswith(".csv"):
+         hf_dataset = Dataset.from_csv(file_path)
+     elif str(file_path).endswith(".parquet"):
+         hf_dataset = Dataset.from_parquet(file_path)
+     else:
+         try:
+             hf_dataset = Dataset.load_from_disk(file_path)
+         except Exception as e:
+             raise ValueError(f"Failed to load dataset from disk: {e}")
+
+     return cast(Dataset, hf_dataset)
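A short usage sketch for the two helpers above (the toy dataset class and the file path are illustrative; the import path is taken from the file list):

```python
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch

class TupleDataset(TorchDataset):
    # Items are plain tuples, so column names must be supplied explicitly;
    # dict, namedtuple, and dataclass items are parsed without them.
    def __getitem__(self, i):
        return (f"text_{i}", i % 2)

    def __len__(self):
        return 10

hf_dataset = hf_dataset_from_torch(TupleDataset(), column_names=["text", "label"])
assert hf_dataset.column_names == ["text", "label"]

# File type is inferred from the extension (.pkl/.json/.jsonl/.csv/.parquet),
# falling back to Dataset.load_from_disk for saved dataset directories.
hf_dataset = hf_dataset_from_disk("data.jsonl")  # illustrative path
```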
orca_sdk/_utils/data_parsing_test.py
@@ -0,0 +1,244 @@
+ import json
+ import pickle
+ import tempfile
+ from collections import namedtuple
+ from dataclasses import dataclass
+
+ import pandas as pd
+ import pytest
+ from datasets import Dataset
+ from datasets.exceptions import DatasetGenerationError
+ from torch.utils.data import DataLoader as TorchDataLoader
+ from torch.utils.data import Dataset as TorchDataset
+
+ from ..conftest import SAMPLE_DATA
+ from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+
+
+ class PytorchDictDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return self.data[i]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_dict():
+     # Given a PyTorch dataset that returns a dictionary for each item
+     dataset = PytorchDictDataset()
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert set(hf_dataset.column_names) == {"text", "label", "key", "score", "source_id"}
+
+
+ class PytorchTupleDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return self.data[i]["text"], self.data[i]["label"]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_tuple():
+     # Given a PyTorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # And the correct number of column names passed in
+     hf_dataset = hf_dataset_from_torch(dataset, column_names=["text", "label"])
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["text", "label"]
+
+
+ def test_hf_dataset_from_torch_tuple_error():
+     # Given a PyTorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # Then the HF dataset should raise an error if no column names are passed in
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset)
+
+
+ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
+     # Given a PyTorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # Then the HF dataset should raise an error if not enough column names are passed in
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset, column_names=["value"])
+
+
+ DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
+
+
+ class PytorchNamedTupleDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return DatasetTuple(self.data[i]["text"], self.data[i]["label"])
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_named_tuple():
+     # Given a PyTorch dataset that returns a namedtuple for each item
+     dataset = PytorchNamedTupleDataset()
+     # And no column names are passed in
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["value", "label"]
+
+
+ @dataclass
+ class DatasetItem:
+     text: str
+     label: int
+
+
+ class PytorchDataclassDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return DatasetItem(text=self.data[i]["text"], label=self.data[i]["label"])
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_dataclass():
+     # Given a PyTorch dataset that returns a dataclass for each item
+     dataset = PytorchDataclassDataset()
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["text", "label"]
+
+
+ class PytorchInvalidDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return [self.data[i]["text"], self.data[i]["label"]]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_invalid_dataset():
+     # Given a PyTorch dataset that returns a list for each item
+     dataset = PytorchInvalidDataset()
+     # Then the HF dataset should raise an error
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset)
+
+
+ def test_hf_dataset_from_torchdataloader():
+     # Given a PyTorch dataloader that returns a column-oriented batch of items
+     dataset = PytorchDictDataset()
+
+     def collate_fn(x: list[dict]):
+         return {"value": [item["text"] for item in x], "label": [item["label"] for item in x]}
+
+     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
+     hf_dataset = hf_dataset_from_torch(dataloader)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_pickle_list():
+     with tempfile.NamedTemporaryFile(suffix=".pkl") as temp_file:
+         # Given a pickle file with test data that is a list
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "wb") as f:
+             pickle.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_pickle_dict():
+     with tempfile.NamedTemporaryFile(suffix=".pkl") as temp_file:
+         # Given a pickle file with test data that is a dict
+         test_data = {"value": [f"test_{i}" for i in range(30)], "label": [i % 2 for i in range(30)]}
+         with open(temp_file.name, "wb") as f:
+             pickle.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_json():
+     with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
+         # Given a JSON file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             json.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_jsonl():
+     with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file:
+         # Given a JSONL file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             for item in test_data:
+                 f.write(json.dumps(item) + "\n")
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_csv():
+     with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
+         # Given a CSV file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             f.write("value,label\n")
+             for item in test_data:
+                 f.write(f"{item['value']},{item['label']}\n")
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_parquet():
+     with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
+         # Given a Parquet file with test data
+         data = {
+             "value": [f"test_{i}" for i in range(30)],
+             "label": [i % 2 for i in range(30)],
+         }
+         df = pd.DataFrame(data)
+         df.to_parquet(temp_file.name)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
orca_sdk/_utils/prediction_result_ui.css
@@ -0,0 +1,18 @@
+ .white {
+     background-color: white;
+ }
+ .success {
+     color: gray;
+     font-size: 12px;
+     height: 24px;
+ }
+ .html-container:has(.no-padding) {
+     padding: 0;
+     height: 24px;
+ }
+ .progress-bar {
+     background-color: #2b9a66;
+ }
+ .progress-level-inner {
+     display: none;
+ }
orca_sdk/_utils/prediction_result_ui.py
@@ -0,0 +1,64 @@
+ import logging
+ import re
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import gradio as gr
+
+ from ..labeled_memoryset import LabeledMemoryLookup
+
+ if TYPE_CHECKING:
+     from ..telemetry import LabelPrediction
+
+
+ def inspect_prediction_result(prediction_result: "LabelPrediction"):
+     label_names = prediction_result.memoryset.label_names
+
+     def update_label(val: str, memory: LabeledMemoryLookup, progress=gr.Progress(track_tqdm=True)):
+         progress(0)
+         match = re.search(r".*\((\d+)\)$", val)
+         if match:
+             progress(0.5)
+             new_label = int(match.group(1))
+             memory.update(label=new_label)
+             progress(1)
+             return "✅ Changes saved"
+         else:
+             logging.error(f"Invalid label format: {val}")
+
+     with gr.Blocks(
+         fill_width=True,
+         title="Prediction Results",
+         css_paths=str(Path(__file__).parent / "prediction_result_ui.css"),
+     ) as prediction_result_ui:
+         gr.Markdown("# Prediction Results")
+         gr.Markdown(f"**Input:** {prediction_result.input_value}")
+         gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
+         gr.Markdown("### Memory Lookups")
+
+         with gr.Row(equal_height=True, variant="panel"):
+             with gr.Column(scale=7):
+                 gr.Markdown("**Value**")
+             with gr.Column(scale=3, min_width=150):
+                 gr.Markdown("**Label**")
+         for i, mem_lookup in enumerate(prediction_result.memory_lookups):
+             with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
+                 with gr.Column(scale=7):
+                     gr.Markdown(mem_lookup.value, label="Value", height=50)
+                 with gr.Column(scale=3, min_width=150):
+                     dropdown = gr.Dropdown(
+                         choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
+                         label="Label",
+                         value=f"{label_names[mem_lookup.label]} ({mem_lookup.label})",
+                         interactive=True,
+                         container=False,
+                     )
+                     changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                     dropdown.change(
+                         lambda val, mem_lookup=mem_lookup: update_label(val, mem_lookup),
+                         inputs=[dropdown],
+                         outputs=[changes_saved],
+                         show_progress="full",
+                     )
+
+     prediction_result_ui.launch()
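The dropdown above encodes each choice as `"{label_name} ({label_id})"`, and `update_label` recovers the numeric id with a regex. A minimal standalone check of that round-trip (plain Python, no Gradio required; label names are illustrative):

```python
import re

label_names = ["negative", "positive"]  # illustrative label names
choice = f"{label_names[1]} ({1})"      # the format the dropdown uses

# Same regex as update_label: capture the trailing "(<digits>)" group.
match = re.search(r".*\((\d+)\)$", choice)
assert match is not None
assert int(match.group(1)) == 1
```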
orca_sdk/_utils/task.py
@@ -0,0 +1,73 @@
+ import time
+
+ from tqdm.auto import tqdm
+
+ from .._generated_api_client.api import abort_task as _abort_task
+ from .._generated_api_client.api import get_task_status_task
+ from .._generated_api_client.api import list_tasks as _list_tasks
+ from .._generated_api_client.models import Task, TaskStatus, TaskStatusInfo
+
+ task_config = {
+     "retry_interval": 3,
+     "show_progress": True,
+     "max_wait": 60 * 60,
+ }
+
+
+ def set_task_config(
+     retry_interval: int | None = None,
+     show_progress: bool | None = None,
+     max_wait: int | None = None,
+ ) -> None:
+     if retry_interval is not None:
+         task_config["retry_interval"] = retry_interval
+     if show_progress is not None:
+         task_config["show_progress"] = show_progress
+     if max_wait is not None:
+         task_config["max_wait"] = max_wait
+
+
+ def wait_for_task(task_id: str, description: str | None = None, show_progress: bool = True) -> None:
+     start_time = time.time()
+     pbar = None
+     steps_total = None
+     show_progress = show_progress and task_config["show_progress"]
+     while True:
+         task_status = get_task_status_task(task_id)
+
+         # setup progress bar if steps total is known
+         if task_status.steps_total is not None and steps_total is None:
+             steps_total = task_status.steps_total
+         if not pbar and steps_total is not None and show_progress:
+             pbar = tqdm(total=steps_total, desc=description)
+
+         # return if task is complete
+         if task_status.status == TaskStatus.COMPLETED:
+             if pbar:
+                 pbar.update(steps_total - pbar.n)
+                 pbar.close()
+             return
+
+         # raise error if task failed
+         if task_status.status == TaskStatus.FAILED:
+             raise RuntimeError(f"Task failed with {task_status.exception}")
+
+         # raise error if task timed out
+         if (time.time() - start_time) > task_config["max_wait"]:
+             raise RuntimeError(f"Task {task_id} timed out after {task_config['max_wait']}s")
+
+         # update progress bar
+         if pbar and task_status.steps_completed is not None:
+             pbar.update(task_status.steps_completed - pbar.n)
+
+         # sleep before retrying
+         time.sleep(task_config["retry_interval"])
+
+
+ def abort_task(task_id: str) -> TaskStatusInfo:
+     _abort_task(task_id)
+     return get_task_status_task(task_id)
+
+
+ def list_tasks() -> list[Task]:
+     return _list_tasks()
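A usage sketch for the task helpers above (import path inferred from the file list; the task id and the `Task` field names are illustrative assumptions, not confirmed by this hunk):

```python
from orca_sdk._utils.task import list_tasks, set_task_config, wait_for_task

# Poll every 5 seconds, give up after 10 minutes, keep the tqdm bar visible.
set_task_config(retry_interval=5, max_wait=600, show_progress=True)

for task in list_tasks():
    print(task.id, task.status)  # field names assumed from the generated Task model

# Blocks until the task completes; raises RuntimeError on failure or timeout.
wait_for_task("0b8a...", description="evaluation")  # illustrative task id
```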