orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,14 +1,17 @@
1
+ import logging
2
+ from typing import get_args
1
3
  from uuid import uuid4
2
4
 
3
5
  import pytest
4
6
 
5
7
  from .datasource import Datasource
6
8
  from .embedding_model import (
9
+ ClassificationMetrics,
7
10
  FinetunedEmbeddingModel,
8
11
  PretrainedEmbeddingModel,
9
12
  PretrainedEmbeddingModelName,
10
- TaskStatus,
11
13
  )
14
+ from .job import Status
12
15
  from .memoryset import LabeledMemoryset
13
16
 
14
17
 
@@ -22,20 +25,24 @@ def test_open_pretrained_model():
22
25
  assert model is PretrainedEmbeddingModel.GTE_BASE
23
26
 
24
27
 
25
- def test_open_pretrained_model_unauthenticated(unauthenticated):
26
- with pytest.raises(ValueError, match="Invalid API key"):
27
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
28
+ def test_open_pretrained_model_unauthenticated(unauthenticated_client):
29
+ with unauthenticated_client.use():
30
+ with pytest.raises(ValueError, match="Invalid API key"):
31
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
28
32
 
29
33
 
30
34
  def test_open_pretrained_model_not_found():
31
35
  with pytest.raises(LookupError):
32
- PretrainedEmbeddingModel._get("INVALID_MODEL")
36
+ PretrainedEmbeddingModel._get("INVALID_MODEL") # type: ignore
33
37
 
34
38
 
35
39
  def test_all_pretrained_models():
36
40
  models = PretrainedEmbeddingModel.all()
37
- assert len(models) == len(PretrainedEmbeddingModelName)
38
- assert all(m.name in PretrainedEmbeddingModelName.__members__ for m in models)
41
+ assert len(models) > 1
42
+ if len(models) != len(get_args(PretrainedEmbeddingModelName)):
43
+ logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
44
+ model_names = [m.name for m in models]
45
+ assert all(m in model_names for m in get_args(PretrainedEmbeddingModelName))
39
46
 
40
47
 
41
48
  def test_embed_text():
@@ -46,9 +53,17 @@ def test_embed_text():
46
53
  assert isinstance(embedding[0], float)
47
54
 
48
55
 
49
- def test_embed_text_unauthenticated(unauthenticated):
50
- with pytest.raises(ValueError, match="Invalid API key"):
51
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
56
+ def test_embed_text_unauthenticated(unauthenticated_client):
57
+ with unauthenticated_client.use():
58
+ with pytest.raises(ValueError, match="Invalid API key"):
59
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
60
+
61
+
62
def test_evaluate_pretrained_model(datasource: Datasource):
    """Evaluating GTE_BASE on the shared datasource yields classification metrics with accuracy above 0.5."""
    result = PretrainedEmbeddingModel.GTE_BASE.evaluate(datasource=datasource, label_column="label")
    assert result is not None
    assert isinstance(result, ClassificationMetrics)
    assert result.accuracy > 0.5
52
67
 
53
68
 
54
69
  @pytest.fixture(scope="session")
@@ -62,26 +77,25 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
62
77
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
63
78
  assert finetuned_model.embedding_dim == 768
64
79
  assert finetuned_model.max_seq_length == 512
65
- assert finetuned_model._status == TaskStatus.COMPLETED
80
+ assert finetuned_model._status == Status.COMPLETED
66
81
 
67
82
 
68
- def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
69
- finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_from_memoryset", memoryset)
83
+ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
84
+ finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
85
+ "test_finetuned_model_from_memoryset", readonly_memoryset
86
+ )
70
87
  assert finetuned_model is not None
71
88
  assert finetuned_model.name == "test_finetuned_model_from_memoryset"
72
89
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
73
90
  assert finetuned_model.embedding_dim == 768
74
91
  assert finetuned_model.max_seq_length == 512
75
- assert finetuned_model._status == TaskStatus.COMPLETED
92
+ assert finetuned_model._status == Status.COMPLETED
76
93
 
77
94
 
78
95
  def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
79
96
  with pytest.raises(ValueError):
80
97
  PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
81
98
 
82
- with pytest.raises(ValueError):
83
- PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="error")
84
-
85
99
 
86
100
  def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
87
101
  with pytest.raises(ValueError):
@@ -93,12 +107,13 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
93
107
  assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
94
108
  assert new_model.embedding_dim == 768
95
109
  assert new_model.max_seq_length == 512
