adaptive-sdk 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_sdk/client.py +2 -0
- adaptive_sdk/graphql_client/__init__.py +6 -4
- adaptive_sdk/graphql_client/add_model_to_use_case.py +6 -0
- adaptive_sdk/graphql_client/async_client.py +35 -19
- adaptive_sdk/graphql_client/client.py +35 -19
- adaptive_sdk/graphql_client/custom_fields.py +20 -4
- adaptive_sdk/graphql_client/custom_mutations.py +29 -8
- adaptive_sdk/graphql_client/deploy_model.py +12 -0
- adaptive_sdk/graphql_client/enums.py +3 -20
- adaptive_sdk/graphql_client/fragments.py +4 -4
- adaptive_sdk/graphql_client/input_types.py +157 -18
- adaptive_sdk/graphql_client/remove_model_from_use_case.py +6 -0
- adaptive_sdk/input_types/typed_dicts.py +14 -15
- adaptive_sdk/resources/__init__.py +3 -0
- adaptive_sdk/resources/artifacts.py +61 -0
- adaptive_sdk/resources/chat.py +11 -9
- adaptive_sdk/resources/interactions.py +57 -25
- adaptive_sdk/resources/models.py +86 -95
- adaptive_sdk/rest/rest_types.py +2 -1
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/METADATA +4 -1
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/RECORD +22 -19
- adaptive_sdk/graphql_client/attach_model_to_use_case.py +0 -12
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/WHEEL +0 -0
adaptive_sdk/graphql_client/input_types.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Any, List, Optional
 from pydantic import Field
 from .base_model import BaseModel
-from .enums import AbcampaignStatus, CompletionGroupBy, CompletionSource, DatasetKind, DatasetSource, DateBucketUnit, ExternalModelProviderName, FeedbackType, GraderTypeEnum, JobArtifactKind, JobKind, JobStatus, MetricAggregation, MetricKind, MetricScoringType, ModelCapabilityFilter, ModelOnline,
+from .enums import AbcampaignStatus, CompletionGroupBy, CompletionSource, DatasetKind, DatasetSource, DateBucketUnit, ExternalModelProviderName, FeedbackType, GraderTypeEnum, JobArtifactKind, JobKind, JobStatus, MetricAggregation, MetricKind, MetricScoringType, ModelCapabilityFilter, ModelOnline, PrebuiltCriteriaKey, Protocol, SelectionTypeInput, SortDirection, TimeseriesInterval, UnitPosition

 class AbCampaignFilter(BaseModel):
     """@private"""
@@ -42,10 +42,16 @@ class AddModelInput(BaseModel):
     name: str
     key: Optional[str] = None

+class AddModelToUseCaseInput(BaseModel):
+    """@private"""
+    use_case: str = Field(alias='useCase')
+    model: str
+
 class AnthropicProviderDataInput(BaseModel):
     """@private"""
     api_key: str = Field(alias='apiKey')
     external_model_id: str = Field(alias='externalModelId')
+    endpoint: Optional[str] = None

 class ApiKeyCreate(BaseModel):
     """@private"""
@@ -58,15 +64,6 @@ class ArtifactFilter(BaseModel):
     kinds: Optional[List[JobArtifactKind]] = None
     job_id: Optional[Any] = Field(alias='jobId', default=None)

-class AttachModel(BaseModel):
-    """@private"""
-    use_case: str = Field(alias='useCase')
-    model: str
-    attached: bool = True
-    placement: Optional['ModelPlacementInput'] = None
-    wait: bool = False
-    'Wait for the model to be deployed or not'
-
 class CancelAllocationInput(BaseModel):
     """@private"""
     harmony_group: str = Field(alias='harmonyGroup')
@@ -93,6 +90,38 @@ class CompletionFeedbackFilterInput(BaseModel):
     reasons: Optional[List[str]] = None
     user: Optional[Any] = None

+class CompletionFilterExpression(BaseModel):
+    """@private
+    Advanced filter expression supporting AND/OR/NOT logic"""
+    and_: Optional[List['CompletionFilterExpression']] = Field(alias='and', default=None)
+    'Combine multiple conditions with AND (all must match)'
+    or_: Optional[List['CompletionFilterExpression']] = Field(alias='or', default=None)
+    'Combine multiple conditions with OR (at least one must match)'
+    not_: Optional['CompletionFilterExpression'] = Field(alias='not', default=None)
+    'Negate a condition'
+    timerange: Optional['TimeRange'] = None
+    'Filter by time'
+    model: Optional['IdOrKeyCondition'] = None
+    'Filter by model'
+    label: Optional['LabelCondition'] = None
+    'Filter by label key-value pairs'
+    feedbacks: Optional['FeedbackCondition'] = None
+    'Filter by feedback/metric values'
+    source: Optional[CompletionSource] = None
+    'Filter by completion source'
+    prompt_hash: Optional['StringCondition'] = Field(alias='promptHash', default=None)
+    'Filter by prompt hash'
+    session_id: Optional[Any] = Field(alias='sessionId', default=None)
+    'Filter by session ID'
+    user_id: Optional[Any] = Field(alias='userId', default=None)
+    'Filter by user ID'
+    completion_id: Optional[Any] = Field(alias='completionId', default=None)
+    'Filter by completion ID'
+    completion: Optional['TextCondition'] = None
+    'Filter by completion content'
+    prompt: Optional['TextCondition'] = None
+    'Filter by prompt content'
+
 class CompletionLabelValue(BaseModel):
     """@private"""
     key: str
@@ -174,6 +203,18 @@ class DatasetUploadProcessingStatusInput(BaseModel):
     use_case: str = Field(alias='useCase')
     dataset_id: Any = Field(alias='datasetId')

+class DeleteModelInput(BaseModel):
+    """@private"""
+    model: str
+
+class DeployModelInput(BaseModel):
+    """@private"""
+    use_case: str = Field(alias='useCase')
+    model: str
+    placement: Optional['ModelPlacementInput'] = None
+    wait: bool = False
+    'Wait for the model to be deployed or not'
+
 class EmojiInput(BaseModel):
     """@private"""
     native: str
@@ -185,6 +226,18 @@ class FeedbackAddInput(BaseModel):
     reason: Optional[str] = None
     user_id: Optional[Any] = Field(alias='userId', default=None)

+class FeedbackCondition(BaseModel):
+    """@private
+    Feedback/metric filter condition with numeric comparisons"""
+    metric: str
+    'Metric to filter by'
+    value: Optional['FloatNumericCondition'] = None
+    'Numeric value condition'
+    reasons: Optional[List[str]] = None
+    'Filter by feedback reasons'
+    user: Optional[Any] = None
+    'Filter by user who gave the feedback'
+
 class FeedbackFilterInput(BaseModel):
     """@private"""
     labels: Optional[List['LabelFilter']] = None
@@ -194,6 +247,22 @@ class FeedbackUpdateInput(BaseModel):
     value: Optional[Any] = None
     details: Optional[str] = None

+class FloatNumericCondition(BaseModel):
+    """@private
+    Numeric matching condition for filter expressions, parameterized by the numeric type"""
+    eq: Optional[float] = None
+    'Equal to value'
+    neq: Optional[float] = None
+    'Not Equal to value'
+    gt: Optional[float] = None
+    'Greater than value'
+    gte: Optional[float] = None
+    'Greater than or equal to value'
+    lt: Optional[float] = None
+    'Less than value'
+    lte: Optional[float] = None
+    'Less than or equal to value'
+
 class FromGroupsQuery(BaseModel):
     """@private"""
     filters: 'ListCompletionsFilterInput'
@@ -211,6 +280,7 @@ class GoogleProviderDataInput(BaseModel):
     """@private"""
     api_key: str = Field(alias='apiKey')
     external_model_id: str = Field(alias='externalModelId')
+    endpoint: Optional[str] = None

 class GraderConfigInput(BaseModel):
     """@private"""
@@ -243,6 +313,32 @@ class GroupSelectionQuery(BaseModel):
     group_id: str = Field(alias='groupId')
     selection: 'GroupSelection'

+class IdOrKeyCondition(BaseModel):
+    """@private
+    String matching condition for filter expressions"""
+    eq: Optional[str] = None
+    'Exact match'
+    in_: Optional[List[str]] = Field(alias='in', default=None)
+    'Match any of the provided values (OR)'
+    neq: Optional[str] = None
+    'Does not equal'
+
+class IntegerNumericCondition(BaseModel):
+    """@private
+    Numeric matching condition for filter expressions, parameterized by the numeric type"""
+    eq: Optional[int] = None
+    'Equal to value'
+    neq: Optional[int] = None
+    'Not Equal to value'
+    gt: Optional[int] = None
+    'Greater than value'
+    gte: Optional[int] = None
+    'Greater than or equal to value'
+    lt: Optional[int] = None
+    'Less than value'
+    lte: Optional[int] = None
+    'Less than or equal to value'
+
 class JobArtifactFilter(BaseModel):
     """@private"""
     kinds: List[JobArtifactKind]
@@ -293,6 +389,14 @@ class JudgeUpdate(BaseModel):
     examples: Optional[List['JudgeExampleInput']] = None
     model: Optional[str] = None

+class LabelCondition(BaseModel):
+    """@private
+    Label-specific filter condition"""
+    key: str
+    'Label key'
+    value: Optional['StringCondition'] = None
+    'Label value condition (optional - if not set, just checks for key existence)'
+
 class LabelFilter(BaseModel):
     """@private"""
     key: str
@@ -316,6 +420,10 @@ class ListCompletionsFilterInput(BaseModel):
     prompt_hash: Optional[str] = Field(alias='promptHash', default=None)
     completion_id: Optional[Any] = Field(alias='completionId', default=None)
     source: Optional[List[CompletionSource]] = None
+    completion: Optional[str] = None
+    prompt: Optional[str] = None
+    advanced_filter: Optional['CompletionFilterExpression'] = Field(alias='advancedFilter', default=None)
+    'Advanced filter supporting AND/OR/NOT logic\nWhen set, this takes precedence over the simple filter fields above\n(except use_case which is always required)'

 class ListJobsFilterInput(BaseModel):
     """@private"""
@@ -369,6 +477,8 @@ class ModelFilter(BaseModel):
     capabilities: Optional['CapabilityFilter'] = None
     view_all: Optional[bool] = Field(alias='viewAll', default=None)
     online: Optional[List[ModelOnline]] = None
+    published: Optional[bool] = None
+    size: Optional['IntegerNumericCondition'] = None

 class ModelPlacementInput(BaseModel):
     """@private"""
@@ -378,23 +488,22 @@ class ModelPlacementInput(BaseModel):
 class ModelProviderDataInput(BaseModel):
     """@private"""
     open_ai: Optional['OpenAIProviderDataInput'] = Field(alias='openAI', default=None)
+    legacy_open_ai: Optional['OpenAIProviderDataInput'] = Field(alias='legacyOpenAI', default=None)
     google: Optional['GoogleProviderDataInput'] = None
     anthropic: Optional['AnthropicProviderDataInput'] = None

-class ModelServiceDisconnect(BaseModel):
-    """@private"""
-    use_case: str = Field(alias='useCase')
-    model_service: str = Field(alias='modelService')
-
 class ModelServiceFilter(BaseModel):
     """@private"""
     model: Optional[str] = None
     capabilities: Optional['CapabilityFilter'] = None
+    active_only: bool = Field(alias='activeOnly', default=True)
+    'If true (default), only return model services whose model has a binding.\nIf false, return all model services regardless of binding status.'

 class OpenAIProviderDataInput(BaseModel):
     """@private"""
     api_key: str = Field(alias='apiKey')
-    external_model_id:
+    external_model_id: str = Field(alias='externalModelId')
+    endpoint: Optional[str] = None

 class OrderPair(BaseModel):
     """@private"""
@@ -424,6 +533,11 @@ class RemoteEnvCreate(BaseModel):
     name: Optional[str] = None
     description: Optional[str] = None

+class RemoveModelFromUseCaseInput(BaseModel):
+    """@private"""
+    use_case: str = Field(alias='useCase')
+    model: str
+
 class ResizePartitionInput(BaseModel):
     """@private"""
     harmony_group: str = Field(alias='harmonyGroup')
@@ -444,6 +558,16 @@ class SearchInput(BaseModel):
     """@private"""
     query: str

+class StringCondition(BaseModel):
+    """@private
+    String matching condition for filter expressions"""
+    eq: Optional[str] = None
+    'Exact match'
+    in_: Optional[List[str]] = Field(alias='in', default=None)
+    'Match any of the provided values (OR)'
+    neq: Optional[str] = None
+    'Does not equal'
+
 class SystemPromptTemplateCreate(BaseModel):
     """@private"""
     name: str
@@ -472,6 +596,13 @@ class TeamMemberSet(BaseModel):
     team: str
     role: str

+class TextCondition(BaseModel):
+    """@private"""
+    eq: Optional[str] = None
+    'Exact match'
+    contains: Optional[str] = None
+    'Text contains this substring (case insensitive)'
+
 class TimeRange(BaseModel):
     """@private"""
     from_: int | str = Field(alias='from')
@@ -502,12 +633,17 @@ class UpdateCompletion(BaseModel):
     metadata: Optional[Any] = None
     'set metadata associated with this prompt for use with external reward servers'

+class UpdateModelInput(BaseModel):
+    """@private"""
+    model: str
+    published: Optional[bool] = None
+    stable: Optional[bool] = None
+
 class UpdateModelService(BaseModel):
     """@private"""
     use_case: str = Field(alias='useCase')
     model_service: str = Field(alias='modelService')
     is_default: Optional[bool] = Field(alias='isDefault', default=None)
-    attached: Optional[bool] = None
     desired_online: Optional[bool] = Field(alias='desiredOnline', default=None)
     name: Optional[str] = None
     system_prompt_template: Optional[Any] = Field(alias='systemPromptTemplate', default=None)
@@ -597,12 +733,14 @@ class WidgetInput(BaseModel):
     aggregation: MetricAggregation
     unit: 'UnitConfigInput'
 AddExternalModelInput.model_rebuild()
-
+CompletionFilterExpression.model_rebuild()
 CompletionsByFilters.model_rebuild()
 CreateRecipeInput.model_rebuild()
 CustomRecipeFilterInput.model_rebuild()
 DatasetCompletionQuery.model_rebuild()
 DatasetCreateFromFilters.model_rebuild()
+DeployModelInput.model_rebuild()
+FeedbackCondition.model_rebuild()
 FeedbackFilterInput.model_rebuild()
 FromGroupsQuery.model_rebuild()
 GlobalUsageFilterInput.model_rebuild()
@@ -614,6 +752,7 @@ JudgeConfigInput.model_rebuild()
 JudgeCreate.model_rebuild()
 JudgeExampleInput.model_rebuild()
 JudgeUpdate.model_rebuild()
+LabelCondition.model_rebuild()
 ListCompletionsFilterInput.model_rebuild()
 ListJobsFilterInput.model_rebuild()
 MetricGetOrCreate.model_rebuild()
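The new condition inputs above nest inside `CompletionFilterExpression`, which can then be passed through the new `advanced_filter` field of `ListCompletionsFilterInput`. A minimal sketch of how they might compose, assuming the generated models accept population by Python field name (the `and`/`or`/`not` aliases are applied on serialization); the metric and label values are placeholders:

```python
from adaptive_sdk.graphql_client.input_types import (
    CompletionFilterExpression,
    FeedbackCondition,
    FloatNumericCondition,
    LabelCondition,
    StringCondition,
    TextCondition,
)

# Hypothetical filter: completions labeled env=prod whose "thumbs_up"
# feedback is >= 0.5, excluding prompts that contain "test".
expr = CompletionFilterExpression(
    and_=[
        CompletionFilterExpression(
            label=LabelCondition(key="env", value=StringCondition(eq="prod"))
        ),
        CompletionFilterExpression(
            feedbacks=FeedbackCondition(
                metric="thumbs_up", value=FloatNumericCondition(gte=0.5)
            )
        ),
        CompletionFilterExpression(
            not_=CompletionFilterExpression(prompt=TextCondition(contains="test"))
        ),
    ]
)
```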
adaptive_sdk/input_types/typed_dicts.py CHANGED
@@ -1,6 +1,8 @@
 from __future__ import annotations
-
-from
+
+from typing import Any, List, Literal, TypeAlias, TypedDict
+
+from typing_extensions import NotRequired, Required


 class ChatMessage(TypedDict, total=True):
@@ -24,28 +26,26 @@ class CompletionComparisonFilterInput(TypedDict, total=True):
     metric: Required[str]


+class NumericCondition(TypedDict, total=False):
+    eq: NotRequired[float]
+    neq: NotRequired[float]
+    gt: NotRequired[float]
+    gte: NotRequired[float]
+    lt: NotRequired[float]
+    lte: NotRequired[float]
+
+
 class CompletionFeedbackFilterInput(TypedDict, total=False):
     """
     Filter for completion metric feedbacks.

     Args:
         metric: Feedback key logged against.
-        gt: >
-        gte: >=
-        eq: ==
-        neq: !=
-        lt: <
-        lte: <=
         user: Feedbacks logged by `user` id.
     """

     metric: Required[str]
-    gt: NotRequired[float]
-    gte: NotRequired[float]
-    eq: NotRequired[float]
-    neq: NotRequired[float]
-    lt: NotRequired[float]
-    lte: NotRequired[float]
+    value: NotRequired[NumericCondition]
     reasons: NotRequired[List[str]]
     user: NotRequired[Any]

@@ -135,7 +135,6 @@ class ListCompletionsFilterInput(TypedDict, total=False):
     labels: NotRequired[List["CompletionLabelFilter"]]
     prompt_hash: NotRequired[str]
     completion_id: NotRequired[Any]
-    tags: NotRequired[List[str]]
     source: NotRequired[List[CompletionSource]]


adaptive_sdk/resources/__init__.py CHANGED
@@ -1,4 +1,5 @@
 from .abtests import ABTests, AsyncABTests
+from .artifacts import Artifacts, AsyncArtifacts
 from .chat import Chat, AsyncChat
 from .compute_pools import ComputePools, AsyncComputePools  # type: ignore[attr-defined]
 from .recipes import Recipes, AsyncRecipes
@@ -17,6 +18,7 @@ from .users import Users, AsyncUsers

 __all__ = [
     "ABTests",
+    "Artifacts",
     "Chat",
     "ComputePools",
     "Recipes",
@@ -33,6 +35,7 @@ __all__ = [
     "UseCase",
     "Users",
     "AsyncABTests",
+    "AsyncArtifacts",
     "AsyncChat",
     "AsyncComputePools",
     "AsyncRecipes",
adaptive_sdk/resources/artifacts.py ADDED
@@ -0,0 +1,61 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from .base_resource import SyncAPIResource, AsyncAPIResource
+
+if TYPE_CHECKING:
+    from adaptive_sdk.client import Adaptive, AsyncAdaptive
+
+
+class Artifacts(SyncAPIResource):  # type: ignore[misc]
+    """
+    Resource to interact with job artifacts.
+    """
+
+    def __init__(self, client: Adaptive) -> None:
+        SyncAPIResource.__init__(self, client)
+
+    def download(self, artifact_id: str, destination_path: str) -> None:
+        """
+        Download an artifact file to a local path.
+
+        Args:
+            artifact_id: The UUID of the artifact to download.
+            destination_path: Local file path where the artifact will be saved.
+
+        Raises:
+            HTTPError: If the download fails or the artifact is not found.
+        """
+        download_url = f"/artifacts/{artifact_id}/download"
+        response = self._rest_client.get(download_url)
+        response.raise_for_status()
+
+        with open(destination_path, "wb") as f:
+            f.write(response.content)
+
+
+class AsyncArtifacts(AsyncAPIResource):  # type: ignore[misc]
+    """
+    Async resource to interact with job artifacts.
+    """
+
+    def __init__(self, client: AsyncAdaptive) -> None:
+        AsyncAPIResource.__init__(self, client)
+
+    async def download(self, artifact_id: str, destination_path: str) -> None:
+        """
+        Download an artifact file to a local path.
+
+        Args:
+            artifact_id: The UUID of the artifact to download.
+            destination_path: Local file path where the artifact will be saved.
+
+        Raises:
+            HTTPError: If the download fails or the artifact is not found.
+        """
+        download_url = f"/artifacts/{artifact_id}/download"
+        response = await self._rest_client.get(download_url)
+        response.raise_for_status()
+
+        with open(destination_path, "wb") as f:
+            f.write(response.content)
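A minimal usage sketch for the new artifacts resource. It assumes the client wires it up as `client.artifacts` (the two-line `client.py` change in this release is consistent with that, but the attribute name is an assumption), and the artifact ID below is a placeholder:

```python
# Given an existing Adaptive client instance `client` (construction omitted),
# and assuming the resource is exposed as `client.artifacts`:
client.artifacts.download(
    artifact_id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
    destination_path="./artifact.bin",
)
```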
adaptive_sdk/resources/chat.py CHANGED
@@ -48,6 +48,7 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> rest_types.ChatResponse: ...

     @overload
@@ -67,6 +68,7 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> Generator[rest_types.ChatResponseChunk, None, None]: ...

     def create(
@@ -85,6 +87,7 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> rest_types.ChatResponse | Generator[rest_types.ChatResponseChunk, None, None]:
         """
         Create a chat completion.
@@ -135,6 +138,7 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
             ab_campaign=ab_campaign,
             n=n,
             labels=labels,
+            store=store,
         )
         if input.stream:
             return self._stream(input)
@@ -142,9 +146,7 @@ class Chat(SyncAPIResource, UseCaseResource): # type: ignore[misc]
         rest_error_handler(r)
         return rest_types.ChatResponse.model_validate(r.json())

-    def _stream(
-        self, input: rest_types.ChatInput
-    ) -> Generator[rest_types.ChatResponseChunk, None, None]:
+    def _stream(self, input: rest_types.ChatInput) -> Generator[rest_types.ChatResponseChunk, None, None]:
         import httpx_sse

         with httpx_sse.connect_sse(
@@ -183,6 +185,7 @@ class AsyncChat(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> rest_types.ChatResponse: ...

     @overload  # type: ignore[no-redef, misc]
@@ -202,6 +205,7 @@ class AsyncChat(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> AsyncGenerator[rest_types.ChatResponseChunk, None]: ...

     async def create(  # type: ignore
@@ -220,6 +224,7 @@ class AsyncChat(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
         ab_campaign: str | None = None,
         n: int | None = None,
         labels: Dict[str, str] | None = None,
+        store: bool | None = None,
     ) -> rest_types.ChatResponse | AsyncGenerator[rest_types.ChatResponseChunk, None]:
         """
         Create a chat completion.
@@ -270,18 +275,15 @@ class AsyncChat(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
             ab_campaign=ab_campaign,
             n=n,
             labels=labels,
+            store=store,
         )
         if input.stream:
             return self._stream(input)
-        r = await self._rest_client.post(
-            ROUTE, json=input.model_dump(exclude_none=True)
-        )
+        r = await self._rest_client.post(ROUTE, json=input.model_dump(exclude_none=True))
         rest_error_handler(r)
         return rest_types.ChatResponse.model_validate(r.json())

-    async def _stream(
-        self, input: rest_types.ChatInput
-    ) -> AsyncGenerator[rest_types.ChatResponseChunk, None]:
+    async def _stream(self, input: rest_types.ChatInput) -> AsyncGenerator[rest_types.ChatResponseChunk, None]:
         import httpx_sse

         async with httpx_sse.aconnect_sse(