orca-sdk 0.0.97__py3-none-any.whl → 0.0.98__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. orca_sdk/__init__.py +1 -0
  2. orca_sdk/_shared/__init__.py +1 -0
  3. orca_sdk/_utils/analysis_ui.py +5 -5
  4. orca_sdk/_utils/auth.py +23 -33
  5. orca_sdk/_utils/pagination.py +126 -0
  6. orca_sdk/_utils/pagination_test.py +132 -0
  7. orca_sdk/classification_model.py +188 -126
  8. orca_sdk/classification_model_test.py +57 -8
  9. orca_sdk/client.py +3515 -0
  10. orca_sdk/conftest.py +10 -0
  11. orca_sdk/credentials.py +59 -21
  12. orca_sdk/credentials_test.py +20 -0
  13. orca_sdk/datasource.py +42 -76
  14. orca_sdk/embedding_model.py +225 -71
  15. orca_sdk/embedding_model_test.py +27 -36
  16. orca_sdk/job.py +49 -45
  17. orca_sdk/job_test.py +16 -0
  18. orca_sdk/memoryset.py +340 -353
  19. orca_sdk/memoryset_test.py +7 -11
  20. orca_sdk/regression_model.py +120 -111
  21. orca_sdk/regression_model_test.py +15 -0
  22. orca_sdk/telemetry.py +162 -139
  23. {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.98.dist-info}/METADATA +2 -5
  24. orca_sdk-0.0.98.dist-info/RECORD +40 -0
  25. orca_sdk/_generated_api_client/__init__.py +0 -3
  26. orca_sdk/_generated_api_client/api/__init__.py +0 -307
  27. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  29. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  30. orca_sdk/_generated_api_client/api/auth/create_org_plan_auth_org_plan_post.py +0 -168
  31. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  32. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  33. orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +0 -122
  34. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  35. orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +0 -168
  36. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  37. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +0 -170
  38. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
  39. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  40. orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
  41. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
  42. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  43. orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  44. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  46. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
  47. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  48. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +0 -224
  49. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +0 -229
  50. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
  51. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  52. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
  53. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  54. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
  55. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  56. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
  57. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  58. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  59. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  60. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  62. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  66. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  67. orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
  68. orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
  69. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  70. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  71. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  74. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  75. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
  76. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
  77. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  78. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
  80. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
  81. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
  82. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
  83. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
  84. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
  87. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
  88. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
  92. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
  93. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  94. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  95. orca_sdk/_generated_api_client/api/regression_model/create_regression_model_regression_model_post.py +0 -170
  96. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  97. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
  98. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
  99. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  100. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
  101. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
  102. orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
  103. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
  104. orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
  105. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  106. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  107. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  108. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
  109. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -288
  110. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  111. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
  112. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  113. orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
  114. orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +0 -239
  115. orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +0 -192
  116. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  117. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
  118. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  119. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
  120. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
  121. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  122. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  123. orca_sdk/_generated_api_client/client.py +0 -216
  124. orca_sdk/_generated_api_client/errors.py +0 -38
  125. orca_sdk/_generated_api_client/models/__init__.py +0 -345
  126. orca_sdk/_generated_api_client/models/action_recommendation.py +0 -82
  127. orca_sdk/_generated_api_client/models/action_recommendation_action.py +0 -11
  128. orca_sdk/_generated_api_client/models/add_memory_recommendations.py +0 -85
  129. orca_sdk/_generated_api_client/models/add_memory_suggestion.py +0 -79
  130. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
  131. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
  132. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
  133. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
  134. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  135. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
  136. orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +0 -145
  137. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
  138. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
  139. orca_sdk/_generated_api_client/models/class_representatives.py +0 -92
  140. orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
  141. orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
  142. orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -227
  143. orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
  144. orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -210
  145. orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
  146. orca_sdk/_generated_api_client/models/column_info.py +0 -145
  147. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  148. orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -81
  149. orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +0 -8
  150. orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
  151. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
  152. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
  153. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
  154. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
  155. orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -237
  156. orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +0 -101
  157. orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -365
  158. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
  159. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
  160. orca_sdk/_generated_api_client/models/create_org_plan_request.py +0 -73
  161. orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +0 -11
  162. orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -157
  163. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
  164. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  165. orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
  166. orca_sdk/_generated_api_client/models/embed_request.py +0 -155
  167. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -205
  168. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -197
  169. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
  170. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
  171. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  172. orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -123
  173. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
  174. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
  175. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
  176. orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
  177. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  178. orca_sdk/_generated_api_client/models/filter_item.py +0 -239
  179. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
  180. orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +0 -8
  181. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +0 -8
  182. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -22
  183. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  184. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  185. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  186. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  187. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  188. orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
  189. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -81
  190. orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +0 -8
  191. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  192. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
  193. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  194. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
  195. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
  196. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
  197. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  198. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
  199. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  200. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  201. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  202. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  203. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
  204. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  205. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
  206. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  207. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
  208. orca_sdk/_generated_api_client/models/lookup_request.py +0 -101
  209. orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
  210. orca_sdk/_generated_api_client/models/memory_metrics.py +0 -263
  211. orca_sdk/_generated_api_client/models/memory_type.py +0 -9
  212. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -245
  213. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
  214. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
  215. orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +0 -79
  216. orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +0 -138
  217. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
  218. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
  219. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
  220. orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
  221. orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
  222. orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
  223. orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
  224. orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
  225. orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -333
  226. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
  227. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
  228. orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -265
  229. orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
  230. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
  231. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
  232. orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
  233. orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
  234. orca_sdk/_generated_api_client/models/memoryset_update.py +0 -121
  235. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -99
  236. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -23
  237. orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +0 -8
  238. orca_sdk/_generated_api_client/models/org_plan.py +0 -99
  239. orca_sdk/_generated_api_client/models/org_plan_tier.py +0 -11
  240. orca_sdk/_generated_api_client/models/paginated_task.py +0 -108
  241. orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
  242. orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
  243. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  244. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  245. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  246. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  247. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
  248. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
  249. orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -111
  250. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -115
  251. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
  252. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  253. orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
  254. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
  255. orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
  256. orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -191
  257. orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
  258. orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
  259. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
  260. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
  261. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
  262. orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
  263. orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
  264. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
  265. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
  266. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
  267. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
  268. orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
  269. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
  270. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
  271. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  272. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
  273. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -81
  274. orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +0 -8
  275. orca_sdk/_generated_api_client/models/task.py +0 -198
  276. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  277. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  278. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +0 -8
  279. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
  280. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +0 -8
  281. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +0 -8
  282. orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -217
  283. orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
  284. orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
  285. orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -185
  286. orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
  287. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -73
  288. orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +0 -8
  289. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -81
  290. orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +0 -8
  291. orca_sdk/_generated_api_client/models/update_org_plan_request.py +0 -73
  292. orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +0 -11
  293. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
  294. orca_sdk/_generated_api_client/models/validation_error.py +0 -99
  295. orca_sdk/_generated_api_client/py.typed +0 -1
  296. orca_sdk/_generated_api_client/types.py +0 -56
  297. orca_sdk-0.0.97.dist-info/RECORD +0 -309
  298. {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.98.dist-info}/WHEEL +0 -0