96
- assert new_model._status == TaskStatus.COMPLETED
110
+ assert new_model._status == Status.COMPLETED
97
111
 
98
112
 
99
- def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
100
- with pytest.raises(ValueError, match="Invalid API key"):
101
- PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
113
+ def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
114
+ with unauthenticated_client.use():
115
+ with pytest.raises(ValueError, match="Invalid API key"):
116
+ PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
102
117
 
103
118
 
104
119
  def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
@@ -106,7 +121,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
106
121
  "test_memoryset_finetuned_model",
107
122
  datasource,
108
123
  embedding_model=finetuned_model,
109
- value_column="text",
110
124
  )
111
125
  assert memoryset is not None
112
126
  assert memoryset.name == "test_memoryset_finetuned_model"
@@ -139,13 +153,15 @@ def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
139
153
  assert any(model.name == finetuned_model.name for model in models)
140
154
 
141
155
 
142
- def test_all_finetuned_models_unauthenticated(unauthenticated):
143
- with pytest.raises(ValueError, match="Invalid API key"):
144
- FinetunedEmbeddingModel.all()
156
+ def test_all_finetuned_models_unauthenticated(unauthenticated_client):
157
+ with unauthenticated_client.use():
158
+ with pytest.raises(ValueError, match="Invalid API key"):
159
+ FinetunedEmbeddingModel.all()
145
160
 
146
161
 
147
- def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
148
- assert finetuned_model not in FinetunedEmbeddingModel.all()
162
+ def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
163
+ with unauthorized_client.use():
164
+ assert finetuned_model not in FinetunedEmbeddingModel.all()
149
165
 
150
166
 
151
167
  def test_drop_finetuned_model(datasource: Datasource):
@@ -156,9 +172,10 @@ def test_drop_finetuned_model(datasource: Datasource):
156
172
  FinetunedEmbeddingModel.open("finetuned_model_to_delete")
157
173
 
158
174
 
159
- def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
160
- with pytest.raises(ValueError, match="Invalid API key"):
161
- PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
175
+ def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
176
+ with unauthenticated_client.use():
177
+ with pytest.raises(ValueError, match="Invalid API key"):
178
+ PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
162
179
 
163
180
 
164
181
  def test_drop_finetuned_model_not_found():
@@ -168,6 +185,22 @@ def test_drop_finetuned_model_not_found():
168
185
  FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
169
186
 
170
187
 
171
- def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
172
- with pytest.raises(LookupError):
173
- FinetunedEmbeddingModel.drop(finetuned_model.id)
188
+ def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
189
+ with unauthorized_client.use():
190
+ with pytest.raises(LookupError):
191
+ FinetunedEmbeddingModel.drop(finetuned_model.id)
192
+
193
+
194
def test_supports_instructions():
    """GTE_BASE is expected to reject instruction prompts while BGE_BASE accepts them."""
    plain_model = PretrainedEmbeddingModel.GTE_BASE
    assert not plain_model.supports_instructions

    instruction_aware_model = PretrainedEmbeddingModel.BGE_BASE
    assert instruction_aware_model.supports_instructions
200
+
201
+
202
def test_use_explicit_instruction_prompt():
    """An explicit instruction prompt must change the embedding produced by an instruction-aware model."""
    model = PretrainedEmbeddingModel.BGE_BASE
    assert model.supports_instructions
    # named `text` rather than `input` so the builtin is not shadowed
    text = "Hello world"
    assert model.embed(text, prompt="Represent this sentence for sentiment retrieval:") != model.embed(text)
orca_sdk/job.py ADDED
@@ -0,0 +1,343 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from datetime import datetime, timedelta
5
+ from enum import Enum
6
+ from typing import Callable, Generic, TypedDict, TypeVar, cast
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ from .client import OrcaClient
11
+
12
+
13
class JobConfig(TypedDict):
    """Settings shared by all jobs; the defaults live in ``Job.config`` and are changed via ``Job.set_config``."""

    # seconds to wait between polls of the job status
    refresh_interval: int
    # whether waiting on a job shows a progress bar
    show_progress: bool
    # maximum number of seconds to wait for a job to complete
    max_wait: int
