orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,14 +1,17 @@
1
+ import logging
2
+ from typing import get_args
1
3
  from uuid import uuid4
2
4
 
3
5
  import pytest
4
6
 
5
7
  from .datasource import Datasource
6
8
  from .embedding_model import (
9
+ ClassificationMetrics,
7
10
  FinetunedEmbeddingModel,
8
11
  PretrainedEmbeddingModel,
9
12
  PretrainedEmbeddingModelName,
10
- TaskStatus,
11
13
  )
14
+ from .job import Status
12
15
  from .memoryset import LabeledMemoryset
13
16
 
14
17
 
@@ -29,13 +32,16 @@ def test_open_pretrained_model_unauthenticated(unauthenticated):
29
32
 
30
33
  def test_open_pretrained_model_not_found():
31
34
  with pytest.raises(LookupError):
32
- PretrainedEmbeddingModel._get("INVALID_MODEL")
35
+ PretrainedEmbeddingModel._get("INVALID_MODEL") # type: ignore
33
36
 
34
37
 
35
38
  def test_all_pretrained_models():
36
39
  models = PretrainedEmbeddingModel.all()
37
- assert len(models) == len(PretrainedEmbeddingModelName)
38
- assert all(m.name in PretrainedEmbeddingModelName.__members__ for m in models)
40
+ assert len(models) > 1
41
+ if len(models) != len(get_args(PretrainedEmbeddingModelName)):
42
+ logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
43
+ model_names = [m.name for m in models]
44
+ assert all(m in model_names for m in get_args(PretrainedEmbeddingModelName))
39
45
 
40
46
 
41
47
  def test_embed_text():
@@ -51,6 +57,13 @@ def test_embed_text_unauthenticated(unauthenticated):
51
57
  PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
52
58
 
53
59
 
60
def test_evaluate_pretrained_model(datasource: Datasource):
    """Evaluating GTE_BASE on the datasource returns classification metrics with accuracy above 0.5"""
    result = PretrainedEmbeddingModel.GTE_BASE.evaluate(label_column="label", datasource=datasource)
    assert result is not None
    assert isinstance(result, ClassificationMetrics)
    assert result.accuracy > 0.5
65
+
66
+
54
67
  @pytest.fixture(scope="session")
55
68
  def finetuned_model(datasource) -> FinetunedEmbeddingModel:
56
69
  return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
@@ -62,26 +75,25 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
62
75
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
63
76
  assert finetuned_model.embedding_dim == 768
64
77
  assert finetuned_model.max_seq_length == 512
65
- assert finetuned_model._status == TaskStatus.COMPLETED
78
+ assert finetuned_model._status == Status.COMPLETED
66
79
 
67
80
 
68
- def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
69
- finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_from_memoryset", memoryset)
81
+ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
82
+ finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
83
+ "test_finetuned_model_from_memoryset", readonly_memoryset
84
+ )
70
85
  assert finetuned_model is not None
71
86
  assert finetuned_model.name == "test_finetuned_model_from_memoryset"
72
87
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
73
88
  assert finetuned_model.embedding_dim == 768
74
89
  assert finetuned_model.max_seq_length == 512
75
- assert finetuned_model._status == TaskStatus.COMPLETED
90
+ assert finetuned_model._status == Status.COMPLETED
76
91
 
77
92
 
78
93
  def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
79
94
  with pytest.raises(ValueError):
80
95
  PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
81
96
 
82
- with pytest.raises(ValueError):
83
- PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="error")
84
-
85
97
 
86
98
  def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
87
99
  with pytest.raises(ValueError):
@@ -93,7 +105,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
93
105
  assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
94
106
  assert new_model.embedding_dim == 768
95
107
  assert new_model.max_seq_length == 512
96
- assert new_model._status == TaskStatus.COMPLETED
108
+ assert new_model._status == Status.COMPLETED
97
109
 
98
110
 
99
111
  def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
@@ -106,7 +118,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
106
118
  "test_memoryset_finetuned_model",
107
119
  datasource,
108
120
  embedding_model=finetuned_model,
109
- value_column="text",
110
121
  )
111
122
  assert memoryset is not None