@@ -1,56 +1,61 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
4
  from contextlib import contextmanager
6
5
  from datetime import datetime
7
6
  from typing import Any, Generator, Iterable, Literal, cast, overload
8
- from uuid import UUID
9
7
 
10
8
  from datasets import Dataset
11
9
 
12
- from ._generated_api_client.api import (
13
- create_classification_model,
14
- delete_classification_model,
15
- evaluate_classification_model,
16
- get_classification_model,
17
- get_classification_model_evaluation,
18
- list_classification_models,
19
- list_predictions,
20
- predict_label_gpu,
21
- record_prediction_feedback,
22
- update_classification_model,
23
- )
24
- from ._generated_api_client.models import (
25
- ClassificationEvaluationRequest,
10
+ from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
11
+ from ._utils.common import UNSET, CreateMode, DropMode
12
+ from .client import (
13
+ BootstrapClassificationModelMeta,
14
+ BootstrapClassificationModelResult,
26
15
  ClassificationModelMetadata,
27
- ClassificationPredictionRequest,
28
- CreateClassificationModelRequest,
29
- LabelPredictionWithMemoriesAndFeedback,
30
- ListPredictionsRequest,
31
- )
32
- from ._generated_api_client.models import (
33
- PredictionSortItemItemType0 as PredictionSortColumns,
34
- )
35
- from ._generated_api_client.models import (
36
- PredictionSortItemItemType1 as PredictionSortDirection,
37
- )
38
- from ._generated_api_client.models import (
39
16
  PredictiveModelUpdate,
40
17
  RACHeadType,
18
+ orca_api,
41
19
  )
42
- from ._generated_api_client.types import UNSET as CLIENT_UNSET
43
- from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
44
- from ._utils.common import UNSET, CreateMode, DropMode
45
20
  from .datasource import Datasource
46
21
  from .job import Job
47
22
  from .memoryset import (
48
23
  FilterItem,
49
24
  FilterItemTuple,
50
25
  LabeledMemoryset,
26
+ _is_metric_column,
51
27
  _parse_filter_item_from_tuple,
52
28
  )
53
- from .telemetry import ClassificationPrediction, _parse_feedback
29
+ from .telemetry import (
30
+ ClassificationPrediction,
31
+ TelemetryMode,
32
+ _get_telemetry_config,
33
+ _parse_feedback,
34
+ )
35
+
36
+
37
+ class BootstrappedClassificationModel:
38
+
39
+ datasource: Datasource | None
40
+ memoryset: LabeledMemoryset | None
41
+ classification_model: ClassificationModel | None
42
+ agent_output: BootstrapClassificationModelResult | None
43
+
44
+ def __init__(self, metadata: BootstrapClassificationModelMeta):
45
+ self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
46
+ self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
47
+ self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
48
+ self.agent_output = metadata["agent_output"]
49
+
50
+ def __repr__(self):
51
+ return (
52
+ "BootstrappedClassificationModel({\n"
53
+ f" datasource: {self.datasource},\n"
54
+ f" memoryset: {self.memoryset},\n"
55
+ f" classification_model: {self.classification_model},\n"
56
+ f" agent_output: {self.agent_output},\n"
57
+ "})"
58
+ )
54
59
 
55
60
 
56
61
  class ClassificationModel:
@@ -86,18 +91,18 @@ class ClassificationModel:
86
91
 
87
92
  def __init__(self, metadata: ClassificationModelMetadata):
88
93
  # for internal use only, do not document
89
- self.id = metadata.id
90
- self.name = metadata.name
91
- self.description = metadata.description
92
- self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
93
- self.head_type = metadata.head_type
94
- self.num_classes = metadata.num_classes
95
- self.memory_lookup_count = metadata.memory_lookup_count
96
- self.weigh_memories = metadata.weigh_memories
97
- self.min_memory_weight = metadata.min_memory_weight
98
- self.version = metadata.version
99
- self.locked = metadata.locked
100
- self.created_at = metadata.created_at
94
+ self.id = metadata["id"]
95
+ self.name = metadata["name"]
96
+ self.description = metadata["description"]
97
+ self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
98
+ self.head_type = metadata["head_type"]
99
+ self.num_classes = metadata["num_classes"]
100
+ self.memory_lookup_count = metadata["memory_lookup_count"]
101
+ self.weigh_memories = metadata["weigh_memories"]
102
+ self.min_memory_weight = metadata["min_memory_weight"]
103
+ self.version = metadata["version"]
104
+ self.locked = metadata["locked"]
105
+ self.created_at = datetime.fromisoformat(metadata["created_at"])
101
106
 
102
107
  self._memoryset_override_id: str | None = None
103
108
  self._last_prediction: ClassificationPrediction | None = None
@@ -140,7 +145,7 @@ class ClassificationModel:
140
145
  cls,
141
146
  name: str,
142
147
  memoryset: LabeledMemoryset,
143
- head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
148
+ head_type: RACHeadType = "KNN",
144
149
  *,
145
150
  description: str | None = None,
146
151
  num_classes: int | None = None,
@@ -206,17 +211,18 @@ class ClassificationModel:
206
211
 
207
212
  return existing
208
213
 
209
- metadata = create_classification_model(
210
- body=CreateClassificationModelRequest(
211
- name=name,
212
- memoryset_id=memoryset.id,
213
- head_type=RACHeadType(head_type),
214
- memory_lookup_count=memory_lookup_count,
215
- num_classes=num_classes,
216
- weigh_memories=weigh_memories,
217
- min_memory_weight=min_memory_weight,
218
- description=description,
219
- ),
214
+ metadata = orca_api.POST(
215
+ "/classification_model",
216
+ json={
217
+ "name": name,
218
+ "memoryset_name_or_id": memoryset.id,
219
+ "head_type": head_type,
220
+ "memory_lookup_count": memory_lookup_count,
221
+ "num_classes": num_classes,
222
+ "weigh_memories": weigh_memories,
223
+ "min_memory_weight": min_memory_weight,
224
+ "description": description,
225
+ },
220
226
  )
221
227
  return cls(metadata)
222
228
 
@@ -234,7 +240,7 @@ class ClassificationModel:
234
240
  Raises:
235
241
  LookupError: If the classification model does not exist
236
242
  """
237
- return cls(get_classification_model(name))
243
+ return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
238
244
 
239
245
  @classmethod
240
246
  def exists(cls, name_or_id: str) -> bool:
@@ -261,7 +267,7 @@ class ClassificationModel:
261
267
  Returns:
262
268
  List of handles to all classification models in the OrcaCloud
263
269
  """
264
- return [cls(metadata) for metadata in list_classification_models()]
270
+ return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
265
271
 
266
272
  @classmethod
267
273
  def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -280,7 +286,7 @@ class ClassificationModel:
280
286
  LookupError: If the classification model does not exist and if_not_exists is `"error"`
281
287
  """
282
288
  try:
283
- delete_classification_model(name_or_id)
289
+ orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
284
290
  logging.info(f"Deleted model {name_or_id}")
285
291
  except LookupError:
286
292
  if if_not_exists == "error":
@@ -311,11 +317,12 @@ class ClassificationModel:
311
317
  Lock the model:
312
318
  >>> model.set(locked=True)
313
319
  """
314
- update_data = PredictiveModelUpdate(
315
- description=CLIENT_UNSET if description is UNSET else description,
316
- locked=CLIENT_UNSET if locked is UNSET else locked,
317
- )
318
- update_classification_model(self.id, body=update_data)
320
+ update: PredictiveModelUpdate = {}
321
+ if description is not UNSET:
322
+ update["description"] = description
323
+ if locked is not UNSET:
324
+ update["locked"] = locked
325
+ orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
319
326
  self.refresh()
320
327
 
321
328
  def lock(self) -> None:
@@ -333,7 +340,9 @@ class ClassificationModel:
333
340
  expected_labels: list[int] | None = None,
334
341
  filters: list[FilterItemTuple] = [],
335
342
  tags: set[str] | None = None,
336
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
343
+ save_telemetry: TelemetryMode = "on",
344
+ prompt: str | None = None,
345
+ use_lookup_cache: bool = True,
337
346
  ) -> list[ClassificationPrediction]:
338
347
  pass
339
348
 
@@ -344,17 +353,21 @@ class ClassificationModel:
344
353
  expected_labels: int | None = None,
345
354
  filters: list[FilterItemTuple] = [],
346
355
  tags: set[str] | None = None,
347
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
356
+ save_telemetry: TelemetryMode = "on",
357
+ prompt: str | None = None,
358
+ use_lookup_cache: bool = True,
348
359
  ) -> ClassificationPrediction:
349
360
  pass
350
361
 
351
362
  def predict(
352
363
  self,
353
364
  value: list[str] | str,
354
- expected_labels: list[int] | int | None = None,
365
+ expected_labels: list[int] | list[str] | int | str | None = None,
355
366
  filters: list[FilterItemTuple] = [],
356
367
  tags: set[str] | None = None,
357
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
368
+ save_telemetry: TelemetryMode = "on",
369
+ prompt: str | None = None,
370
+ use_lookup_cache: bool = True,
358
371
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
359
372
  """
360
373
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -362,6 +375,7 @@ class ClassificationModel:
362
375
  Params:
363
376
  value: Value(s) to get predict the labels of
364
377
  expected_labels: Expected label(s) for the given input to record for model evaluation
378
+ filters: Optional filters to apply during memory lookup
365
379
  tags: Tags to add to the prediction(s)
366
380
  save_telemetry: Whether to save telemetry for the prediction(s). One of
367
381
  * `"off"`: Do not save telemetry
@@ -369,6 +383,7 @@ class ClassificationModel:
369
383
  environment variable is set.
370
384
  * `"sync"`: Save telemetry synchronously
371
385
  * `"async"`: Save telemetry asynchronously
386
+ prompt: Optional prompt to use for instruction-tuned embedding models
372
387
 
373
388
  Returns:
374
389
  Label prediction or list of label predictions
@@ -384,48 +399,60 @@ class ClassificationModel:
384
399
  ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
385
400
  ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
386
401
  ]
402
+
403
+ Using a prompt with an instruction-tuned embedding model:
404
+ >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
405
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
387
406
  """
388
407
 
389
408
  parsed_filters = [
390
409
  _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
391
410
  ]
392
411
 
393
- if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
412
+ if any(_is_metric_column(filter[0]) for filter in filters):
394
413
  raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
395
414
 
396
- response = predict_label_gpu(
397
- self.id,
398
- body=ClassificationPredictionRequest(
399
- input_values=value if isinstance(value, list) else [value],
400
- memoryset_override_id=self._memoryset_override_id,
401
- expected_labels=(
402
- expected_labels
403
- if isinstance(expected_labels, list)
404
- else [expected_labels] if expected_labels is not None else None
405
- ),
406
- tags=list(tags or set()),
407
- save_telemetry=save_telemetry != "off",
408
- save_telemetry_synchronously=(
409
- os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
410
- ),
411
- filters=cast(list[FilterItem], parsed_filters),
412
- ),
415
+ if isinstance(expected_labels, int):
416
+ expected_labels = [expected_labels]
417
+ elif isinstance(expected_labels, str):
418
+ expected_labels = [self.memoryset.label_names.index(expected_labels)]
419
+ elif isinstance(expected_labels, list):
420
+ expected_labels = [
421
+ self.memoryset.label_names.index(label) if isinstance(label, str) else label
422
+ for label in expected_labels
423
+ ]
424
+
425
+ telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
426
+ response = orca_api.POST(
427
+ "/gpu/classification_model/{name_or_id}/prediction",
428
+ params={"name_or_id": self.id},
429
+ json={
430
+ "input_values": value if isinstance(value, list) else [value],
431
+ "memoryset_override_name_or_id": self._memoryset_override_id,
432
+ "expected_labels": expected_labels,
433
+ "tags": list(tags or set()),
434
+ "save_telemetry": telemetry_on,
435
+ "save_telemetry_synchronously": telemetry_sync,
436
+ "filters": cast(list[FilterItem], parsed_filters),
437
+ "prompt": prompt,
438
+ "use_lookup_cache": use_lookup_cache,
439
+ },
413
440
  )
414
441
 
415
- if save_telemetry != "off" and any(p.prediction_id is None for p in response):
442
+ if telemetry_on and any(p["prediction_id"] is None for p in response):
416
443
  raise RuntimeError("Failed to save prediction to database.")
417
444
 
418
445
  predictions = [
419
446
  ClassificationPrediction(
420
- prediction_id=prediction.prediction_id,
421
- label=prediction.label,
422
- label_name=prediction.label_name,
447
+ prediction_id=prediction["prediction_id"],
448
+ label=prediction["label"],
449
+ label_name=prediction["label_name"],
423
450
  score=None,
424
- confidence=prediction.confidence,
425
- anomaly_score=prediction.anomaly_score,
451
+ confidence=prediction["confidence"],
452
+ anomaly_score=prediction["anomaly_score"],
426
453
  memoryset=self.memoryset,
427
454
  model=self,
428
- logits=prediction.logits,
455
+ logits=prediction["logits"],
429
456
  input_value=input_value,
430
457
  )
431
458
  for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
@@ -439,7 +466,7 @@ class ClassificationModel:
439
466
  limit: int = 100,
440
467
  offset: int = 0,
441
468
  tag: str | None = None,
442
- sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
469
+ sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
443
470
  expected_label_match: bool | None = None,
444
471
  ) -> list[ClassificationPrediction]:
445
472
  """
@@ -475,30 +502,31 @@ class ClassificationModel:
475
502
  >>> predictions = model.predictions(expected_label_match=False)
476
503
  [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
477
504
  """
478
- predictions = list_predictions(
479
- body=ListPredictionsRequest(
480
- model_id=self.id,
481
- limit=limit,
482
- offset=offset,
483
- sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
484
- tag=tag,
485
- expected_label_match=expected_label_match,
486
- ),
505
+ predictions = orca_api.POST(
506
+ "/telemetry/prediction",
507
+ json={
508
+ "model_id": self.id,
509
+ "limit": limit,
510
+ "offset": offset,
511
+ "sort": [list(sort_item) for sort_item in sort],
512
+ "tag": tag,
513
+ "expected_label_match": expected_label_match,
514
+ },
487
515
  )
488
516
  return [
489
517
  ClassificationPrediction(
490
- prediction_id=prediction.prediction_id,
491
- label=prediction.label,
492
- label_name=prediction.label_name,
518
+ prediction_id=prediction["prediction_id"],
519
+ label=prediction["label"],
520
+ label_name=prediction["label_name"],
493
521
  score=None,
494
- confidence=prediction.confidence,
495
- anomaly_score=prediction.anomaly_score,
522
+ confidence=prediction["confidence"],
523
+ anomaly_score=prediction["anomaly_score"],
496
524
  memoryset=self.memoryset,
497
525
  model=self,
498
526
  telemetry=prediction,
499
527
  )
500
528
  for prediction in predictions
501
- if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
529
+ if "label" in prediction
502
530
  ]
503
531
 
504
532
  def _evaluate_datasource(
@@ -510,23 +538,28 @@ class ClassificationModel:
510
538
  tags: set[str] | None,
511
539
  background: bool = False,
512
540
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
513
- response = evaluate_classification_model(
514
- self.id,
515
- body=ClassificationEvaluationRequest(
516
- datasource_id=datasource.id,
517
- datasource_label_column=label_column,
518
- datasource_value_column=value_column,
519
- memoryset_override_id=self._memoryset_override_id,
520
- record_telemetry=record_predictions,
521
- telemetry_tags=list(tags) if tags else None,
522
- ),
541
+ response = orca_api.POST(
542
+ "/classification_model/{model_name_or_id}/evaluation",
543
+ params={"model_name_or_id": self.id},
544
+ json={
545
+ "datasource_name_or_id": datasource.id,
546
+ "datasource_label_column": label_column,
547
+ "datasource_value_column": value_column,
548
+ "memoryset_override_name_or_id": self._memoryset_override_id,
549
+ "record_telemetry": record_predictions,
550
+ "telemetry_tags": list(tags) if tags else None,
551
+ },
523
552
  )
524
553
 
525
- job = Job(
526
- response.task_id,
527
- lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
528
- and ClassificationMetrics(**r.to_dict()),
529
- )
554
+ def get_value():
555
+ res = orca_api.GET(
556
+ "/classification_model/{model_name_or_id}/evaluation/{task_id}",
557
+ params={"model_name_or_id": self.id, "task_id": response["task_id"]},
558
+ )
559
+ assert res["result"] is not None
560
+ return ClassificationMetrics(**res["result"])
561
+
562
+ job = Job(response["task_id"], get_value)
530
563
  return job if background else job.result()
531
564
 
532
565
  def _evaluate_dataset(
@@ -709,8 +742,37 @@ class ClassificationModel:
709
742
  ValueError: If the value does not match previous value types for the category, or is a
710
743
  [`float`][float] that is not between `-1.0` and `+1.0`.
711
744
  """
712
- record_prediction_feedback(
713
- body=[
745
+ orca_api.PUT(
746
+ "/telemetry/prediction/feedback",
747
+ json=[
714
748
  _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
715
749
  ],
716
750
  )
751
+
752
+ @staticmethod
753
+ def bootstrap_model(
754
+ model_description: str,
755
+ label_names: list[str],
756
+ initial_examples: list[tuple[str, str]],
757
+ num_examples_per_label: int,
758
+ background: bool = False,
759
+ ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
760
+ response = orca_api.POST(
761
+ "/agents/bootstrap_classification_model",
762
+ json={
763
+ "model_description": model_description,
764
+ "label_names": label_names,
765
+ "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
766
+ "num_examples_per_label": num_examples_per_label,
767
+ },
768
+ )
769
+
770
+ def get_result() -> BootstrappedClassificationModel:
771
+ res = orca_api.GET(
772
+ "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
773
+ )
774
+ assert res["result"] is not None
775
+ return BootstrappedClassificationModel(res["result"])
776
+
777
+ job = Job(response["task_id"], get_result)
778
+ return job if background else job.result()
@@ -326,6 +326,37 @@ def test_predict_with_filters(classification_model: ClassificationModel):
326
326
  assert filtered_prediction.label_name == "cats"
327
327
 
328
328
 
329
+ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
330
+ model = ClassificationModel.create(
331
+ "test_predict_with_memoryset_update",
332
+ writable_memoryset,
333
+ num_classes=2,
334
+ memory_lookup_count=3,
335
+ )
336
+
337
+ prediction = model.predict("Do you love soup?")
338
+ assert prediction.label == 0
339
+ assert prediction.label_name == "soup"
340
+
341
+ # insert new memories
342
+ writable_memoryset.insert(
343
+ [
344
+ {"value": "Do you love soup?", "label": 1, "key": "g1"},
345
+ {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
346
+ {"value": "Do you love crackers?", "label": 1, "key": "g2"},
347
+ {"value": "Do you love broth?", "label": 1, "key": "g2"},
348
+ {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
349
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
350
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
351
+ ],
352
+ )
353
+ prediction = model.predict("Do you love soup?")
354
+ assert prediction.label == 1
355
+ assert prediction.label_name == "cats"
356
+
357
+ ClassificationModel.drop("test_predict_with_memoryset_update")
358
+
359
+
329
360
  def test_last_prediction_with_batch(classification_model: ClassificationModel):
330
361
  predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
331
362
  assert classification_model.last_prediction is not None
@@ -396,6 +427,8 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
396
427
  # Make a prediction with expected label to simulate incorrect prediction
397
428
  prediction = model.predict("Do you love soup?", expected_labels=1)
398
429
 
430
+ memoryset_length = model.memoryset.length
431
+
399
432
  try:
400
433
  # Get action recommendation
401
434
  action, rationale = prediction.recommend_action()
@@ -406,19 +439,22 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
406
439
  assert len(rationale) > 10
407
440
 
408
441
  # Test memory suggestions
409
- memory_suggestions = prediction.generate_memory_suggestions(num_memories=2)
442
+ suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
443
+ memory_suggestions = suggestions_response.suggestions
410
444
 
411
445
  assert memory_suggestions is not None
412
446
  assert len(memory_suggestions) == 2
413
447
 
414
448
  for suggestion in memory_suggestions:
415
- assert isinstance(suggestion, dict)
416
- assert "value" in suggestion
417
- assert "label" in suggestion
418
- assert isinstance(suggestion["value"], str)
419
- assert len(suggestion["value"]) > 0
420
- assert isinstance(suggestion["label"], int)
421
- assert 0 <= suggestion["label"] < len(model.memoryset.label_names)
449
+ assert isinstance(suggestion[0], str)
450
+ assert len(suggestion[0]) > 0
451
+ assert isinstance(suggestion[1], str)
452
+ assert suggestion[1] in model.memoryset.label_names
453
+
454
+ suggestions_response.apply()
455
+
456
+ model.memoryset.refresh()
457
+ assert model.memoryset.length == memoryset_length + 2
422
458
 
423
459
  except Exception as e:
424
460
  if "ANTHROPIC_API_KEY" in str(e):
@@ -427,3 +463,16 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
427
463
  raise e
428
464
  finally:
429
465
  ClassificationModel.drop("test_model_for_action")
466
+
467
+
468
+ def test_predict_with_prompt(classification_model: ClassificationModel):
469
+ """Test that prompt parameter is properly passed through to predictions"""
470
+ # Test with an instruction-supporting embedding model if available
471
+ prediction_with_prompt = classification_model.predict(
472
+ "I love this product!", prompt="Represent this text for sentiment classification:"
473
+ )
474
+ prediction_without_prompt = classification_model.predict("I love this product!")
475
+
476
+ # Both should work and return valid predictions
477
+ assert prediction_with_prompt.label is not None
478
+ assert prediction_without_prompt.label is not None