17
+
18
+
19
class Status(Enum):
    """
    Status of a cloud job in the task queue

    Member values are the exact status strings reported by the task API, so a raw status can be
    converted with ``Status(task["status"])``.
    """

    # the INITIALIZED state should never be returned by the API
    INITIALIZED = "INITIALIZED"
    """The job has been initialized"""

    DISPATCHED = "DISPATCHED"
    """The job has been queued and is waiting to be processed"""

    WAITING = "WAITING"
    """The job is waiting for dependencies to complete"""

    PROCESSING = "PROCESSING"
    """The job is being processed"""

    COMPLETED = "COMPLETED"
    """The job has been completed successfully"""

    FAILED = "FAILED"
    """The job has failed"""

    ABORTING = "ABORTING"
    """The job is being aborted"""

    ABORTED = "ABORTED"
    """The job has been aborted"""


# Type of the value a job resolves to once its status is COMPLETED (see ``Job.value``)
TResult = TypeVar("TResult")
49
+
50
+
51
class Job(Generic[TResult]):
    """
    Handle to a job that is run in the OrcaCloud

    Attributes:
        id: Unique identifier for the job
        type: Type of the job
        status: Current status of the job
        steps_total: Total number of steps in the job, present if the job started processing
        steps_completed: Number of steps completed in the job, present if the job started processing
        completion: Fraction of the job's steps that have been completed (``steps_completed / steps_total``),
            0 while no step total has been reported
        exception: Exception that occurred during the job, present if the status is `FAILED`
        value: Value of the result of the job, present if the status is `COMPLETED`
        created_at: When the job was queued for processing
        updated_at: When the job was last updated
        refreshed_at: When the job status was last refreshed

    Note:
        Reading `status`, `updated_at`, `steps_total`, `steps_completed`, `exception`, or `value`
        refreshes the job synchronously (an API request made during the attribute access) once
        at least ``Job.config["refresh_interval"]`` seconds have passed since the last refresh.
    """

    id: str
    type: str
    status: Status
    steps_total: int | None
    steps_completed: int | None
    exception: str | None
    value: TResult | None
    updated_at: datetime
    created_at: datetime
    refreshed_at: datetime
+
83
@property
def completion(self) -> float:
    """
    Fraction of the job's steps that have been completed (``steps_completed / steps_total``)

    Returns 0 while the job has not reported a positive number of total steps.
    """
    # read once: every access of steps_total may trigger a throttled refresh
    steps_total = self.steps_total
    if not steps_total:
        # None (no progress reported yet) or 0 (nothing to do): avoid dividing by zero
        return 0.0
    return (self.steps_completed or 0) / steps_total
89
+
90
# Global configuration for all jobs
# Class-level dict shared by every Job; Job.set_config() updates its entries in place.
config: JobConfig = {
    "refresh_interval": 3,  # seconds between status polls; also the refresh throttle on attribute access
    "show_progress": True,  # show a progress bar while waiting by default
    "max_wait": 60 * 60,  # wait at most one hour for a job to complete
}
96
+
97
def __repr__(self) -> str:
    # renders e.g. "Job({ type: <type>, status: <status>, completion: 42% })"
    return f"Job({{ type: {self.type}, status: {self.status}, completion: {self.completion:.0%} }})"
99
+
100
@classmethod
def set_config(
    cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
):
    """
    Update the global configuration shared by all jobs

    Only options that are passed (i.e. not ``None``) are changed; the others keep their current values.

    Args:
        refresh_interval: Seconds to wait between polls of the job status, default is 3
        show_progress: Whether waiting on a job shows a progress bar, default is True
        max_wait: Maximum number of seconds to wait for a job to complete, default is 1 hour
    """
    overrides = {
        option: setting
        for option, setting in (
            ("refresh_interval", refresh_interval),
            ("show_progress", show_progress),
            ("max_wait", max_wait),
        )
        if setting is not None
    }
    # mutate the shared dict in place, in the same option order as before
    cls.config.update(cast("JobConfig", overrides))
