adaptive-sdk 0.1.14__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_sdk/graphql_client/__init__.py +704 -12
- adaptive_sdk/graphql_client/custom_fields.py +1640 -960
- adaptive_sdk/graphql_client/custom_mutations.py +332 -177
- adaptive_sdk/graphql_client/custom_queries.py +278 -133
- adaptive_sdk/graphql_client/custom_typing_fields.py +74 -0
- adaptive_sdk/graphql_client/enums.py +38 -0
- adaptive_sdk/graphql_client/input_types.py +586 -213
- adaptive_sdk/resources/abtests.py +2 -0
- adaptive_sdk/resources/base_resource.py +22 -0
- adaptive_sdk/resources/chat.py +16 -7
- adaptive_sdk/resources/compute_pools.py +16 -4
- adaptive_sdk/resources/embeddings.py +11 -7
- adaptive_sdk/resources/feedback.py +22 -33
- adaptive_sdk/resources/interactions.py +6 -3
- adaptive_sdk/resources/jobs.py +90 -6
- adaptive_sdk/resources/models.py +20 -3
- adaptive_sdk/resources/permissions.py +12 -1
- adaptive_sdk/resources/recipes.py +177 -31
- adaptive_sdk/resources/roles.py +15 -8
- adaptive_sdk/resources/teams.py +30 -1
- adaptive_sdk/resources/use_cases.py +9 -6
- adaptive_sdk/resources/users.py +37 -22
- {adaptive_sdk-0.1.14.dist-info → adaptive_sdk-0.12.1.dist-info}/METADATA +1 -1
- {adaptive_sdk-0.1.14.dist-info → adaptive_sdk-0.12.1.dist-info}/RECORD +25 -25
- {adaptive_sdk-0.1.14.dist-info → adaptive_sdk-0.12.1.dist-info}/WHEEL +0 -0
|
@@ -91,6 +91,7 @@ class ABTests(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
91
91
|
Args:
|
|
92
92
|
active: Filter on active or inactive AB tests.
|
|
93
93
|
status: Filter on one of the possible AB test statuses.
|
|
94
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
94
95
|
"""
|
|
95
96
|
if status:
|
|
96
97
|
status_input = AbcampaignStatus(status.upper())
|
|
@@ -187,6 +188,7 @@ class AsyncABTests(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
187
188
|
Args:
|
|
188
189
|
active: Filter on active or inactive AB tests.
|
|
189
190
|
status: Filter on one of the possible AB test statuses.
|
|
191
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
190
192
|
"""
|
|
191
193
|
if status:
|
|
192
194
|
status_input = AbcampaignStatus(status.upper())
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from typing import TYPE_CHECKING
|
|
3
4
|
|
|
4
5
|
if TYPE_CHECKING:
|
|
@@ -20,10 +21,23 @@ class AsyncAPIResource:
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class UseCaseResource:
|
|
24
|
+
"""Mixin class for resources that operate within a use case context."""
|
|
25
|
+
|
|
23
26
|
def __init__(self, client: UseCaseClient) -> None:
|
|
24
27
|
self._client = client
|
|
25
28
|
|
|
26
29
|
def use_case_key(self, use_case: str | None) -> str:
|
|
30
|
+
"""Get the use case key, falling back to the client's default.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
use_case: Optional explicit use case key.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
The resolved use case key.
|
|
37
|
+
|
|
38
|
+
Raises:
|
|
39
|
+
ValueError: If no use case is provided and no default is set.
|
|
40
|
+
"""
|
|
27
41
|
target_use_case = use_case or self._client.default_use_case
|
|
28
42
|
if target_use_case is None:
|
|
29
43
|
raise ValueError(
|
|
@@ -36,6 +50,14 @@ or explicitly pass `use_case` as an input parameter."""
|
|
|
36
50
|
return target_use_case
|
|
37
51
|
|
|
38
52
|
def optional_use_case_key(self, use_case: str | None) -> str | None:
|
|
53
|
+
"""Get the use case key if available, or None.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
use_case: Optional explicit use case key.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
The use case key if provided or default is set, otherwise None.
|
|
60
|
+
"""
|
|
39
61
|
if use_case:
|
|
40
62
|
return use_case
|
|
41
63
|
elif self._client.default_use_case:
|
adaptive_sdk/resources/chat.py
CHANGED
|
@@ -1,23 +1,26 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import json
|
|
3
|
-
from loguru import logger
|
|
4
|
-
from uuid import UUID
|
|
5
4
|
from typing import (
|
|
6
|
-
|
|
7
|
-
List,
|
|
5
|
+
TYPE_CHECKING,
|
|
8
6
|
AsyncGenerator,
|
|
7
|
+
Dict,
|
|
9
8
|
Generator,
|
|
9
|
+
List,
|
|
10
10
|
Literal,
|
|
11
11
|
overload,
|
|
12
|
-
TYPE_CHECKING,
|
|
13
12
|
)
|
|
13
|
+
from uuid import UUID
|
|
14
|
+
|
|
15
|
+
from loguru import logger
|
|
14
16
|
from typing_extensions import override
|
|
17
|
+
|
|
15
18
|
from adaptive_sdk import input_types
|
|
16
19
|
from adaptive_sdk.error_handling import rest_error_handler
|
|
17
20
|
from adaptive_sdk.rest import rest_types
|
|
18
21
|
from adaptive_sdk.utils import convert_optional_UUID, get_full_model_path
|
|
19
22
|
|
|
20
|
-
from .base_resource import
|
|
23
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
21
24
|
|
|
22
25
|
if TYPE_CHECKING:
|
|
23
26
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
@@ -102,11 +105,14 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
102
105
|
top_p: Threshold for top-p sampling.
|
|
103
106
|
stream_include_usage: If set, an additional chunk will be streamed with the token usage statistics for
|
|
104
107
|
the entire request.
|
|
108
|
+
session_id: Session ID to group related interactions.
|
|
109
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
105
110
|
user: ID of user making request. If not `None`, will be logged as metadata for the request.
|
|
106
111
|
ab_campaign: AB test key. If set, request will be guaranteed to count towards AB test results,
|
|
107
112
|
no matter the configured `traffic_split`.
|
|
108
113
|
n: Number of chat completions to generate for each input messages.
|
|
109
114
|
labels: Key-value pairs of interaction labels.
|
|
115
|
+
store: Whether to store the interaction for future reference. Stores by default.
|
|
110
116
|
|
|
111
117
|
Examples:
|
|
112
118
|
```
|
|
@@ -237,13 +243,16 @@ class AsyncChat(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
237
243
|
max_tokens: Maximum # of tokens allowed to generate.
|
|
238
244
|
temperature: Sampling temperature.
|
|
239
245
|
top_p: Threshold for top-p sampling.
|
|
240
|
-
stream_include_usage: If set, an additional chunk will be streamed with the token
|
|
246
|
+
stream_include_usage: If set, an additional chunk will be streamed with the token usage statistics for
|
|
241
247
|
the entire request.
|
|
248
|
+
session_id: Session ID to group related interactions.
|
|
249
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
242
250
|
user: ID of user making request. If not `None`, will be logged as metadata for the request.
|
|
243
251
|
ab_campaign: AB test key. If set, request will be guaranteed to count towards AB test results,
|
|
244
252
|
no matter the configured `traffic_split`.
|
|
245
253
|
n: Number of chat completions to generate for each input messages.
|
|
246
254
|
labels: Key-value pairs of interaction labels.
|
|
255
|
+
store: Whether to store the interaction for future reference. Stores by default.
|
|
247
256
|
|
|
248
257
|
Examples:
|
|
249
258
|
```
|
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
# type: ignore
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
|
+
|
|
4
5
|
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
5
7
|
from pydantic import BaseModel
|
|
6
8
|
|
|
7
|
-
from adaptive_sdk.graphql_client import
|
|
9
|
+
from adaptive_sdk.graphql_client import HarmonyStatus, ResizePartitionInput
|
|
8
10
|
|
|
9
|
-
from .base_resource import
|
|
11
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
12
14
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
@@ -28,6 +30,11 @@ class ComputePools(SyncAPIResource, UseCaseResource):
|
|
|
28
30
|
UseCaseResource.__init__(self, client)
|
|
29
31
|
|
|
30
32
|
def list(self):
|
|
33
|
+
"""List all compute pools available in the system.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
A list of compute pool objects.
|
|
37
|
+
"""
|
|
31
38
|
return self._gql_client.list_compute_pools().compute_pools
|
|
32
39
|
|
|
33
40
|
def resize_inference_partition(self, compute_pool_key: str, size: int) -> list[ResizeResult]:
|
|
@@ -52,7 +59,7 @@ class ComputePools(SyncAPIResource, UseCaseResource):
|
|
|
52
59
|
_ = self._gql_client.resize_inference_partition(input)
|
|
53
60
|
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=True))
|
|
54
61
|
except Exception as e:
|
|
55
|
-
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False,
|
|
62
|
+
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False, error=str(e)))
|
|
56
63
|
|
|
57
64
|
return resize_results
|
|
58
65
|
|
|
@@ -67,6 +74,11 @@ class AsyncComputePools(AsyncAPIResource, UseCaseResource):
|
|
|
67
74
|
UseCaseResource.__init__(self, client)
|
|
68
75
|
|
|
69
76
|
async def list(self):
|
|
77
|
+
"""List all compute pools available in the system.
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
A list of compute pool objects.
|
|
81
|
+
"""
|
|
70
82
|
return (await self._gql_client.list_compute_pools()).compute_pools
|
|
71
83
|
|
|
72
84
|
async def resize_inference_partition(self, compute_pool_key: str, size: int) -> list[ResizeResult]:
|
|
@@ -91,6 +103,6 @@ class AsyncComputePools(AsyncAPIResource, UseCaseResource):
|
|
|
91
103
|
_ = await self._gql_client.resize_inference_partition(input)
|
|
92
104
|
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=True))
|
|
93
105
|
except Exception as e:
|
|
94
|
-
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False,
|
|
106
|
+
resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False, error=str(e)))
|
|
95
107
|
|
|
96
108
|
return resize_results
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
|
3
4
|
from uuid import UUID
|
|
4
|
-
|
|
5
|
+
|
|
5
6
|
from adaptive_sdk.error_handling import rest_error_handler
|
|
6
7
|
from adaptive_sdk.rest import rest_types
|
|
7
8
|
from adaptive_sdk.utils import convert_optional_UUID, get_full_model_path
|
|
8
9
|
|
|
9
|
-
from .base_resource import
|
|
10
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
10
11
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
12
13
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
@@ -51,7 +52,8 @@ class Embeddings(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
51
52
|
user=convert_optional_UUID(user),
|
|
52
53
|
)
|
|
53
54
|
r = self._rest_client.post(
|
|
54
|
-
ROUTE,
|
|
55
|
+
ROUTE,
|
|
56
|
+
json=emb_input.model_dump_json(exclude_none=True),
|
|
55
57
|
)
|
|
56
58
|
rest_error_handler(r)
|
|
57
59
|
return rest_types.EmbeddingsResponseList.model_validate(r.json())
|
|
@@ -69,7 +71,7 @@ class AsyncEmbeddings(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
69
71
|
async def create(
|
|
70
72
|
self,
|
|
71
73
|
input: str,
|
|
72
|
-
|
|
74
|
+
model: str | None = None,
|
|
73
75
|
encoding_format: Literal["Float", "Base64"] = "Float",
|
|
74
76
|
use_case: str | None = None,
|
|
75
77
|
user: str | UUID | None = None,
|
|
@@ -82,18 +84,20 @@ class AsyncEmbeddings(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
82
84
|
model: Target model key for inference. If `None`, the requests will be routed to the use case's default model.
|
|
83
85
|
Request will error if default model is not an embedding model.
|
|
84
86
|
encoding_format: Encoding format of response.
|
|
87
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
85
88
|
user: ID of user making the requests. If not `None`, will be logged as metadata for the request.
|
|
86
89
|
"""
|
|
87
90
|
encoding_format_enum = rest_types.EmbeddingsEncodingFormat(encoding_format)
|
|
88
91
|
emb_input = rest_types.GenerateEmbeddingsInput(
|
|
89
92
|
input=input,
|
|
90
|
-
model=get_full_model_path(self.use_case_key(use_case),
|
|
93
|
+
model=get_full_model_path(self.use_case_key(use_case), model),
|
|
91
94
|
encoding_format=encoding_format_enum,
|
|
92
95
|
dimensions=None,
|
|
93
96
|
user=convert_optional_UUID(user),
|
|
94
97
|
)
|
|
95
98
|
r = await self._rest_client.post(
|
|
96
|
-
ROUTE,
|
|
99
|
+
ROUTE,
|
|
100
|
+
json=emb_input.model_dump_json(exclude_none=True),
|
|
97
101
|
)
|
|
98
102
|
rest_error_handler(r)
|
|
99
103
|
return rest_types.EmbeddingsResponseList.model_validate(r.json())
|
|
@@ -1,25 +1,29 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, List, Literal, Sequence
|
|
3
4
|
from uuid import UUID
|
|
5
|
+
|
|
4
6
|
from typing_extensions import override
|
|
7
|
+
|
|
5
8
|
from adaptive_sdk import input_types
|
|
6
9
|
from adaptive_sdk.error_handling import rest_error_handler
|
|
7
10
|
from adaptive_sdk.graphql_client import (
|
|
11
|
+
MetricCreate,
|
|
8
12
|
MetricData,
|
|
9
13
|
MetricDataAdmin,
|
|
10
|
-
MetricCreate,
|
|
11
14
|
MetricKind,
|
|
12
|
-
MetricScoringType,
|
|
13
|
-
MetricWithContextData,
|
|
14
15
|
MetricLink,
|
|
16
|
+
MetricScoringType,
|
|
15
17
|
MetricUnlink,
|
|
18
|
+
MetricWithContextData,
|
|
16
19
|
)
|
|
17
20
|
from adaptive_sdk.rest import rest_types
|
|
18
21
|
from adaptive_sdk.utils import (
|
|
19
22
|
convert_optional_UUID,
|
|
20
23
|
validate_comparison_completion,
|
|
21
24
|
)
|
|
22
|
-
|
|
25
|
+
|
|
26
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
23
27
|
|
|
24
28
|
if TYPE_CHECKING:
|
|
25
29
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
@@ -41,10 +45,8 @@ class Feedback(SyncAPIResource, UseCaseResource): # type: ignore
|
|
|
41
45
|
def register_key(
|
|
42
46
|
self,
|
|
43
47
|
key: str,
|
|
44
|
-
kind: Literal["scalar", "bool"]
|
|
45
|
-
scoring_type: Literal[
|
|
46
|
-
"higher_is_better", "lower_is_better"
|
|
47
|
-
] = "higher_is_better",
|
|
48
|
+
kind: Literal["scalar", "bool"],
|
|
49
|
+
scoring_type: Literal["higher_is_better", "lower_is_better"] = "higher_is_better",
|
|
48
50
|
name: str | None = None,
|
|
49
51
|
description: str | None = None,
|
|
50
52
|
) -> MetricData:
|
|
@@ -58,7 +60,7 @@ class Feedback(SyncAPIResource, UseCaseResource): # type: ignore
|
|
|
58
60
|
If `"scalar"`, you can log any integer or float value.
|
|
59
61
|
scoring_type: Indication of what good means for this feedback key; a higher numeric value (or `True`)
|
|
60
62
|
, or a lower numeric value (or `False`).
|
|
61
|
-
name Human-readable feedback name that will render in the UI. If `None`, will be the same as `key`.
|
|
63
|
+
name: Human-readable feedback name that will render in the UI. If `None`, will be the same as `key`.
|
|
62
64
|
description: Description of intended purpose or nuances of feedback. Will render in the UI.
|
|
63
65
|
"""
|
|
64
66
|
input = MetricCreate(
|
|
@@ -140,9 +142,7 @@ class Feedback(SyncAPIResource, UseCaseResource): # type: ignore
|
|
|
140
142
|
user_id=convert_optional_UUID(user),
|
|
141
143
|
details=details,
|
|
142
144
|
)
|
|
143
|
-
r = self._rest_client.post(
|
|
144
|
-
FEEDBACK_ROUTE, json=input.model_dump(exclude_none=True)
|
|
145
|
-
)
|
|
145
|
+
r = self._rest_client.post(FEEDBACK_ROUTE, json=input.model_dump(exclude_none=True))
|
|
146
146
|
rest_error_handler(r)
|
|
147
147
|
return rest_types.FeedbackOutput.model_validate(r.json())
|
|
148
148
|
|
|
@@ -171,9 +171,7 @@ class Feedback(SyncAPIResource, UseCaseResource): # type: ignore
|
|
|
171
171
|
tied: Indicator if both completions tied as equally bad or equally good.
|
|
172
172
|
"""
|
|
173
173
|
|
|
174
|
-
clean_preffered_completion = validate_comparison_completion(
|
|
175
|
-
preferred_completion
|
|
176
|
-
)
|
|
174
|
+
clean_preffered_completion = validate_comparison_completion(preferred_completion)
|
|
177
175
|
clean_other_completion = validate_comparison_completion(other_completion)
|
|
178
176
|
input_messages = [rest_types.ChatMessage(**m) for m in messages] if messages else None # type: ignore
|
|
179
177
|
|
|
@@ -186,9 +184,7 @@ class Feedback(SyncAPIResource, UseCaseResource): # type: ignore
|
|
|
186
184
|
tied=rest_types.ComparisonTie(tied) if tied else None,
|
|
187
185
|
use_case=self.use_case_key(use_case),
|
|
188
186
|
)
|
|
189
|
-
r = self._rest_client.post(
|
|
190
|
-
PREFERENCE_ROUTE, json=input.model_dump(exclude_none=True)
|
|
191
|
-
)
|
|
187
|
+
r = self._rest_client.post(PREFERENCE_ROUTE, json=input.model_dump(exclude_none=True))
|
|
192
188
|
rest_error_handler(r)
|
|
193
189
|
return rest_types.ComparisonOutput.model_validate(r.json())
|
|
194
190
|
|
|
@@ -206,9 +202,7 @@ class AsyncFeedback(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
206
202
|
self,
|
|
207
203
|
key: str,
|
|
208
204
|
kind: Literal["scalar", "bool"],
|
|
209
|
-
scoring_type: Literal[
|
|
210
|
-
"higher_is_better", "lower_is_better"
|
|
211
|
-
] = "higher_is_better",
|
|
205
|
+
scoring_type: Literal["higher_is_better", "lower_is_better"] = "higher_is_better",
|
|
212
206
|
name: str | None = None,
|
|
213
207
|
description: str | None = None,
|
|
214
208
|
) -> MetricData:
|
|
@@ -295,7 +289,7 @@ class AsyncFeedback(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
295
289
|
value: The feedback values.
|
|
296
290
|
completion_id: The completion_id to attach the feedback to.
|
|
297
291
|
feedback_key: The feedback key to log against.
|
|
298
|
-
|
|
292
|
+
user_id: ID of user submitting feedback. If not `None`, will be logged as metadata for the request.
|
|
299
293
|
details: Textual details for the feedback. Can be used to provide further context on the feedback `value`.
|
|
300
294
|
"""
|
|
301
295
|
input = rest_types.AddFeedbackRequest(
|
|
@@ -305,9 +299,7 @@ class AsyncFeedback(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
305
299
|
user_id=convert_optional_UUID(user_id),
|
|
306
300
|
details=details,
|
|
307
301
|
)
|
|
308
|
-
r = await self._rest_client.post(
|
|
309
|
-
FEEDBACK_ROUTE, json=input.model_dump(exclude_none=True)
|
|
310
|
-
)
|
|
302
|
+
r = await self._rest_client.post(FEEDBACK_ROUTE, json=input.model_dump(exclude_none=True))
|
|
311
303
|
rest_error_handler(r)
|
|
312
304
|
return rest_types.FeedbackOutput.model_validate(r.json())
|
|
313
305
|
|
|
@@ -330,15 +322,14 @@ class AsyncFeedback(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
330
322
|
corresponding to a valid model key and its attributed completion.
|
|
331
323
|
other_completion: Can be a completion_id or a dict with keys `model` and `text`,
|
|
332
324
|
corresponding to a valid model key and its attributed completion.
|
|
333
|
-
|
|
325
|
+
user_id: ID of user submitting feedback.
|
|
334
326
|
messages: Input chat messages, each dict with keys `role` and `content`.
|
|
335
327
|
Ignored if `preferred_` and `other_completion` are completion_ids.
|
|
336
328
|
tied: Indicator if both completions tied as equally bad or equally good.
|
|
329
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
337
330
|
|
|
338
331
|
"""
|
|
339
|
-
clean_preffered_completion = validate_comparison_completion(
|
|
340
|
-
preferred_completion
|
|
341
|
-
)
|
|
332
|
+
clean_preffered_completion = validate_comparison_completion(preferred_completion)
|
|
342
333
|
clean_other_completion = validate_comparison_completion(other_completion)
|
|
343
334
|
input_messages = [rest_types.ChatMessage(**m) for m in messages] if messages else None # type: ignore
|
|
344
335
|
|
|
@@ -351,8 +342,6 @@ class AsyncFeedback(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
351
342
|
tied=rest_types.ComparisonTie(tied) if tied else None,
|
|
352
343
|
use_case=self.use_case_key(use_case),
|
|
353
344
|
)
|
|
354
|
-
r = await self._rest_client.post(
|
|
355
|
-
PREFERENCE_ROUTE, json=input.model_dump(exclude_none=True)
|
|
356
|
-
)
|
|
345
|
+
r = await self._rest_client.post(PREFERENCE_ROUTE, json=input.model_dump(exclude_none=True))
|
|
357
346
|
rest_error_handler(r)
|
|
358
347
|
return rest_types.ComparisonOutput.model_validate(r.json())
|
|
@@ -76,11 +76,13 @@ class Interactions(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
76
76
|
Create/log an interaction.
|
|
77
77
|
|
|
78
78
|
Args:
|
|
79
|
-
model: Model key.
|
|
80
79
|
messages: Input chat messages, each dict should have keys `role` and `content`.
|
|
81
80
|
completion: Model completion.
|
|
81
|
+
model: Model key.
|
|
82
82
|
feedbacks: List of feedbacks, each dict should have keys `feedback_key`, `value` and optional(`details`).
|
|
83
83
|
user: ID of user making the request. If not `None`, will be logged as metadata for the interaction.
|
|
84
|
+
session_id: Session ID to group related interactions.
|
|
85
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
84
86
|
ab_campaign: AB test key. If set, provided `feedbacks` will count towards AB test results.
|
|
85
87
|
labels: Key-value pairs of interaction labels.
|
|
86
88
|
created_at: Timestamp of interaction creation or ingestion.
|
|
@@ -187,14 +189,15 @@ class AsyncInteractions(AsyncAPIResource, UseCaseResource): # type: ignore[misc
|
|
|
187
189
|
Create/log an interaction.
|
|
188
190
|
|
|
189
191
|
Args:
|
|
190
|
-
model: Model key.
|
|
191
192
|
messages: Input chat messages, each dict should have keys `role` and `content`.
|
|
192
193
|
completion: Model completion.
|
|
194
|
+
model: Model key.
|
|
193
195
|
feedbacks: List of feedbacks, each dict should have keys `feedback_key`, `value` and optional(`details`).
|
|
194
196
|
user: ID of user making the request. If not `None`, will be logged as metadata for the interaction.
|
|
197
|
+
session_id: Session ID to group related interactions.
|
|
198
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
195
199
|
ab_campaign: AB test key. If set, provided `feedbacks` will count towards AB test results.
|
|
196
200
|
labels: Key-value pairs of interaction labels.
|
|
197
|
-
created_at: Timestamp of interaction creation or ingestion.
|
|
198
201
|
"""
|
|
199
202
|
input_messages, input_feedbacks = _prepare_add_interactions_inputs(messages, feedbacks)
|
|
200
203
|
|
adaptive_sdk/resources/jobs.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
3
4
|
|
|
4
5
|
from adaptive_sdk.graphql_client import (
|
|
5
6
|
CursorPageInput,
|
|
6
7
|
JobInput,
|
|
7
|
-
ListJobsJobs,
|
|
8
|
-
ListJobsFilterInput,
|
|
9
8
|
JobKind,
|
|
9
|
+
ListJobsFilterInput,
|
|
10
|
+
ListJobsJobs,
|
|
10
11
|
ListJobsJobsNodes,
|
|
11
12
|
)
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from .base_resource import SyncAPIResource, AsyncAPIResource, UseCaseResource
|
|
15
13
|
from adaptive_sdk.output_types import JobDataPlus
|
|
16
14
|
|
|
15
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
16
|
+
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
18
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
19
19
|
|
|
@@ -28,6 +28,14 @@ class Jobs(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
28
28
|
UseCaseResource.__init__(self, client)
|
|
29
29
|
|
|
30
30
|
def get(self, job_id: str) -> JobDataPlus | None:
|
|
31
|
+
"""Get the details of a specific job.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
job_id: The ID of the job to retrieve.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
The job data if found, otherwise None.
|
|
38
|
+
"""
|
|
31
39
|
job_data = self._gql_client.describe_job(id=job_id).job
|
|
32
40
|
return JobDataPlus.from_job_data(job_data) if job_data else None
|
|
33
41
|
|
|
@@ -51,6 +59,19 @@ class Jobs(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
51
59
|
) = None,
|
|
52
60
|
use_case: str | None = None,
|
|
53
61
|
) -> ListJobsJobs:
|
|
62
|
+
"""List jobs with pagination and filtering options.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
first: Number of jobs to return from the beginning.
|
|
66
|
+
last: Number of jobs to return from the end.
|
|
67
|
+
after: Cursor for forward pagination.
|
|
68
|
+
before: Cursor for backward pagination.
|
|
69
|
+
kind: Filter by job types.
|
|
70
|
+
use_case: Filter by use case key.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
A paginated list of jobs.
|
|
74
|
+
"""
|
|
54
75
|
use_case = self.optional_use_case_key(use_case)
|
|
55
76
|
page = CursorPageInput(first=first, last=last, after=after, before=before)
|
|
56
77
|
validated_filter = ListJobsFilterInput(
|
|
@@ -74,6 +95,19 @@ class Jobs(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
74
95
|
use_case: str | None = None,
|
|
75
96
|
compute_pool: str | None = None,
|
|
76
97
|
) -> JobDataPlus:
|
|
98
|
+
"""Run a job using a specified recipe.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
recipe_key: The key of the recipe to run.
|
|
102
|
+
num_gpus: Number of GPUs to allocate for the job.
|
|
103
|
+
args: Optional arguments to pass to the recipe; must match the recipe schema.
|
|
104
|
+
name: Optional human-readable name for the job.
|
|
105
|
+
use_case: Use case key for the job.
|
|
106
|
+
compute_pool: Optional compute pool key to run the job on.
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
The created job data.
|
|
110
|
+
"""
|
|
77
111
|
args = args or {}
|
|
78
112
|
job_data = self._gql_client.create_job(
|
|
79
113
|
input=JobInput(
|
|
@@ -88,6 +122,14 @@ class Jobs(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
88
122
|
return JobDataPlus.from_job_data(job_data)
|
|
89
123
|
|
|
90
124
|
def cancel(self, job_id: str) -> JobDataPlus:
|
|
125
|
+
"""Cancel a running job.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
job_id: The ID of the job to cancel.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
The updated job data after cancellation.
|
|
132
|
+
"""
|
|
91
133
|
job_data = self._gql_client.cancel_job(job_id=job_id).cancel_job
|
|
92
134
|
return JobDataPlus.from_job_data(job_data)
|
|
93
135
|
|
|
@@ -102,6 +144,14 @@ class AsyncJobs(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
102
144
|
UseCaseResource.__init__(self, client)
|
|
103
145
|
|
|
104
146
|
async def get(self, job_id: str) -> JobDataPlus | None:
|
|
147
|
+
"""Get the details of a specific job.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
job_id: The ID of the job to retrieve.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
The job data if found, otherwise None.
|
|
154
|
+
"""
|
|
105
155
|
job_data = (await self._gql_client.describe_job(id=job_id)).job
|
|
106
156
|
return JobDataPlus.from_job_data(job_data) if job_data else None
|
|
107
157
|
|
|
@@ -125,6 +175,19 @@ class AsyncJobs(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
125
175
|
) = None,
|
|
126
176
|
use_case: str | None = None,
|
|
127
177
|
) -> ListJobsJobs:
|
|
178
|
+
"""List jobs with pagination and filtering options.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
first: Number of jobs to return from the beginning.
|
|
182
|
+
last: Number of jobs to return from the end.
|
|
183
|
+
after: Cursor for forward pagination.
|
|
184
|
+
before: Cursor for backward pagination.
|
|
185
|
+
kind: Filter by job types.
|
|
186
|
+
use_case: Filter by use case key.
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
A paginated list of jobs.
|
|
190
|
+
"""
|
|
128
191
|
page = CursorPageInput(first=first, last=last, after=after, before=before)
|
|
129
192
|
validated_filter = ListJobsFilterInput(
|
|
130
193
|
useCase=self.use_case_key(use_case),
|
|
@@ -146,6 +209,19 @@ class AsyncJobs(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
146
209
|
use_case: str | None = None,
|
|
147
210
|
compute_pool: str | None = None,
|
|
148
211
|
) -> JobDataPlus:
|
|
212
|
+
"""Run a job using a specified recipe.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
recipe_key: The key of the recipe to run.
|
|
216
|
+
num_gpus: Number of GPUs to allocate for the job.
|
|
217
|
+
args: Optional arguments to pass to the recipe.
|
|
218
|
+
name: Optional human-readable name for the job.
|
|
219
|
+
use_case: Use case key for the job.
|
|
220
|
+
compute_pool: Optional compute pool key to run the job on.
|
|
221
|
+
|
|
222
|
+
Returns:
|
|
223
|
+
The created job data.
|
|
224
|
+
"""
|
|
149
225
|
args = args or {}
|
|
150
226
|
job_data = (
|
|
151
227
|
await self._gql_client.create_job(
|
|
@@ -162,5 +238,13 @@ class AsyncJobs(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
162
238
|
return JobDataPlus.from_job_data(job_data)
|
|
163
239
|
|
|
164
240
|
async def cancel(self, job_id: str) -> JobDataPlus:
|
|
241
|
+
"""Cancel a running job.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
job_id: The ID of the job to cancel.
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
The updated job data after cancellation.
|
|
248
|
+
"""
|
|
165
249
|
job_data = (await self._gql_client.cancel_job(job_id=job_id)).cancel_job
|
|
166
250
|
return JobDataPlus.from_job_data(job_data)
|
adaptive_sdk/resources/models.py
CHANGED
|
@@ -224,9 +224,10 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
224
224
|
|
|
225
225
|
Args:
|
|
226
226
|
model: Model key.
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
227
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
228
|
+
|
|
229
|
+
Returns:
|
|
230
|
+
True if the model was successfully attached.
|
|
230
231
|
"""
|
|
231
232
|
|
|
232
233
|
input = AddModelToUseCaseInput(
|
|
@@ -243,6 +244,18 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
243
244
|
use_case: str | None = None,
|
|
244
245
|
placement: input_types.ModelPlacementInput | None = None,
|
|
245
246
|
) -> ModelServiceData:
|
|
247
|
+
"""Deploy a model for inference in the specified use case.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
model: Model key.
|
|
251
|
+
wait: If `True`, block until the model is online.
|
|
252
|
+
make_default: Make the model the use case's default after deployment.
|
|
253
|
+
use_case: Use case key.
|
|
254
|
+
placement: Optional placement configuration for the model.
|
|
255
|
+
|
|
256
|
+
Returns:
|
|
257
|
+
The model service data after deployment.
|
|
258
|
+
"""
|
|
246
259
|
input = DeployModelInput(
|
|
247
260
|
model=model,
|
|
248
261
|
useCase=self.use_case_key(use_case),
|
|
@@ -424,6 +437,10 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
424
437
|
|
|
425
438
|
Args:
|
|
426
439
|
model: Model key.
|
|
440
|
+
use_case: Use case key. Falls back to client's default if not provided.
|
|
441
|
+
|
|
442
|
+
Returns:
|
|
443
|
+
The updated model service data.
|
|
427
444
|
"""
|
|
428
445
|
return await self.update(model=model, use_case=use_case)
|
|
429
446
|
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from typing import TYPE_CHECKING, List
|
|
3
4
|
|
|
4
5
|
from adaptive_sdk.graphql_client import ListPermissions
|
|
5
6
|
|
|
6
|
-
from .base_resource import
|
|
7
|
+
from .base_resource import AsyncAPIResource, SyncAPIResource, UseCaseResource
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
9
10
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
@@ -19,6 +20,11 @@ class Permissions(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
19
20
|
UseCaseResource.__init__(self, client)
|
|
20
21
|
|
|
21
22
|
def list(self) -> List[str]:
|
|
23
|
+
"""List all available permissions in the system.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
A list of permission identifiers.
|
|
27
|
+
"""
|
|
22
28
|
return self._gql_client.list_permissions().permissions
|
|
23
29
|
|
|
24
30
|
|
|
@@ -32,4 +38,9 @@ class AsyncPermissions(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
32
38
|
UseCaseResource.__init__(self, client)
|
|
33
39
|
|
|
34
40
|
async def list(self) -> List[str]:
|
|
41
|
+
"""List all available permissions in the system.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
A list of permission identifiers.
|
|
45
|
+
"""
|
|
35
46
|
return (await self._gql_client.list_permissions()).permissions
|