seekrai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- seekrai/__init__.py +64 -0
- seekrai/abstract/__init__.py +1 -0
- seekrai/abstract/api_requestor.py +710 -0
- seekrai/cli/__init__.py +0 -0
- seekrai/cli/api/__init__.py +0 -0
- seekrai/cli/api/chat.py +245 -0
- seekrai/cli/api/completions.py +107 -0
- seekrai/cli/api/files.py +125 -0
- seekrai/cli/api/finetune.py +175 -0
- seekrai/cli/api/images.py +82 -0
- seekrai/cli/api/models.py +42 -0
- seekrai/cli/cli.py +77 -0
- seekrai/client.py +154 -0
- seekrai/constants.py +32 -0
- seekrai/error.py +188 -0
- seekrai/filemanager.py +393 -0
- seekrai/legacy/__init__.py +0 -0
- seekrai/legacy/base.py +27 -0
- seekrai/legacy/complete.py +91 -0
- seekrai/legacy/embeddings.py +25 -0
- seekrai/legacy/files.py +140 -0
- seekrai/legacy/finetune.py +173 -0
- seekrai/legacy/images.py +25 -0
- seekrai/legacy/models.py +44 -0
- seekrai/resources/__init__.py +25 -0
- seekrai/resources/chat/__init__.py +24 -0
- seekrai/resources/chat/completions.py +241 -0
- seekrai/resources/completions.py +205 -0
- seekrai/resources/embeddings.py +100 -0
- seekrai/resources/files.py +173 -0
- seekrai/resources/finetune.py +425 -0
- seekrai/resources/images.py +156 -0
- seekrai/resources/models.py +75 -0
- seekrai/seekrflow_response.py +50 -0
- seekrai/types/__init__.py +67 -0
- seekrai/types/abstract.py +26 -0
- seekrai/types/chat_completions.py +151 -0
- seekrai/types/common.py +64 -0
- seekrai/types/completions.py +86 -0
- seekrai/types/embeddings.py +35 -0
- seekrai/types/error.py +16 -0
- seekrai/types/files.py +88 -0
- seekrai/types/finetune.py +218 -0
- seekrai/types/images.py +42 -0
- seekrai/types/models.py +43 -0
- seekrai/utils/__init__.py +28 -0
- seekrai/utils/_log.py +61 -0
- seekrai/utils/api_helpers.py +84 -0
- seekrai/utils/files.py +204 -0
- seekrai/utils/tools.py +75 -0
- seekrai/version.py +6 -0
- seekrai-0.0.1.dist-info/LICENSE +201 -0
- seekrai-0.0.1.dist-info/METADATA +401 -0
- seekrai-0.0.1.dist-info/RECORD +56 -0
- seekrai-0.0.1.dist-info/WHEEL +4 -0
- seekrai-0.0.1.dist-info/entry_points.txt +3 -0
seekrai/cli/api/chat.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import cmd
|
|
4
|
+
import json
|
|
5
|
+
from typing import List, Tuple
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
from seekrai import seekrflow
|
|
10
|
+
from seekrai.types.chat_completions import (
|
|
11
|
+
ChatCompletionChoicesChunk,
|
|
12
|
+
ChatCompletionChunk,
|
|
13
|
+
ChatCompletionResponse,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ChatShell(cmd.Cmd):
    """Interactive REPL for chatting with a model.

    Input starting with "/" is dispatched as a shell command (/exit, /reset,
    /help); any other line is sent as a user message and the streamed
    assistant reply is echoed token by token.
    """

    intro = "Type /exit to exit, /help, or /? to list commands.\n"
    prompt = ">>> "

    def __init__(
        self,
        client: SeekrFlow,
        model: str,
        max_tokens: int | None = None,
        stop: List[str] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        repetition_penalty: float | None = None,
        safety_model: str | None = None,
        system_message: str | None = None,
    ) -> None:
        """Store the client and sampling parameters; seed the history."""
        # NOTE(review): the module imports ``seekrflow`` but this annotation
        # references ``SeekrFlow`` — confirm the intended export name.
        super().__init__()
        self.client = client
        self.model = model
        self.max_tokens = max_tokens
        self.stop = stop
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        self.safety_model = safety_model
        self.system_message = system_message

        # Conversation history; starts with the system message when one is set.
        self.messages = (
            [{"role": "system", "content": self.system_message}]
            if self.system_message
            else []
        )

    def precmd(self, line: str) -> str:
        # "/foo" dispatches to do_foo; anything else becomes chat input.
        if line.startswith("/"):
            return line[1:]
        else:
            return "say " + line

    def do_say(self, arg: str) -> None:
        """Send *arg* as a user message and stream the assistant reply."""
        self.messages.append({"role": "user", "content": arg})

        output = ""

        for chunk in self.client.chat.completions.create(
            messages=self.messages,
            model=self.model,
            max_tokens=self.max_tokens,
            stop=self.stop,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            repetition_penalty=self.repetition_penalty,
            safety_model=self.safety_model,
            stream=True,
        ):
            # assertions for type checking
            assert isinstance(chunk, ChatCompletionChunk)
            assert chunk.choices
            assert chunk.choices[0].delta

            # Fix: the original also asserted on delta.content, which raised
            # AssertionError for chunks carrying an empty/None delta (and all
            # asserts disappear under ``python -O``). Skip such chunks instead.
            token = chunk.choices[0].delta.content
            if not token:
                continue

            click.echo(token, nl=False)

            output += token

        click.echo("\n")

        self.messages.append({"role": "assistant", "content": output})

    def do_reset(self, arg: str) -> None:
        """Clear history back to just the system message (if any)."""
        self.messages = (
            [{"role": "system", "content": self.system_message}]
            if self.system_message
            else []
        )

    def do_exit(self, arg: str) -> bool:
        """Terminate the shell loop."""
        return True
+
|
|
102
|
+
@click.command(name="chat.interactive")
@click.pass_context
@click.option("--model", type=str, required=True, help="Model name")
@click.option("--max-tokens", type=int, help="Max tokens to generate")
@click.option(
    "--stop", type=str, multiple=True, help="List of strings to stop generation"
)
@click.option("--temperature", type=float, help="Sampling temperature")
@click.option("--top-p", type=float, help="Top p sampling")
@click.option("--top-k", type=int, help="Top k sampling")
@click.option("--repetition-penalty", type=float, help="Repetition penalty")
@click.option("--safety-model", type=str, help="Moderation model")
@click.option("--system-message", type=str, help="System message to use for the chat")
def interactive(
    ctx: click.Context,
    model: str,
    max_tokens: int | None = None,
    stop: List[str] | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    repetition_penalty: float | None = None,
    safety_model: str | None = None,
    system_message: str | None = None,
) -> None:
    """Interactive chat shell"""
    # Fixes: --top-p was declared type=int and --top-k type=float (swapped
    # relative to the float/int parameter annotations), and there was no
    # --repetition-penalty option at all, so that parameter could never be
    # set from the command line.
    client: SeekrFlow = ctx.obj

    ChatShell(
        client=client,
        model=model,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        safety_model=safety_model,
        system_message=system_message,
    ).cmdloop()
|
142
|
+
|
|
143
|
+
@click.command(name="chat.completions")
@click.pass_context
@click.option(
    "--message",
    type=(str, str),
    multiple=True,
    required=True,
    help="Message to generate chat completions from",
)
@click.option("--model", type=str, required=True, help="Model name")
@click.option("--max-tokens", type=int, help="Max tokens to generate")
@click.option(
    "--stop", type=str, multiple=True, help="List of strings to stop generation"
)
@click.option("--temperature", type=float, help="Sampling temperature")
@click.option("--top-p", type=float, help="Top p sampling")
@click.option("--top-k", type=int, help="Top k sampling")
@click.option("--repetition-penalty", type=float, help="Repetition penalty")
@click.option("--no-stream", is_flag=True, help="Disable streaming")
@click.option("--logprobs", type=int, help="Return logprobs. Only works with --raw.")
@click.option("--echo", is_flag=True, help="Echo prompt. Only works with --raw.")
@click.option("--n", type=int, help="Number of output generations")
@click.option("--safety-model", type=str, help="Moderation model")
@click.option("--raw", is_flag=True, help="Output raw JSON")
def chat(
    ctx: click.Context,
    message: List[Tuple[str, str]],
    model: str,
    max_tokens: int | None = None,
    stop: List[str] | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    repetition_penalty: float | None = None,
    no_stream: bool = False,
    logprobs: int | None = None,
    echo: bool | None = None,
    n: int | None = None,
    safety_model: str | None = None,
    raw: bool = False,
) -> None:
    """Generate chat completions from messages"""
    # Fix: --top-p was declared type=int and --top-k type=float — swapped
    # relative to the float/int parameter annotations (top_p is a probability
    # and would have been truncated to 0 or 1).
    client: SeekrFlow = ctx.obj

    # Each --message is a (role, content) pair.
    messages = [{"role": msg[0], "content": msg[1]} for msg in message]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        repetition_penalty=repetition_penalty,
        stream=not no_stream,
        logprobs=logprobs,
        echo=echo,
        n=n,
        safety_model=safety_model,
    )

    if not no_stream:
        for chunk in response:
            # assertions for type checking
            assert isinstance(chunk, ChatCompletionChunk)
            assert chunk.choices

            if raw:
                click.echo(f"{json.dumps(chunk.model_dump())}")
                continue

            # With n > 1 the chunks interleave completions; label each one.
            should_print_header = len(chunk.choices) > 1
            for stream_choice in sorted(chunk.choices, key=lambda c: c.index):  # type: ignore
                assert isinstance(stream_choice, ChatCompletionChoicesChunk)
                assert stream_choice.delta

                if should_print_header:
                    click.echo(f"\n===== Completion {stream_choice.index} =====\n")
                click.echo(f"{stream_choice.delta.content}", nl=False)

                if should_print_header:
                    click.echo("\n")

        # new line after stream ends
        click.echo("\n")
    else:
        # assertions for type checking
        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.choices, list)

        if raw:
            click.echo(f"{json.dumps(response.model_dump(), indent=4)}")
            return

        should_print_header = len(response.choices) > 1
        for i, choice in enumerate(response.choices):
            if should_print_header:
                click.echo(f"===== Completion {i} =====")
            click.echo(choice.message.content)  # type: ignore

            if should_print_header:
                click.echo("\n")
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
import click
|
|
7
|
+
|
|
8
|
+
from seekrai import seekrflow
|
|
9
|
+
from seekrai.types import CompletionChunk
|
|
10
|
+
from seekrai.types.completions import CompletionChoicesChunk, CompletionResponse
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@click.command()
@click.pass_context
@click.argument("prompt", type=str, required=True)
@click.option("--model", type=str, required=True, help="Model name")
@click.option("--no-stream", is_flag=True, help="Disable streaming")
@click.option("--max-tokens", type=int, help="Max tokens to generate")
@click.option(
    "--stop", type=str, multiple=True, help="List of strings to stop generation"
)
@click.option("--temperature", type=float, help="Sampling temperature")
@click.option("--top-p", type=float, help="Top p sampling")
@click.option("--top-k", type=int, help="Top k sampling")
@click.option("--repetition-penalty", type=float, help="Repetition penalty")
@click.option("--logprobs", type=int, help="Return logprobs. Only works with --raw.")
@click.option("--echo", is_flag=True, help="Echo prompt. Only works with --raw.")
@click.option("--n", type=int, help="Number of output generations")
@click.option("--safety-model", type=str, help="Moderation model")
@click.option("--raw", is_flag=True, help="Return raw JSON response")
def completions(
    ctx: click.Context,
    prompt: str,
    model: str,
    max_tokens: int | None = 512,
    stop: List[str] | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    repetition_penalty: float | None = None,
    no_stream: bool = False,
    logprobs: int | None = None,
    echo: bool | None = None,
    n: int | None = None,
    safety_model: str | None = None,
    raw: bool = False,
) -> None:
    """Generate text completions"""
    # Fixes: --top-p was type=int and --top-k type=float (swapped relative to
    # the parameter annotations), and there was no --repetition-penalty
    # option even though the function accepts one.
    # NOTE(review): the 512 default on max_tokens is dead — click always
    # passes a value (None when the option is omitted); confirm intent.
    client: SeekrFlow = ctx.obj

    response = client.completions.create(
        model=model,
        prompt=prompt,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        repetition_penalty=repetition_penalty,
        stream=not no_stream,
        logprobs=logprobs,
        echo=echo,
        n=n,
        safety_model=safety_model,
    )

    if not no_stream:
        for chunk in response:
            # assertions for type checking
            assert isinstance(chunk, CompletionChunk)
            assert chunk.choices

            if raw:
                click.echo(f"{json.dumps(chunk.model_dump())}")
                continue

            # With n > 1 the chunks interleave completions; label each one.
            should_print_header = len(chunk.choices) > 1
            for stream_choice in sorted(chunk.choices, key=lambda c: c.index):  # type: ignore
                # assertions for type checking
                assert isinstance(stream_choice, CompletionChoicesChunk)
                assert stream_choice.delta

                if should_print_header:
                    click.echo(f"\n===== Completion {stream_choice.index} =====\n")
                click.echo(f"{stream_choice.delta.content}", nl=False)

                if should_print_header:
                    click.echo("\n")

        # new line after stream ends
        click.echo("\n")
    else:
        # assertions for type checking
        assert isinstance(response, CompletionResponse)
        assert isinstance(response.choices, list)

        if raw:
            click.echo(f"{json.dumps(response.model_dump(), indent=4)}")
            return

        should_print_header = len(response.choices) > 1
        for i, choice in enumerate(response.choices):
            if should_print_header:
                click.echo(f"===== Completion {i} =====")
            click.echo(choice.text)

            if should_print_header or not choice.text.endswith("\n"):
                click.echo("\n")