112
123
  assert memoryset.name == "test_memoryset_finetuned_model"
@@ -171,3 +182,18 @@ def test_drop_finetuned_model_not_found():
171
182
  def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
172
183
  with pytest.raises(LookupError):
173
184
  FinetunedEmbeddingModel.drop(finetuned_model.id)
185
+
186
+
187
def test_supports_instructions():
    """GTE_BASE must not report instruction support, BGE_BASE must report it"""
    plain_model = PretrainedEmbeddingModel.GTE_BASE
    assert not plain_model.supports_instructions

    prompted_model = PretrainedEmbeddingModel.BGE_BASE
    assert prompted_model.supports_instructions
193
+
194
+
195
def test_use_explicit_instruction_prompt():
    """Embedding with an explicit prompt must differ from embedding the same text without one"""
    model = PretrainedEmbeddingModel.BGE_BASE
    assert model.supports_instructions
    # named `text` rather than `input` so the builtin is not shadowed
    text = "Hello world"
    prompted_embedding = model.embed(text, prompt="Represent this sentence for sentiment retrieval:")
    assert prompted_embedding != model.embed(text)
orca_sdk/job.py ADDED
@@ -0,0 +1,337 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from datetime import datetime, timedelta
5
+ from enum import Enum
6
+ from typing import Callable, Generic, TypedDict, TypeVar, cast
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ from .client import orca_api
11
+
12
+
13
class JobConfig(TypedDict):
    """Shape of the global settings dictionary used when polling and waiting for jobs"""

    # seconds to wait between polls of the job status
    refresh_interval: int
    # whether waiting for a job shows a progress bar
    show_progress: bool
    # maximum number of seconds to wait for a job to complete
    max_wait: int
17
+
18
+
19
class Status(Enum):
    """Lifecycle state of a cloud job in the task queue"""

    # NOTE: the API is never expected to report the INITIALIZED state
    INITIALIZED = "INITIALIZED"
    """The job was initialized"""

    DISPATCHED = "DISPATCHED"
    """The job is queued and waits to be picked up for processing"""

    WAITING = "WAITING"
    """The job waits for its dependencies to complete"""

    PROCESSING = "PROCESSING"
    """The job is currently being processed"""

    COMPLETED = "COMPLETED"
    """The job finished successfully"""

    FAILED = "FAILED"
    """The job failed"""

    ABORTING = "ABORTING"
    """The job is in the process of being aborted"""

    ABORTED = "ABORTED"
    """The job was aborted"""
46
+
47
+
48
TResult = TypeVar("TResult")


