orca-sdk 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205) hide show
  1. orca_sdk-0.1.2/PKG-INFO +97 -0
  2. orca_sdk-0.1.2/README.md +72 -0
  3. orca_sdk-0.1.2/orca_sdk/__init__.py +30 -0
  4. orca_sdk-0.1.2/orca_sdk/_shared/__init__.py +10 -0
  5. orca_sdk-0.1.2/orca_sdk/_shared/metrics.py +393 -0
  6. orca_sdk-0.1.2/orca_sdk/_shared/metrics_test.py +273 -0
  7. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/analysis_ui.py +13 -11
  8. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/analysis_ui_style.css +0 -3
  9. orca_sdk-0.1.2/orca_sdk/_utils/auth.py +61 -0
  10. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/auth_test.py +1 -1
  11. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/data_parsing.py +28 -2
  12. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/data_parsing_test.py +15 -15
  13. orca_sdk-0.1.2/orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk-0.1.2/orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk-0.1.2/orca_sdk/_utils/prediction_result_ui.py +110 -0
  16. orca_sdk-0.1.2/orca_sdk/_utils/tqdm_file_reader.py +12 -0
  17. orca_sdk-0.1.2/orca_sdk/_utils/value_parser.py +45 -0
  18. orca_sdk-0.1.2/orca_sdk/_utils/value_parser_test.py +39 -0
  19. orca_sdk-0.1.2/orca_sdk/classification_model.py +809 -0
  20. orca_sdk-0.1.2/orca_sdk/classification_model_test.py +496 -0
  21. orca_sdk-0.1.2/orca_sdk/client.py +3747 -0
  22. orca_sdk-0.1.2/orca_sdk/conftest.py +262 -0
  23. orca_sdk-0.1.2/orca_sdk/credentials.py +177 -0
  24. orca_sdk-0.1.0/orca_sdk/orca_credentials_test.py → orca_sdk-0.1.2/orca_sdk/credentials_test.py +21 -1
  25. orca_sdk-0.1.2/orca_sdk/datasource.py +524 -0
  26. orca_sdk-0.1.2/orca_sdk/datasource_test.py +337 -0
  27. orca_sdk-0.1.2/orca_sdk/embedding_model.py +690 -0
  28. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/embedding_model_test.py +40 -14
  29. orca_sdk-0.1.2/orca_sdk/job.py +337 -0
  30. orca_sdk-0.1.2/orca_sdk/job_test.py +108 -0
  31. orca_sdk-0.1.2/orca_sdk/memoryset.py +2190 -0
  32. orca_sdk-0.1.2/orca_sdk/memoryset_test.py +510 -0
  33. orca_sdk-0.1.2/orca_sdk/regression_model.py +684 -0
  34. orca_sdk-0.1.2/orca_sdk/regression_model_test.py +369 -0
  35. orca_sdk-0.1.2/orca_sdk/telemetry.py +692 -0
  36. orca_sdk-0.1.2/orca_sdk/telemetry_test.py +119 -0
  37. orca_sdk-0.1.2/pyproject.toml +84 -0
  38. orca_sdk-0.1.0/PKG-INFO +0 -39
  39. orca_sdk-0.1.0/README.md +0 -15
  40. orca_sdk-0.1.0/orca_sdk/__init__.py +0 -19
  41. orca_sdk-0.1.0/orca_sdk/_generated_api_client/__init__.py +0 -3
  42. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/__init__.py +0 -193
  43. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  44. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  45. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  46. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  47. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  48. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  49. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  50. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  51. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  52. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  53. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  54. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  55. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  56. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  57. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  58. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  59. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  60. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  61. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  62. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  63. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  64. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  65. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  66. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  67. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  68. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  69. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  70. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  71. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  72. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  73. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  74. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  75. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  76. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  77. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  78. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  79. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  80. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  81. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  82. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  83. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  84. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  85. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  86. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  87. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  88. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  89. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  90. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  91. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  92. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  93. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  94. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  95. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  96. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  97. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  98. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  99. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  100. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  101. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  102. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  103. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  104. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  105. orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  106. orca_sdk-0.1.0/orca_sdk/_generated_api_client/client.py +0 -216
  107. orca_sdk-0.1.0/orca_sdk/_generated_api_client/errors.py +0 -38
  108. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/__init__.py +0 -159
  109. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  110. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  111. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/base_model.py +0 -55
  112. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  113. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  114. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  115. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/column_info.py +0 -114
  116. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/column_type.py +0 -14
  117. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  118. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  119. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  120. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  121. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  122. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  123. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  124. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  125. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  126. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  127. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  128. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  129. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  130. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  131. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  132. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  133. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  134. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  135. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  136. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  137. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  138. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  139. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  140. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  141. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  142. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  143. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  144. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  145. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  146. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  147. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  148. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  149. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  150. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  151. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  152. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  153. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  154. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  155. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  156. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  157. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  158. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  159. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  160. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  161. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  162. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  163. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  164. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  165. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  166. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  167. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  168. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  169. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  170. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  171. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  172. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  173. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  174. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  175. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  176. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  177. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  178. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task.py +0 -198
  179. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task_status.py +0 -14
  180. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  181. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  182. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  183. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  184. orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  185. orca_sdk-0.1.0/orca_sdk/_generated_api_client/py.typed +0 -1
  186. orca_sdk-0.1.0/orca_sdk/_generated_api_client/types.py +0 -56
  187. orca_sdk-0.1.0/orca_sdk/_utils/__init__.py +0 -0
  188. orca_sdk-0.1.0/orca_sdk/_utils/auth.py +0 -63
  189. orca_sdk-0.1.0/orca_sdk/_utils/prediction_result_ui.py +0 -64
  190. orca_sdk-0.1.0/orca_sdk/_utils/task.py +0 -73
  191. orca_sdk-0.1.0/orca_sdk/classification_model.py +0 -499
  192. orca_sdk-0.1.0/orca_sdk/classification_model_test.py +0 -266
  193. orca_sdk-0.1.0/orca_sdk/conftest.py +0 -117
  194. orca_sdk-0.1.0/orca_sdk/datasource.py +0 -333
  195. orca_sdk-0.1.0/orca_sdk/datasource_test.py +0 -95
  196. orca_sdk-0.1.0/orca_sdk/embedding_model.py +0 -336
  197. orca_sdk-0.1.0/orca_sdk/labeled_memoryset.py +0 -1154
  198. orca_sdk-0.1.0/orca_sdk/labeled_memoryset_test.py +0 -271
  199. orca_sdk-0.1.0/orca_sdk/orca_credentials.py +0 -75
  200. orca_sdk-0.1.0/orca_sdk/telemetry.py +0 -386
  201. orca_sdk-0.1.0/orca_sdk/telemetry_test.py +0 -100
  202. orca_sdk-0.1.0/pyproject.toml +0 -71
  203. {orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth → orca_sdk-0.1.2/orca_sdk/_utils}/__init__.py +0 -0
  204. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/common.py +0 -0
  205. {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/prediction_result_ui.css +0 -0
@@ -0,0 +1,97 @@
1
+ Metadata-Version: 2.4
2
+ Name: orca_sdk
3
+ Version: 0.1.2
4
+ Summary: SDK for interacting with Orca Services
5
+ License-Expression: Apache-2.0
6
+ Author: Orca DB Inc.
7
+ Author-email: dev-rel@orcadb.ai
8
+ Requires-Python: >=3.11,<3.14
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Requires-Dist: datasets (>=3.1.0,<4)
14
+ Requires-Dist: gradio (>=5.44.1,<6)
15
+ Requires-Dist: httpx (>=0.28.1)
16
+ Requires-Dist: httpx-retries (>=0.4.3,<0.5.0)
17
+ Requires-Dist: numpy (>=2.1.0,<3)
18
+ Requires-Dist: pandas (>=2.2.3,<3)
19
+ Requires-Dist: pyarrow (>=18.0.0,<19)
20
+ Requires-Dist: python-dotenv (>=1.1.0)
21
+ Requires-Dist: scikit-learn (>=1.6.1,<2)
22
+ Requires-Dist: torch (>=2.8.0,<3)
23
+ Description-Content-Type: text/markdown
24
+
25
+ <!--
26
+ IMPORTANT NOTE:
27
+ - This file is rendered on the public-facing PyPI page here: https://pypi.org/project/orca_sdk/
28
+ - Only content suitable for public consumption should be placed in this file; everything else should go into CONTRIBUTING.md
29
+ -->
30
+
31
+ # OrcaSDK
32
+
33
+ OrcaSDK is a Python library for building and using retrieval-augmented models with [OrcaCloud](https://orcadb.ai). It enables you to create, deploy, and maintain models that can adapt to changing circumstances without retraining by accessing external data called "memories."
34
+
35
+ ## Documentation
36
+
37
+ You can find the documentation for all things Orca at [docs.orcadb.ai](https://docs.orcadb.ai). This includes tutorials, how-to guides, and the full interface reference for OrcaSDK.
38
+
39
+ ## Features
40
+
41
+ - **Labeled Memorysets**: Store and manage labeled examples that your models can use to guide predictions
42
+ - **Classification Models**: Build retrieval-augmented classification models that adapt to new data without retraining
43
+ - **Embedding Models**: Use pre-trained or fine-tuned embedding models to represent your data
44
+ - **Telemetry**: Collect feedback and monitor memory usage to optimize model performance
45
+ - **Datasources**: Easily ingest data from various sources into your memorysets
46
+
47
+ ## Installation
48
+
49
+ OrcaSDK is compatible with Python 3.11 through 3.13 and is available on [PyPI](https://pypi.org/project/orca_sdk/). You can install it with your favorite Python package manager:
50
+
51
+ - Pip: `pip install orca_sdk`
52
+ - Conda: `conda install orca_sdk`
53
+ - Poetry: `poetry add orca_sdk`
54
+
55
+ ## Quick Start
56
+
57
+ ```python
58
+ from dotenv import load_dotenv
59
+ from orca_sdk import OrcaCredentials, LabeledMemoryset, ClassificationModel
60
+
61
+ # Load your API key from environment variables
62
+ load_dotenv()
63
+ assert OrcaCredentials.is_authenticated()
64
+
65
+ # Create a labeled memoryset
66
+ memoryset = LabeledMemoryset.from_disk("my_memoryset", "./data.jsonl")
67
+
68
+ # Create a classification model using the memoryset
69
+ model = ClassificationModel("my_model", memoryset)
70
+
71
+ # Make predictions
72
+ prediction = model.predict("my input")
73
+
74
+ # Get Action Recommendation
75
+ action, rationale = prediction.recommend_action()
76
+ print(f"Recommended action: {action}")
77
+ print(f"Rationale: {rationale}")
78
+
79
+ # Generate and add synthetic memory suggestions
80
+ if action == "add_memories":
81
+ suggestions = prediction.generate_memory_suggestions(num_memories=3)
82
+
83
+ # Review suggestions
84
+ for suggestion in suggestions:
85
+ print(f"Suggested: '{suggestion['value']}' -> {suggestion['label']}")
86
+
87
+ # Add suggestions to memoryset
88
+ model.memoryset.insert(suggestions)
89
+ print(f"Added {len(suggestions)} new memories to improve model performance!")
90
+ ```
91
+
92
+ For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart-sdk/).
93
+
94
+ ## Support
95
+
96
+ If you have any questions, please reach out to us at support@orcadb.ai.
97
+
@@ -0,0 +1,72 @@
1
+ <!--
2
+ IMPORTANT NOTE:
3
+ - This file is rendered on the public-facing PyPI page here: https://pypi.org/project/orca_sdk/
4
+ - Only content suitable for public consumption should be placed in this file; everything else should go into CONTRIBUTING.md
5
+ -->
6
+
7
+ # OrcaSDK
8
+
9
+ OrcaSDK is a Python library for building and using retrieval-augmented models with [OrcaCloud](https://orcadb.ai). It enables you to create, deploy, and maintain models that can adapt to changing circumstances without retraining by accessing external data called "memories."
10
+
11
+ ## Documentation
12
+
13
+ You can find the documentation for all things Orca at [docs.orcadb.ai](https://docs.orcadb.ai). This includes tutorials, how-to guides, and the full interface reference for OrcaSDK.
14
+
15
+ ## Features
16
+
17
+ - **Labeled Memorysets**: Store and manage labeled examples that your models can use to guide predictions
18
+ - **Classification Models**: Build retrieval-augmented classification models that adapt to new data without retraining
19
+ - **Embedding Models**: Use pre-trained or fine-tuned embedding models to represent your data
20
+ - **Telemetry**: Collect feedback and monitor memory usage to optimize model performance
21
+ - **Datasources**: Easily ingest data from various sources into your memorysets
22
+
23
+ ## Installation
24
+
25
+ OrcaSDK is compatible with Python 3.11 through 3.13 and is available on [PyPI](https://pypi.org/project/orca_sdk/). You can install it with your favorite Python package manager:
26
+
27
+ - Pip: `pip install orca_sdk`
28
+ - Conda: `conda install orca_sdk`
29
+ - Poetry: `poetry add orca_sdk`
30
+
31
+ ## Quick Start
32
+
33
+ ```python
34
+ from dotenv import load_dotenv
35
+ from orca_sdk import OrcaCredentials, LabeledMemoryset, ClassificationModel
36
+
37
+ # Load your API key from environment variables
38
+ load_dotenv()
39
+ assert OrcaCredentials.is_authenticated()
40
+
41
+ # Create a labeled memoryset
42
+ memoryset = LabeledMemoryset.from_disk("my_memoryset", "./data.jsonl")
43
+
44
+ # Create a classification model using the memoryset
45
+ model = ClassificationModel("my_model", memoryset)
46
+
47
+ # Make predictions
48
+ prediction = model.predict("my input")
49
+
50
+ # Get Action Recommendation
51
+ action, rationale = prediction.recommend_action()
52
+ print(f"Recommended action: {action}")
53
+ print(f"Rationale: {rationale}")
54
+
55
+ # Generate and add synthetic memory suggestions
56
+ if action == "add_memories":
57
+ suggestions = prediction.generate_memory_suggestions(num_memories=3)
58
+
59
+ # Review suggestions
60
+ for suggestion in suggestions:
61
+ print(f"Suggested: '{suggestion['value']}' -> {suggestion['label']}")
62
+
63
+ # Add suggestions to memoryset
64
+ model.memoryset.insert(suggestions)
65
+ print(f"Added {len(suggestions)} new memories to improve model performance!")
66
+ ```
67
+
68
+ For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart-sdk/).
69
+
70
+ ## Support
71
+
72
+ If you have any questions, please reach out to us at support@orcadb.ai.
@@ -0,0 +1,30 @@
1
+ """
2
+ OrcaSDK is a Python library for building and using retrieval augmented models in the OrcaCloud.
3
+ """
4
+
5
+ from ._utils.common import UNSET, CreateMode, DropMode
6
+ from .classification_model import ClassificationMetrics, ClassificationModel
7
+ from .client import orca_api
8
+ from .credentials import OrcaCredentials
9
+ from .datasource import Datasource
10
+ from .embedding_model import (
11
+ FinetunedEmbeddingModel,
12
+ PretrainedEmbeddingModel,
13
+ PretrainedEmbeddingModelName,
14
+ )
15
+ from .job import Job, Status
16
+ from .memoryset import (
17
+ CascadingEditSuggestion,
18
+ FilterItemTuple,
19
+ LabeledMemory,
20
+ LabeledMemoryLookup,
21
+ LabeledMemoryset,
22
+ ScoredMemory,
23
+ ScoredMemoryLookup,
24
+ ScoredMemoryset,
25
+ )
26
+ from .regression_model import RegressionModel
27
+ from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction
28
+
29
# Only list names that should appear on the root page of the reference docs because they are
# defined in private modules (here, `._utils.common`); the other public names are imported above.
# NOTE(review): `__all__` also controls `from orca_sdk import *`, so a star import exports only
# these three names (not ClassificationModel, LabeledMemoryset, ...) — confirm this is intended.
__all__ = ["UNSET", "CreateMode", "DropMode"]
@@ -0,0 +1,10 @@
1
+ from .metrics import (
2
+ ClassificationMetrics,
3
+ PRCurve,
4
+ RegressionMetrics,
5
+ ROCCurve,
6
+ calculate_classification_metrics,
7
+ calculate_pr_curve,
8
+ calculate_regression_metrics,
9
+ calculate_roc_curve,
10
+ )
@@ -0,0 +1,393 @@
1
+ """
2
+ This module contains metrics for usage with the Hugging Face Trainer.
3
+
4
+ IMPORTANT:
5
+ - This is a shared file between OrcaLib and the OrcaSDK.
6
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
7
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
8
+
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any, Literal, TypedDict, cast
13
+
14
+ import numpy as np
15
+ import sklearn.metrics
16
+ from numpy.typing import NDArray
17
+
18
+
19
+ # we don't want to depend on scipy or torch in orca_sdk
20
+ def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
21
+ shifted = logits - np.max(logits, axis=axis, keepdims=True)
22
+ exps = np.exp(shifted)
23
+ return exps / np.sum(exps, axis=axis, keepdims=True)
24
+
25
+
26
+ # We don't want to depend on transformers just for the eval_pred type in orca_sdk
27
+ def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
28
+ # convert results from Trainer compute_metrics param for use in calculate_classification_metrics
29
+ logits, references = eval_pred # transformers.trainer_utils.EvalPrediction
30
+ if isinstance(logits, tuple):
31
+ logits = logits[0]
32
+ if not isinstance(logits, np.ndarray):
33
+ raise ValueError("Logits must be a numpy array")
34
+ if not isinstance(references, np.ndarray):
35
+ raise ValueError(
36
+ "Multiple label columns found, use the `label_names` training argument to specify which one to use"
37
+ )
38
+
39
+ return (references, logits)
40
+
41
+
42
+ class PRCurve(TypedDict):
43
+ thresholds: list[float]
44
+ precisions: list[float]
45
+ recalls: list[float]
46
+
47
+
48
+ def calculate_pr_curve(
49
+ references: NDArray[np.int64],
50
+ probabilities: NDArray[np.float32],
51
+ max_length: int = 100,
52
+ ) -> PRCurve:
53
+ if probabilities.ndim == 1:
54
+ probabilities_slice = probabilities
55
+ elif probabilities.ndim == 2:
56
+ probabilities_slice = probabilities[:, 1]
57
+ else:
58
+ raise ValueError("Probabilities must be 1 or 2 dimensional")
59
+
60
+ if len(probabilities_slice) != len(references):
61
+ raise ValueError("Probabilities and references must have the same length")
62
+
63
+ precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(references, probabilities_slice)
64
+
65
+ # Convert all arrays to float32 immediately after getting them
66
+ precisions = precisions.astype(np.float32)
67
+ recalls = recalls.astype(np.float32)
68
+ thresholds = thresholds.astype(np.float32)
69
+
70
+ # Concatenate with 0 to include the lowest threshold
71
+ thresholds = np.concatenate(([0], thresholds))
72
+
73
+ # Sort by threshold
74
+ sorted_indices = np.argsort(thresholds)
75
+ thresholds = thresholds[sorted_indices]
76
+ precisions = precisions[sorted_indices]
77
+ recalls = recalls[sorted_indices]
78
+
79
+ if len(precisions) > max_length:
80
+ new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
81
+ new_precisions = np.interp(new_thresholds, thresholds, precisions)
82
+ new_recalls = np.interp(new_thresholds, thresholds, recalls)
83
+ thresholds = new_thresholds
84
+ precisions = new_precisions
85
+ recalls = new_recalls
86
+
87
+ return PRCurve(
88
+ thresholds=cast(list[float], thresholds.tolist()),
89
+ precisions=cast(list[float], precisions.tolist()),
90
+ recalls=cast(list[float], recalls.tolist()),
91
+ )
92
+
93
+
94
+ class ROCCurve(TypedDict):
95
+ thresholds: list[float]
96
+ false_positive_rates: list[float]
97
+ true_positive_rates: list[float]
98
+
99
+
100
+ def calculate_roc_curve(
101
+ references: NDArray[np.int64],
102
+ probabilities: NDArray[np.float32],
103
+ max_length: int = 100,
104
+ ) -> ROCCurve:
105
+ if probabilities.ndim == 1:
106
+ probabilities_slice = probabilities
107
+ elif probabilities.ndim == 2:
108
+ probabilities_slice = probabilities[:, 1]
109
+ else:
110
+ raise ValueError("Probabilities must be 1 or 2 dimensional")
111
+
112
+ if len(probabilities_slice) != len(references):
113
+ raise ValueError("Probabilities and references must have the same length")
114
+
115
+ # Convert probabilities to float32 before calling sklearn_roc_curve
116
+ probabilities_slice = probabilities_slice.astype(np.float32)
117
+ fpr, tpr, thresholds = sklearn.metrics.roc_curve(references, probabilities_slice)
118
+
119
+ # Convert all arrays to float32 immediately after getting them
120
+ fpr = fpr.astype(np.float32)
121
+ tpr = tpr.astype(np.float32)
122
+ thresholds = thresholds.astype(np.float32)
123
+
124
+ # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
125
+ thresholds[0] = 1.0
126
+
127
+ # Sort by threshold
128
+ sorted_indices = np.argsort(thresholds)
129
+ thresholds = thresholds[sorted_indices]
130
+ fpr = fpr[sorted_indices]
131
+ tpr = tpr[sorted_indices]
132
+
133
+ if len(fpr) > max_length:
134
+ new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
135
+ new_fpr = np.interp(new_thresholds, thresholds, fpr)
136
+ new_tpr = np.interp(new_thresholds, thresholds, tpr)
137
+ thresholds = new_thresholds
138
+ fpr = new_fpr
139
+ tpr = new_tpr
140
+
141
+ return ROCCurve(
142
+ false_positive_rates=cast(list[float], fpr.tolist()),
143
+ true_positive_rates=cast(list[float], tpr.tolist()),
144
+ thresholds=cast(list[float], thresholds.tolist()),
145
+ )
146
+
147
+
148
+ @dataclass
149
+ class ClassificationMetrics:
150
+ coverage: float
151
+ """Percentage of predictions that are not none"""
152
+
153
+ f1_score: float
154
+ """F1 score of the predictions"""
155
+
156
+ accuracy: float
157
+ """Accuracy of the predictions"""
158
+
159
+ loss: float | None
160
+ """Cross-entropy loss of the logits"""
161
+
162
+ anomaly_score_mean: float | None = None
163
+ """Mean of anomaly scores across the dataset"""
164
+
165
+ anomaly_score_median: float | None = None
166
+ """Median of anomaly scores across the dataset"""
167
+
168
+ anomaly_score_variance: float | None = None
169
+ """Variance of anomaly scores across the dataset"""
170
+
171
+ roc_auc: float | None = None
172
+ """Receiver operating characteristic area under the curve"""
173
+
174
+ pr_auc: float | None = None
175
+ """Average precision (area under the curve of the precision-recall curve)"""
176
+
177
+ pr_curve: PRCurve | None = None
178
+ """Precision-recall curve"""
179
+
180
+ roc_curve: ROCCurve | None = None
181
+ """Receiver operating characteristic curve"""
182
+
183
+ def __repr__(self) -> str:
184
+ return (
185
+ "ClassificationMetrics({\n"
186
+ + f" accuracy: {self.accuracy:.4f},\n"
187
+ + f" f1_score: {self.f1_score:.4f},\n"
188
+ + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
189
+ + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
190
+ + (
191
+ f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
192
+ if self.anomaly_score_mean
193
+ else ""
194
+ )
195
+ + "})"
196
+ )
197
+
198
+
199
+ def calculate_classification_metrics(
200
+ expected_labels: list[int] | NDArray[np.int64],
201
+ logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
202
+ anomaly_scores: list[float] | None = None,
203
+ average: Literal["micro", "macro", "weighted", "binary"] | None = None,
204
+ multi_class: Literal["ovr", "ovo"] = "ovr",
205
+ include_curves: bool = False,
206
+ ) -> ClassificationMetrics:
207
+ references = np.array(expected_labels)
208
+
209
+ logits = np.array(logits)
210
+ if logits.ndim == 1:
211
+ if (logits > 1).any() or (logits < 0).any():
212
+ raise ValueError("Logits must be between 0 and 1 for binary classification")
213
+ # convert 1D probabilities (binary) to 2D logits
214
+ logits = np.column_stack([1 - logits, logits])
215
+ probabilities = logits # no need to convert to probabilities
216
+ elif logits.ndim == 2:
217
+ if logits.shape[1] < 2:
218
+ raise ValueError("Use a different metric function for regression tasks")
219
+ if not (logits > 0).all():
220
+ # convert logits to probabilities with softmax if necessary
221
+ probabilities = softmax(logits)
222
+ elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
223
+ # convert logits to probabilities through normalization if necessary
224
+ probabilities = logits / logits.sum(-1, keepdims=True)
225
+ else:
226
+ probabilities = logits
227
+ else:
228
+ raise ValueError("Logits must be 1 or 2 dimensional")
229
+
230
+ predictions = np.argmax(probabilities, axis=-1)
231
+ predictions[np.isnan(probabilities).all(axis=-1)] = -1 # set predictions to -1 for all nan logits
232
+
233
+ num_classes_references = len(set(references))
234
+ num_classes_predictions = len(set(predictions))
235
+ num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
236
+ coverage = 1 - num_none_predictions / len(probabilities)
237
+
238
+ if average is None:
239
+ average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
240
+
241
+ anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
242
+ anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
243
+ anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
244
+
245
+ accuracy = sklearn.metrics.accuracy_score(references, predictions)
246
+ f1 = sklearn.metrics.f1_score(references, predictions, average=average)
247
+ # Ensure sklearn sees the full class set corresponding to probability columns
248
+ # to avoid errors when y_true does not contain all classes.
249
+ loss = (
250
+ sklearn.metrics.log_loss(
251
+ references,
252
+ probabilities,
253
+ labels=list(range(probabilities.shape[1])),
254
+ )
255
+ if num_none_predictions == 0
256
+ else None
257
+ )
258
+
259
+ if num_classes_references == num_classes_predictions and num_none_predictions == 0:
260
+ # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
261
+ if num_classes_references == 2:
262
+ roc_auc = sklearn.metrics.roc_auc_score(references, logits[:, 1])
263
+ roc_curve = calculate_roc_curve(references, logits[:, 1]) if include_curves else None
264
+ pr_auc = sklearn.metrics.average_precision_score(references, logits[:, 1])
265
+ pr_curve = calculate_pr_curve(references, logits[:, 1]) if include_curves else None
266
+ else:
267
+ roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
268
+ roc_curve = None
269
+ pr_auc = None
270
+ pr_curve = None
271
+ else:
272
+ roc_auc = None
273
+ pr_auc = None
274
+ pr_curve = None
275
+ roc_curve = None
276
+
277
+ return ClassificationMetrics(
278
+ coverage=coverage,
279
+ accuracy=float(accuracy),
280
+ f1_score=float(f1),
281
+ loss=float(loss) if loss is not None else None,
282
+ anomaly_score_mean=anomaly_score_mean,
283
+ anomaly_score_median=anomaly_score_median,
284
+ anomaly_score_variance=anomaly_score_variance,
285
+ roc_auc=float(roc_auc) if roc_auc is not None else None,
286
+ pr_auc=float(pr_auc) if pr_auc is not None else None,
287
+ pr_curve=pr_curve,
288
+ roc_curve=roc_curve,
289
+ )
290
+
291
+
292
+ @dataclass
293
+ class RegressionMetrics:
294
+ coverage: float
295
+ """Percentage of predictions that are not none"""
296
+
297
+ mse: float
298
+ """Mean squared error of the predictions"""
299
+
300
+ rmse: float
301
+ """Root mean squared error of the predictions"""
302
+
303
+ mae: float
304
+ """Mean absolute error of the predictions"""
305
+
306
+ r2: float
307
+ """R-squared score (coefficient of determination) of the predictions"""
308
+
309
+ explained_variance: float
310
+ """Explained variance score of the predictions"""
311
+
312
+ loss: float
313
+ """Mean squared error loss of the predictions"""
314
+
315
+ anomaly_score_mean: float | None = None
316
+ """Mean of anomaly scores across the dataset"""
317
+
318
+ anomaly_score_median: float | None = None
319
+ """Median of anomaly scores across the dataset"""
320
+
321
+ anomaly_score_variance: float | None = None
322
+ """Variance of anomaly scores across the dataset"""
323
+
324
+ def __repr__(self) -> str:
325
+ return (
326
+ "RegressionMetrics({\n"
327
+ + f" mae: {self.mae:.4f},\n"
328
+ + f" rmse: {self.rmse:.4f},\n"
329
+ + f" r2: {self.r2:.4f},\n"
330
+ + (
331
+ f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
332
+ if self.anomaly_score_mean
333
+ else ""
334
+ )
335
+ + "})"
336
+ )
337
+
338
+
339
+ def calculate_regression_metrics(
340
+ expected_scores: NDArray[np.float32] | list[float],
341
+ predicted_scores: NDArray[np.float32] | list[float],
342
+ anomaly_scores: list[float] | None = None,
343
+ ) -> RegressionMetrics:
344
+ """
345
+ Calculate regression metrics for model evaluation.
346
+
347
+ Params:
348
+ references: True target values
349
+ predictions: Predicted values from the model
350
+ anomaly_scores: Optional anomaly scores for each prediction
351
+
352
+ Returns:
353
+ Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance
354
+
355
+ Raises:
356
+ ValueError: If predictions and references have different lengths
357
+ """
358
+ references = np.array(expected_scores)
359
+ predictions = np.array(predicted_scores)
360
+
361
+ if len(predictions) != len(references):
362
+ raise ValueError("Predictions and references must have the same length")
363
+
364
+ anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
365
+ anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
366
+ anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
367
+
368
+ none_prediction_mask = np.isnan(predictions)
369
+ num_none_predictions = none_prediction_mask.sum()
370
+ coverage = 1 - num_none_predictions / len(predictions)
371
+ if num_none_predictions > 0:
372
+ references = references[~none_prediction_mask]
373
+ predictions = predictions[~none_prediction_mask]
374
+
375
+ # Calculate core regression metrics
376
+ mse = float(sklearn.metrics.mean_squared_error(references, predictions))
377
+ rmse = float(np.sqrt(mse))
378
+ mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
379
+ r2 = float(sklearn.metrics.r2_score(references, predictions))
380
+ explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))
381
+
382
+ return RegressionMetrics(
383
+ coverage=coverage,
384
+ mse=mse,
385
+ rmse=rmse,
386
+ mae=mae,
387
+ r2=r2,
388
+ explained_variance=explained_var,
389
+ loss=mse, # For regression, loss is typically MSE
390
+ anomaly_score_mean=anomaly_score_mean,
391
+ anomaly_score_median=anomaly_score_median,
392
+ anomaly_score_variance=anomaly_score_variance,
393
+ )