seekrai/cli/api/files.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import pathlib
|
|
3
|
+
from textwrap import wrap
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
from tabulate import tabulate
|
|
7
|
+
|
|
8
|
+
from seekrai import seekrflow
|
|
9
|
+
from seekrai.types import FilePurpose
|
|
10
|
+
from seekrai.utils import check_file, convert_bytes, convert_unix_timestamp
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@click.group()
@click.pass_context
def files(ctx: click.Context) -> None:
    """File API commands"""
    # Group entry point: subcommands receive the shared client via ctx.obj.
|
|
20
|
+
@files.command()
@click.pass_context
@click.argument(
    "file",
    type=click.Path(
        exists=True, file_okay=True, resolve_path=True, readable=True, dir_okay=False
    ),
    required=True,
)
@click.option(
    "--purpose",
    type=str,
    default=FilePurpose.FineTune.value,
    help="Purpose of file upload. Acceptable values in enum `seekrai.types.FilePurpose`. Defaults to `fine-tunes`.",
)
def upload(ctx: click.Context, file: pathlib.Path, purpose: str) -> None:
    """Upload file"""
    # Push the local file to the Files API and print the resulting record.
    client: SeekrFlow = ctx.obj
    uploaded = client.files.upload(file=file, purpose=purpose)
    click.echo(json.dumps(uploaded.model_dump(), indent=4))