class Job(Generic[TResult]):
    """
    Handle to a job that is run in the OrcaCloud

    Attributes:
        id: Unique identifier for the job
        type: Type of the job
        status: Current status of the job
        steps_total: Total number of steps in the job, present if the job started processing
        steps_completed: Number of steps completed in the job, present if the job started processing
        completion: Percentage of the job that has been completed, present if the job started processing
        exception: Exception that occurred during the job, present if the status is `FAILED`
        value: Value of the result of the job, present if the status is `COMPLETED`
        created_at: When the job was queued for processing
        updated_at: When the job was last updated
        refreshed_at: When the job status was last refreshed

    Note:
        Reading `status`, `updated_at`, `steps_total`, `steps_completed`, `exception` or `value`
        synchronously refreshes the job from the API when the last refresh is older than the
        globally configured `refresh_interval`.
    """

    id: str
    type: str
    status: Status
    steps_total: int | None
    steps_completed: int | None
    exception: str | None
    value: TResult | None
    updated_at: datetime
    created_at: datetime
    refreshed_at: datetime

    @property
    def completion(self) -> float:
        """
        Percentage of the job that has been completed, present if the job started processing
        """
        # read once: every read of a tracked attribute may refresh the job
        steps_total = self.steps_total
        if not steps_total:
            # unknown or zero total: report no progress instead of dividing by zero
            return 0.0
        return (self.steps_completed or 0) / steps_total

    # Global configuration for all jobs
    config: JobConfig = {
        "refresh_interval": 3,
        "show_progress": True,
        "max_wait": 60 * 60,
    }

    def __repr__(self) -> str:
        return f"Job({{ type: {self.type}, status: {self.status}, completion: {self.completion:.0%} }})"

    @classmethod
    def set_config(
        cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
    ):
        """
        Set global configuration for running jobs

        Args:
            refresh_interval: Time to wait between polling the job status in seconds, default is 3
            show_progress: Whether to show a progress bar when calling the wait method, default is True
            max_wait: Maximum time to wait for the job to complete in seconds, default is 1 hour
        """
        if refresh_interval is not None:
            cls.config["refresh_interval"] = refresh_interval
        if show_progress is not None:
            cls.config["show_progress"] = show_progress
        if max_wait is not None:
            cls.config["max_wait"] = max_wait

    @classmethod
    def query(
        cls,
        status: Status | list[Status] | None = None,
        type: str | list[str] | None = None,
        limit: int = 100,
        offset: int = 0,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[Job]:
        """
        Query the job queue for jobs matching the given filters

        Args:
            status: Optional status or list of statuses to filter by
            type: Optional type or list of types to filter by
            limit: Maximum number of jobs to return
            offset: Offset into the list of jobs to return
            start: Optional minimum creation time of the jobs to query for
            end: Optional maximum creation time of the jobs to query for

        Returns:
            List of jobs matching the given filters
        """
        if isinstance(status, list):
            status_filter = [s.value for s in status]
        elif isinstance(status, Status):
            status_filter = status.value
        else:
            status_filter = None

        paginated_tasks = orca_api.GET(
            "/task",
            params={
                "status": status_filter,
                "type": type,
                "limit": limit,
                "offset": offset,
                "start_timestamp": start.isoformat() if start is not None else None,
                "end_timestamp": end.isoformat() if end is not None else None,
            },
        )

        # can't use constructor because it makes an API call, so we construct the objects manually
        return [cls._from_task(task) for task in paginated_tasks["items"]]

    @classmethod
    def _from_task(cls, task) -> Job:
        """Build a job handle from an already fetched task payload without calling the API again"""
        job = cls.__new__(cls)
        job.id = task["id"]
        # handles built here need a value resolver too, otherwise a later
        # automatic refresh of a completed job fails on the missing attribute
        job._get_value = job._fetch_result
        job.type = task["type"]
        job.status = Status(task["status"])
        job.steps_total = task["steps_total"]
        job.steps_completed = task["steps_completed"]
        job.exception = task["exception"]
        job.value = task["result"]
        job.updated_at = datetime.fromisoformat(task["updated_at"])
        job.created_at = datetime.fromisoformat(task["created_at"])
        job.refreshed_at = datetime.now()
        return job

    def _fetch_result(self) -> TResult | None:
        """Default value resolver: read the raw `result` field of the task from the API"""
        task = orca_api.GET("/task/{task_id}", params={"task_id": self.id})
        return cast("TResult | None", task["result"])

    def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
        """
        Create a handle to a job in the job queue

        Args:
            id: Unique identifier for the job
            get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
        """
        self.id = id
        task = orca_api.GET("/task/{task_id}", params={"task_id": id})

        self._get_value = get_value or self._fetch_result
        self.type = task["type"]
        self.status = Status(task["status"])
        self.steps_total = task["steps_total"]
        self.steps_completed = task["steps_completed"]
        self.exception = task["exception"]
        if task["status"] != "COMPLETED":
            self.value = None
        elif get_value is not None:
            self.value = get_value()
        else:
            # the payload fetched above already carries the raw result
            self.value = cast("TResult | None", task["result"])
        self.updated_at = datetime.fromisoformat(task["updated_at"])
        self.created_at = datetime.fromisoformat(task["created_at"])
        self.refreshed_at = datetime.now()

    def refresh(self, throttle: float = 0):
        """
        Refresh the status and progress of the job

        Args:
            throttle: Minimum time in seconds between refreshes
        """
        current_time = datetime.now()
        # Skip refresh if last refresh was too recent
        if (current_time - self.refreshed_at) < timedelta(seconds=throttle):
            return
        self.refreshed_at = current_time

        status_info = orca_api.GET("/task/{task_id}/status", params={"task_id": self.id})
        self.status = Status(status_info["status"])
        # keep the previously known step counts when the status payload omits them
        if status_info["steps_total"] is not None:
            self.steps_total = status_info["steps_total"]
        if status_info["steps_completed"] is not None:
            self.steps_completed = status_info["steps_completed"]

        self.exception = status_info["exception"]
        self.updated_at = datetime.fromisoformat(status_info["updated_at"])

        if status_info["status"] == "COMPLETED":
            self.value = self._get_value()

    def __getattribute__(self, name: str):
        # if the attribute is not immutable, refresh the job if it hasn't been refreshed recently
        if name in ("status", "updated_at", "steps_total", "steps_completed", "exception", "value"):
            self.refresh(self.config["refresh_interval"])
        return super().__getattribute__(name)

    def wait(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> None:
        """
        Block until the job is complete

        Args:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will not return the result or raise an exception if the job fails. Call
            [`result`][orca_sdk.Job.result] instead if you want to get the result.

        Raises:
            RuntimeError: If the job times out
        """
        if show_progress is None:
            show_progress = self.config["show_progress"]
        if refresh_interval is None:
            refresh_interval = self.config["refresh_interval"]
        if max_wait is None:
            max_wait = self.config["max_wait"]

        # monotonic clock: elapsed time must not be affected by system clock adjustments
        started = time.monotonic()
        pbar = None
        try:
            while True:
                # attribute reads are throttled by the global refresh interval, so refresh
                # explicitly to honor a polling interval that is shorter than that throttle
                self.refresh(throttle=refresh_interval)

                # setup progress bar if steps total is known; compare against None because
                # the truthiness of a tqdm bar depends on its total
                steps_total = self.steps_total
                if pbar is None and show_progress and steps_total is not None:
                    desc = " ".join(self.type.split("_")).lower()
                    pbar = tqdm(total=steps_total, desc=desc)

                # return if job is complete
                if self.status in (Status.COMPLETED, Status.FAILED, Status.ABORTED):
                    final_total = self.steps_total
                    if pbar is not None and final_total is not None:
                        pbar.update(final_total - pbar.n)
                    return

                # raise error if job timed out
                if (time.monotonic() - started) > max_wait:
                    raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

                # update progress bar
                steps_completed = self.steps_completed
                if pbar is not None and steps_completed is not None:
                    pbar.update(steps_completed - pbar.n)

                # sleep before retrying
                time.sleep(refresh_interval)
        finally:
            # close the bar on every exit path, including timeouts and API errors
            if pbar is not None:
                pbar.close()

    def result(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> TResult:
        """
        Block until the job is complete and return the result value

        Args:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will raise an exception if the job fails. Use [`wait`][orca_sdk.Job.wait]
            if you just want to wait for the job to complete without raising errors on failure.

        Returns:
            The result value of the job

        Raises:
            RuntimeError: If the job fails or times out
        """
        # read once so the check and the return see the same value
        value = self.value
        if value is not None:
            return value
        self.wait(show_progress, refresh_interval, max_wait)
        if self.status != Status.COMPLETED:
            raise RuntimeError(f"Job failed with exception: {self.exception}")
        value = self.value
        assert value is not None
        return value

    def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
        """
        Abort the job

        Args:
            show_progress: Optionally show a progress bar while waiting for the job to abort
            refresh_interval: Polling interval in seconds while waiting for the job to abort
            max_wait: Maximum time to wait for the job to abort in seconds
        """
        orca_api.DELETE("/task/{task_id}/abort", params={"task_id": self.id})
        self.wait(show_progress, refresh_interval, max_wait)
orca_sdk/job_test.py ADDED
@@ -0,0 +1,108 @@
1
+ import time
2
+
3
+ import pytest
4
+ from datasets import Dataset
5
+
6
+ from .classification_model import ClassificationModel
7
+ from .datasource import Datasource
8
+ from .job import Job, Status
9
+
10
+
11
@pytest.fixture(scope="session")
def datasource_without_nones(hf_dataset: Dataset):
    """Session-scoped datasource built from the rows of `hf_dataset` whose label is not None"""
    labeled_rows = hf_dataset.filter(lambda row: row["label"] is not None)
    return Datasource.from_hf_dataset("test_datasource_without_nones", labeled_rows)
16
+
17
+
18
def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0.2):
    """
    Poll the given jobs until every one of them is in one of the expected statuses.

    Params:
        job_ids: Ids of the jobs to poll
        expected_statuses: Statuses that count as having been reached
        timeout: Maximum number of seconds to keep polling
        poll_interval: Number of seconds to sleep between polling rounds

    Raises:
        TimeoutError: If not all jobs are in an expected status before the timeout elapses
    """
    # use the monotonic clock so system clock adjustments cannot shorten or extend the timeout
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # a fresh Job handle is created per id on every round; the generator lets all()
        # stop at the first job that is not in an expected status
        if all(Job(job_id).status in expected_statuses for job_id in job_ids):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"Jobs did not reach statuses {expected_statuses} within {timeout} seconds")
29
+
30
+
31
def test_job_creation(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """A background evaluation should return a job handle with its metadata populated"""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    assert job.id is not None
    assert job.type == "EVALUATE_MODEL"
    assert job.status in (Status.DISPATCHED, Status.PROCESSING)
    # all timestamps should be set on a freshly created job
    for timestamp_attr in ("created_at", "updated_at", "refreshed_at"):
        assert getattr(job, timestamp_attr) is not None
    # the new job should be discoverable through a query on its type
    matching_jobs = Job.query(limit=5, type="EVALUATE_MODEL")
    assert len(matching_jobs) > 0
40
+
41
+
42
def test_job_result(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Calling result() should block until completion and return a value"""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    outcome = job.result(show_progress=False)
    assert outcome is not None
    assert job.status == Status.COMPLETED
    # a completed job should report all of its steps as done
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
49
+
50
+
51
def test_job_wait(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Calling wait() should block until completion and leave the value populated"""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    job.wait(show_progress=False)
    assert job.status == Status.COMPLETED
    # a completed job should report all of its steps as done
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
    assert job.value is not None
58
+
59
+
60
def test_job_refresh(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """A job should refresh on attribute access after the interval and immediately via refresh()"""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    previous_refresh = job.refreshed_at

    # once the refresh interval has elapsed, reading the status attribute should refresh the job
    # NOTE(review): set_config appears to change a shared default that is not restored here — confirm
    Job.set_config(refresh_interval=1)
    time.sleep(1)
    _ = job.status
    assert job.refreshed_at > previous_refresh

    # an explicit refresh() call should refresh the job right away
    previous_refresh = job.refreshed_at
    job.refresh()
    assert job.refreshed_at > previous_refresh
72
+
73
+
74
def test_job_query_pagination(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Exercise the limit, offset and status filters of Job.query()"""
    # start several jobs so there is more than one page of results
    created_job_ids = [
        classification_model.evaluate(datasource_without_nones, background=True).id for _ in range(3)
    ]

    # give the jobs time to become at least PROCESSING or COMPLETED
    wait_for_jobs_status(created_job_ids, expected_statuses=[Status.PROCESSING, Status.COMPLETED])

    # limit alone should cap the page size
    first_page = Job.query(type="EVALUATE_MODEL", limit=2)
    assert len(first_page) == 2

    # limit combined with offset should still return a full page
    second_page = Job.query(type="EVALUATE_MODEL", limit=2, offset=1)
    assert len(second_page) == 2

    # the two pages may overlap, but they must not contain exactly the same jobs
    first_page_ids = {job.id for job in first_page}
    second_page_ids = {job.id for job in second_page}
    assert first_page_ids != second_page_ids

    # filtering by a single status
    # NOTE(review): a job might change status between the query and this check — confirm this is not flaky
    processing_jobs = Job.query(status=Status.PROCESSING, limit=10)
    for job in processing_jobs:
        assert job.status == Status.PROCESSING

    # filtering by a list of statuses
    allowed_statuses = [Status.PROCESSING, Status.COMPLETED]
    processing_or_completed_jobs = Job.query(status=[Status.PROCESSING, Status.COMPLETED], limit=10)
    for job in processing_or_completed_jobs:
        assert job.status in allowed_statuses