118
+
119
@classmethod
def query(
    cls,
    status: Status | list[Status] | None = None,
    type: str | list[str] | None = None,
    limit: int = 100,
    offset: int = 0,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[Job]:
    """
    Query the job queue for jobs matching the given filters

    Args:
        status: Optional status or list of statuses to filter by
        type: Optional type or list of types to filter by
        limit: Maximum number of jobs to return
        offset: Offset into the list of jobs to return
        start: Optional minimum creation time of the jobs to query for
        end: Optional maximum creation time of the jobs to query for

    Returns:
        List of jobs matching the given filters. Their values resolve to the raw task result,
        as for a job constructed without a custom ``get_value``.
    """
    client = OrcaClient._resolve_client()

    if isinstance(status, list):
        status_filter: str | list[str] | None = [s.value for s in status]
    elif isinstance(status, Status):
        status_filter = status.value
    else:
        status_filter = None

    paginated_tasks = client.GET(
        "/task",
        params={
            "status": status_filter,
            "type": type,
            "limit": limit,
            "offset": offset,
            "start_timestamp": start.isoformat() if start is not None else None,
            "end_timestamp": end.isoformat() if end is not None else None,
        },
    )

    def make_value_getter(task_id: str) -> Callable[[], TResult | None]:
        # Default resolver, same as the constructor's: re-read the task and return its raw result.
        def get_value() -> TResult | None:
            task = OrcaClient._resolve_client().GET("/task/{task_id}", params={"task_id": task_id})
            return cast("TResult | None", task["result"])

        return get_value

    def from_task(task) -> Job:
        # The constructor makes an API call per job, so hydrate instances from the listing directly.
        job = cls.__new__(cls)
        job.id = task["id"]
        # refresh() calls _get_value() once a job is COMPLETED; without it that raises AttributeError.
        job._get_value = make_value_getter(task["id"])
        job.type = task["type"]
        job.status = Status(task["status"])
        job.steps_total = task["steps_total"]
        job.steps_completed = task["steps_completed"]
        job.exception = task["exception"]
        job.value = cast("TResult", task["result"]) if task["result"] is not None else None
        job.updated_at = datetime.fromisoformat(task["updated_at"])
        job.created_at = datetime.fromisoformat(task["created_at"])
        job.refreshed_at = datetime.now()
        return job

    return [from_task(task) for task in paginated_tasks["items"]]
180
+
181
def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
    """
    Create a handle to a job in the job queue

    The current state of the job is fetched once, eagerly, when the handle is created.

    Args:
        id: Unique identifier for the job
        get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
    """
    self.id = id
    task_info = OrcaClient._resolve_client().GET("/task/{task_id}", params={"task_id": id})

    def fetch_raw_result():
        # fallback resolver: re-read the task and hand back its result payload untouched
        latest = OrcaClient._resolve_client().GET("/task/{task_id}", params={"task_id": id})
        return cast(TResult | None, latest["result"])

    self._get_value = get_value or fetch_raw_result
    self.type = task_info["type"]
    self.status = Status(task_info["status"])
    self.steps_total = task_info["steps_total"]
    self.steps_completed = task_info["steps_completed"]
    self.exception = task_info["exception"]

    # only completed jobs carry a value; a custom resolver takes precedence over the raw result
    if task_info["status"] != "COMPLETED":
        self.value = None
    elif get_value is not None:
        self.value = get_value()
    elif task_info["result"] is None:
        self.value = None
    else:
        self.value = cast(TResult, task_info["result"])

    self.updated_at = datetime.fromisoformat(task_info["updated_at"])
    self.created_at = datetime.fromisoformat(task_info["created_at"])
    self.refreshed_at = datetime.now()
215
+
216
def refresh(self, throttle: float = 0):
    """
    Refresh the status and progress of the job

    Params:
        throttle: Minimum time in seconds between refreshes; if the previous refresh happened
            more recently than this, the call does nothing

    Note:
        Handles created by [`query`][orca_sdk.Job.query] are built without calling `__init__`
        and therefore have no custom value resolver. For those, the raw task result is fetched
        from the API when the job is completed.
    """
    current_time = datetime.now()
    # Skip refresh if last refresh was too recent
    if (current_time - self.refreshed_at) < timedelta(seconds=throttle):
        return
    # Record the refresh time before calling out, so that attribute reads triggered while
    # refreshing hit the throttle instead of re-entering refresh (given a non-zero interval).
    self.refreshed_at = current_time

    client = OrcaClient._resolve_client()
    status_info = client.GET("/task/{task_id}/status", params={"task_id": self.id})
    self.status = Status(status_info["status"])
    # The status endpoint may omit progress counters; keep the last known values in that case
    if status_info["steps_total"] is not None:
        self.steps_total = status_info["steps_total"]
    if status_info["steps_completed"] is not None:
        self.steps_completed = status_info["steps_completed"]

    self.exception = status_info["exception"]
    self.updated_at = datetime.fromisoformat(status_info["updated_at"])

    if status_info["status"] == "COMPLETED":
        # Objects built by `query()` bypass `__init__` and never get `_get_value`;
        # fall back to the raw task result instead of raising AttributeError.
        get_value = getattr(self, "_get_value", None)
        if get_value is not None:
            self.value = get_value()
        else:
            task = client.GET("/task/{task_id}", params={"task_id": self.id})
            self.value = cast(TResult | None, task["result"])
242
+
243
def __getattribute__(self, name: str):
    # Server-backed, mutable fields trigger a throttled refresh before they are read, so
    # callers see reasonably fresh values; every other attribute is a plain lookup.
    live_fields = ("status", "updated_at", "steps_total", "steps_completed", "exception", "value")
    if name in live_fields:
        self.refresh(self.config["refresh_interval"])
    return super().__getattribute__(name)
248
+
249
def wait(
    self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
) -> None:
    """
    Block until the job is complete

    Params:
        show_progress: Show a progress bar while waiting for the job to complete
        refresh_interval: Polling interval in seconds while waiting for the job to complete
        max_wait: Maximum time to wait for the job to complete in seconds

    Note:
        The defaults for the config parameters can be set globally using the
        [`set_config`][orca_sdk.Job.set_config] method.

        This method will not return the result or raise an exception if the job fails. Call
        [`result`][orca_sdk.Job.result] instead if you want to get the result.

    Raises:
        RuntimeError: If the job times out
    """
    # monotonic clock: elapsed-time checks must not be affected by wall-clock adjustments
    start_time = time.monotonic()
    show_progress = show_progress if show_progress is not None else self.config["show_progress"]
    refresh_interval = refresh_interval if refresh_interval is not None else self.config["refresh_interval"]
    max_wait = max_wait if max_wait is not None else self.config["max_wait"]
    pbar = None
    try:
        while True:
            # Poll at the requested interval explicitly. Attribute reads below only refresh at
            # the global config interval, which would otherwise override a shorter value here.
            self.refresh(refresh_interval)

            # setup progress bar once steps total is known
            # (compare with None: a tqdm bar with total=0 is falsy)
            if pbar is None and show_progress and self.steps_total is not None:
                desc = " ".join(self.type.split("_")).lower()
                pbar = tqdm(total=self.steps_total, desc=desc)

            # return if job is in a terminal state
            if self.status in [Status.COMPLETED, Status.FAILED, Status.ABORTED]:
                if pbar is not None:
                    pbar.update(self.steps_total - pbar.n)
                return

            # raise error if job timed out
            if (time.monotonic() - start_time) > max_wait:
                raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

            # update progress bar
            if pbar is not None and self.steps_completed is not None:
                pbar.update(self.steps_completed - pbar.n)

            # sleep before retrying
            time.sleep(refresh_interval)
    finally:
        # always release the progress bar, including on timeout or API errors
        if pbar is not None:
            pbar.close()
298
+
299
def result(
    self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
) -> TResult:
    """
    Block until the job is complete and return the result value

    Params:
        show_progress: Show a progress bar while waiting for the job to complete
        refresh_interval: Polling interval in seconds while waiting for the job to complete
        max_wait: Maximum time to wait for the job to complete in seconds

    Note:
        The defaults for the config parameters can be set globally using the
        [`set_config`][orca_sdk.Job.set_config] method.

        This method raises if the job does not complete successfully. Use
        [`wait`][orca_sdk.Job.wait] to block without raising on failure.

    Returns:
        The result value of the job

    Raises:
        RuntimeError: If the job fails or times out
    """
    # Fast path: the value is already known (reading it may trigger a throttled refresh)
    if self.value is not None:
        return self.value

    self.wait(show_progress, refresh_interval, max_wait)
    if self.status == Status.COMPLETED:
        assert self.value is not None
        return self.value
    raise RuntimeError(f"Job failed with exception: {self.exception}")
330
+
331
+
332
def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
    """
    Abort the job and block until it reaches a terminal state

    Params:
        show_progress: Optionally show a progress bar while waiting for the job to abort
        refresh_interval: Polling interval in seconds while waiting for the job to abort
        max_wait: Maximum time to wait for the job to abort in seconds

    Raises:
        RuntimeError: If the job has not stopped within `max_wait` seconds
    """
    OrcaClient._resolve_client().DELETE("/task/{task_id}/abort", params={"task_id": self.id})
    self.wait(show_progress=show_progress, refresh_interval=refresh_interval, max_wait=max_wait)
orca_sdk/job_test.py ADDED
@@ -0,0 +1,108 @@
1
+ import time
2
+
3
+ import pytest
4
+ from datasets import Dataset
5
+
6
+ from .classification_model import ClassificationModel
7
+ from .datasource import Datasource
8
+ from .job import Job, Status
9
+
10
+
11
@pytest.fixture(scope="session")
def datasource_without_nones(hf_dataset: Dataset):
    """Session-wide datasource built from `hf_dataset` with every unlabeled row removed."""
    labeled_rows = hf_dataset.filter(lambda row: row["label"] is not None)
    return Datasource.from_hf_dataset("test_datasource_without_nones", labeled_rows)
16
+
17
+
18
def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0.2):
    """
    Poll until every job in `job_ids` reports a status contained in `expected_statuses`.

    Raises:
        TimeoutError: If the jobs have not all reached an expected status within `timeout` seconds
    """
    started_at = time.time()
    while time.time() - started_at < timeout:
        # Build fresh handles on each poll; `Job(...)` fetches the current task state
        handles = [Job(job_id) for job_id in job_ids]
        if all(handle.status in expected_statuses for handle in handles):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"Jobs did not reach statuses {expected_statuses} within {timeout} seconds")