|
45
|
+
@files.command()
@click.pass_context
def list(ctx: click.Context) -> None:
    """List files"""
    # Render the account's files as a grid table, one row per file.
    client: SeekrFlow = ctx.obj

    response = client.files.list()

    rows = [
        {
            "File name": "\n".join(wrap(item.filename or "", width=30)),
            "File ID": item.id,
            # convert_bytes takes a float; round-trip through str for mypy
            "Size": convert_bytes(float(str(item.bytes))),
            "Created At": convert_unix_timestamp(item.created_at or 0),
            "Line Count": item.line_count,
        }
        for item in response.data or []
    ]

    click.echo(tabulate(rows, headers="keys", tablefmt="grid", showindex=True))
|
|
71
|
+
@files.command()
@click.pass_context
@click.argument("id", type=str, required=True)
def retrieve(ctx: click.Context, id: str) -> None:
    """Retrieve file metadata"""
    # Fix: the docstring previously read "Upload file" (copy-pasted from the
    # upload command); click surfaces it as this command's --help text.
    client: SeekrFlow = ctx.obj

    response = client.files.retrieve(id=id)

    click.echo(json.dumps(response.model_dump(), indent=4))
|
84
|
+
@files.command()
@click.pass_context
@click.argument("id", type=str, required=True)
@click.option("--output", type=str, default=None, help="Output filename")
def retrieve_content(ctx: click.Context, id: str, output: str) -> None:
    """Retrieve file content and output to file"""
    # Download the file's content (optionally to --output) and echo the
    # API's response record as JSON.
    client: SeekrFlow = ctx.obj
    result = client.files.retrieve_content(id=id, output=output)
    click.echo(json.dumps(result.model_dump(), indent=4))
