orca-sdk 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,8 @@ IMPORTANT:
 
 """
 
-from dataclasses import dataclass
+import logging
+from dataclasses import dataclass, field
 from typing import Any, Literal, Sequence, TypedDict, cast
 
 import numpy as np
@@ -242,6 +243,12 @@ class ClassificationMetrics:
     roc_curve: ROCCurve | None = None
     """Receiver operating characteristic curve"""
 
+    confusion_matrix: list[list[int]] | None = None
+    """Confusion matrix where confusion_matrix[i][j] is the count of samples with true label i predicted as label j"""
+
+    warnings: list[str] = field(default_factory=list)
+    """Human-readable warnings about skipped or adjusted metrics"""
+
     def __repr__(self) -> str:
         return (
             "ClassificationMetrics({\n"
@@ -321,7 +328,9 @@ def calculate_classification_metrics(
     average: Literal["micro", "macro", "weighted", "binary"] | None = None,
     multi_class: Literal["ovr", "ovo"] = "ovr",
     include_curves: bool = False,
+    include_confusion_matrix: bool = False,
 ) -> ClassificationMetrics:
+    warnings: list[str] = []
     references = np.array(expected_labels)
 
     # Convert to numpy array, handling None values
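For orientation, a hypothetical call exercising the new parameter. Input values are invented; the positional call style mirrors the package's own tests further below.

```python
import numpy as np

y_true = [0, 1, 1, 0]  # invented labels
y_score = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7], [0.6, 0.4]])

metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
print(metrics.confusion_matrix)  # [[2, 0], [0, 2]] for these inputs
print(metrics.warnings)          # [] when no metric had to be skipped
```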
@@ -342,6 +351,7 @@ def calculate_classification_metrics(
             pr_auc=None,
             pr_curve=None,
             roc_curve=None,
+            confusion_matrix=None,
         )
 
     # Convert logits to probabilities
@@ -351,41 +361,102 @@
     predictions[np.isnan(probabilities).all(axis=-1)] = -1  # set predictions to -1 for all nan logits
 
     num_classes_references = len(set(references))
-    num_classes_predictions = len(set(predictions))
+    num_classes_predictions = probabilities.shape[1]  # Number of probability columns (model's known classes)
     num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
     coverage = 1 - (num_none_predictions / len(probabilities) if len(probabilities) > 0 else 0)
+    if num_none_predictions > 0:
+        warnings.append(f"Some predictions were missing (coverage={coverage:.3f}); loss and AUC metrics were skipped.")
 
     if average is None:
         average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
 
     accuracy = sklearn.metrics.accuracy_score(references, predictions)
     f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+
+    # Check for unknown classes early (before log_loss)
+    classes_in_references = np.unique(references)
+    has_unknown_classes = np.max(classes_in_references) >= num_classes_predictions
+    if has_unknown_classes:
+        logging.warning(
+            f"Test labels contain classes not in the model's predictions. "
+            f"Model has {num_classes_predictions} classes (0 - {num_classes_predictions - 1}), "
+            f"but test labels contain class {np.max(classes_in_references)}. "
+            f"ROC AUC and PR AUC cannot be calculated."
+        )
+        warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+
     # Ensure sklearn sees the full class set corresponding to probability columns
     # to avoid errors when y_true does not contain all classes.
+    # Skip log_loss if there are unknown classes (would cause ValueError)
     loss = (
         sklearn.metrics.log_loss(
             references,
             probabilities,
             labels=list(range(probabilities.shape[1])),
         )
-        if num_none_predictions == 0
+        if num_none_predictions == 0 and not has_unknown_classes
         else None
     )
 
-    if num_classes_references == num_classes_predictions and num_none_predictions == 0:
-        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
-        if num_classes_references == 2:
-            # Use probabilities[:, 1] which is guaranteed to be 2D
-            probabilities_positive = probabilities[:, 1]
-            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities_positive)
-            roc_curve = calculate_roc_curve(references, probabilities_positive) if include_curves else None
-            pr_auc = sklearn.metrics.average_precision_score(references, probabilities_positive)
-            pr_curve = calculate_pr_curve(references, probabilities_positive) if include_curves else None
-        else:
-            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
+    # Calculate ROC AUC with filtering for class mismatch
+    if num_none_predictions == 0:
+        # Check if y_true contains classes not in the model (unknown classes)
+        if has_unknown_classes:
+            # Unknown classes present - can't calculate meaningful ROC AUC
+            logging.warning(
+                "Cannot calculate ROC AUC and PR AUC: test labels contain classes not in the model's predictions."
+            )
+            if "y_true contains classes unknown to the model" not in " ".join(warnings):
+                warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+            roc_auc = None
+            pr_auc = None
+            pr_curve = None
             roc_curve = None
+        elif len(classes_in_references) < 2:
+            # Need at least 2 classes for ROC AUC
+            logging.warning(
+                f"Cannot calculate ROC AUC and PR AUC: need at least 2 classes, but only {len(classes_in_references)} class(es) found in test labels."
+            )
+            roc_auc = None
             pr_auc = None
             pr_curve = None
+            roc_curve = None
+            warnings.append("ROC AUC requires at least 2 classes; metric was skipped.")
+        else:
+            # Filter probabilities to only classes present in references
+            if len(classes_in_references) < num_classes_predictions:
+                # Subset and renormalize probabilities
+                probabilities_filtered = probabilities[:, classes_in_references]
+                # Safe renormalization: guard against zero denominators
+                row_sums = probabilities_filtered.sum(axis=1, keepdims=True)
+                probabilities_filtered = probabilities_filtered / np.where(row_sums > 0, row_sums, 1.0)
+
+                # Remap references to filtered indices
+                class_mapping = {cls: idx for idx, cls in enumerate(classes_in_references)}
+                references_remapped = np.array([class_mapping[y] for y in references])
+                warnings.append(
+                    f"ROC AUC computed only on classes present in y_true: {classes_in_references.tolist()}."
+                )
+            else:
+                # All classes present, no filtering needed
+                probabilities_filtered = probabilities
+                references_remapped = references
+
+            # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+            if len(classes_in_references) == 2:
+                # Use probabilities[:, 1] which is guaranteed to be 2D
+                probabilities_positive = cast(NDArray[np.float32], probabilities_filtered[:, 1].astype(np.float32))
+                roc_auc = sklearn.metrics.roc_auc_score(references_remapped, probabilities_positive)
+                roc_curve = calculate_roc_curve(references_remapped, probabilities_positive) if include_curves else None
+                pr_auc = sklearn.metrics.average_precision_score(references_remapped, probabilities_positive)
+                pr_curve = calculate_pr_curve(references_remapped, probabilities_positive) if include_curves else None
+            else:
+                roc_auc = sklearn.metrics.roc_auc_score(
+                    references_remapped, probabilities_filtered, multi_class=multi_class
+                )
+                roc_curve = None
+                pr_auc = None
+                pr_curve = None
     else:
         roc_auc = None
         pr_auc = None
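A standalone sketch of the subset-and-renormalize step this hunk introduces, with invented values. It keeps only the probability columns for classes that actually occur in y_true, then renormalizes each row, guarding rows whose kept mass sums to zero.

```python
import numpy as np

probabilities = np.array([[0.7, 0.2, 0.1], [0.0, 0.0, 1.0]])  # invented
classes_in_references = np.array([0, 1])  # class 2 absent from y_true

filtered = probabilities[:, classes_in_references]
row_sums = filtered.sum(axis=1, keepdims=True)
filtered = filtered / np.where(row_sums > 0, row_sums, 1.0)
# row 0 -> [0.778, 0.222] (approx); row 1 stays [0.0, 0.0] instead of NaN,
# because the zero denominator is replaced by 1.0 before dividing
```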
@@ -395,6 +466,31 @@ def calculate_classification_metrics(
     # Calculate anomaly score statistics
     anomaly_score_mean, anomaly_score_median, anomaly_score_variance = calculate_anomaly_score_stats(anomaly_scores)
 
+    # Calculate confusion matrix if requested
+    confusion_matrix: list[list[int]] | None = None
+    if include_confusion_matrix:
+        # Get the number of classes from the probabilities shape
+        num_classes = probabilities.shape[1]
+        labels = list(range(num_classes))
+        # Filter out NaN predictions (which are set to -1) before computing confusion matrix
+        valid_mask = predictions != -1
+        num_filtered = (~valid_mask).sum()
+        if num_filtered > 0:
+            warning_msg = (
+                f"Confusion matrix computation: filtered out {num_filtered} samples with NaN predictions "
+                f"({num_filtered}/{len(predictions)} = {num_filtered / len(predictions):.1%})"
+            )
+            logging.warning(warning_msg)
+            warnings.append(warning_msg)
+
+        if np.any(valid_mask):
+            # Compute confusion matrix with explicit labels to ensure consistent shape
+            cm = sklearn.metrics.confusion_matrix(references[valid_mask], predictions[valid_mask], labels=labels)
+        else:
+            # No valid predictions; return an all-zero confusion matrix
+            cm = np.zeros((num_classes, num_classes), dtype=int)
+        confusion_matrix = cast(list[list[int]], cm.tolist())
+
     return ClassificationMetrics(
         coverage=coverage,
         accuracy=float(accuracy),
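An isolated sketch of the masking logic above, with invented values: -1 marks all-NaN predictions, which are dropped before the confusion matrix is computed, while explicit labels preserve the full matrix shape.

```python
import numpy as np
import sklearn.metrics

references = np.array([0, 1, 0, 1])
predictions = np.array([-1, -1, 0, 1])  # first two rows had NaN logits

valid = predictions != -1
cm = sklearn.metrics.confusion_matrix(references[valid], predictions[valid], labels=[0, 1])
# cm == [[1, 0], [0, 1]]: explicit labels keep the 2x2 shape even though
# only two samples survive the mask
```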
@@ -407,6 +503,8 @@ def calculate_classification_metrics(
         pr_auc=float(pr_auc) if pr_auc is not None else None,
         pr_curve=pr_curve,
         roc_curve=roc_curve,
+        confusion_matrix=confusion_matrix,
+        warnings=warnings,
     )
 
 
@@ -364,3 +364,207 @@ def test_regression_metrics_all_predictions_nan():
     assert metrics.anomaly_score_mean is None
     assert metrics.anomaly_score_median is None
     assert metrics.anomaly_score_variance is None
+
+
+def test_roc_auc_handles_missing_classes_in_y_true():
+    """Test that ROC AUC is calculated with filtering when test set has fewer classes than model predictions."""
+    # Model trained on classes [0, 1, 2], but test set only has [0, 1]
+    y_true = np.array([0, 1, 0, 1])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.6, 0.3, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should calculate ROC AUC by filtering to classes [0, 1]
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc == 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.loss is not None
+
+
+def test_roc_auc_with_all_classes_present():
+    """Test that ROC AUC works when all classes are present in test set."""
+    # Model trained on classes [0, 1, 2], test set has all three
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.1, 0.1, 0.8],  # Predicts class 2
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+            [0.1, 0.2, 0.7],  # Predicts class 2
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # ROC AUC should be calculated when all classes present
+    assert metrics.roc_auc is not None
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+
+
+def test_roc_auc_handles_subset_of_many_classes():
+    """Test ROC AUC where model knows 15 classes, test has 10."""
+    # Simulate the actual error scenario from the bug report
+    num_model_classes = 15
+    num_test_classes = 10
+    num_samples = 50
+
+    # Test set only uses classes 0-9
+    y_true = np.random.randint(0, num_test_classes, size=num_samples)
+
+    # Model produces predictions for all 15 classes
+    y_score = np.random.rand(num_samples, num_model_classes)
+    y_score = y_score / y_score.sum(axis=1, keepdims=True)  # Normalize to probabilities
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should calculate ROC AUC by filtering to classes 0-9
+    assert metrics.roc_auc is not None
+    assert 0.0 <= metrics.roc_auc <= 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_handles_unknown_classes_in_y_true():
+    """Test that metrics handle when y_true contains classes not in y_score."""
+    # Model trained on classes [0, 1, 2], but test set has class 3
+    y_true = np.array([0, 1, 2, 3])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],
+            [0.1, 0.8, 0.1],
+            [0.1, 0.1, 0.8],
+            [0.3, 0.4, 0.3],  # Unknown class 3
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should skip ROC AUC and loss when unknown classes present
+    assert metrics.roc_auc is None
+    assert metrics.loss is None  # Loss also skipped to avoid ValueError
+    assert any("unknown" in w for w in metrics.warnings)
+    # Other metrics should still work (they handle extra classes)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+
+
+def test_roc_auc_handles_zero_probability_on_present_classes():
+    """Test ROC AUC when a sample has zero probability on all present classes (edge case for renormalization)."""
+    # Model trained on classes [0, 1, 2, 3], test set only has [0, 1, 2]
+    # One sample has ALL probability mass on excluded class 3 (zero on [0, 1, 2])
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.08, 0.02],
+            [0.1, 0.8, 0.08, 0.02],
+            [0.1, 0.1, 0.78, 0.02],
+            [0.6, 0.3, 0.08, 0.02],
+            [0.0, 0.0, 0.0, 1.0],  # zero denominator
+            [0.1, 0.1, 0.78, 0.02],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should still calculate ROC AUC despite zero-denominator case
+    # The safe renormalization should prevent NaN/inf
+    assert metrics.roc_auc is not None
+    assert not np.isnan(metrics.roc_auc)
+    assert not np.isinf(metrics.roc_auc)
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_skipped_for_single_class():
+    """Test that ROC AUC is skipped when only one class is present in y_true."""
+    # Model trained on classes [0, 1, 2], but test set only has class 0
+    y_true = np.array([0, 0, 0, 0])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],
+            [0.8, 0.1, 0.1],
+            [0.85, 0.1, 0.05],
+            [0.9, 0.05, 0.05],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # ROC AUC requires at least 2 classes
+    assert metrics.roc_auc is None
+    assert metrics.accuracy == 1.0
+    assert metrics.loss is not None
+    assert any("requires at least 2 classes" in w for w in metrics.warnings)
+
+
+# Confusion Matrix Tests
+def test_confusion_matrix_binary_classification():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    expected_cm = sklearn.metrics.confusion_matrix(y_true, [0, 1, 1, 0, 0], labels=[0, 1])
+    assert metrics.confusion_matrix == expected_cm.tolist()
+
+
+def test_confusion_matrix_multiclass():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # All predictions correct
+    assert metrics.confusion_matrix == [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
+
+
+def test_confusion_matrix_with_misclassifications():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # Class 0: 1 correct (index 0), 1 predicted as class 1 (index 3)
+    # Class 1: 2 correct (indices 1, 4)
+    # Class 2: 1 predicted as class 1 (index 2), 1 correct (index 5)
+    assert metrics.confusion_matrix == [[1, 1, 0], [0, 2, 0], [0, 1, 1]]
+
+
+def test_confusion_matrix_handles_nan_logits():
+    logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
+    expected_labels = [0, 1, 0, 1]
+    metrics = calculate_classification_metrics(expected_labels, logits, include_confusion_matrix=True)
+
+    # NaN predictions are set to -1, so they won't match any true label
+    # Only the last 2 predictions are valid: pred=[1, 1], true=[0, 1]
+    assert metrics.confusion_matrix is not None
+    # With NaN handling, predictions become [-1, -1, 1, 1]
+    # Only position 3 is correct (true=1, pred=1)
+    # Position 2 is wrong (true=0, pred=1)
+    assert len(metrics.confusion_matrix) == 2  # 2 classes
+    assert len(metrics.confusion_matrix[0]) == 2
orca_sdk/async_client.py CHANGED
@@ -62,10 +62,12 @@ class ActionRecommendation(TypedDict):
 class AddMemorySuggestion(TypedDict):
     value: str
     label_name: str
+    similarity: NotRequired[float | None]
 
 
 class AliveResponse(TypedDict):
     ok: bool
+    checks: dict[str, bool]
 
 
 class ApiKeyMetadata(TypedDict):
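Since both additions use NotRequired, existing payloads stay valid; a sketch with invented values:

```python
# Callers that never set `similarity` keep type-checking unchanged.
suggestion: AddMemorySuggestion = {"value": "refund requested", "label_name": "billing"}
scored: AddMemorySuggestion = {"value": "refund requested", "label_name": "billing", "similarity": 0.87}
```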
@@ -292,7 +294,7 @@ class JobStatusInfo(TypedDict):
 class LabelClassMetrics(TypedDict):
     label: int | None
     label_name: NotRequired[str | None]
-    average_lookup_score: float
+    average_lookup_score: float | None
     memory_count: int
 
 
@@ -346,7 +348,7 @@ class MemoryMetrics(TypedDict):
     cluster: NotRequired[int]
     embedding_2d: NotRequired[list]
     anomaly_score: NotRequired[float]
-    neighbor_label_logits: NotRequired[list[float]]
+    neighbor_label_logits: NotRequired[list[float] | None]
     neighbor_predicted_label: NotRequired[int | None]
     neighbor_predicted_label_ambiguity: NotRequired[float]
     neighbor_predicted_label_confidence: NotRequired[float]
@@ -1157,6 +1159,9 @@ class FieldValidationError(TypedDict):
 
 class AddMemoryRecommendations(TypedDict):
     memories: list[AddMemorySuggestion]
+    attempts_used: NotRequired[int]
+    partial: NotRequired[bool]
+    rejection_counts: NotRequired[dict[str, int]]
 
 
 class BootstrapClassificationModelRequest(TypedDict):
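A hedged sketch of consuming the new optional bookkeeping fields; the field names come from the hunk above, the helper and its values are invented, and `.get()` covers servers that omit NotRequired keys.

```python
def summarize(rec: AddMemoryRecommendations) -> str:
    # TypedDicts are plain dicts at runtime, so .get() handles absent keys
    n = len(rec["memories"])
    attempts = rec.get("attempts_used", 0)
    note = " (partial)" if rec.get("partial", False) else ""
    return f"{n} suggestions after {attempts} attempts{note}"
```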
@@ -1272,10 +1277,19 @@ class CreateClassificationModelRequest(TypedDict):
     num_classes: NotRequired[int | None]
 
 
-class CreateMemorysetRequest(TypedDict):
+class CreateMemorysetFromDatasourceRequest(TypedDict):
     name: str
     description: NotRequired[str | None]
     notes: NotRequired[str | None]
+    pretrained_embedding_model_name: NotRequired[PretrainedEmbeddingModelName | None]
+    finetuned_embedding_model_name_or_id: NotRequired[str | None]
+    max_seq_length_override: NotRequired[int | None]
+    label_names: NotRequired[list[str] | None]
+    index_type: NotRequired[Literal["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "DISKANN"]]
+    index_params: NotRequired[dict[str, int | float | str]]
+    prompt: NotRequired[str]
+    hidden: NotRequired[bool]
+    memory_type: NotRequired[MemoryType | None]
     datasource_name_or_id: str
     datasource_label_column: NotRequired[str | None]
     datasource_score_column: NotRequired[str | None]
@@ -1283,6 +1297,14 @@ class CreateMemorysetRequest(TypedDict):
     datasource_source_id_column: NotRequired[str | None]
     datasource_partition_id_column: NotRequired[str | None]
     remove_duplicates: NotRequired[bool]
+    batch_size: NotRequired[int]
+    subsample: NotRequired[int | float | None]
+
+
+class CreateMemorysetRequest(TypedDict):
+    name: str
+    description: NotRequired[str | None]
+    notes: NotRequired[str | None]
     pretrained_embedding_model_name: NotRequired[PretrainedEmbeddingModelName | None]
     finetuned_embedding_model_name_or_id: NotRequired[str | None]
     max_seq_length_override: NotRequired[int | None]
@@ -1291,9 +1313,7 @@ class CreateMemorysetRequest(TypedDict):
     index_params: NotRequired[dict[str, int | float | str]]
     prompt: NotRequired[str]
     hidden: NotRequired[bool]
-    batch_size: NotRequired[int]
-    subsample: NotRequired[int | float | None]
-    memory_type: NotRequired[MemoryType]
+    memory_type: NotRequired[MemoryType | None]
 
 
 class CreateRegressionModelRequest(TypedDict):
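A hypothetical pair of payloads under the split (field values invented): the datasource-backed request keeps batch_size and subsample, which the plain CreateMemorysetRequest no longer accepts.

```python
from_datasource: CreateMemorysetFromDatasourceRequest = {
    "name": "support-tickets",
    "datasource_name_or_id": "tickets-2024",  # required for this variant
    "batch_size": 64,
    "subsample": 0.1,
}
plain: CreateMemorysetRequest = {
    "name": "support-tickets-empty",
    "memory_type": None,  # now explicitly nullable
}
```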
@@ -1653,8 +1673,8 @@ class MemorysetMetadata(TypedDict):
     created_at: str
     updated_at: str
     memories_updated_at: str
-    insertion_job_id: str
-    insertion_status: JobStatus
+    insertion_job_id: str | None
+    insertion_status: JobStatus | None
     metrics: MemorysetMetrics
     memory_type: MemoryType
     label_names: list[str] | None
@@ -1664,7 +1684,7 @@ class MemorysetMetadata(TypedDict):
     document_prompt_override: str | None
     query_prompt_override: str | None
     hidden: bool
-    insertion_task_id: str
+    insertion_task_id: str | None
 
 
 class PaginatedWorkerInfo(TypedDict):
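Because the insertion fields are now nullable (plausibly for memorysets created empty, with no insertion job to report), consumers should guard for None; a sketch where the "COMPLETED" status literal is invented for illustration:

```python
def insertion_done(meta: MemorysetMetadata) -> bool:
    status = meta["insertion_status"]
    # None means no insertion job was ever started for this memoryset
    return status is None or status == "COMPLETED"
```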
@@ -1735,6 +1755,22 @@ class OrcaAsyncClient(AsyncClient):
     ) -> AliveResponse:
         pass
 
+    @overload
+    async def GET(
+        self,
+        path: Literal["/gpu/check/alive"],
+        *,
+        params: None = None,
+        parse_as: Literal["json"] = "json",
+        headers: HeaderTypes | None = None,
+        cookies: CookieTypes | None = None,
+        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        extensions: RequestExtensions | None = None,
+    ) -> AliveResponse:
+        pass
+
     @overload
     async def GET(
         self,
@@ -1754,7 +1790,7 @@
     @overload
     async def GET(
         self,
-        path: Literal["/gpu/check/healthy"],
+        path: Literal["/check/healthy"],
         *,
         params: None = None,
         parse_as: Literal["json"] = "json",
@@ -1770,7 +1806,7 @@
     @overload
     async def GET(
         self,
-        path: Literal["/check/healthy"],
+        path: Literal["/gpu/check/healthy"],
         *,
         params: None = None,
         parse_as: Literal["json"] = "json",
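A hypothetical helper against the new overload; the path and return shape come from the overload and the AliveResponse TypedDict above, while the helper itself is invented.

```python
async def gpu_alive(client: OrcaAsyncClient) -> bool:
    # Typed as AliveResponse via the Literal["/gpu/check/alive"] overload
    res = await client.GET("/gpu/check/alive")
    return res["ok"] and all(res["checks"].values())
```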
@@ -2895,6 +2931,26 @@ class OrcaAsyncClient(AsyncClient):
         path: Literal["/memoryset"],
         *,
         params: None = None,
+        json: CreateMemorysetFromDatasourceRequest,
+        data: None = None,
+        files: None = None,
+        content: None = None,
+        parse_as: Literal["json"] = "json",
+        headers: HeaderTypes | None = None,
+        cookies: CookieTypes | None = None,
+        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        extensions: RequestExtensions | None = None,
+    ) -> MemorysetMetadata:
+        pass
+
+    @overload
+    async def POST(
+        self,
+        path: Literal["/memoryset/empty"],
+        *,
+        params: None = None,
         json: CreateMemorysetRequest,
         data: None = None,
         files: None = None,
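Taken together, POST /memoryset now accepts the datasource-backed request while the new POST /memoryset/empty takes the plain one; a hypothetical call with an invented payload, using the json= keyword from the overload signatures:

```python
async def create_empty_memoryset(client: OrcaAsyncClient) -> MemorysetMetadata:
    # Typed as MemorysetMetadata via the Literal["/memoryset/empty"] overload
    return await client.POST("/memoryset/empty", json={"name": "scratch-set"})
```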
@@ -114,13 +114,14 @@ class ClassificationModel:
         return isinstance(other, ClassificationModel) and self.id == other.id
 
     def __repr__(self):
+        memoryset_repr = self.memoryset.__repr__().replace("\n", "\n ")
         return (
             "ClassificationModel({\n"
             f" name: '{self.name}',\n"
             f" head_type: {self.head_type},\n"
             f" num_classes: {self.num_classes},\n"
             f" memory_lookup_count: {self.memory_lookup_count},\n"
-            f" memoryset: LabeledMemoryset.open('{self.memoryset.name}'),\n"
+            f" memoryset: {memoryset_repr},\n"
             "})"
         )
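An isolated sketch of the indentation trick: re-indenting the inner repr's newlines keeps a nested multi-line repr aligned inside the outer one. The strings and indent width here are invented for illustration.

```python
inner = "LabeledMemoryset({\n  name: 'my_memoryset',\n})"
# Each newline in the inner repr gains the outer block's indentation
print("ClassificationModel({\n  memoryset: " + inner.replace("\n", "\n  ") + ",\n})")
# ClassificationModel({
#   memoryset: LabeledMemoryset({
#     name: 'my_memoryset',
#   }),
# })
```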