29
+
30
+
31
def test_job_creation(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    job = classification_model.evaluate(datasource_without_nones, background=True)
    # A freshly dispatched background job exposes its identity and timestamps immediately
    assert job.id is not None
    assert job.type == "EVALUATE_MODEL"
    assert job.status in [Status.DISPATCHED, Status.PROCESSING]
    for timestamp in (job.created_at, job.updated_at, job.refreshed_at):
        assert timestamp is not None
    recent_jobs = Job.query(limit=5, type="EVALUATE_MODEL")
    assert len(recent_jobs) >= 1
40
+
41
+
42
def test_job_result(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    job = classification_model.evaluate(datasource_without_nones, background=True)
    value = job.result(show_progress=False)
    assert value is not None
    assert job.status == Status.COMPLETED
    # A completed job should report that every step has run
    completed_steps = job.steps_completed
    assert completed_steps is not None
    assert completed_steps == job.steps_total
49
+
50
+
51
def test_job_wait(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    job = classification_model.evaluate(datasource_without_nones, background=True)
    job.wait(show_progress=False)
    assert job.status == Status.COMPLETED
    # wait() does not return the value, but the handle should have picked it up
    completed_steps = job.steps_completed
    assert completed_steps is not None
    assert completed_steps == job.steps_total
    assert job.value is not None
58
+
59
+
60
def test_job_refresh(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    job = classification_model.evaluate(datasource_without_nones, background=True)
    last_refreshed_at = job.refreshed_at
    # set_config changes a setting shared by all jobs; restore it afterwards so later tests
    # are not affected by the shortened refresh interval
    original_refresh_interval = job.config["refresh_interval"]
    Job.set_config(refresh_interval=1)
    try:
        # accessing the status attribute should refresh the job after the refresh interval
        time.sleep(1)
        job.status
        assert job.refreshed_at > last_refreshed_at
        last_refreshed_at = job.refreshed_at
        # calling refresh() should immediately refresh the job
        job.refresh()
        assert job.refreshed_at > last_refreshed_at
    finally:
        Job.set_config(refresh_interval=original_refresh_interval)
72
+
73
+
74
def test_job_query_pagination(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Test pagination with Job.query() method"""
    # Start several background evaluations so there are enough jobs to page through
    created_job_ids = [
        classification_model.evaluate(datasource_without_nones, background=True).id for _ in range(3)
    ]
    wait_for_jobs_status(created_job_ids, expected_statuses=[Status.PROCESSING, Status.COMPLETED])

    first_page = Job.query(type="EVALUATE_MODEL", limit=2)
    assert len(first_page) == 2

    shifted_page = Job.query(type="EVALUATE_MODEL", limit=2, offset=1)
    assert len(shifted_page) == 2

    # With limit=2 and offset=1 the two windows may share a job, but they must not be identical
    first_ids = {job.id for job in first_page}
    shifted_ids = {job.id for job in shifted_page}
    assert len(first_ids.symmetric_difference(shifted_ids)) > 0

    # Filtering by a single status
    processing_jobs = Job.query(status=Status.PROCESSING, limit=10)
    for job in processing_jobs:
        assert job.status == Status.PROCESSING

    # Filtering by several statuses at once
    accepted_statuses = [Status.PROCESSING, Status.COMPLETED]
    mixed_jobs = Job.query(status=accepted_statuses, limit=10)
    for job in mixed_jobs:
        assert job.status in [Status.PROCESSING, Status.COMPLETED]