|
98
|
+
@files.command()
@click.pass_context
@click.argument("id", type=str, required=True)
def delete(ctx: click.Context, id: str) -> None:
    """Delete remote file"""
    # Remove the file server-side and echo the deletion record as JSON.
    client: SeekrFlow = ctx.obj
    result = client.files.delete(id=id)
    click.echo(json.dumps(result.model_dump(), indent=4))
|
111
|
+
@files.command()
@click.pass_context
@click.argument(
    "file",
    type=click.Path(
        exists=True, file_okay=True, resolve_path=True, readable=True, dir_okay=False
    ),
    required=True,
)
def check(ctx: click.Context, file: pathlib.Path) -> None:
    """Check file for issues"""
    # Purely local validation — the client on ctx.obj is not needed here.
    click.echo(json.dumps(check_file(file), indent=4))
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from textwrap import wrap
|
|
3
|
+
|
|
4
|
+
import click
|
|
5
|
+
from tabulate import tabulate
|
|
6
|
+
|
|
7
|
+
from seekrai import seekrflow
|
|
8
|
+
from seekrai.utils import finetune_price_to_dollars, parse_timestamp
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@click.group(name="fine-tuning")
@click.pass_context
def fine_tuning(ctx: click.Context) -> None:
    """Fine-tunes API commands"""
    # Group entry point: subcommands receive the shared client via ctx.obj.
