seekrai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
- seekrai/__init__.py +64 -0
- seekrai/abstract/__init__.py +1 -0
- seekrai/abstract/api_requestor.py +710 -0
- seekrai/cli/__init__.py +0 -0
- seekrai/cli/api/__init__.py +0 -0
- seekrai/cli/api/chat.py +245 -0
- seekrai/cli/api/completions.py +107 -0
- seekrai/cli/api/files.py +125 -0
- seekrai/cli/api/finetune.py +175 -0
- seekrai/cli/api/images.py +82 -0
- seekrai/cli/api/models.py +42 -0
- seekrai/cli/cli.py +77 -0
- seekrai/client.py +154 -0
- seekrai/constants.py +32 -0
- seekrai/error.py +188 -0
- seekrai/filemanager.py +393 -0
- seekrai/legacy/__init__.py +0 -0
- seekrai/legacy/base.py +27 -0
- seekrai/legacy/complete.py +91 -0
- seekrai/legacy/embeddings.py +25 -0
- seekrai/legacy/files.py +140 -0
- seekrai/legacy/finetune.py +173 -0
- seekrai/legacy/images.py +25 -0
- seekrai/legacy/models.py +44 -0
- seekrai/resources/__init__.py +25 -0
- seekrai/resources/chat/__init__.py +24 -0
- seekrai/resources/chat/completions.py +241 -0
- seekrai/resources/completions.py +205 -0
- seekrai/resources/embeddings.py +100 -0
- seekrai/resources/files.py +173 -0
- seekrai/resources/finetune.py +425 -0
- seekrai/resources/images.py +156 -0
- seekrai/resources/models.py +75 -0
- seekrai/seekrflow_response.py +50 -0
- seekrai/types/__init__.py +67 -0
- seekrai/types/abstract.py +26 -0
- seekrai/types/chat_completions.py +151 -0
- seekrai/types/common.py +64 -0
- seekrai/types/completions.py +86 -0
- seekrai/types/embeddings.py +35 -0
- seekrai/types/error.py +16 -0
- seekrai/types/files.py +88 -0
- seekrai/types/finetune.py +218 -0
- seekrai/types/images.py +42 -0
- seekrai/types/models.py +43 -0
- seekrai/utils/__init__.py +28 -0
- seekrai/utils/_log.py +61 -0
- seekrai/utils/api_helpers.py +84 -0
- seekrai/utils/files.py +204 -0
- seekrai/utils/tools.py +75 -0
- seekrai/version.py +6 -0
- seekrai-0.0.1.dist-info/LICENSE +201 -0
- seekrai-0.0.1.dist-info/METADATA +401 -0
- seekrai-0.0.1.dist-info/RECORD +56 -0
- seekrai-0.0.1.dist-info/WHEEL +4 -0
- seekrai-0.0.1.dist-info/entry_points.txt +3 -0
seekrai/legacy/finetune.py
ADDED
@@ -0,0 +1,173 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List

import seekrai
from seekrai.legacy.base import API_KEY_WARNING, deprecated


class Finetune:
    @classmethod
    @deprecated  # type: ignore
    def create(
        cls,
        training_file: str,  # training file_id
        model: str,
        n_epochs: int = 1,
        n_checkpoints: int | None = 1,
        batch_size: int | None = 32,
        learning_rate: float = 0.00001,
        suffix: (
            str | None
        ) = None,  # resulting finetuned model name will include the suffix
        estimate_price: bool = False,
        wandb_api_key: str | None = None,
        confirm_inputs: bool = False,
    ):
        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        if estimate_price:
            raise ValueError("Price estimation is not supported in version >= 1.0.1")

        if confirm_inputs:
            raise ValueError("Input confirmation is not supported in version >= 1.0.1")

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.create(
            training_file=training_file,
            model=model,
            n_epochs=n_epochs,
            n_checkpoints=n_checkpoints,
            batch_size=batch_size,
            learning_rate=learning_rate,
            suffix=suffix,
            wandb_api_key=wandb_api_key,
        ).model_dump()

    @classmethod
    @deprecated  # type: ignore
    def list(
        cls,
    ) -> Dict[str, Any]:
        """Legacy finetuning list function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.list().model_dump()

    @classmethod
    @deprecated  # type: ignore
    def retrieve(
        cls,
        fine_tune_id: str,
    ) -> Dict[str, Any]:
        """Legacy finetuning retrieve function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.retrieve(id=fine_tune_id).model_dump()

    @classmethod
    @deprecated  # type: ignore
    def cancel(
        cls,
        fine_tune_id: str,
    ) -> Dict[str, Any]:
        """Legacy finetuning cancel function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.cancel(id=fine_tune_id).model_dump()

    @classmethod
    @deprecated  # type: ignore
    def list_events(
        cls,
        fine_tune_id: str,
    ) -> Dict[str, Any]:
        """Legacy finetuning list events function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.list_events(id=fine_tune_id).model_dump()

    @classmethod
    @deprecated  # type: ignore
    def get_checkpoints(
        cls,
        fine_tune_id: str,
    ) -> List[Any]:
        """Legacy finetuning get checkpoints function."""

        finetune_events = list(cls.retrieve(fine_tune_id=fine_tune_id)["events"])

        saved_events = [i for i in finetune_events if i["type"] in ["CHECKPOINT_SAVE"]]

        return saved_events

    @classmethod
    @deprecated  # type: ignore
    def get_job_status(cls, fine_tune_id: str) -> str:
        """Legacy finetuning get job status function."""
        return str(cls.retrieve(fine_tune_id=fine_tune_id)["status"])

    @classmethod
    @deprecated  # type: ignore
    def is_final_model_available(cls, fine_tune_id: str) -> bool:
        """Legacy finetuning is final model available function."""

        finetune_events = list(cls.retrieve(fine_tune_id=fine_tune_id)["events"])

        for i in finetune_events:
            if i["type"] in ["JOB_COMPLETE", "JOB_ERROR"]:
                if i["checkpoint_path"] != "":
                    return False
                else:
                    return True
        return False

    @classmethod
    @deprecated  # type: ignore
    def download(
        cls,
        fine_tune_id: str,
        output: str | None = None,
        step: int = -1,
    ) -> Dict[str, Any]:
        """Legacy finetuning download function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.fine_tuning.download(
            id=fine_tune_id, output=output, checkpoint_step=step
        ).model_dump()
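For orientation, a minimal usage sketch of the deprecated shim above and the client path it delegates to; the file ID, model name, and API key are placeholders:

import seekrai
from seekrai.legacy.finetune import Finetune

seekrai.api_key = "sk-placeholder"  # placeholder; setting this triggers API_KEY_WARNING in the shim

# Deprecated classmethod API: returns a plain dict via .model_dump()
job = Finetune.create(training_file="file-abc123", model="example-base-model")

# The non-legacy path the shim delegates to:
client = seekrai.SeekrFlow(api_key="sk-placeholder")
job = client.fine_tuning.create(
    training_file="file-abc123",
    model="example-base-model",
    n_epochs=1,
).model_dump()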
seekrai/legacy/images.py
ADDED
@@ -0,0 +1,25 @@
import warnings
from typing import Any, Dict

import seekrai
from seekrai.legacy.base import API_KEY_WARNING, deprecated


class Image:
    @classmethod
    @deprecated  # type: ignore
    def create(
        cls,
        prompt: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Legacy image function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return client.images.generate(prompt=prompt, **kwargs).model_dump()
seekrai/legacy/models.py
ADDED
@@ -0,0 +1,44 @@
import warnings
from typing import Any, Dict, List

import seekrai
from seekrai.legacy.base import API_KEY_WARNING, deprecated


class Models:
    @classmethod
    @deprecated  # type: ignore
    def list(
        cls,
    ) -> List[Dict[str, Any]]:
        """Legacy model list function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        return [item.model_dump() for item in client.models.list()]

    @classmethod
    @deprecated  # type: ignore
    def info(
        cls,
        model: str,
    ) -> Dict[str, Any]:
        """Legacy model info function."""

        api_key = None
        if seekrai.api_key:
            warnings.warn(API_KEY_WARNING)
            api_key = seekrai.api_key

        client = seekrai.SeekrFlow(api_key=api_key)

        model_list = client.models.list()

        for item in model_list:
            if item.id == model:
                return item.model_dump()
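Note that Models.info above falls through without an explicit return when no listed model matches, so despite the Dict[str, Any] annotation it can return None. A minimal defensive sketch for callers (the model name is a placeholder):

from seekrai.legacy.models import Models

info = Models.info(model="example-model")  # placeholder id; may be None if not listed
if info is None:
    raise ValueError("model id not present in client.models.list()")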
seekrai/resources/__init__.py
ADDED
@@ -0,0 +1,25 @@
from seekrai.resources.chat import AsyncChat, Chat
from seekrai.resources.completions import AsyncCompletions, Completions
from seekrai.resources.embeddings import AsyncEmbeddings, Embeddings
from seekrai.resources.files import AsyncFiles, Files
from seekrai.resources.finetune import AsyncFineTuning, FineTuning
from seekrai.resources.images import AsyncImages, Images
from seekrai.resources.models import AsyncModels, Models


__all__ = [
    "AsyncCompletions",
    "Completions",
    "AsyncChat",
    "Chat",
    "AsyncEmbeddings",
    "Embeddings",
    "AsyncFineTuning",
    "FineTuning",
    "AsyncFiles",
    "Files",
    "AsyncImages",
    "Images",
    "AsyncModels",
    "Models",
]
seekrai/resources/chat/__init__.py
ADDED
@@ -0,0 +1,24 @@
from functools import cached_property

from seekrai.resources.chat.completions import AsyncChatCompletions, ChatCompletions
from seekrai.types import (
    SeekrFlowClient,
)


class Chat:
    def __init__(self, client: SeekrFlowClient) -> None:
        self._client = client

    @cached_property
    def completions(self) -> ChatCompletions:
        return ChatCompletions(self._client)


class AsyncChat:
    def __init__(self, client: SeekrFlowClient) -> None:
        self._client = client

    @cached_property
    def completions(self) -> AsyncChatCompletions:
        return AsyncChatCompletions(self._client)
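The cached_property pattern above constructs each sub-resource lazily and memoizes it per instance. A short sketch, assuming a SeekrFlow client instance satisfies the SeekrFlowClient protocol (the Chat wrapper is normally reached through the client rather than constructed directly):

import seekrai
from seekrai.resources.chat import Chat

client = seekrai.SeekrFlow(api_key="sk-placeholder")  # placeholder key
chat = Chat(client)
first = chat.completions          # ChatCompletions built on first access
assert chat.completions is first  # ...then cached by functools.cached_property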
seekrai/resources/chat/completions.py
ADDED
@@ -0,0 +1,241 @@
from __future__ import annotations

from typing import Any, AsyncGenerator, Dict, Iterator, List

from seekrai.abstract import api_requestor
from seekrai.seekrflow_response import SeekrFlowResponse
from seekrai.types import (
    ChatCompletionChunk,
    ChatCompletionRequest,
    ChatCompletionResponse,
    SeekrFlowClient,
    SeekrFlowRequest,
)


class ChatCompletions:
    def __init__(self, client: SeekrFlowClient) -> None:
        self._client = client

    def create(
        self,
        *,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int | None = None,
        stop: List[str] | None = None,
        temperature: float = 0.7,
        top_p: float = 1,
        top_k: int = 5,
        repetition_penalty: float = 1,
        stream: bool = False,
        logprobs: int = 0,
        echo: bool = False,
        n: int = 1,
        safety_model: str | None = None,
        response_format: Dict[str, str | Dict[str, Any]] | None = None,
        tools: Dict[str, str | Dict[str, Any]] | None = None,
        tool_choice: str | Dict[str, str | Dict[str, str]] | None = None,
    ) -> ChatCompletionResponse | Iterator[ChatCompletionChunk]:
        """
        Method to generate completions based on a given prompt using a specified model.

        Args:
            messages (List[Dict[str, str]]): A list of messages in the format
                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
            model (str): The name of the model to query.
            max_tokens (int, optional): The maximum number of tokens to generate.
                Defaults to 512.
            stop (List[str], optional): List of strings at which to stop generation.
                Defaults to None.
            temperature (float, optional): A decimal number that determines the degree of randomness in the response.
                Defaults to None.
            top_p (float, optional): The top_p (nucleus) parameter is used to dynamically adjust the number
                of choices for each predicted token based on the cumulative probabilities.
                Defaults to None.
            top_k (int, optional): The top_k parameter is used to limit the number of choices for the
                next predicted word or token.
                Defaults to None.
            repetition_penalty (float, optional): A number that controls the diversity of generated text
                by reducing the likelihood of repeated sequences. Higher values decrease repetition.
                Defaults to None.
            stream (bool, optional): Flag indicating whether to stream the generated completions.
                Defaults to False.
            logprobs (int, optional): Number of top-k logprobs to return.
                Defaults to None.
            echo (bool, optional): Echo prompt in output. Can be used with logprobs to return prompt logprobs.
                Defaults to None.
            n (int, optional): Number of completions to generate. Setting to None will return a single generation.
                Defaults to None.
            safety_model (str, optional): A moderation model to validate tokens. Choice between available moderation
                models found [here](https://docs.seekrflow.ai/docs/inference-models#moderation-models).
                Defaults to None.
            response_format (Dict[str, Any], optional): An object specifying the format that the model must output.
                Defaults to None.
            tools (Dict[str, str | Dict[str, str | Dict[str, Any]]], optional): A list of tools the model may call.
                Currently, only functions are supported as a tool.
                Use this to provide a list of functions the model may generate JSON inputs for.
                Defaults to None.
            tool_choice: Controls which (if any) function is called by the model. auto means the model can pick
                between generating a message or calling a function. Specifying a particular function
                via {"type": "function", "function": {"name": "my_function"}} forces the model to call that function.
                Sets to `auto` if None.
                Defaults to None.

        Returns:
            ChatCompletionResponse | Iterator[ChatCompletionChunk]: Object containing the completions
            or an iterator over completion chunks.
        """

        requestor = api_requestor.APIRequestor(
            client=self._client,
        )

        parameter_payload = ChatCompletionRequest(
            model=model,
            messages=messages,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            repetition_penalty=repetition_penalty,
            stream=stream,
            logprobs=logprobs,
            echo=echo,
            n=n,
            safety_model=safety_model,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
        ).model_dump()

        response, _, _ = requestor.request(
            options=SeekrFlowRequest(
                method="POST",
                url="chat/completions",
                params=parameter_payload,
            ),
            stream=stream,
        )

        if stream:
            # must be an iterator
            assert not isinstance(response, SeekrFlowResponse)
            return (ChatCompletionChunk(**line.data) for line in response)
        assert isinstance(response, SeekrFlowResponse)
        return ChatCompletionResponse(**response.data)


class AsyncChatCompletions:
    def __init__(self, client: SeekrFlowClient) -> None:
        self._client = client

    async def create(
        self,
        *,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int | None = None,
        stop: List[str] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        repetition_penalty: float | None = None,
        stream: bool = False,
        logprobs: int | None = None,
        echo: bool | None = None,
        n: int | None = None,
        safety_model: str | None = None,
        response_format: Dict[str, Any] | None = None,
        # tools: Dict[str, str | Dict[str, str | Dict[str, Any]]] | None = None,
        # tool_choice: str | Dict[str, str | Dict[str, str]] | None = None,
    ) -> AsyncGenerator[ChatCompletionChunk, None] | ChatCompletionResponse:
        """
        Async method to generate completions based on a given prompt using a specified model.

        Args:
            messages (List[Dict[str, str]]): A list of messages in the format
                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
            model (str): The name of the model to query.
            max_tokens (int, optional): The maximum number of tokens to generate.
                Defaults to 512.
            stop (List[str], optional): List of strings at which to stop generation.
                Defaults to None.
            temperature (float, optional): A decimal number that determines the degree of randomness in the response.
                Defaults to None.
            top_p (float, optional): The top_p (nucleus) parameter is used to dynamically adjust the number
                of choices for each predicted token based on the cumulative probabilities.
                Defaults to None.
            top_k (int, optional): The top_k parameter is used to limit the number of choices for the
                next predicted word or token.
                Defaults to None.
            repetition_penalty (float, optional): A number that controls the diversity of generated text
                by reducing the likelihood of repeated sequences. Higher values decrease repetition.
                Defaults to None.
            stream (bool, optional): Flag indicating whether to stream the generated completions.
                Defaults to False.
            logprobs (int, optional): Number of top-k logprobs to return.
                Defaults to None.
            echo (bool, optional): Echo prompt in output. Can be used with logprobs to return prompt logprobs.
                Defaults to None.
            n (int, optional): Number of completions to generate. Setting to None will return a single generation.
                Defaults to None.
            safety_model (str, optional): A moderation model to validate tokens. Choice between available moderation
                models found [here](https://docs.seekrflow.ai/docs/inference-models#moderation-models).
                Defaults to None.
            response_format (Dict[str, Any], optional): An object specifying the format that the model must output.
                Defaults to None.
            tools (Dict[str, str | Dict[str, str | Dict[str, Any]]], optional): A list of tools the model may call.
                Currently, only functions are supported as a tool.
                Use this to provide a list of functions the model may generate JSON inputs for.
                Defaults to None.
            tool_choice: Controls which (if any) function is called by the model. auto means the model can pick
                between generating a message or calling a function. Specifying a particular function
                via {"type": "function", "function": {"name": "my_function"}} forces the model to call that function.
                Sets to `auto` if None.
                Defaults to None.

        Returns:
            AsyncGenerator[ChatCompletionChunk, None] | ChatCompletionResponse: Object containing the completions
            or an iterator over completion chunks.
        """

        requestor = api_requestor.APIRequestor(
            client=self._client,
        )

        parameter_payload = ChatCompletionRequest(
            model=model,
            messages=messages,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            repetition_penalty=repetition_penalty,
            stream=stream,
            logprobs=logprobs,
            echo=echo,
            n=n,
            safety_model=safety_model,
            response_format=response_format,
            # tools=tools,
            # tool_choice=tool_choice,
        ).model_dump()

        response, _, _ = await requestor.arequest(
            options=SeekrFlowRequest(
                method="POST",
                url="chat/completions",
                params=parameter_payload,
            ),
            stream=stream,
        )

        if stream:
            # must be an iterator
            assert not isinstance(response, SeekrFlowResponse)
            return (ChatCompletionChunk(**line.data) async for line in response)
        assert isinstance(response, SeekrFlowResponse)
        return ChatCompletionResponse(**response.data)
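A hedged usage sketch of the sync entry point above. It assumes the SeekrFlow client exposes this resource as client.chat.completions, mirroring the Chat.completions property and the way the legacy shims reach client.fine_tuning and client.images; the model name and key are placeholders:

import seekrai

client = seekrai.SeekrFlow(api_key="sk-placeholder")  # placeholder key
messages = [{"role": "user", "content": "Say hello."}]

# Non-streaming: a ChatCompletionResponse (a pydantic model, hence the .model_dump() calls elsewhere)
response = client.chat.completions.create(messages=messages, model="example-model")

# Streaming: a generator yielding ChatCompletionChunk objects as SSE lines arrive
for chunk in client.chat.completions.create(
    messages=messages, model="example-model", stream=True
):
    pass  # consume each ChatCompletionChunk here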