orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. orca_sdk/_generated_api_client/api/__init__.py +8 -0
  2. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
  3. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  4. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  5. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  6. orca_sdk/_generated_api_client/models/__init__.py +10 -0
  7. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  8. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  9. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  10. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  11. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  12. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  13. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  15. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  16. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  17. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  18. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  19. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  20. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  21. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  22. orca_sdk/_utils/data_parsing.py +31 -2
  23. orca_sdk/_utils/data_parsing_test.py +18 -15
  24. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  25. orca_sdk/classification_model.py +32 -12
  26. orca_sdk/classification_model_test.py +95 -34
  27. orca_sdk/conftest.py +87 -25
  28. orca_sdk/datasource.py +56 -12
  29. orca_sdk/datasource_test.py +9 -0
  30. orca_sdk/embedding_model_test.py +6 -5
  31. orca_sdk/memoryset.py +78 -0
  32. orca_sdk/memoryset_test.py +199 -123
  33. orca_sdk/telemetry.py +5 -3
  34. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
  35. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
  36. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0

orca_sdk/_utils/data_parsing_test.py CHANGED
@@ -1,4 +1,5 @@
  import json
+ import logging
  import pickle
  import tempfile
  from collections import namedtuple
@@ -14,6 +15,8 @@ from torch.utils.data import Dataset as TorchDataset
  from ..conftest import SAMPLE_DATA
  from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

  class PytorchDictDataset(TorchDataset):
  def __init__(self):
@@ -29,11 +32,11 @@ class PytorchDictDataset(TorchDataset):
  def test_hf_dataset_from_torch_dict():
  # Given a Pytorch dataset that returns a dictionary for each item
  dataset = PytorchDictDataset()
- hf_dataset = hf_dataset_from_torch(dataset)
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
  # Then the HF dataset should be created successfully
  assert isinstance(hf_dataset, Dataset)
  assert len(hf_dataset) == len(dataset)
- assert set(hf_dataset.column_names) == {"text", "label", "key", "score", "source_id"}
+ assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}


  class PytorchTupleDataset(TorchDataset):
@@ -41,7 +44,7 @@ class PytorchTupleDataset(TorchDataset):
  self.data = SAMPLE_DATA

  def __getitem__(self, i):
- return self.data[i]["text"], self.data[i]["label"]
+ return self.data[i]["value"], self.data[i]["label"]

  def __len__(self):
  return len(self.data)
@@ -51,11 +54,11 @@ def test_hf_dataset_from_torch_tuple():
  # Given a Pytorch dataset that returns a tuple for each item
  dataset = PytorchTupleDataset()
  # And the correct number of column names passed in
- hf_dataset = hf_dataset_from_torch(dataset, column_names=["text", "label"])
+ hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
  # Then the HF dataset should be created successfully
  assert isinstance(hf_dataset, Dataset)
  assert len(hf_dataset) == len(dataset)
- assert hf_dataset.column_names == ["text", "label"]
+ assert hf_dataset.column_names == ["value", "label"]


  def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +66,7 @@ def test_hf_dataset_from_torch_tuple_error():
  dataset = PytorchTupleDataset()
  # Then the HF dataset should raise an error if no column names are passed in
  with pytest.raises(DatasetGenerationError):
- hf_dataset_from_torch(dataset)
+ hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +74,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
  dataset = PytorchTupleDataset()
  # Then the HF dataset should raise an error if not enough column names are passed in
  with pytest.raises(DatasetGenerationError):
- hf_dataset_from_torch(dataset, column_names=["value"])
+ hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)


  DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +85,7 @@ class PytorchNamedTupleDataset(TorchDataset):
  self.data = SAMPLE_DATA

  def __getitem__(self, i):
- return DatasetTuple(self.data[i]["text"], self.data[i]["label"])
+ return DatasetTuple(self.data[i]["value"], self.data[i]["label"])

  def __len__(self):
  return len(self.data)
@@ -92,7 +95,7 @@ def test_hf_dataset_from_torch_named_tuple():
  # Given a Pytorch dataset that returns a namedtuple for each item
  dataset = PytorchNamedTupleDataset()
  # And no column names are passed in
- hf_dataset = hf_dataset_from_torch(dataset)
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
  # Then the HF dataset should be created successfully
  assert isinstance(hf_dataset, Dataset)
  assert len(hf_dataset) == len(dataset)
@@ -110,7 +113,7 @@ class PytorchDataclassDataset(TorchDataset):
  self.data = SAMPLE_DATA

  def __getitem__(self, i):
- return DatasetItem(text=self.data[i]["text"], label=self.data[i]["label"])
+ return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])

  def __len__(self):
  return len(self.data)
@@ -119,7 +122,7 @@ class PytorchDataclassDataset(TorchDataset):
  def test_hf_dataset_from_torch_dataclass():
  # Given a Pytorch dataset that returns a dataclass for each item
  dataset = PytorchDataclassDataset()
- hf_dataset = hf_dataset_from_torch(dataset)
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
  # Then the HF dataset should be created successfully
  assert isinstance(hf_dataset, Dataset)
  assert len(hf_dataset) == len(dataset)
@@ -131,7 +134,7 @@ class PytorchInvalidDataset(TorchDataset):
  self.data = SAMPLE_DATA

  def __getitem__(self, i):
- return [self.data[i]["text"], self.data[i]["label"]]
+ return [self.data[i]["value"], self.data[i]["label"]]

  def __len__(self):
  return len(self.data)
@@ -142,7 +145,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
  dataset = PytorchInvalidDataset()
  # Then the HF dataset should raise an error
  with pytest.raises(DatasetGenerationError):
- hf_dataset_from_torch(dataset)
+ hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torchdataloader():
@@ -150,10 +153,10 @@ def test_hf_dataset_from_torchdataloader():
  dataset = PytorchDictDataset()

  def collate_fn(x: list[dict]):
- return {"value": [item["text"] for item in x], "label": [item["label"] for item in x]}
+ return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}

  dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
- hf_dataset = hf_dataset_from_torch(dataloader)
+ hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
  # Then the HF dataset should be created successfully
  assert isinstance(hf_dataset, Dataset)
  assert len(hf_dataset) == len(dataset)

orca_sdk/_utils/tqdm_file_reader.py ADDED
@@ -0,0 +1,12 @@
+ class TqdmFileReader:
+     def __init__(self, file_obj, pbar):
+         self.file_obj = file_obj
+         self.pbar = pbar
+
+     def read(self, size=-1):
+         data = self.file_obj.read(size)
+         self.pbar.update(len(data))
+         return data
+
+     def __getattr__(self, attr):
+         return getattr(self.file_obj, attr)
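
The new TqdmFileReader forwards every attribute to the wrapped file object and, on each read(), advances the supplied progress bar by the number of bytes returned. A minimal usage sketch (illustrative only; upload_chunk is a placeholder, not part of the SDK):

import os
from tqdm import tqdm
from orca_sdk._utils.tqdm_file_reader import TqdmFileReader

path = "train.parquet"
with open(path, "rb") as f, tqdm(total=os.path.getsize(path), unit="B", unit_scale=True) as pbar:
    reader = TqdmFileReader(f, pbar)
    # Each read() call updates the progress bar as a side effect before returning the chunk
    while chunk := reader.read(1024 * 1024):
        upload_chunk(chunk)  # placeholder: hand the chunk to whatever consumes the stream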

orca_sdk/classification_model.py CHANGED
@@ -1,10 +1,13 @@
  from __future__ import annotations

  import logging
+ import os
  from contextlib import contextmanager
  from datetime import datetime
  from typing import Any, Generator, Iterable, Literal, cast, overload
- from uuid import UUID
+ from uuid import UUID, uuid4
+
+ import numpy as np

  import numpy as np
  from datasets import Dataset
@@ -312,7 +315,8 @@ class ClassificationModel:
  value: list[str],
  expected_labels: list[int] | None = None,
  tags: set[str] = set(),
- disable_telemetry: bool = False,
+ save_telemetry: bool = True,
+ save_telemetry_synchronously: bool = False,
  ) -> list[LabelPrediction]:
  pass

@@ -322,7 +326,8 @@ class ClassificationModel:
  value: str,
  expected_labels: int | None = None,
  tags: set[str] = set(),
- disable_telemetry: bool = False,
+ save_telemetry: bool = True,
+ save_telemetry_synchronously: bool = False,
  ) -> LabelPrediction:
  pass

@@ -331,7 +336,8 @@ class ClassificationModel:
  value: list[str] | str,
  expected_labels: list[int] | int | None = None,
  tags: set[str] = set(),
- disable_telemetry: bool = False,
+ save_telemetry: bool = True,
+ save_telemetry_synchronously: bool = False,
  ) -> list[LabelPrediction] | LabelPrediction:
  """
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -340,7 +346,10 @@ class ClassificationModel:
  value: Value(s) to get predict the labels of
  expected_labels: Expected label(s) for the given input to record for model evaluation
  tags: Tags to add to the prediction(s)
- disable_telemetry: Whether to disable telemetry for the prediction(s)
+ save_telemetry: Whether to enable telemetry for the prediction(s)
+ save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
+ asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
+ may be overriden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.

  Returns:
  Label prediction or list of label predictions
@@ -358,6 +367,13 @@ class ClassificationModel:
  ]
  """

+ if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
+ env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
+ logging.info(
+ f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
+ )
+ save_telemetry_synchronously = env_var.lower() == "true"
+
  response = predict_gpu(
  self.id,
  body=PredictionRequest(
@@ -366,14 +382,17 @@
  expected_labels=(
  expected_labels
  if isinstance(expected_labels, list)
- else [expected_labels] if expected_labels is not None else None
+ else [expected_labels]
+ if expected_labels is not None
+ else None
  ),
  tags=list(tags),
- disable_telemetry=disable_telemetry,
+ save_telemetry=save_telemetry,
+ save_telemetry_synchronously=save_telemetry_synchronously,
  ),
  )

- if not disable_telemetry and any(p.prediction_id is None for p in response):
+ if save_telemetry and any(p.prediction_id is None for p in response):
  raise RuntimeError("Failed to save prediction to database.")

  predictions = [
@@ -386,8 +405,9 @@
  memoryset=self.memoryset,
  model=self,
  logits=prediction.logits,
+ input_value=input_value,
  )
- for prediction in response
+ for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
  ]
  self._last_prediction_was_batch = isinstance(value, list)
  self._last_prediction = predictions[-1]
@@ -463,7 +483,6 @@ class ClassificationModel:
  predictions: list[LabelPrediction],
  expected_labels: list[int],
  ) -> ClassificationEvaluationResult:
-
  targets_array = np.array(expected_labels)
  predictions_array = np.array([p.label for p in predictions])

@@ -553,7 +572,8 @@ class ClassificationModel:
  batch[value_column],
  expected_labels=batch[label_column],
  tags=tags,
- disable_telemetry=(not record_predictions),
+ save_telemetry=record_predictions,
+ save_telemetry_synchronously=(not record_predictions),
  )
  )
  expected_labels.extend(batch[label_column])
@@ -581,7 +601,7 @@ class ClassificationModel:
  batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)

  Returns:
- Dictionary with evaluation metrics
+ Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)

  Examples:
  Evaluate using a Datasource:
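
Taken together, the classification_model.py changes replace disable_telemetry with the pair save_telemetry / save_telemetry_synchronously and let the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable override the synchronous flag. A hedged usage sketch (the top-level import path and model name are assumptions, not taken from the diff):

import os
from orca_sdk import ClassificationModel  # assumed import path

# If set, this overrides the save_telemetry_synchronously argument passed below
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"

model = ClassificationModel.open("test_model")

# Telemetry is recorded, and the env var forces it to be written before predict() returns
prediction = model.predict("Do you love soup?", save_telemetry_synchronously=False)

# Skip telemetry entirely; prediction_id stays None, as the updated tests assert
untracked = model.predict("Are cats cute?", save_telemetry=False)
assert untracked.prediction_id is None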

orca_sdk/classification_model_test.py CHANGED
@@ -1,3 +1,5 @@
+ import logging
+ import os
  from uuid import uuid4

  import numpy as np
@@ -9,46 +11,51 @@ from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
  from .memoryset import LabeledMemoryset

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

- def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
+
+
+ SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+
+
+ def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
  assert model is not None
  assert model.name == "test_model"
- assert model.memoryset == memoryset
+ assert model.memoryset == readonly_memoryset
  assert model.num_classes == 2
  assert model.memory_lookup_count == 3


- def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset)
+ ClassificationModel.create("test_model", readonly_memoryset)
  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset, if_exists="error")
+ ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")


- def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset, if_exists="open", head_type="MMOE")
+ ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")

  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset, if_exists="open", memory_lookup_count=37)
+ ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)

  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset, if_exists="open", num_classes=19)
+ ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)

  with pytest.raises(ValueError):
- ClassificationModel.create("test_model", memoryset, if_exists="open", min_memory_weight=0.77)
+ ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)

- new_model = ClassificationModel.create("test_model", memoryset, if_exists="open")
+ new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
  assert new_model is not None
  assert new_model.name == "test_model"
- assert new_model.memoryset == memoryset
+ assert new_model.memoryset == readonly_memoryset
  assert new_model.num_classes == 2
  assert new_model.memory_lookup_count == 3


- def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
+ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
  with pytest.raises(ValueError, match="Invalid API key"):
- ClassificationModel.create("test_model", memoryset)
+ ClassificationModel.create("test_model", readonly_memoryset)


  def test_get_model(model: ClassificationModel):
@@ -107,8 +114,8 @@ def test_update_model_no_description(model: ClassificationModel):
  assert model.description is None


- def test_delete_model(memoryset: LabeledMemoryset):
- ClassificationModel.create("model_to_delete", LabeledMemoryset.open(memoryset.name))
+ def test_delete_model(readonly_memoryset: LabeledMemoryset):
+ ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
  assert ClassificationModel.open("model_to_delete")
  ClassificationModel.drop("model_to_delete")
  with pytest.raises(LookupError):
@@ -133,25 +140,38 @@ def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):


  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
- memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset, value_column="text")
+ memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
  ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
  with pytest.raises(RuntimeError):
  LabeledMemoryset.drop(memoryset.id)


- def test_evaluate_combined(model):
- data = [
- {"text": "chicken noodle soup is the best", "label": 1},
- {"text": "cats are cute", "label": 0},
- {"text": "soup is great for the winter", "label": 0},
- {"text": "i love cats", "label": 1},
- ]
-
- eval_datasource = Datasource.from_list("eval_datasource", data)
- result_datasource = model.evaluate(eval_datasource, value_column="text")
-
- eval_dataset = Dataset.from_list(data)
- result_dataset = model.evaluate(eval_dataset, value_column="text")
+ def test_evaluate(model, eval_datasource: Datasource):
+ result = model.evaluate(eval_datasource)
+ assert result is not None
+ assert isinstance(result, dict)
+ # And anomaly score statistics are present and valid
+ assert isinstance(result["anomaly_score_mean"], float)
+ assert isinstance(result["anomaly_score_median"], float)
+ assert isinstance(result["anomaly_score_variance"], float)
+ assert -1.0 <= result["anomaly_score_mean"] <= 1.0
+ assert -1.0 <= result["anomaly_score_median"] <= 1.0
+ assert -1.0 <= result["anomaly_score_variance"] <= 1.0
+ assert isinstance(result["accuracy"], float)
+ assert isinstance(result["f1_score"], float)
+ assert isinstance(result["loss"], float)
+ assert len(result["precision_recall_curve"]["thresholds"]) == 4
+ assert len(result["precision_recall_curve"]["precisions"]) == 4
+ assert len(result["precision_recall_curve"]["recalls"]) == 4
+ assert len(result["roc_curve"]["thresholds"]) == 4
+ assert len(result["roc_curve"]["false_positive_rates"]) == 4
+ assert len(result["roc_curve"]["true_positive_rates"]) == 4
+
+
+ def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
+ result_datasource = model.evaluate(eval_datasource)
+
+ result_dataset = model.evaluate(eval_dataset)

  for result in [result_datasource, result_dataset]:
  assert result is not None
@@ -217,7 +237,7 @@ def test_predict(model: ClassificationModel, label_names: list[str]):


  def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
- predictions = model.predict(["Do you love soup?", "Are cats cute?"], disable_telemetry=True)
+ predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
  assert len(predictions) == 2
  assert predictions[0].prediction_id is None
  assert predictions[1].prediction_id is None
@@ -239,9 +259,12 @@ def test_predict_unauthorized(unauthorized, model: ClassificationModel):
  model.predict(["Do you love soup?", "Are cats cute?"])


- def test_predict_constraint_violation(memoryset: LabeledMemoryset):
+ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
  model = ClassificationModel.create(
- "test_model_lookup_count_too_high", memoryset, num_classes=2, memory_lookup_count=memoryset.length + 2
+ "test_model_lookup_count_too_high",
+ readonly_memoryset,
+ num_classes=2,
+ memory_lookup_count=readonly_memoryset.length + 2,
  )
  with pytest.raises(RuntimeError):
  model.predict("test")
@@ -281,7 +304,6 @@ def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset:
  inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
  "test_memoryset_inverted_labels",
  hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
- value_column="text",
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
  )
  with model.use_memoryset(inverted_labeled_memoryset):
@@ -323,3 +345,42 @@ def test_last_prediction_with_single(model: ClassificationModel):
  assert model.last_prediction.prediction_id == prediction.prediction_id
  assert model.last_prediction.input_value == "Do you love soup?"
  assert model._last_prediction_was_batch is False
+
+
+ @pytest.mark.skipif(
+ SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
+ )
+ def test_explain(writable_memoryset: LabeledMemoryset):
+
+ writable_memoryset.analyze(
+ {"name": "neighbor", "neighbor_counts": [1, 3]},
+ lookup_count=3,
+ )
+
+ model = ClassificationModel.create(
+ "test_model_for_explain",
+ writable_memoryset,
+ num_classes=2,
+ memory_lookup_count=3,
+ description="This is a test model for explain",
+ )
+
+ predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+ assert len(predictions) == 2
+
+ try:
+ explanation = predictions[0].explanation
+ print(explanation)
+ assert explanation is not None
+ assert len(explanation) > 10
+ assert "soup" in explanation.lower()
+ except Exception as e:
+ if "ANTHROPIC_API_KEY" in str(e):
+ logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set on server")
+ else:
+ raise e
+ finally:
+ try:
+ ClassificationModel.drop("test_model_for_explain")
+ except Exception as e:
+ logging.info(f"Failed to drop test model for explain: {e}")
orca_sdk/conftest.py CHANGED
@@ -17,6 +17,8 @@ logging.basicConfig(level=logging.INFO)

  os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")

+ os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
+

  def _create_org_id():
  # UUID start to identify test data (0xtest...)
@@ -69,22 +71,22 @@ def label_names():


  SAMPLE_DATA = [
- {"text": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
- {"text": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
- {"text": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
- {"text": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
- {"text": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
- {"text": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
- {"text": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
- {"text": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
- {"text": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
- {"text": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
- {"text": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
- {"text": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
- {"text": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
- {"text": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
- {"text": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
- {"text": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
+ {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
+ {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
+ {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
+ {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
+ {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
+ {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
+ {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
+ {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
+ {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
+ {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
+ {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
+ {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
+ {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
+ {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
+ {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
+ {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
  ]


@@ -94,7 +96,7 @@ def hf_dataset(label_names):
  SAMPLE_DATA,
  features=Features(
  {
- "text": Value("string"),
+ "value": Value("string"),
  "label": ClassLabel(names=label_names),
  "key": Value("string"),
  "score": Value("float"),
@@ -106,23 +108,83 @@

  @pytest.fixture(scope="session")
  def datasource(hf_dataset) -> Datasource:
- return Datasource.from_hf_dataset("test_datasource", hf_dataset)
+ datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
+ return datasource
+
+
+ EVAL_DATASET = [
+ {"value": "chicken noodle soup is the best", "label": 1},
+ {"value": "cats are cute", "label": 0},
+ {"value": "soup is great for the winter", "label": 0},
+ {"value": "i love cats", "label": 1},
+ ]


  @pytest.fixture(scope="session")
- def memoryset(datasource) -> LabeledMemoryset:
- return LabeledMemoryset.create(
- "test_memoryset",
+ def eval_datasource() -> Datasource:
+ eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
+ return eval_datasource
+
+
+ @pytest.fixture(scope="session")
+ def eval_dataset() -> Dataset:
+ eval_dataset = Dataset.from_list(EVAL_DATASET)
+ return eval_dataset
+
+
+ @pytest.fixture(scope="session")
+ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
+ memoryset = LabeledMemoryset.create(
+ "test_readonly_memoryset",
  datasource=datasource,
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
- value_column="text",
  source_id_column="source_id",
  max_seq_length_override=32,
  )
+ return memoryset
+
+
+ @pytest.fixture(scope="function")
+ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[LabeledMemoryset, None, None]:
+ """
+ Function-scoped fixture that provides a writable memoryset for tests that mutate state.
+
+ This fixture creates a fresh `LabeledMemoryset` named 'test_writable_memoryset' before each test.
+ After the test, it attempts to restore the memoryset to its initial state by deleting any added entries
+ and reinserting sample data — unless the memoryset has been dropped by the test itself, in which case
+ it will be recreated on the next invocation.
+
+ Note: Re-creating the memoryset from scratch is surprisingly more expensive than cleaning it up.
+ """
+ # It shouldn't be possible for this memoryset to already exist
+ memoryset = LabeledMemoryset.create(
+ "test_writable_memoryset",
+ datasource=datasource,
+ embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+ source_id_column="source_id",
+ max_seq_length_override=32,
+ if_exists="open",
+ )
+ try:
+ yield memoryset
+ finally:
+ # Restore the memoryset to a clean state for the next test.
+ OrcaCredentials.set_api_key(api_key, check_validity=False)
+
+ if LabeledMemoryset.exists("test_writable_memoryset"):
+ memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
+
+ if memory_ids:
+ memoryset.delete(memory_ids)
+ memoryset.refresh()
+ assert len(memoryset) == 0
+ memoryset.insert(SAMPLE_DATA)
+ # If the test dropped the memoryset, do nothing — it will be recreated on the next use.


  @pytest.fixture(scope="session")
- def model(memoryset) -> ClassificationModel:
- return ClassificationModel.create(
- "test_model", memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
+ def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+ model = ClassificationModel.create(
+ "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
  )
+ return model
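
The conftest.py changes split the old shared memoryset fixture into a session-scoped readonly_memoryset and a function-scoped writable_memoryset that restores SAMPLE_DATA after each test. A hedged sketch of how a test might choose between them (test names and the inserted record are illustrative):

def test_lookup_does_not_mutate(readonly_memoryset):
    # Session-scoped: safe to share because nothing is written to it
    assert len(readonly_memoryset) == len(SAMPLE_DATA)

def test_insert_is_cleaned_up(writable_memoryset):
    # Function-scoped: the fixture's finally block deletes added memories and reinserts SAMPLE_DATA
    writable_memoryset.insert(
        [{"value": "broth is underrated", "label": 0, "key": "val17", "score": 1.7, "source_id": "s17"}]
    )
    writable_memoryset.refresh()
    assert len(writable_memoryset) == len(SAMPLE_DATA) + 1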