orca-sdk 0.0.91__tar.gz → 0.0.92__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/PKG-INFO +3 -1
  2. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
  3. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
  4. orca_sdk-0.0.92/orca_sdk/_shared/__init__.py +1 -0
  5. orca_sdk-0.0.92/orca_sdk/_shared/metrics.py +195 -0
  6. orca_sdk-0.0.92/orca_sdk/_shared/metrics_test.py +169 -0
  7. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/classification_model.py +144 -19
  8. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/classification_model_test.py +49 -22
  9. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/telemetry.py +3 -0
  10. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/pyproject.toml +3 -1
  11. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/README.md +0 -0
  12. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/__init__.py +0 -0
  13. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/__init__.py +0 -0
  14. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/__init__.py +0 -0
  15. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  16. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -0
  17. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -0
  18. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -0
  19. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -0
  20. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -0
  21. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  22. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -0
  23. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -0
  24. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -0
  25. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -0
  26. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -0
  27. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -0
  28. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -0
  29. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -0
  30. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -0
  31. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/update_model_classification_model_name_or_id_patch.py +0 -0
  32. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  33. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -0
  34. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -0
  35. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -0
  36. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -0
  37. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -0
  38. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -0
  39. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -0
  40. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  41. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -0
  42. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -0
  43. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  44. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -0
  45. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -0
  46. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -0
  47. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -0
  48. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -0
  49. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  50. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -0
  51. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -0
  52. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -0
  53. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -0
  54. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -0
  55. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -0
  56. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -0
  57. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -0
  58. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -0
  59. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -0
  60. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -0
  61. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -0
  62. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -0
  63. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -0
  64. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -0
  65. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -0
  66. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -0
  67. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -0
  68. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -0
  69. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -0
  70. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  71. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -0
  72. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -0
  73. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -0
  74. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  75. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -0
  76. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -0
  77. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -0
  78. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  79. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -0
  80. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -0
  81. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -0
  82. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -0
  83. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -0
  84. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -0
  85. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -0
  86. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -0
  87. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -0
  88. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -0
  89. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/client.py +0 -0
  90. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/errors.py +0 -0
  91. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/__init__.py +0 -0
  92. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -0
  93. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -0
  94. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -0
  95. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/base_model.py +0 -0
  96. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -0
  97. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -0
  98. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -0
  99. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -0
  100. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/column_info.py +0 -0
  101. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/column_type.py +0 -0
  102. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -0
  103. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -0
  104. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -0
  105. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -0
  106. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -0
  107. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -0
  108. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -0
  109. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -0
  110. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -0
  111. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -0
  112. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embed_request.py +0 -0
  113. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -0
  114. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -0
  115. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -0
  116. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -0
  117. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -0
  118. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/evaluation_request.py +0 -0
  119. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/evaluation_response.py +0 -0
  120. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -0
  121. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/feedback_type.py +0 -0
  122. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item.py +0 -0
  123. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -0
  124. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -0
  125. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_op.py +0 -0
  126. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -0
  127. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -0
  128. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -0
  129. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/get_memories_request.py +0 -0
  130. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -0
  131. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -0
  132. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -0
  133. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -0
  134. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
  135. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory.py +0 -0
  136. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -0
  137. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -0
  138. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -0
  139. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -0
  140. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -0
  141. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -0
  142. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -0
  143. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -0
  144. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -0
  145. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -0
  146. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -0
  147. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memoryset_update.py +0 -0
  148. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/list_memories_request.py +0 -0
  149. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -0
  150. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/lookup_request.py +0 -0
  151. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -0
  152. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memory_metrics.py +0 -0
  153. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -0
  154. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -0
  155. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -0
  156. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -0
  157. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -0
  158. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -0
  159. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -0
  160. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -0
  161. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -0
  162. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -0
  163. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -0
  164. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -0
  165. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -0
  166. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -0
  167. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -0
  168. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -0
  169. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -0
  170. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -0
  171. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -0
  172. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/paginated_labeled_memory_with_feedback_metrics.py +0 -0
  173. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/precision_recall_curve.py +0 -0
  174. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -0
  175. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -0
  176. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -0
  177. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -0
  178. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_request.py +0 -0
  179. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -0
  180. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -0
  181. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -0
  182. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -0
  183. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_head_type.py +0 -0
  184. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -0
  185. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_model_update.py +0 -0
  186. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/roc_curve.py +0 -0
  187. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -0
  188. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task.py +0 -0
  189. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task_status.py +0 -0
  190. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task_status_info.py +0 -0
  191. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -0
  192. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -0
  193. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -0
  194. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -0
  195. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -0
  196. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -0
  197. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -0
  198. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -0
  199. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -0
  200. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/py.typed +0 -0
  201. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/types.py +0 -0
  202. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/__init__.py +0 -0
  203. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/analysis_ui.py +0 -0
  204. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/analysis_ui_style.css +0 -0
  205. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/auth.py +0 -0
  206. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/auth_test.py +0 -0
  207. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/common.py +0 -0
  208. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/data_parsing.py +0 -0
  209. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/data_parsing_test.py +0 -0
  210. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/prediction_result_ui.css +0 -0
  211. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/prediction_result_ui.py +0 -0
  212. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/task.py +0 -0
  213. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/value_parser.py +0 -0
  214. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/value_parser_test.py +0 -0
  215. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/conftest.py +0 -0
  216. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/credentials.py +0 -0
  217. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/credentials_test.py +0 -0
  218. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/datasource.py +0 -0
  219. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/datasource_test.py +0 -0
  220. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/embedding_model.py +0 -0
  221. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/embedding_model_test.py +0 -0
  222. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/memoryset.py +0 -0
  223. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/memoryset_test.py +0 -0
  224. {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/telemetry_test.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: orca_sdk
3
- Version: 0.0.91
3
+ Version: 0.0.92
4
4
  Summary: SDK for interacting with Orca Services
5
5
  License: Apache-2.0
6
6
  Author: Orca DB Inc.
@@ -20,7 +20,9 @@ Requires-Dist: pandas (>=2.2.3,<3.0.0)
20
20
  Requires-Dist: pyarrow (>=18.0.0,<19.0.0)
21
21
  Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
22
22
  Requires-Dist: python-dotenv (>=1.1.0,<2.0.0)
23
+ Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
23
24
  Requires-Dist: torch (>=2.5.1,<3.0.0)
25
+ Requires-Dist: transformers (>=4.51.3,<5.0.0)
24
26
  Description-Content-Type: text/markdown
25
27
 
26
28
  <!--
@@ -10,7 +10,7 @@ The main change is:
10
10
 
11
11
  # flake8: noqa: C901
12
12
 
13
- from typing import Any, Type, TypeVar, Union, cast
13
+ from typing import Any, List, Type, TypeVar, Union, cast
14
14
 
15
15
  from attrs import define as _attrs_define
16
16
  from attrs import field as _attrs_field
@@ -28,6 +28,7 @@ class BaseLabelPredictionResult:
28
28
  anomaly_score (Union[None, float]):
29
29
  label (int):
30
30
  label_name (Union[None, str]):
31
+ logits (List[float]):
31
32
  """
32
33
 
33
34
  prediction_id: Union[None, str]
@@ -35,6 +36,7 @@ class BaseLabelPredictionResult:
35
36
  anomaly_score: Union[None, float]
36
37
  label: int
37
38
  label_name: Union[None, str]
39
+ logits: List[float]
38
40
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
39
41
 
40
42
  def to_dict(self) -> dict[str, Any]:
@@ -51,6 +53,8 @@ class BaseLabelPredictionResult:
51
53
  label_name: Union[None, str]
52
54
  label_name = self.label_name
53
55
 
56
+ logits = self.logits
57
+
54
58
  field_dict: dict[str, Any] = {}
55
59
  field_dict.update(self.additional_properties)
56
60
  field_dict.update(
@@ -60,6 +64,7 @@ class BaseLabelPredictionResult:
60
64
  "anomaly_score": anomaly_score,
61
65
  "label": label,
62
66
  "label_name": label_name,
67
+ "logits": logits,
63
68
  }
64
69
  )
65
70
 
@@ -94,12 +99,15 @@ class BaseLabelPredictionResult:
94
99
 
95
100
  label_name = _parse_label_name(d.pop("label_name"))
96
101
 
102
+ logits = cast(List[float], d.pop("logits"))
103
+
97
104
  base_label_prediction_result = cls(
98
105
  prediction_id=prediction_id,
99
106
  confidence=confidence,
100
107
  anomaly_score=anomaly_score,
101
108
  label=label,
102
109
  label_name=label_name,
110
+ logits=logits,
103
111
  )
104
112
 
105
113
  base_label_prediction_result.additional_properties = d
@@ -43,6 +43,7 @@ class LabeledMemorysetMetadata:
43
43
  label_names (List[str]):
44
44
  created_at (datetime.datetime):
45
45
  updated_at (datetime.datetime):
46
+ memories_updated_at (datetime.datetime):
46
47
  insertion_task_id (str):
47
48
  insertion_status (TaskStatus): Status of task in the task queue
48
49
  metrics (MemorysetMetrics):
@@ -59,6 +60,7 @@ class LabeledMemorysetMetadata:
59
60
  label_names: List[str]
60
61
  created_at: datetime.datetime
61
62
  updated_at: datetime.datetime
63
+ memories_updated_at: datetime.datetime
62
64
  insertion_task_id: str
63
65
  insertion_status: TaskStatus
64
66
  metrics: "MemorysetMetrics"
@@ -97,6 +99,8 @@ class LabeledMemorysetMetadata:
97
99
 
98
100
  updated_at = self.updated_at.isoformat()
99
101
 
102
+ memories_updated_at = self.memories_updated_at.isoformat()
103
+
100
104
  insertion_task_id = self.insertion_task_id
101
105
 
102
106
  insertion_status = (
@@ -120,6 +124,7 @@ class LabeledMemorysetMetadata:
120
124
  "label_names": label_names,
121
125
  "created_at": created_at,
122
126
  "updated_at": updated_at,
127
+ "memories_updated_at": memories_updated_at,
123
128
  "insertion_task_id": insertion_task_id,
124
129
  "insertion_status": insertion_status,
125
130
  "metrics": metrics,
@@ -180,6 +185,8 @@ class LabeledMemorysetMetadata:
180
185
 
181
186
  updated_at = isoparse(d.pop("updated_at"))
182
187
 
188
+ memories_updated_at = isoparse(d.pop("memories_updated_at"))
189
+
183
190
  insertion_task_id = d.pop("insertion_task_id")
184
191
 
185
192
  insertion_status = TaskStatus(d.pop("insertion_status"))
@@ -198,6 +205,7 @@ class LabeledMemorysetMetadata:
198
205
  label_names=label_names,
199
206
  created_at=created_at,
200
207
  updated_at=updated_at,
208
+ memories_updated_at=memories_updated_at,
201
209
  insertion_task_id=insertion_task_id,
202
210
  insertion_status=insertion_status,
203
211
  metrics=metrics,
@@ -0,0 +1 @@
1
+ from .metrics import calculate_pr_curve, calculate_roc_curve, compute_classifier_metrics
@@ -0,0 +1,195 @@
1
+ """
2
+ This module contains metrics for usage with the Hugging Face Trainer.
3
+
4
+ IMPORTANT:
5
+ - This is a shared file between OrcaLib and the Orca SDK.
6
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
7
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
8
+
9
+ """
10
+
11
+ from typing import Literal, Tuple, TypedDict
12
+
13
+ import numpy as np
14
+ from numpy.typing import NDArray
15
+ from scipy.special import softmax
16
+ from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
17
+ from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
18
+ from sklearn.metrics import roc_auc_score
19
+ from sklearn.metrics import roc_curve as sklearn_roc_curve
20
+ from transformers.trainer_utils import EvalPrediction
21
+
22
+
23
+ class ClassificationMetrics(TypedDict):
24
+ accuracy: float
25
+ f1_score: float
26
+ roc_auc: float | None # receiver operating characteristic area under the curve (if all classes are present)
27
+ pr_auc: float | None # precision-recall area under the curve (only for binary classification)
28
+ log_loss: float # cross-entropy loss for probabilities
29
+
30
+
31
def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
    """
    Compute standard metrics for classifier with Hugging Face Trainer.

    The model outputs are converted to per-sample probabilities before scoring:

    - if any value is not strictly positive, a softmax over the class axis is applied
    - otherwise, if the rows do not already sum to 1, each row is divided by its sum
    - otherwise the values are used unchanged

    Args:
        eval_pred: The predictions containing logits and expected labels as given by the Trainer.

    Returns:
        A dictionary containing the accuracy, f1 score, ROC AUC, PR AUC, and log loss.

    Raises:
        ValueError: If the logits are not a numpy array, or if the labels are not a
            single numpy array (several label columns were passed).
    """
    logits, references = eval_pred
    if isinstance(logits, tuple):
        # when the model returns several outputs only the first one is scored
        logits = logits[0]
    if not isinstance(logits, np.ndarray):
        raise ValueError("Logits must be a numpy array")
    if not isinstance(references, np.ndarray):
        raise ValueError(
            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
        )

    if not (logits > 0).all():
        # convert logits to probabilities with softmax if necessary; the axis must be
        # given explicitly because scipy's default (axis=None) normalizes over the
        # whole array instead of over the classes of each sample
        probabilities = softmax(logits, axis=-1)
    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
        # convert logits to probabilities through normalization if necessary
        probabilities = logits / logits.sum(-1, keepdims=True)
    else:
        probabilities = logits

    return classification_scores(references, probabilities)
61
+
62
+
63
def classification_scores(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
    multi_class: Literal["ovr", "ovo"] = "ovr",
) -> ClassificationMetrics:
    """
    Score class probabilities against the expected labels.

    Labels are treated as column indices into `probabilities` (predictions are the
    argmax over the class axis).

    Args:
        references: Expected label for each sample.
        probabilities: Either a 1D array of positive-class probabilities (binary) or a
            2D array with one column per class.
        average: Averaging mode passed to the F1 score. Defaults to "binary" for two
            classes and "weighted" otherwise.
        multi_class: Multiclass strategy passed to the ROC AUC score.

    Returns:
        The accuracy, F1 score, ROC AUC, PR AUC and log loss. `roc_auc` is `None` unless
        every class occurs in the references and the predictions cover the same number
        of classes; `pr_auc` is additionally `None` for more than two classes.

    Raises:
        ValueError: If `probabilities` has a single column or is not 1 or 2 dimensional.
    """
    if probabilities.ndim == 1:
        # convert 1D probabilities (binary) to 2D logits
        probabilities = np.column_stack([1 - probabilities, probabilities])
    elif probabilities.ndim == 2:
        if probabilities.shape[1] < 2:
            raise ValueError("Use a different metric function for regression tasks")
    else:
        raise ValueError("Probabilities must be 1 or 2 dimensional")

    num_classes = probabilities.shape[1]
    predictions = np.argmax(probabilities, axis=-1)

    num_classes_references = len(set(references))
    num_classes_predictions = len(set(predictions))

    if average is None:
        average = "binary" if num_classes == 2 else "weighted"

    accuracy = accuracy_score(references, predictions)
    f1 = f1_score(references, predictions, average=average)
    # pass the full label set explicitly so the loss is still defined when some
    # class does not occur in the references (e.g. a small or single-label batch)
    loss = log_loss(references, probabilities, labels=np.arange(num_classes))

    roc_auc = None
    pr_auc = None
    # the area-under-curve scores are undefined when a class is missing from the references
    all_classes_present = num_classes_references == num_classes
    if all_classes_present and num_classes_references == num_classes_predictions:
        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
        if num_classes == 2:
            roc_auc = roc_auc_score(references, probabilities[:, 1])
            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
            pr_auc = auc(recalls, precisions)
        else:
            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)

    return {
        "accuracy": float(accuracy),
        "f1_score": float(f1),
        "roc_auc": float(roc_auc) if roc_auc is not None else None,
        "pr_auc": float(pr_auc) if pr_auc is not None else None,
        "log_loss": float(loss),
    }
110
+
111
+
112
def calculate_pr_curve(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    max_length: int = 100,
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """
    Compute the precision-recall curve of a binary classifier.

    Args:
        references: Expected label for each sample.
        probabilities: Positive-class scores as a 1D array, or a 2D array whose
            column 1 holds the positive-class scores.
        max_length: Maximum number of points to return. Longer curves are resampled
            by linear interpolation onto `max_length` evenly spaced thresholds in [0, 1].

    Returns:
        A tuple `(precisions, recalls, thresholds)` of float32 arrays of equal length,
        ordered by ascending threshold.

    Raises:
        ValueError: If `probabilities` is not 1 or 2 dimensional, or its length differs
            from the length of `references`.
    """
    if probabilities.ndim == 1:
        probabilities_slice = probabilities
    elif probabilities.ndim == 2:
        probabilities_slice = probabilities[:, 1]
    else:
        raise ValueError("Probabilities must be 1 or 2 dimensional")

    if len(probabilities_slice) != len(references):
        raise ValueError("Probabilities and references must have the same length")

    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)

    # Convert all arrays to float32 immediately after getting them
    precisions = precisions.astype(np.float32)
    recalls = recalls.astype(np.float32)
    thresholds = thresholds.astype(np.float32)

    # Concatenate with 0 to include the lowest threshold (sklearn returns one
    # threshold fewer than precision/recall values, this makes the lengths match)
    thresholds = np.concatenate(([0], thresholds))

    # Sort by threshold. The sort must be stable: equal thresholds (a score of exactly 0
    # next to the prepended 0, or scores that coincide after the float32 cast) have to
    # keep their relative order, otherwise the curve points would be shuffled.
    sorted_indices = np.argsort(thresholds, kind="stable")
    thresholds = thresholds[sorted_indices]
    precisions = precisions[sorted_indices]
    recalls = recalls[sorted_indices]

    if len(precisions) > max_length:
        # resample the curve onto an evenly spaced threshold grid
        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
        new_precisions = np.interp(new_thresholds, thresholds, precisions)
        new_recalls = np.interp(new_thresholds, thresholds, recalls)
        thresholds = new_thresholds
        precisions = new_precisions
        recalls = new_recalls

    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
152
+
153
+
154
def calculate_roc_curve(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    max_length: int = 100,
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """
    Compute the receiver operating characteristic curve of a binary classifier.

    Args:
        references: Expected label for each sample.
        probabilities: Positive-class scores as a 1D array, or a 2D array whose
            column 1 holds the positive-class scores.
        max_length: Maximum number of points to return. Longer curves are resampled
            by linear interpolation onto `max_length` evenly spaced thresholds in [0, 1].

    Returns:
        A tuple `(fpr, tpr, thresholds)` of float32 arrays of equal length, ordered by
        ascending threshold.

    Raises:
        ValueError: If `probabilities` is not 1 or 2 dimensional, or its length differs
            from the length of `references`.
    """
    if probabilities.ndim == 1:
        probabilities_slice = probabilities
    elif probabilities.ndim == 2:
        probabilities_slice = probabilities[:, 1]
    else:
        raise ValueError("Probabilities must be 1 or 2 dimensional")

    if len(probabilities_slice) != len(references):
        raise ValueError("Probabilities and references must have the same length")

    # Convert probabilities to float32 before calling sklearn_roc_curve
    probabilities_slice = probabilities_slice.astype(np.float32)
    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)

    # Convert all arrays to float32 immediately after getting them
    fpr = fpr.astype(np.float32)
    tpr = tpr.astype(np.float32)
    thresholds = thresholds.astype(np.float32)

    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
    thresholds[0] = 1.0

    # sklearn returns the thresholds in descending order, so reverse first. If the
    # highest score is exactly 1.0 it ties with the replaced first threshold; reversing
    # and then sorting stably keeps that (fpr=0, tpr=0) end point last instead of
    # leaving the order of the tie undefined.
    thresholds = thresholds[::-1]
    fpr = fpr[::-1]
    tpr = tpr[::-1]

    # Sort by threshold
    sorted_indices = np.argsort(thresholds, kind="stable")
    thresholds = thresholds[sorted_indices]
    fpr = fpr[sorted_indices]
    tpr = tpr[sorted_indices]

    if len(fpr) > max_length:
        # resample the curve onto an evenly spaced threshold grid
        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
        new_fpr = np.interp(new_thresholds, thresholds, fpr)
        new_tpr = np.interp(new_thresholds, thresholds, tpr)
        thresholds = new_thresholds
        fpr = new_fpr
        tpr = new_tpr

    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
@@ -0,0 +1,169 @@
1
+ """
2
+ IMPORTANT:
3
+ - This is a shared file between OrcaLib and the Orca SDK.
4
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
5
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
6
+ """
7
+
8
+ from typing import Literal
9
+
10
+ import numpy as np
11
+ import pytest
12
+
13
+ from .metrics import (
14
+ EvalPrediction,
15
+ calculate_pr_curve,
16
+ calculate_roc_curve,
17
+ classification_scores,
18
+ compute_classifier_metrics,
19
+ softmax,
20
+ )
21
+
22
+
23
def test_binary_metrics():
    """A 1D array of positive-class scores is scored as a binary problem."""
    labels = np.array([0, 1, 1, 0, 1])
    positive_scores = np.array([0.1, 0.9, 0.8, 0.3, 0.2])

    result = classification_scores(labels, positive_scores)

    assert result["accuracy"] == 0.8
    assert result["f1_score"] == 0.8
    # both area-under-curve scores are available and good, but not perfect
    for key in ("roc_auc", "pr_auc"):
        assert result[key] is not None
        assert result[key] > 0.8
        assert result[key] < 1.0
    assert result["log_loss"] is not None
    assert result["log_loss"] > 0.0
39
+
40
+
41
def test_multiclass_metrics_with_2_classes():
    """A two-column probability matrix yields the same scores as the binary case."""
    labels = np.array([0, 1, 1, 0, 1])
    class_probabilities = np.array(
        [
            [0.9, 0.1],
            [0.1, 0.9],
            [0.2, 0.8],
            [0.7, 0.3],
            [0.8, 0.2],
        ]
    )

    result = classification_scores(labels, class_probabilities)

    assert result["accuracy"] == 0.8
    assert result["f1_score"] == 0.8
    # both area-under-curve scores are available and good, but not perfect
    for key in ("roc_auc", "pr_auc"):
        assert result[key] is not None
        assert result[key] > 0.8
        assert result[key] < 1.0
    assert result["log_loss"] is not None
    assert result["log_loss"] > 0.0
57
+
58
+
59
@pytest.mark.parametrize(
    "average, multiclass",
    [(avg, strategy) for strategy in ("ovr", "ovo") for avg in ("micro", "macro", "weighted")],
)
def test_multiclass_metrics_with_3_classes(
    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
):
    """Perfect three-class predictions score 1.0 and report no PR AUC."""
    labels = np.array([0, 1, 1, 0, 2])
    class_probabilities = np.array(
        [
            [0.9, 0.1, 0.0],
            [0.1, 0.9, 0.0],
            [0.2, 0.8, 0.0],
            [0.7, 0.3, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )

    result = classification_scores(labels, class_probabilities, average=average, multi_class=multiclass)

    assert result["accuracy"] == 1.0
    assert result["f1_score"] == 1.0
    assert result["roc_auc"] is not None
    assert result["roc_auc"] > 0.8
    # the precision-recall AUC is only reported for binary problems
    assert result["pr_auc"] is None
    assert result["log_loss"] is not None
    assert result["log_loss"] > 0.0
78
+
79
+
80
def test_does_not_modify_logits_unless_necessary():
    """Inputs that already are valid probabilities are scored as they are."""
    probabilities = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
    labels = np.array([0, 1, 0, 1])

    expected_loss = classification_scores(labels, probabilities)["log_loss"]
    actual = compute_classifier_metrics(EvalPrediction(probabilities, labels))

    assert actual["log_loss"] == expected_loss
85
+
86
+
87
def test_normalizes_logits_if_necessary():
    """Strictly positive rows that do not sum to 1 are divided by their row sum."""
    positive_logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
    labels = np.array([0, 1, 0, 1])

    row_normalized = positive_logits / positive_logits.sum(axis=1, keepdims=True)
    expected_loss = classification_scores(labels, row_normalized)["log_loss"]
    actual = compute_classifier_metrics(EvalPrediction(positive_logits, labels))

    assert actual["log_loss"] == expected_loss
94
+
95
+
96
def test_softmaxes_logits_if_necessary():
    """Logits containing non-positive values are converted with a per-row softmax."""
    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
    references = np.array([0, 1, 0, 1])
    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
    # the expected probabilities must be normalized per sample (axis=-1); scipy's
    # default softmax normalizes over the whole array, which is not a valid
    # probability distribution for each row
    expected = classification_scores(references, softmax(logits, axis=-1))["log_loss"]
    assert metrics["log_loss"] == pytest.approx(expected)
101
+
102
+
103
def test_precision_recall_curve():
    """The PR curve has one point per sample plus one, ordered by ascending threshold."""
    labels = np.array([0, 1, 1, 0, 1])
    scores = np.array([0.1, 0.9, 0.8, 0.6, 0.2])

    precision, recall, thresholds = calculate_pr_curve(labels, scores)
    for curve in (precision, recall, thresholds):
        assert curve is not None

    assert len(precision) == len(recall) == len(thresholds) == 6

    # lowest threshold: everything is predicted positive
    assert precision[0] == 0.6
    assert recall[0] == 1.0
    # highest threshold: nothing is predicted positive
    assert precision[-1] == 1.0
    assert recall[-1] == 0.0

    # test that thresholds are sorted
    assert np.all(np.diff(thresholds) >= 0)
120
+
121
+
122
def test_roc_curve():
    """The ROC curve runs from (1, 1) down to (0, 0) as the threshold increases."""
    labels = np.array([0, 1, 1, 0, 1])
    scores = np.array([0.1, 0.9, 0.8, 0.6, 0.2])

    fpr, tpr, thresholds = calculate_roc_curve(labels, scores)
    for curve in (fpr, tpr, thresholds):
        assert curve is not None

    assert len(fpr) == len(tpr) == len(thresholds) == 6

    # lowest threshold first, highest threshold last
    assert fpr[0] == 1.0
    assert tpr[0] == 1.0
    assert fpr[-1] == 0.0
    assert tpr[-1] == 0.0

    # test that thresholds are sorted
    assert np.all(np.diff(thresholds) >= 0)
139
+
140
+
141
def test_precision_recall_curve_max_length():
    """A PR curve longer than `max_length` is resampled down to `max_length` points."""
    labels = np.array([0, 1, 1, 0, 1])
    scores = np.array([0.1, 0.9, 0.8, 0.6, 0.2])

    precision, recall, thresholds = calculate_pr_curve(labels, scores, max_length=5)

    assert len(precision) == len(recall) == len(thresholds) == 5

    # the end points of the curve survive the resampling
    assert precision[0] == 0.6
    assert recall[0] == 1.0
    assert precision[-1] == 1.0
    assert recall[-1] == 0.0

    # test that thresholds are sorted
    assert np.all(np.diff(thresholds) >= 0)
155
+
156
+
157
def test_roc_curve_max_length():
    """A ROC curve longer than `max_length` is resampled down to `max_length` points."""
    labels = np.array([0, 1, 1, 0, 1])
    scores = np.array([0.1, 0.9, 0.8, 0.6, 0.2])

    fpr, tpr, thresholds = calculate_roc_curve(labels, scores, max_length=5)

    assert len(fpr) == len(tpr) == len(thresholds) == 5

    # the end points of the curve survive the resampling
    assert fpr[0] == 1.0
    assert tpr[0] == 1.0
    assert fpr[-1] == 0.0
    assert tpr[-1] == 0.0

    # test that thresholds are sorted
    assert np.all(np.diff(thresholds) >= 0)