orca_sdk-0.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
orca_sdk/classification_model.py
@@ -0,0 +1,1165 @@
+ from __future__ import annotations
+
+ import logging
+ from contextlib import contextmanager
+ from datetime import datetime
+ from typing import Any, Generator, Iterable, Literal, cast, overload
+
+ from datasets import Dataset
+
+ from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
+ from ._utils.common import UNSET, CreateMode, DropMode
+ from .async_client import OrcaAsyncClient
+ from .client import (
+     BootstrapClassificationModelMeta,
+     BootstrapLabeledMemoryDataResult,
+     ClassificationModelMetadata,
+     ClassificationPredictionRequest,
+     ListPredictionsRequest,
+     OrcaClient,
+     PredictiveModelUpdate,
+     RACHeadType,
+ )
+ from .datasource import Datasource
+ from .job import Job
+ from .memoryset import (
+     FilterItem,
+     FilterItemTuple,
+     LabeledMemoryset,
+     _is_metric_column,
+     _parse_filter_item_from_tuple,
+ )
+ from .telemetry import (
+     ClassificationPrediction,
+     TelemetryMode,
+     _get_telemetry_config,
+     _parse_feedback,
+ )
+
+
+ class BootstrappedClassificationModel:
+
+     datasource: Datasource | None
+     memoryset: LabeledMemoryset | None
+     classification_model: ClassificationModel | None
+     agent_output: BootstrapLabeledMemoryDataResult | None
+
+     def __init__(self, metadata: BootstrapClassificationModelMeta):
+         self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
+         self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
+         self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
+         self.agent_output = metadata["agent_output"]
+
+     def __repr__(self):
+         return (
+             "BootstrappedClassificationModel({\n"
+             f" datasource: {self.datasource},\n"
+             f" memoryset: {self.memoryset},\n"
+             f" classification_model: {self.classification_model},\n"
+             f" agent_output: {self.agent_output},\n"
+             "})"
+         )
+
+
+ class ClassificationModel:
+     """
+     A handle to a classification model in OrcaCloud
+
+     Attributes:
+         id: Unique identifier for the model
+         name: Unique name of the model
+         description: Optional description of the model
+         memoryset: Memoryset that the model uses
+         head_type: Classification head type of the model
+         num_classes: Number of distinct classes the model can predict
+         memory_lookup_count: Number of memories the model uses for each prediction
+         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
+         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+         locked: Whether the model is locked to prevent accidental deletion
+         created_at: When the model was created
+     """
+
+     id: str
+     name: str
+     description: str | None
+     memoryset: LabeledMemoryset
+     head_type: RACHeadType
+     num_classes: int
+     memory_lookup_count: int
+     weigh_memories: bool | None
+     min_memory_weight: float | None
+     version: int
+     locked: bool
+     created_at: datetime
+
+     def __init__(self, metadata: ClassificationModelMetadata):
+         # for internal use only, do not document
+         self.id = metadata["id"]
+         self.name = metadata["name"]
+         self.description = metadata["description"]
+         self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
+         self.head_type = metadata["head_type"]
+         self.num_classes = metadata["num_classes"]
+         self.memory_lookup_count = metadata["memory_lookup_count"]
+         self.weigh_memories = metadata["weigh_memories"]
+         self.min_memory_weight = metadata["min_memory_weight"]
+         self.version = metadata["version"]
+         self.locked = metadata["locked"]
+         self.created_at = datetime.fromisoformat(metadata["created_at"])
+
+         self._memoryset_override_id: str | None = None
+         self._last_prediction: ClassificationPrediction | None = None
+         self._last_prediction_was_batch: bool = False
+
+     def __eq__(self, other) -> bool:
+         return isinstance(other, ClassificationModel) and self.id == other.id
+
+     def __repr__(self):
+         memoryset_repr = self.memoryset.__repr__().replace("\n", "\n ")
+         return (
+             "ClassificationModel({\n"
+             f" name: '{self.name}',\n"
+             f" head_type: {self.head_type},\n"
+             f" num_classes: {self.num_classes},\n"
+             f" memory_lookup_count: {self.memory_lookup_count},\n"
+             f" memoryset: {memoryset_repr},\n"
+             "})"
+         )
+
+     @property
+     def last_prediction(self) -> ClassificationPrediction:
+         """
+         Last prediction made by the model
+
+         Note:
+             If the last prediction was part of a batch prediction, the last prediction from the
+             batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
+             is raised.
+         """
+         if self._last_prediction_was_batch:
+             logging.warning(
+                 "Last prediction was part of a batch prediction, returning the last prediction from the batch"
+             )
+         if self._last_prediction is None:
+             raise LookupError("No prediction has been made yet")
+         return self._last_prediction
+
+     @classmethod
+     def create(
+         cls,
+         name: str,
+         memoryset: LabeledMemoryset,
+         head_type: RACHeadType = "KNN",
+         *,
+         description: str | None = None,
+         num_classes: int | None = None,
+         memory_lookup_count: int | None = None,
+         weigh_memories: bool = True,
+         min_memory_weight: float | None = None,
+         if_exists: CreateMode = "error",
+     ) -> ClassificationModel:
+         """
+         Create a new classification model
+
+         Params:
+             name: Name for the new model (must be unique)
+             memoryset: Memoryset to attach the model to
+             head_type: Type of model head to use
+             num_classes: Number of classes this model can predict, will be inferred from memoryset if not specified
+             memory_lookup_count: Number of memories to lookup for each prediction,
+                 by default the system uses a simple heuristic to choose a number of memories that works well in most cases
+             weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
+             min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+             if_exists: What to do if a model with the same name already exists, defaults to
+                 `"error"`. Other option is `"open"` to open the existing model.
+             description: Optional description for the model, this will be used in agentic flows,
+                 so make sure it is concise and describes the purpose of your model.
+
+         Returns:
+             Handle to the new model in the OrcaCloud
+
+         Raises:
+             ValueError: If the model already exists and if_exists is `"error"` or if it is
+                 `"open"` and the existing model has different attributes.
+
+         Examples:
+             Create a new model using default options:
+             >>> model = ClassificationModel.create(
+             ...     "my_model",
+             ...     LabeledMemoryset.open("my_memoryset"),
+             ... )
+
+             Create a new model with non-default model head and options:
+             >>> model = ClassificationModel.create(
+             ...     name="my_model",
+             ...     memoryset=LabeledMemoryset.open("my_memoryset"),
+             ...     head_type=RACHeadType.MMOE,
+             ...     num_classes=5,
+             ...     memory_lookup_count=20,
+             ... )
+         """
+         if cls.exists(name):
+             if if_exists == "error":
+                 raise ValueError(f"Model with name {name} already exists")
+             elif if_exists == "open":
+                 existing = cls.open(name)
+                 for attribute in {
+                     "head_type",
+                     "memory_lookup_count",
+                     "num_classes",
+                     "min_memory_weight",
+                 }:
+                     local_attribute = locals()[attribute]
+                     existing_attribute = getattr(existing, attribute)
+                     if local_attribute is not None and local_attribute != existing_attribute:
+                         raise ValueError(f"Model with name {name} already exists with different {attribute}")
+
+                 # special case for memoryset
+                 if existing.memoryset.id != memoryset.id:
+                     raise ValueError(f"Model with name {name} already exists with different memoryset")
+
+                 return existing
+
+         client = OrcaClient._resolve_client()
+         metadata = client.POST(
+             "/classification_model",
+             json={
+                 "name": name,
+                 "memoryset_name_or_id": memoryset.id,
+                 "head_type": head_type,
+                 "memory_lookup_count": memory_lookup_count,
+                 "num_classes": num_classes,
+                 "weigh_memories": weigh_memories,
+                 "min_memory_weight": min_memory_weight,
+                 "description": description,
+             },
+         )
+         return cls(metadata)
+
+     @classmethod
+     def open(cls, name: str) -> ClassificationModel:
+         """
+         Get a handle to a classification model in the OrcaCloud
+
+         Params:
+             name: Name or unique identifier of the classification model
+
+         Returns:
+             Handle to the existing classification model in the OrcaCloud
+
+         Raises:
+             LookupError: If the classification model does not exist
+         """
+         client = OrcaClient._resolve_client()
+         return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
+
+     @classmethod
+     def exists(cls, name_or_id: str) -> bool:
+         """
+         Check if a classification model exists in the OrcaCloud
+
+         Params:
+             name_or_id: Name or id of the classification model
+
+         Returns:
+             `True` if the classification model exists, `False` otherwise
+         """
+         try:
+             cls.open(name_or_id)
+             return True
+         except LookupError:
+             return False
+
+     @classmethod
+     def all(cls) -> list[ClassificationModel]:
+         """
+         Get a list of handles to all classification models in the OrcaCloud
+
+         Returns:
+             List of handles to all classification models in the OrcaCloud
+         """
+         client = OrcaClient._resolve_client()
+         return [cls(metadata) for metadata in client.GET("/classification_model")]
+
+     @classmethod
+     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
+         """
+         Delete a classification model from the OrcaCloud
+
+         Warning:
+             This will delete the model and all associated data, including predictions, evaluations, and feedback.
+
+         Params:
+             name_or_id: Name or id of the classification model
+             if_not_exists: What to do if the classification model does not exist, defaults to `"error"`.
+                 Other option is `"ignore"` to do nothing if the classification model does not exist.
+
+         Raises:
+             LookupError: If the classification model does not exist and if_not_exists is `"error"`
+         """
+         try:
+             client = OrcaClient._resolve_client()
+             client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
+             logging.info(f"Deleted model {name_or_id}")
+         except LookupError:
+             if if_not_exists == "error":
+                 raise
+
+     def refresh(self):
+         """Refresh the model data from the OrcaCloud"""
+         self.__dict__.update(self.open(self.name).__dict__)
+
+     def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
+         """
+         Update editable attributes of the model.
+
+         Note:
+             If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+         Params:
+             description: Value to set for the description
+             locked: Value to set for the locked status
+
+         Examples:
+             Update the description:
+             >>> model.set(description="New description")
+
+             Remove description:
+             >>> model.set(description=None)
+
+             Lock the model:
+             >>> model.set(locked=True)
+         """
+         update: PredictiveModelUpdate = {}
+         if description is not UNSET:
+             update["description"] = description
+         if locked is not UNSET:
+             update["locked"] = locked
+         client = OrcaClient._resolve_client()
+         client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+         self.refresh()
+
+     def lock(self) -> None:
+         """Lock the model to prevent accidental deletion"""
+         self.set(locked=True)
+
+     def unlock(self) -> None:
+         """Unlock the model to allow deletion"""
+         self.set(locked=False)
+
+     @overload
+     def predict(
+         self,
+         value: list[str],
+         expected_labels: list[int] | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | list[str | None] | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         use_gpu: bool = True,
+         batch_size: int = 100,
+     ) -> list[ClassificationPrediction]:
+         pass
+
+     @overload
+     def predict(
+         self,
+         value: str,
+         expected_labels: int | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         use_gpu: bool = True,
+         batch_size: int = 100,
+     ) -> ClassificationPrediction:
+         pass
+
+     def predict(
+         self,
+         value: list[str] | str,
+         expected_labels: list[int] | list[str] | int | str | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | None | list[str | None] = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         use_gpu: bool = True,
+         batch_size: int = 100,
+     ) -> list[ClassificationPrediction] | ClassificationPrediction:
+         """
+         Predict label(s) for the given input value(s) grounded in similar memories
+
+         Params:
+             value: Value(s) to predict the labels of
+             expected_labels: Expected label(s) for the given input to record for model evaluation
+             filters: Optional filters to apply during memory lookup
+             tags: Tags to add to the prediction(s)
+             save_telemetry: Whether to save telemetry for the prediction(s). One of
+                 * `"off"`: Do not save telemetry
+                 * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                     environment variable is set.
+                 * `"sync"`: Save telemetry synchronously
+                 * `"async"`: Save telemetry asynchronously
+             prompt: Optional prompt to use for instruction-tuned embedding models
+             use_lookup_cache: Whether to use cached lookup results for faster predictions
+             timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+             ignore_unlabeled: If True, only use labeled memories during lookup.
+                 If False (default), allow unlabeled memories when necessary.
+             partition_id: Optional partition ID(s) to use during memory lookup
+             partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
+                 * `"ignore_partitions"`: Ignore partitions
+                 * `"include_global"`: Include global memories
+                 * `"exclude_global"`: Exclude global memories
+                 * `"only_global"`: Only include global memories
+             use_gpu: Whether to use GPU for the prediction (defaults to True)
+             batch_size: Number of values to process in a single API call
+
+         Returns:
+             Label prediction or list of label predictions
+
+         Raises:
+             ValueError: If timeout_seconds is not a positive integer
+             TimeoutError: If the request times out after the specified duration
+
+         Examples:
+             Predict the label for a single value:
+             >>> prediction = model.predict("I am happy", tags={"test"})
+             ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+             Predict the labels for a list of values:
+             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
+             [
+                 ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                 ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+             ]
+
+             Using a prompt with an instruction-tuned embedding model:
+             >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
+             ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+         """
+
+         if timeout_seconds <= 0:
+             raise ValueError("timeout_seconds must be a positive integer")
+         if batch_size <= 0 or batch_size > 500:
+             raise ValueError("batch_size must be between 1 and 500")
+
+         parsed_filters = [
+             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+         ]
+
+         if any(_is_metric_column(filter[0]) for filter in filters):
+             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+         # Convert to list for batching
+         values = value if isinstance(value, list) else [value]
+         if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+             raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+         if isinstance(partition_id, list) and len(partition_id) != len(values):
+             raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
+         if isinstance(expected_labels, int):
+             expected_labels = [expected_labels] * len(values)
+         elif isinstance(expected_labels, str):
+             expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
+         elif isinstance(expected_labels, list):
+             expected_labels = [
+                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                 for label in expected_labels
+             ]
+
+         if use_gpu:
+             endpoint = "/gpu/classification_model/{name_or_id}/prediction"
+         else:
+             endpoint = "/classification_model/{name_or_id}/prediction"
+
+         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+         client = OrcaClient._resolve_client()
+
+         predictions: list[ClassificationPrediction] = []
+         for i in range(0, len(values), batch_size):
+             batch_values = values[i : i + batch_size]
+             batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
+
+             request_json: ClassificationPredictionRequest = {
+                 "input_values": batch_values,
+                 "memoryset_override_name_or_id": self._memoryset_override_id,
+                 "expected_labels": batch_expected_labels,
+                 "tags": list(tags or set()),
+                 "save_telemetry": telemetry_on,
+                 "save_telemetry_synchronously": telemetry_sync,
+                 "filters": cast(list[FilterItem], parsed_filters),
+                 "prompt": prompt,
+                 "use_lookup_cache": use_lookup_cache,
+                 "ignore_unlabeled": ignore_unlabeled,
+                 "partition_filter_mode": partition_filter_mode,
+             }
+             if partition_filter_mode != "ignore_partitions":
+                 request_json["partition_ids"] = (
+                     partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                 )
+
+             response = client.POST(
+                 endpoint,
+                 params={"name_or_id": self.id},
+                 json=request_json,
+                 timeout=timeout_seconds,
+             )
+
+             if telemetry_on and any(p["prediction_id"] is None for p in response):
+                 raise RuntimeError("Failed to save some prediction to database.")
+
+             predictions.extend(
+                 ClassificationPrediction(
+                     prediction_id=prediction["prediction_id"],
+                     label=prediction["label"],
+                     label_name=prediction["label_name"],
+                     score=None,
+                     confidence=prediction["confidence"],
+                     anomaly_score=prediction["anomaly_score"],
+                     memoryset=self.memoryset,
+                     model=self,
+                     logits=prediction["logits"],
+                     input_value=input_value,
+                 )
+                 for prediction, input_value in zip(response, batch_values)
+             )
+
+         self._last_prediction_was_batch = isinstance(value, list)
+         self._last_prediction = predictions[-1]
+         return predictions if isinstance(value, list) else predictions[0]
+
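A short sketch of the predict flow above, mirroring the docstring examples. It is illustrative only, assumes the model from the earlier sketch, and assumes ClassificationPrediction exposes the label and confidence fields shown in the example output.

```python
# Single-value prediction, tagged for later filtering in telemetry.
prediction = model.predict("I am happy", tags={"demo"})
print(prediction.label, prediction.confidence)

# Batch prediction; expected labels are recorded for evaluation and the
# inputs are sent to the API in chunks of batch_size values per call.
predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0], batch_size=100)
print(model.last_prediction)  # last prediction from the batch
```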
+     @overload
+     async def apredict(
+         self,
+         value: list[str],
+         expected_labels: list[int] | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | list[str | None] | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         batch_size: int = 100,
+     ) -> list[ClassificationPrediction]:
+         pass
+
+     @overload
+     async def apredict(
+         self,
+         value: str,
+         expected_labels: int | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         batch_size: int = 100,
+     ) -> ClassificationPrediction:
+         pass
+
+     async def apredict(
+         self,
+         value: list[str] | str,
+         expected_labels: list[int] | list[str] | int | str | None = None,
+         filters: list[FilterItemTuple] = [],
+         tags: set[str] | None = None,
+         save_telemetry: TelemetryMode = "on",
+         prompt: str | None = None,
+         use_lookup_cache: bool = True,
+         timeout_seconds: int = 10,
+         ignore_unlabeled: bool = False,
+         partition_id: str | None | list[str | None] = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+         batch_size: int = 100,
+     ) -> list[ClassificationPrediction] | ClassificationPrediction:
+         """
+         Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+         Params:
+             value: Value(s) to predict the labels of
+             expected_labels: Expected label(s) for the given input to record for model evaluation
+             filters: Optional filters to apply during memory lookup
+             tags: Tags to add to the prediction(s)
+             save_telemetry: Whether to save telemetry for the prediction(s). One of
+                 * `"off"`: Do not save telemetry
+                 * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                     environment variable is set.
+                 * `"sync"`: Save telemetry synchronously
+                 * `"async"`: Save telemetry asynchronously
+             prompt: Optional prompt to use for instruction-tuned embedding models
+             use_lookup_cache: Whether to use cached lookup results for faster predictions
+             timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+             ignore_unlabeled: If True, only use labeled memories during lookup.
+                 If False (default), allow unlabeled memories when necessary.
+             partition_id: Optional partition ID(s) to use during memory lookup
+             partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
+                 * `"ignore_partitions"`: Ignore partitions
+                 * `"include_global"`: Include global memories
+                 * `"exclude_global"`: Exclude global memories
+                 * `"only_global"`: Only include global memories
+             batch_size: Number of values to process in a single API call
+
+         Returns:
+             Label prediction or list of label predictions.
+
+         Raises:
+             ValueError: If timeout_seconds is not a positive integer
+             TimeoutError: If the request times out after the specified duration
+
+         Examples:
+             Predict the label for a single value:
+             >>> prediction = await model.apredict("I am happy", tags={"test"})
+             ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+             Predict the labels for a list of values:
+             >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+             [
+                 ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                 ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+             ]
+
+             Using a prompt with an instruction-tuned embedding model:
+             >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+             ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+         """
+
+         if timeout_seconds <= 0:
+             raise ValueError("timeout_seconds must be a positive integer")
+         if batch_size <= 0 or batch_size > 500:
+             raise ValueError("batch_size must be between 1 and 500")
+
+         parsed_filters = [
+             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+         ]
+
+         if any(_is_metric_column(filter[0]) for filter in filters):
+             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+         # Convert to list for batching
+         values = value if isinstance(value, list) else [value]
+         if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+             raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+         if isinstance(partition_id, list) and len(partition_id) != len(values):
+             raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
+         if isinstance(expected_labels, int):
+             expected_labels = [expected_labels] * len(values)
+         elif isinstance(expected_labels, str):
+             expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
+         elif isinstance(expected_labels, list):
+             expected_labels = [
+                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                 for label in expected_labels
+             ]
+
+         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+         client = OrcaAsyncClient._resolve_client()
+
+         predictions: list[ClassificationPrediction] = []
+         for i in range(0, len(values), batch_size):
+             batch_values = values[i : i + batch_size]
+             batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
+
+             request_json: ClassificationPredictionRequest = {
+                 "input_values": batch_values,
+                 "memoryset_override_name_or_id": self._memoryset_override_id,
+                 "expected_labels": batch_expected_labels,
+                 "tags": list(tags or set()),
+                 "save_telemetry": telemetry_on,
+                 "save_telemetry_synchronously": telemetry_sync,
+                 "filters": cast(list[FilterItem], parsed_filters),
+                 "prompt": prompt,
+                 "use_lookup_cache": use_lookup_cache,
+                 "ignore_unlabeled": ignore_unlabeled,
+                 "partition_filter_mode": partition_filter_mode,
+             }
+             if partition_filter_mode != "ignore_partitions":
+                 request_json["partition_ids"] = (
+                     partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                 )
+             response = await client.POST(
+                 "/gpu/classification_model/{name_or_id}/prediction",
+                 params={"name_or_id": self.id},
+                 json=request_json,
+                 timeout=timeout_seconds,
+             )
+
+             if telemetry_on and any(p["prediction_id"] is None for p in response):
+                 raise RuntimeError("Failed to save some prediction to database.")
+
+             predictions.extend(
+                 ClassificationPrediction(
+                     prediction_id=prediction["prediction_id"],
+                     label=prediction["label"],
+                     label_name=prediction["label_name"],
+                     score=None,
+                     confidence=prediction["confidence"],
+                     anomaly_score=prediction["anomaly_score"],
+                     memoryset=self.memoryset,
+                     model=self,
+                     logits=prediction["logits"],
+                     input_value=input_value,
+                 )
+                 for prediction, input_value in zip(response, batch_values)
+             )
+
+         self._last_prediction_was_batch = isinstance(value, list)
+         self._last_prediction = predictions[-1]
+         return predictions if isinstance(value, list) else predictions[0]
+
+     def predictions(
+         self,
+         limit: int | None = None,
+         offset: int = 0,
+         tag: str | None = None,
+         sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
+         expected_label_match: bool | None = None,
+         batch_size: int = 100,
+     ) -> list[ClassificationPrediction]:
+         """
+         Get a list of predictions made by this model
+
+         Params:
+             limit: Maximum number of predictions to return. If `None`, returns all predictions
+                 by automatically paginating through results.
+             offset: Optional offset of the first prediction to return
+             tag: Optional tag to filter predictions by
+             sort: Optional list of columns and directions to sort the predictions by.
+                 Predictions can be sorted by `timestamp` or `confidence`.
+             expected_label_match: Optional filter to only include predictions where the expected
+                 label does (`True`) or doesn't (`False`) match the predicted label
+             batch_size: Number of predictions to fetch in a single API call
+
+         Returns:
+             List of label predictions
+
+         Examples:
+             Get all predictions with a specific tag:
+             >>> predictions = model.predictions(tag="evaluation")
+
+             Get the last 3 predictions:
+             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
+             [
+                 ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                 ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                 ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
+             ]
+
+             Get second most confident prediction:
+             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
+             [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+
+             Get predictions where the expected label doesn't match the predicted label:
+             >>> predictions = model.predictions(expected_label_match=False)
+             [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
+         """
+         if batch_size <= 0 or batch_size > 500:
+             raise ValueError("batch_size must be between 1 and 500")
+         if limit == 0:
+             return []
+
+         client = OrcaClient._resolve_client()
+         all_predictions: list[ClassificationPrediction] = []
+
+         if limit is not None and limit < batch_size:
+             pages = [(offset, limit)]
+         else:
+             # automatically paginate the requests if necessary
+             total = client.POST(
+                 "/telemetry/prediction/count",
+                 json={
+                     "model_id": self.id,
+                     "tag": tag,
+                     "expected_label_match": expected_label_match,
+                 },
+             )
+             max_limit = max(total - offset, 0)
+             limit = min(limit, max_limit) if limit is not None else max_limit
+             pages = [(o, min(batch_size, limit - (o - offset))) for o in range(offset, offset + limit, batch_size)]
+
+         for current_offset, current_limit in pages:
+             request_json: ListPredictionsRequest = {
+                 "model_id": self.id,
+                 "limit": current_limit,
+                 "offset": current_offset,
+                 "tag": tag,
+                 "expected_label_match": expected_label_match,
+             }
+             if sort:
+                 request_json["sort"] = sort
+             response = client.POST(
+                 "/telemetry/prediction",
+                 json=request_json,
+             )
+             all_predictions.extend(
+                 ClassificationPrediction(
+                     prediction_id=prediction["prediction_id"],
+                     label=prediction["label"],
+                     label_name=prediction["label_name"],
+                     score=None,
+                     confidence=prediction["confidence"],
+                     anomaly_score=prediction["anomaly_score"],
+                     memoryset=self.memoryset,
+                     model=self,
+                     telemetry=prediction,
+                 )
+                 for prediction in response
+                 if "label" in prediction
+             )
+
+         return all_predictions
+
+     def _evaluate_datasource(
+         self,
+         datasource: Datasource,
+         value_column: str,
+         label_column: str,
+         record_predictions: bool,
+         tags: set[str] | None,
+         subsample: int | float | None,
+         background: bool = False,
+         ignore_unlabeled: bool = False,
+         partition_column: str | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+         client = OrcaClient._resolve_client()
+         response = client.POST(
+             "/classification_model/{model_name_or_id}/evaluation",
+             params={"model_name_or_id": self.id},
+             json={
+                 "datasource_name_or_id": datasource.id,
+                 "datasource_label_column": label_column,
+                 "datasource_value_column": value_column,
+                 "memoryset_override_name_or_id": self._memoryset_override_id,
+                 "record_telemetry": record_predictions,
+                 "telemetry_tags": list(tags) if tags else None,
+                 "subsample": subsample,
+                 "ignore_unlabeled": ignore_unlabeled,
+                 "datasource_partition_column": partition_column,
+                 "partition_filter_mode": partition_filter_mode,
+             },
+         )
+
+         def get_value():
+             client = OrcaClient._resolve_client()
+             res = client.GET(
+                 "/classification_model/{model_name_or_id}/evaluation/{job_id}",
+                 params={"model_name_or_id": self.id, "job_id": response["job_id"]},
+             )
+             assert res["result"] is not None
+             return ClassificationMetrics(
+                 coverage=res["result"].get("coverage"),
+                 f1_score=res["result"].get("f1_score"),
+                 accuracy=res["result"].get("accuracy"),
+                 loss=res["result"].get("loss"),
+                 anomaly_score_mean=res["result"].get("anomaly_score_mean"),
+                 anomaly_score_median=res["result"].get("anomaly_score_median"),
+                 anomaly_score_variance=res["result"].get("anomaly_score_variance"),
+                 roc_auc=res["result"].get("roc_auc"),
+                 pr_auc=res["result"].get("pr_auc"),
+                 pr_curve=res["result"].get("pr_curve"),
+                 roc_curve=res["result"].get("roc_curve"),
+             )
+
+         job = Job(response["job_id"], get_value)
+         return job if background else job.result()
+
+     def _evaluate_dataset(
+         self,
+         dataset: Dataset,
+         value_column: str,
+         label_column: str,
+         record_predictions: bool,
+         tags: set[str],
+         batch_size: int,
+         ignore_unlabeled: bool,
+         partition_column: str | None = None,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+     ) -> ClassificationMetrics:
+         if len(dataset) == 0:
+             raise ValueError("Evaluation dataset cannot be empty")
+
+         if any(x is None for x in dataset[label_column]):
+             raise ValueError("Evaluation dataset cannot contain None values in the label column")
+
+         predictions = [
+             prediction
+             for i in range(0, len(dataset), batch_size)
+             for prediction in self.predict(
+                 dataset[i : i + batch_size][value_column],
+                 expected_labels=dataset[i : i + batch_size][label_column],
+                 tags=tags,
+                 save_telemetry="sync" if record_predictions else "off",
+                 ignore_unlabeled=ignore_unlabeled,
+                 partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
+                 partition_filter_mode=partition_filter_mode,
+             )
+         ]
+
+         return calculate_classification_metrics(
+             expected_labels=dataset[label_column],
+             logits=[p.logits for p in predictions],
+             anomaly_scores=[p.anomaly_score for p in predictions],
+             include_curves=True,
+             include_confusion_matrix=True,
+         )
+
+     @overload
+     def evaluate(
+         self,
+         data: Datasource | Dataset,
+         *,
+         value_column: str = "value",
+         label_column: str = "label",
+         partition_column: str | None = None,
+         record_predictions: bool = False,
+         tags: set[str] = {"evaluation"},
+         batch_size: int = 100,
+         subsample: int | float | None = None,
+         background: Literal[True],
+         ignore_unlabeled: bool = False,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+     ) -> Job[ClassificationMetrics]:
+         pass
+
+     @overload
+     def evaluate(
+         self,
+         data: Datasource | Dataset,
+         *,
+         value_column: str = "value",
+         label_column: str = "label",
+         partition_column: str | None = None,
+         record_predictions: bool = False,
+         tags: set[str] = {"evaluation"},
+         batch_size: int = 100,
+         subsample: int | float | None = None,
+         background: Literal[False] = False,
+         ignore_unlabeled: bool = False,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+     ) -> ClassificationMetrics:
+         pass
+
+     def evaluate(
+         self,
+         data: Datasource | Dataset,
+         *,
+         value_column: str = "value",
+         label_column: str = "label",
+         partition_column: str | None = None,
+         record_predictions: bool = False,
+         tags: set[str] = {"evaluation"},
+         batch_size: int = 100,
+         subsample: int | float | None = None,
+         background: bool = False,
+         ignore_unlabeled: bool = False,
+         partition_filter_mode: Literal[
+             "ignore_partitions", "include_global", "exclude_global", "only_global"
+         ] = "include_global",
+     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+         """
+         Evaluate the classification model on a given dataset or datasource
+
+         Params:
+             data: Dataset or Datasource to evaluate the model on
+             value_column: Name of the column that contains the input values to the model
+             label_column: Name of the column containing the expected labels
+             partition_column: Optional name of the column that contains the partition IDs
+             record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+             tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
+             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+             subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
+             background: Whether to run the operation in the background and return a job handle
+             ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
+             partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
+                 * `"ignore_partitions"`: Ignore partitions
+                 * `"include_global"`: Include global memories
+                 * `"exclude_global"`: Exclude global memories
+                 * `"only_global"`: Only include global memories
+         Returns:
+             ClassificationMetrics containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
+
+         Examples:
+             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
+             ClassificationMetrics({
+                 accuracy: 0.8500,
+                 f1_score: 0.8500,
+                 roc_auc: 0.8500,
+                 pr_auc: 0.8500,
+                 anomaly_score: 0.3500 ± 0.0500,
+             })
+         """
+         if isinstance(data, Datasource):
+             return self._evaluate_datasource(
+                 datasource=data,
+                 value_column=value_column,
+                 label_column=label_column,
+                 record_predictions=record_predictions,
+                 tags=tags,
+                 subsample=subsample,
+                 background=background,
+                 ignore_unlabeled=ignore_unlabeled,
+                 partition_column=partition_column,
+                 partition_filter_mode=partition_filter_mode,
+             )
+         elif isinstance(data, Dataset):
+             return self._evaluate_dataset(
+                 dataset=data,
+                 value_column=value_column,
+                 label_column=label_column,
+                 record_predictions=record_predictions,
+                 tags=tags,
+                 batch_size=batch_size,
+                 ignore_unlabeled=ignore_unlabeled,
+                 partition_column=partition_column,
+                 partition_filter_mode=partition_filter_mode,
+             )
+         else:
+             raise ValueError(f"Invalid data type: {type(data)}")
+
+     def finetune(self, datasource: Datasource):
+         # do not document until implemented
+         raise NotImplementedError("Finetuning is not supported yet")
+
+     @contextmanager
+     def use_memoryset(self, memoryset_override: LabeledMemoryset) -> Generator[None, None, None]:
+         """
+         Temporarily override the memoryset used by the model for predictions
+
+         Params:
+             memoryset_override: Memoryset to override the default memoryset with
+
+         Examples:
+             >>> with model.use_memoryset(LabeledMemoryset.open("my_other_memoryset")):
+             ...     predictions = model.predict("I am happy")
+         """
+         self._memoryset_override_id = memoryset_override.id
+         yield
+         self._memoryset_override_id = None
+
+     @overload
+     def record_feedback(self, feedback: dict[str, Any]) -> None:
+         pass
+
+     @overload
+     def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
+         pass
+
+     def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
+         """
+         Record feedback for a list of predictions.
+
+         We support recording feedback in several categories for each prediction. A
+         [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
+         the first time feedback with a new name is recorded. Categories are global across models.
+         The value type of the category is inferred from the first recorded value. Subsequent
+         feedback for the same category must be of the same type.
+
+         Params:
+             feedback: Feedback to record, this should be dictionaries with the following keys:
+
+                 - `category`: Name of the category under which to record the feedback.
+                 - `value`: Feedback value to record, should be `True` for positive feedback and
+                     `False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
+                     where negative values indicate negative feedback and positive values indicate
+                     positive feedback.
+                 - `comment`: Optional comment to record with the feedback.
+
+         Examples:
+             Record whether predictions were correct or incorrect:
+             >>> model.record_feedback({
+             ...     "prediction": p.prediction_id,
+             ...     "category": "correct",
+             ...     "value": p.label == p.expected_label,
+             ... } for p in predictions)
+
+             Record star rating as normalized continuous score between `-1.0` and `+1.0`:
+             >>> model.record_feedback({
+             ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
+             ...     "category": "rating",
+             ...     "value": -0.5,
+             ...     "comment": "2 stars"
+             ... })
+
+         Raises:
+             ValueError: If the value does not match previous value types for the category, or is a
+                 [`float`][float] that is not between `-1.0` and `+1.0`.
+         """
+         client = OrcaClient._resolve_client()
+         client.PUT(
+             "/telemetry/prediction/feedback",
+             json=[
+                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
+             ],
+         )
+
+     @staticmethod
+     def bootstrap_model(
+         model_description: str,
+         label_names: list[str],
+         initial_examples: list[tuple[str, str]],
+         num_examples_per_label: int,
+         background: bool = False,
+     ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
+         client = OrcaClient._resolve_client()
+         response = client.POST(
+             "/agents/bootstrap_classification_model",
+             json={
+                 "model_description": model_description,
+                 "label_names": label_names,
+                 "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
+                 "num_examples_per_label": num_examples_per_label,
+             },
+         )
+
+         def get_result() -> BootstrappedClassificationModel:
+             client = OrcaClient._resolve_client()
+             res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
+             assert res["result"] is not None
+             return BootstrappedClassificationModel(res["result"])
+
+         job = Job(response["job_id"], get_result)
+         return job if background else job.result()
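To close the loop, a sketch of evaluation and feedback using the evaluate, predictions, and record_feedback methods above. This is illustrative only; the prediction_id and expected_label attributes are assumed to match the record_feedback docstring example, and the model handle comes from the earlier sketches.

```python
from datasets import Dataset

# Evaluate on an in-memory Hugging Face Dataset (column names default to "value" and "label").
eval_data = Dataset.from_dict({"value": ["I am happy", "I am sad"], "label": [1, 0]})
metrics = model.evaluate(eval_data, record_predictions=True, tags={"evaluation"})
print(metrics.accuracy, metrics.f1_score)

# Attach correctness feedback to the predictions recorded during evaluation.
model.record_feedback(
    {
        "prediction": p.prediction_id,
        "category": "correct",
        "value": p.label == p.expected_label,
    }
    for p in model.predictions(tag="evaluation")
)
```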