orca-sdk 0.0.91__py3-none-any.whl → 0.0.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. orca_sdk/_generated_api_client/api/__init__.py +4 -0
  2. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  3. orca_sdk/_generated_api_client/models/__init__.py +4 -0
  4. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
  5. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  6. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  7. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  8. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  9. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  10. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
  11. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  12. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  13. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
  15. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  16. orca_sdk/_shared/__init__.py +1 -0
  17. orca_sdk/_shared/metrics.py +195 -0
  18. orca_sdk/_shared/metrics_test.py +169 -0
  19. orca_sdk/_utils/data_parsing.py +31 -2
  20. orca_sdk/_utils/data_parsing_test.py +18 -15
  21. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  22. orca_sdk/classification_model.py +170 -27
  23. orca_sdk/classification_model_test.py +74 -32
  24. orca_sdk/conftest.py +86 -25
  25. orca_sdk/datasource.py +22 -12
  26. orca_sdk/embedding_model_test.py +6 -5
  27. orca_sdk/memoryset.py +78 -0
  28. orca_sdk/memoryset_test.py +197 -123
  29. orca_sdk/telemetry.py +3 -0
  30. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
  31. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
  32. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
orca_sdk/classification_model_test.py CHANGED
@@ -1,5 +1,6 @@
  from uuid import uuid4

+ import numpy as np
  import pytest
  from datasets.arrow_dataset import Dataset

@@ -9,45 +10,45 @@ from .embedding_model import PretrainedEmbeddingModel
  from .memoryset import LabeledMemoryset


- def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
+ def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
      assert model is not None
      assert model.name == "test_model"
-     assert model.memoryset == memoryset
+     assert model.memoryset == readonly_memoryset
      assert model.num_classes == 2
      assert model.memory_lookup_count == 3


- def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset)
+         ClassificationModel.create("test_model", readonly_memoryset)
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset, if_exists="error")
+         ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")


- def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset, if_exists="open", head_type="MMOE")
+         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset, if_exists="open", memory_lookup_count=37)
+         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset, if_exists="open", num_classes=19)
+         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", memoryset, if_exists="open", min_memory_weight=0.77)
+         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)

-     new_model = ClassificationModel.create("test_model", memoryset, if_exists="open")
+     new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
      assert new_model is not None
      assert new_model.name == "test_model"
-     assert new_model.memoryset == memoryset
+     assert new_model.memoryset == readonly_memoryset
      assert new_model.num_classes == 2
      assert new_model.memory_lookup_count == 3


- def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
+ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
      with pytest.raises(ValueError, match="Invalid API key"):
-         ClassificationModel.create("test_model", memoryset)
+         ClassificationModel.create("test_model", readonly_memoryset)


  def test_get_model(model: ClassificationModel):
@@ -106,8 +107,8 @@ def test_update_model_no_description(model: ClassificationModel):
      assert model.description is None


- def test_delete_model(memoryset: LabeledMemoryset):
-     ClassificationModel.create("model_to_delete", LabeledMemoryset.open(memoryset.name))
+ def test_delete_model(readonly_memoryset: LabeledMemoryset):
+     ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
      assert ClassificationModel.open("model_to_delete")
      ClassificationModel.drop("model_to_delete")
      with pytest.raises(LookupError):
@@ -132,25 +133,23 @@ def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):


  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
-     memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset, value_column="text")
+     memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
      ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
      with pytest.raises(RuntimeError):
          LabeledMemoryset.drop(memoryset.id)


- def test_evaluate(model):
-     eval_datasource = Datasource.from_list(
-         "eval_datasource",
-         [
-             {"text": "chicken noodle soup is the best", "label": 1},
-             {"text": "cats are cute", "label": 0},
-             {"text": "soup is great for the winter", "label": 0},
-             {"text": "i love cats", "label": 1},
-         ],
-     )
-     result = model.evaluate(eval_datasource, value_column="text")
+ def test_evaluate(model, eval_datasource: Datasource):
+     result = model.evaluate(eval_datasource)
      assert result is not None
      assert isinstance(result, dict)
+     # And anomaly score statistics are present and valid
+     assert isinstance(result["anomaly_score_mean"], float)
+     assert isinstance(result["anomaly_score_median"], float)
+     assert isinstance(result["anomaly_score_variance"], float)
+     assert -1.0 <= result["anomaly_score_mean"] <= 1.0
+     assert -1.0 <= result["anomaly_score_median"] <= 1.0
+     assert -1.0 <= result["anomaly_score_variance"] <= 1.0
      assert isinstance(result["accuracy"], float)
      assert isinstance(result["f1_score"], float)
      assert isinstance(result["loss"], float)
@@ -162,6 +161,40 @@ def test_evaluate(model):
      assert len(result["roc_curve"]["true_positive_rates"]) == 4


+ def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
+     result_datasource = model.evaluate(eval_datasource)
+
+     result_dataset = model.evaluate(eval_dataset)
+
+     for result in [result_datasource, result_dataset]:
+         assert result is not None
+         assert isinstance(result, dict)
+         assert isinstance(result["accuracy"], float)
+         assert isinstance(result["f1_score"], float)
+         assert isinstance(result["loss"], float)
+         assert np.allclose(result["accuracy"], 0.5)
+         assert np.allclose(result["f1_score"], 0.5)
+
+         assert isinstance(result["precision_recall_curve"]["thresholds"], list)
+         assert isinstance(result["precision_recall_curve"]["precisions"], list)
+         assert isinstance(result["precision_recall_curve"]["recalls"], list)
+         assert isinstance(result["roc_curve"]["thresholds"], list)
+         assert isinstance(result["roc_curve"]["false_positive_rates"], list)
+         assert isinstance(result["roc_curve"]["true_positive_rates"], list)
+
+         assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+         assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+         assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+         assert np.allclose(result["roc_curve"]["auc"], 0.625)
+
+         assert np.allclose(
+             result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
+         )
+         assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
+         assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
+         assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
+
+
  def test_evaluate_with_telemetry(model):
      samples = [
          {"text": "chicken noodle soup is the best", "label": 1},
@@ -188,9 +221,16 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
      assert predictions[1].label_name == label_names[1]
      assert 0 <= predictions[1].confidence <= 1

+     assert predictions[0].logits is not None
+     assert predictions[1].logits is not None
+     assert len(predictions[0].logits) == 2
+     assert len(predictions[1].logits) == 2
+     assert predictions[0].logits[0] > predictions[0].logits[1]
+     assert predictions[1].logits[0] < predictions[1].logits[1]
+

  def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"], disable_telemetry=True)
+     predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
      assert len(predictions) == 2
      assert predictions[0].prediction_id is None
      assert predictions[1].prediction_id is None
@@ -212,9 +252,12 @@ def test_predict_unauthorized(unauthorized, model: ClassificationModel):
          model.predict(["Do you love soup?", "Are cats cute?"])


- def test_predict_constraint_violation(memoryset: LabeledMemoryset):
+ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
      model = ClassificationModel.create(
-         "test_model_lookup_count_too_high", memoryset, num_classes=2, memory_lookup_count=memoryset.length + 2
+         "test_model_lookup_count_too_high",
+         readonly_memoryset,
+         num_classes=2,
+         memory_lookup_count=readonly_memoryset.length + 2,
      )
      with pytest.raises(RuntimeError):
          model.predict("test")
@@ -254,7 +297,6 @@ def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset:
      inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
          "test_memoryset_inverted_labels",
          hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
-         value_column="text",
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
      )
      with model.use_memoryset(inverted_labeled_memoryset):
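Taken together, these test changes illustrate the new public surface in 0.0.93: evaluation results now carry anomaly score statistics, predictions expose logits, and the disable_telemetry flag is replaced by save_telemetry. A hedged usage sketch follows (not part of the diff; import paths and the datasource name are assumed, everything else is taken from the tests above):

# Sketch only: assumes ClassificationModel and Datasource are importable from these modules.
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.datasource import Datasource

model = ClassificationModel.open("test_model")

# evaluate() accepts a Datasource (or, per test_evaluate_combined, a HF Dataset) with "value"/"label" columns
eval_datasource = Datasource.from_list(
    "eval_datasource_example",
    [{"value": "soup is great for the winter", "label": 0}, {"value": "i love cats", "label": 1}],
)
result = model.evaluate(eval_datasource)
print(result["accuracy"], result["anomaly_score_mean"], result["roc_curve"]["auc"])

# predict() now returns logits alongside label and confidence; save_telemetry replaces disable_telemetry
predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
print(predictions[0].label_name, predictions[0].confidence, predictions[0].logits)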
orca_sdk/conftest.py CHANGED
@@ -17,6 +17,8 @@ logging.basicConfig(level=logging.INFO)

  os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")

+ os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
+

  def _create_org_id():
      # UUID start to identify test data (0xtest...)
@@ -69,22 +71,22 @@ def label_names():


  SAMPLE_DATA = [
-     {"text": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
-     {"text": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
-     {"text": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
-     {"text": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
-     {"text": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
-     {"text": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
-     {"text": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
-     {"text": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
-     {"text": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
-     {"text": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
-     {"text": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
-     {"text": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
-     {"text": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
-     {"text": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
-     {"text": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
-     {"text": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
+     {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
+     {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
+     {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
+     {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
+     {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
+     {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
+     {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
+     {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
+     {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
+     {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
+     {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
+     {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
+     {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
+     {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
+     {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
+     {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
  ]


@@ -94,7 +96,7 @@ def hf_dataset(label_names):
          SAMPLE_DATA,
          features=Features(
              {
-                 "text": Value("string"),
+                 "value": Value("string"),
                  "label": ClassLabel(names=label_names),
                  "key": Value("string"),
                  "score": Value("float"),
@@ -106,23 +108,82 @@

  @pytest.fixture(scope="session")
  def datasource(hf_dataset) -> Datasource:
-     return Datasource.from_hf_dataset("test_datasource", hf_dataset)
+     datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
+     return datasource
+
+
+ EVAL_DATASET = [
+     {"value": "chicken noodle soup is the best", "label": 1},
+     {"value": "cats are cute", "label": 0},
+     {"value": "soup is great for the winter", "label": 0},
+     {"value": "i love cats", "label": 1},
+ ]


  @pytest.fixture(scope="session")
- def memoryset(datasource) -> LabeledMemoryset:
-     return LabeledMemoryset.create(
-         "test_memoryset",
+ def eval_datasource() -> Datasource:
+     eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
+     return eval_datasource
+
+
+ @pytest.fixture(scope="session")
+ def eval_dataset() -> Dataset:
+     eval_dataset = Dataset.from_list(EVAL_DATASET)
+     return eval_dataset
+
+
+ @pytest.fixture(scope="session")
+ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
+     memoryset = LabeledMemoryset.create(
+         "test_readonly_memoryset",
          datasource=datasource,
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
-         value_column="text",
          source_id_column="source_id",
          max_seq_length_override=32,
      )
+     return memoryset
+
+
+ @pytest.fixture(scope="function")
+ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[LabeledMemoryset, None, None]:
+     """
+     Function-scoped fixture that provides a writable memoryset for tests that mutate state.
+
+     This fixture creates a fresh `LabeledMemoryset` named 'test_writable_memoryset' before each test.
+     After the test, it attempts to restore the memoryset to its initial state by deleting any added entries
+     and reinserting sample data — unless the memoryset has been dropped by the test itself, in which case
+     it will be recreated on the next invocation.
+
+     Note: Re-creating the memoryset from scratch is surprisingly more expensive than cleaning it up.
+     """
+     # It shouldn't be possible for this memoryset to already exist
+     memoryset = LabeledMemoryset.create(
+         "test_writable_memoryset",
+         datasource=datasource,
+         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+         source_id_column="source_id",
+         max_seq_length_override=32,
+         if_exists="open",
+     )
+     try:
+         yield memoryset
+     finally:
+         # Restore the memoryset to a clean state for the next test.
+         OrcaCredentials.set_api_key(api_key, check_validity=False)
+
+         if LabeledMemoryset.exists("test_writable_memoryset"):
+             memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
+
+             if memory_ids:
+                 memoryset.delete(memory_ids)
+             assert len(memoryset) == 0
+             memoryset.insert(SAMPLE_DATA)
+         # If the test dropped the memoryset, do nothing — it will be recreated on the next use.


  @pytest.fixture(scope="session")
- def model(memoryset) -> ClassificationModel:
-     return ClassificationModel.create(
-         "test_model", memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
+ def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+     model = ClassificationModel.create(
+         "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
      )
+     return model
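The conftest changes split the old shared memoryset fixture into a session-scoped readonly_memoryset and a function-scoped writable_memoryset that is restored to SAMPLE_DATA after each test. A hypothetical test using the writable fixture (not part of the diff) might look like:

# Hypothetical test, not part of the package: mutations are safe here because the
# writable_memoryset fixture deletes added memories and re-inserts SAMPLE_DATA afterwards.
def test_insert_and_delete_memory(writable_memoryset):
    before = len(writable_memoryset)
    writable_memoryset.insert(
        [{"value": "soup for breakfast", "label": 0, "key": "val17", "score": 1.7, "source_id": "s17"}]
    )
    assert len(writable_memoryset) == before + 1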
orca_sdk/datasource.py CHANGED
@@ -12,6 +12,7 @@ import pyarrow as pa
  from datasets import Dataset
  from torch.utils.data import DataLoader as TorchDataLoader
  from torch.utils.data import Dataset as TorchDataset
+ from tqdm.auto import tqdm

  from ._generated_api_client.api import (
      delete_datasource,
@@ -25,6 +26,7 @@ from ._generated_api_client.client import get_client
  from ._generated_api_client.models import ColumnType, DatasourceMetadata
  from ._utils.common import CreateMode, DropMode
  from ._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+ from ._utils.tqdm_file_reader import TqdmFileReader


  class Datasource:
@@ -113,19 +115,27 @@ class Datasource:
          with tempfile.TemporaryDirectory() as tmp_dir:
              dataset.save_to_disk(tmp_dir)
              files = []
-             for file_path in Path(tmp_dir).iterdir():
-                 buffered_reader = open(file_path, "rb")
-                 files.append(("files", buffered_reader))
-
-             # Do not use Generated client for this endpoint b/c it does not handle files properly
-             metadata = parse_create_response(
-                 response=client.get_httpx_client().request(
-                     method="post",
-                     url="/datasource/",
-                     files=files,
-                     data={"name": name, "description": description},
+
+             # Calculate total size for all files
+             file_paths = list(Path(tmp_dir).iterdir())
+             total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+             with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+                 for file_path in file_paths:
+                     buffered_reader = open(file_path, "rb")
+                     tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+                     files.append(("files", (file_path.name, tqdm_reader)))
+
+                 # Do not use Generated client for this endpoint b/c it does not handle files properly
+                 metadata = parse_create_response(
+                     response=client.get_httpx_client().request(
+                         method="post",
+                         url="/datasource/",
+                         files=files,
+                         data={"name": name, "description": description},
+                     )
                  )
-             )
+
          return cls(metadata=metadata)

      @classmethod
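The upload path above wraps each file in a TqdmFileReader so the multipart request advances the progress bar as httpx reads the files. The wrapper itself (orca_sdk/_utils/tqdm_file_reader.py, 12 lines added per the file list) is not shown in this excerpt; a minimal sketch of such a wrapper, with attribute names assumed, is:

class TqdmFileReader:
    """Sketch of a file-like wrapper that advances a tqdm progress bar as bytes are read."""

    def __init__(self, file, pbar):
        self._file = file
        self._pbar = pbar

    def read(self, size: int = -1) -> bytes:
        chunk = self._file.read(size)
        self._pbar.update(len(chunk))  # count only the bytes actually read
        return chunk

    def __getattr__(self, name):
        # Delegate everything else (name, close, seek, ...) to the wrapped reader.
        return getattr(self._file, name)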
orca_sdk/embedding_model_test.py CHANGED
@@ -53,7 +53,7 @@ def test_embed_text_unauthenticated(unauthenticated):

  @pytest.fixture(scope="session")
  def finetuned_model(datasource) -> FinetunedEmbeddingModel:
-     return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, value_column="text")
+     return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)


  def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
@@ -65,8 +65,10 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
      assert finetuned_model._status == TaskStatus.COMPLETED


- def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
-     finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_from_memoryset", memoryset)
+ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
+     finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+         "test_finetuned_model_from_memoryset", readonly_memoryset
+     )
      assert finetuned_model is not None
      assert finetuned_model.name == "test_finetuned_model_from_memoryset"
      assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
@@ -109,7 +111,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
          "test_memoryset_finetuned_model",
          datasource,
          embedding_model=finetuned_model,
-         value_column="text",
      )
      assert memoryset is not None
      assert memoryset.name == "test_memoryset_finetuned_model"
@@ -152,7 +153,7 @@ def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: Finetu


  def test_drop_finetuned_model(datasource: Datasource):
-     PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource, value_column="text")
+     PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
      assert FinetunedEmbeddingModel.open("finetuned_model_to_delete")
      FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
      with pytest.raises(LookupError):
orca_sdk/memoryset.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, cast, overload

  import pandas as pd
  import pyarrow as pa
+ from attrs import fields
  from datasets import Dataset
  from torch.utils.data import DataLoader as TorchDataLoader
  from torch.utils.data import Dataset as TorchDataset
@@ -29,11 +30,14 @@ from ._generated_api_client.api import (
      memoryset_lookup_gpu,
      potential_duplicate_groups,
      query_memoryset,
+     suggest_cascading_edits,
      update_memories_gpu,
      update_memory_gpu,
      update_memoryset,
  )
  from ._generated_api_client.models import (
+     CascadeEditSuggestionsRequest,
+     CascadingEditSuggestion,
      CloneLabeledMemorysetRequest,
      CreateLabeledMemorysetRequest,
      DeleteMemoriesRequest,
@@ -1180,6 +1184,63 @@ class LabeledMemoryset:
          updated_memories = [LabeledMemory(self.id, memory) for memory in response]
          return updated_memories[0] if isinstance(updates, dict) else updated_memories

+     def get_cascading_edits_suggestions(
+         self: LabeledMemoryset,
+         memory: LabeledMemory,
+         *,
+         old_label: int,
+         new_label: int,
+         max_neighbors: int = 50,
+         max_validation_neighbors: int = 10,
+         similarity_threshold: float | None = None,
+         only_if_has_old_label: bool = True,
+         exclude_if_new_label: bool = True,
+         suggestion_cooldown_time: float = 3600.0 * 24.0,  # 1 day
+         label_confirmation_cooldown_time: float = 3600.0 * 24.0 * 7,  # 1 week
+     ) -> list[CascadingEditSuggestion]:
+         """
+         Suggests cascading edits for a given memory based on nearby points with similar labels.
+
+         This function is triggered after a user changes a memory's label. It looks for nearby
+         candidates in embedding space that may be subject to similar relabeling and returns them
+         as suggestions. The system uses scoring heuristics, label filters, and cooldown tracking
+         to reduce noise and improve usability.
+
+         Params:
+             memory: The memory whose label was just changed.
+             old_label: The label this memory used to have.
+             new_label: The label it was changed to.
+             max_neighbors: Maximum number of neighbors to consider.
+             max_validation_neighbors: Maximum number of neighbors to use for label suggestion.
+             similarity_threshold: If set, only include neighbors with a lookup score above this threshold.
+             only_if_has_old_label: If True, only consider neighbors that have the old label.
+             exclude_if_new_label: If True, exclude neighbors that already have the new label.
+             suggestion_cooldown_time: Minimum time (in seconds) since the last suggestion for a neighbor
+                 to be considered again.
+             label_confirmation_cooldown_time: Minimum time (in seconds) since a neighbor's label was confirmed
+                 to be considered for suggestions.
+             _current_time: Optional override for the current timestamp (useful for testing).
+
+         Returns:
+             A list of CascadingEditSuggestion objects, each containing a neighbor and the suggested new label.
+         """
+
+         return suggest_cascading_edits(
+             name_or_id=self.id,
+             memory_id=memory.memory_id,
+             body=CascadeEditSuggestionsRequest(
+                 old_label=old_label,
+                 new_label=new_label,
+                 max_neighbors=max_neighbors,
+                 max_validation_neighbors=max_validation_neighbors,
+                 similarity_threshold=similarity_threshold,
+                 only_if_has_old_label=only_if_has_old_label,
+                 exclude_if_new_label=exclude_if_new_label,
+                 suggestion_cooldown_time=suggestion_cooldown_time,
+                 label_confirmation_cooldown_time=label_confirmation_cooldown_time,
+             ),
+         )
+
      def delete(self, memory_id: str | Iterable[str]) -> None:
          """
          Delete memories from the memoryset
@@ -1229,6 +1290,9 @@ class LabeledMemoryset:
          Returns:
              dictionary with aggregate metrics for each analysis that was run

+         Raises:
+             ValueError: If an invalid analysis name is provided
+
          Examples:
              Run label and duplicate analysis:
              >>> memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})
@@ -1263,12 +1327,26 @@ class LabeledMemoryset:
              Display label analysis to review potential mislabelings:
              >>> memoryset.display_label_analysis()
          """
+
+         # Get valid analysis names from MemorysetAnalysisConfigs
+         valid_analysis_names = {
+             field.name for field in fields(MemorysetAnalysisConfigs) if field.name != "additional_properties"
+         }
+
          configs: dict[str, dict] = {}
          for analysis in analyses:
              if isinstance(analysis, str):
+                 error_msg = (
+                     f"Invalid analysis name: {analysis}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                 )
+                 if analysis not in valid_analysis_names:
+                     raise ValueError(error_msg)
                  configs[analysis] = {}
              else:
                  name = analysis.pop("name")  # type: ignore
+                 error_msg = f"Invalid analysis name: {name}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                 if name not in valid_analysis_names:
+                     raise ValueError(error_msg)
                  configs[name] = analysis  # type: ignore

          analysis = analyze_memoryset(
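Two behavioral changes stand out in memoryset.py: the new get_cascading_edits_suggestions method, which calls the suggest_cascading_edits endpoint added in this release, and analyze(), which now rejects unknown analysis names. A hedged usage sketch (not part of the diff; argument values and the memoryset name are illustrative, the calls mirror the code above):

# Sketch only: assumes a memoryset whose labels are 0/1, as in the test fixtures.
memoryset = LabeledMemoryset.open("test_readonly_memoryset")
memory = memoryset[0]

# After relabeling `memory` from 0 to 1, ask which neighbors may need the same edit.
suggestions = memoryset.get_cascading_edits_suggestions(
    memory,
    old_label=0,
    new_label=1,
    max_neighbors=25,
    similarity_threshold=0.9,
)
for suggestion in suggestions:
    print(suggestion)

# analyze() now raises ValueError for analysis names not defined in MemorysetAnalysisConfigs.
memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})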