|
18
|
+
@fine_tuning.command()
@click.pass_context
@click.option(
    "--training-file", type=str, required=True, help="Training file ID from Files API"
)
@click.option("--model", type=str, required=True, help="Base model name")
@click.option("--n-epochs", type=int, default=1, help="Number of epochs to train for")
@click.option(
    "--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
)
@click.option("--batch-size", type=int, default=32, help="Train batch size")
@click.option("--learning-rate", type=float, default=3e-5, help="Learning rate")
@click.option(
    "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
)
@click.option("--wandb-api-key", type=str, default=None, help="Wandb API key")
def create(
    ctx: click.Context,
    training_file: str,
    model: str,
    n_epochs: int,
    n_checkpoints: int,
    batch_size: int,
    learning_rate: float,
    suffix: str,
    wandb_api_key: str,
) -> None:
    """Start fine-tuning"""
    # Kick off a fine-tuning job and echo the created job record as JSON.
    # NOTE(review): suffix/wandb_api_key default to None via click, so the
    # str annotations understate the real type — confirm upstream.
    client: SeekrFlow = ctx.obj

    job = client.fine_tuning.create(
        training_file=training_file,
        model=model,
        n_epochs=n_epochs,
        n_checkpoints=n_checkpoints,
        batch_size=batch_size,
        learning_rate=learning_rate,
        suffix=suffix,
        wandb_api_key=wandb_api_key,
    )

    click.echo(json.dumps(job.model_dump(), indent=4))
|
62
|
+
@fine_tuning.command()
@click.pass_context
def list(ctx: click.Context) -> None:
    """List fine-tuning jobs"""
    # Show all jobs oldest-first as a grid table.
    client: SeekrFlow = ctx.obj

    response = client.fine_tuning.list()

    jobs = sorted(
        response.data or [], key=lambda job: parse_timestamp(job.created_at or "")
    )

    rows = [
        {
            "Fine-tune ID": job.id,
            "Model Output Name": "\n".join(wrap(job.output_name or "", width=30)),
            "Status": job.status,
            "Created At": job.created_at,
            # finetune_price_to_dollars takes a float; str() round-trip for mypy
            "Price": f"${finetune_price_to_dollars(float(str(job.total_price)))}",
        }
        for job in jobs
    ]

    click.echo(tabulate(rows, headers="keys", tablefmt="grid", showindex=True))
|
92
|
+
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
    """Retrieve fine-tuning job details"""
    client: SeekrFlow = ctx.obj

    details = client.fine_tuning.retrieve(fine_tune_id)

    # Strip the (potentially long) event log so the JSON stays readable;
    # use the list-events command to inspect events.
    details.events = None

    click.echo(json.dumps(details.model_dump(), indent=4))
+
|
|
107
|
+
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
def cancel(ctx: click.Context, fine_tune_id: str) -> None:
    """Cancel fine-tuning job"""
    # Request cancellation and echo the updated job record as JSON.
    client: SeekrFlow = ctx.obj
    result = client.fine_tuning.cancel(fine_tune_id)
    click.echo(json.dumps(result.model_dump(), indent=4))
|
119
|
+
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
def list_events(ctx: click.Context, fine_tune_id: str) -> None:
    """List fine-tuning events"""
    # Show the job's event log as a grid table, one row per event.
    client: SeekrFlow = ctx.obj

    response = client.fine_tuning.list_events(fine_tune_id)

    rows = [
        {
            "Message": "\n".join(wrap(event.message or "", width=50)),
            "Type": event.type,
            "Created At": parse_timestamp(event.created_at or ""),
            "Hash": event.hash,
        }
        for event in response.data or []
    ]

    click.echo(tabulate(rows, headers="keys", tablefmt="grid", showindex=True))
|
145
|
+
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
@click.option(
    "--output_dir",
    type=click.Path(exists=True, file_okay=False, resolve_path=True),
    required=False,
    default=None,
    help="Output directory",
)
@click.option(
    "--checkpoint-step",
    type=int,
    required=False,
    default=-1,
    help="Download fine-tuning checkpoint. Defaults to latest.",
)
def download(
    ctx: click.Context,
    fine_tune_id: str,
    output_dir: str,
    checkpoint_step: int,
) -> None:
    """Download fine-tuning checkpoint"""
    # checkpoint_step of -1 means "latest" (per the option's help text).
    client: SeekrFlow = ctx.obj

    result = client.fine_tuning.download(
        fine_tune_id, output=output_dir, checkpoint_step=checkpoint_step
    )

    click.echo(json.dumps(result.model_dump(), indent=4))