seekrai 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (45)
  1. seekrai/__init__.py +0 -15
  2. seekrai/abstract/api_requestor.py +121 -297
  3. seekrai/client.py +10 -11
  4. seekrai/constants.py +36 -16
  5. seekrai/error.py +1 -8
  6. seekrai/filemanager.py +40 -79
  7. seekrai/resources/chat/completions.py +13 -13
  8. seekrai/resources/completions.py +4 -4
  9. seekrai/resources/embeddings.py +4 -2
  10. seekrai/resources/files.py +17 -9
  11. seekrai/resources/finetune.py +57 -82
  12. seekrai/resources/images.py +2 -2
  13. seekrai/resources/models.py +115 -15
  14. seekrai/types/__init__.py +5 -4
  15. seekrai/types/common.py +1 -2
  16. seekrai/types/files.py +23 -19
  17. seekrai/types/finetune.py +20 -26
  18. seekrai/types/models.py +26 -20
  19. seekrai/utils/_log.py +3 -3
  20. seekrai/utils/api_helpers.py +2 -2
  21. seekrai/utils/tools.py +1 -1
  22. seekrai-0.1.1.dist-info/METADATA +165 -0
  23. seekrai-0.1.1.dist-info/RECORD +39 -0
  24. seekrai/cli/__init__.py +0 -0
  25. seekrai/cli/api/__init__.py +0 -0
  26. seekrai/cli/api/chat.py +0 -245
  27. seekrai/cli/api/completions.py +0 -107
  28. seekrai/cli/api/files.py +0 -125
  29. seekrai/cli/api/finetune.py +0 -175
  30. seekrai/cli/api/images.py +0 -82
  31. seekrai/cli/api/models.py +0 -42
  32. seekrai/cli/cli.py +0 -77
  33. seekrai/legacy/__init__.py +0 -0
  34. seekrai/legacy/base.py +0 -27
  35. seekrai/legacy/complete.py +0 -91
  36. seekrai/legacy/embeddings.py +0 -25
  37. seekrai/legacy/files.py +0 -140
  38. seekrai/legacy/finetune.py +0 -173
  39. seekrai/legacy/images.py +0 -25
  40. seekrai/legacy/models.py +0 -44
  41. seekrai-0.0.1.dist-info/METADATA +0 -401
  42. seekrai-0.0.1.dist-info/RECORD +0 -56
  43. {seekrai-0.0.1.dist-info → seekrai-0.1.1.dist-info}/LICENSE +0 -0
  44. {seekrai-0.0.1.dist-info → seekrai-0.1.1.dist-info}/WHEEL +0 -0
  45. {seekrai-0.0.1.dist-info → seekrai-0.1.1.dist-info}/entry_points.txt +0 -0
seekrai/resources/finetune.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  from pathlib import Path
4
4
 
5
5
  from seekrai.abstract import api_requestor
6
- from seekrai.filemanager import DownloadManager
7
6
  from seekrai.seekrflow_response import SeekrFlowResponse
8
7
  from seekrai.types import (
9
8
  FinetuneDownloadResult,
@@ -11,12 +10,11 @@ from seekrai.types import (
11
10
  FinetuneListEvents,
12
11
  FinetuneRequest,
13
12
  FinetuneResponse,
13
+ InfrastructureConfig,
14
14
  SeekrFlowClient,
15
15
  SeekrFlowRequest,
16
- TrainingConfig,
17
- InfrastructureConfig
16
+ TrainingConfig,
18
17
  )
19
- from seekrai.utils import normalize_key
20
18
 
21
19
 
22
20
  class FineTuning:
@@ -27,25 +25,13 @@ class FineTuning:
27
25
  self,
28
26
  *,
29
27
  training_config: TrainingConfig,
30
- infrastructure_config: InfrastructureConfig
28
+ infrastructure_config: InfrastructureConfig,
31
29
  # wandb_api_key: str | None = None,
32
30
  ) -> FinetuneResponse:
33
31
  """
34
32
  Method to initiate a fine-tuning job
35
33
 
36
34
  Args:
37
- training_file (str): File-ID of a file uploaded to the SeekrFlow API
38
- model (str): Name of the base model to run fine-tune job on
39
- n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
40
- n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
41
- Defaults to 1.
42
- batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
43
- learning_rate (float, optional): Learning rate multiplier to use for training
44
- Defaults to 0.00001.
45
- suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
46
- Defaults to None.
47
- wandb_api_key (str, optional): API key for Weights & Biases integration.
48
- Defaults to None.
49
35
 
50
36
  Returns:
51
37
  FinetuneResponse: Object containing information about fine-tuning job.
@@ -56,14 +42,13 @@ class FineTuning:
56
42
  )
57
43
 
58
44
  parameter_payload = FinetuneRequest(
59
- training_config=training_config,
60
- infrastructure_config=infrastructure_config
45
+ training_config=training_config, infrastructure_config=infrastructure_config
61
46
  ).model_dump()
62
47
 
63
48
  response, _, _ = requestor.request(
64
49
  options=SeekrFlowRequest(
65
50
  method="POST",
66
- url="fine-tune",
51
+ url="flow/fine-tune",
67
52
  params=parameter_payload,
68
53
  ),
69
54
  stream=False,
@@ -88,7 +73,7 @@ class FineTuning:
88
73
  response, _, _ = requestor.request(
89
74
  options=SeekrFlowRequest(
90
75
  method="GET",
91
- url="fine-tunes",
76
+ url="flow/fine-tunes",
92
77
  ),
93
78
  stream=False,
94
79
  )
@@ -115,7 +100,7 @@ class FineTuning:
115
100
  response, _, _ = requestor.request(
116
101
  options=SeekrFlowRequest(
117
102
  method="GET",
118
- url=f"fine-tunes/{id}",
103
+ url=f"flow/fine-tunes/{id}",
119
104
  ),
120
105
  stream=False,
121
106
  )
@@ -142,7 +127,7 @@ class FineTuning:
142
127
  response, _, _ = requestor.request(
143
128
  options=SeekrFlowRequest(
144
129
  method="POST",
145
- url=f"fine-tunes/{id}/cancel",
130
+ url=f"flow/fine-tunes/{id}/cancel",
146
131
  ),
147
132
  stream=False,
148
133
  )
@@ -169,7 +154,7 @@ class FineTuning:
169
154
  response, _, _ = requestor.request(
170
155
  options=SeekrFlowRequest(
171
156
  method="GET",
172
- url=f"fine-tunes/{id}/events",
157
+ url=f"flow/fine-tunes/{id}/events",
173
158
  ),
174
159
  stream=False,
175
160
  )
@@ -196,33 +181,50 @@ class FineTuning:
196
181
  Returns:
197
182
  FinetuneDownloadResult: Object containing downloaded model metadata
198
183
  """
199
-
200
- url = f"finetune/download?ft_id={id}"
201
-
202
- if checkpoint_step > 0:
203
- url += f"&checkpoint_step={checkpoint_step}"
204
-
205
- remote_name = self.retrieve(id).output_name
206
-
207
- download_manager = DownloadManager(self._client)
208
-
209
- if isinstance(output, str):
210
- output = Path(output)
211
-
212
- downloaded_filename, file_size = download_manager.download(
213
- url, output, normalize_key(remote_name or id), fetch_metadata=True
184
+ raise NotImplementedError("Function not yet implemented")
185
+ # url = f"finetune/download?ft_id={id}"
186
+ #
187
+ # if checkpoint_step > 0:
188
+ # url += f"&checkpoint_step={checkpoint_step}"
189
+ #
190
+ # remote_name = self.retrieve(id).output_name
191
+ #
192
+ # download_manager = DownloadManager(self._client)
193
+ #
194
+ # if isinstance(output, str):
195
+ # output = Path(output)
196
+ #
197
+ # downloaded_filename, file_size = download_manager.download(
198
+ # url, output, normalize_key(remote_name or id), fetch_metadata=True
199
+ # )
200
+ #
201
+ # return FinetuneDownloadResult(
202
+ # object="local",
203
+ # id=id,
204
+ # checkpoint_step=checkpoint_step,
205
+ # filename=downloaded_filename,
206
+ # size=file_size,
207
+ # )
208
+
209
+ def promote(self, id: str) -> FinetuneListEvents:
210
+ requestor = api_requestor.APIRequestor(
211
+ client=self._client,
214
212
  )
215
213
 
216
- return FinetuneDownloadResult(
217
- object="local",
218
- id=id,
219
- checkpoint_step=checkpoint_step,
220
- filename=downloaded_filename,
221
- size=file_size,
214
+ response, _, _ = requestor.request(
215
+ options=SeekrFlowRequest(
216
+ method="GET",
217
+ url=f"flow/fine-tunes/{id}/promote-model",
218
+ params={"fine_tune_id": id},
219
+ ),
220
+ stream=False,
222
221
  )
223
222
 
223
+ assert isinstance(response, SeekrFlowResponse)
224
+
225
+ return FinetuneListEvents(**response.data)
224
226
 
225
- def promote(self, id: str):
227
+ def demote(self, id: str) -> FinetuneListEvents:
226
228
  requestor = api_requestor.APIRequestor(
227
229
  client=self._client,
228
230
  )
@@ -230,13 +232,12 @@ class FineTuning:
230
232
  response, _, _ = requestor.request(
231
233
  options=SeekrFlowRequest(
232
234
  method="GET",
233
- url=f"fine-tunes/{id}/promote-model",
235
+ url=f"flow/fine-tunes/{id}/demote-model",
234
236
  params={"fine_tune_id": id},
235
237
  ),
236
238
  stream=False,
237
239
  )
238
240
 
239
-
240
241
  assert isinstance(response, SeekrFlowResponse)
241
242
 
242
243
  return FinetuneListEvents(**response.data)
@@ -249,32 +250,13 @@ class AsyncFineTuning:
249
250
  async def create(
250
251
  self,
251
252
  *,
252
- training_file: str,
253
- model: str,
254
- n_epochs: int = 1,
255
- n_checkpoints: int | None = 1,
256
- batch_size: int | None = 32,
257
- learning_rate: float = 0.00001,
258
- suffix: str | None = None,
259
- wandb_api_key: str | None = None,
253
+ training_config: TrainingConfig,
254
+ infrastructure_config: InfrastructureConfig,
260
255
  ) -> FinetuneResponse:
261
256
  """
262
257
  Async method to initiate a fine-tuning job
263
258
 
264
259
  Args:
265
- training_file (str): File-ID of a file uploaded to the SeekrFlow API
266
- model (str): Name of the base model to run fine-tune job on
267
- n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
268
- n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
269
- Defaults to 1.
270
- batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
271
- learning_rate (float, optional): Learning rate multiplier to use for training
272
- Defaults to 0.00001.
273
- suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
274
- Defaults to None.
275
- wandb_api_key (str, optional): API key for Weights & Biases integration.
276
- Defaults to None.
277
-
278
260
  Returns:
279
261
  FinetuneResponse: Object containing information about fine-tuning job.
280
262
  """
@@ -284,20 +266,13 @@ class AsyncFineTuning:
284
266
  )
285
267
 
286
268
  parameter_payload = FinetuneRequest(
287
- model=model,
288
- training_file=training_file,
289
- n_epochs=n_epochs,
290
- n_checkpoints=n_checkpoints,
291
- batch_size=batch_size,
292
- learning_rate=learning_rate,
293
- suffix=suffix,
294
- wandb_key=wandb_api_key,
269
+ training_config=training_config, infrastructure_config=infrastructure_config
295
270
  ).model_dump()
296
271
 
297
272
  response, _, _ = await requestor.arequest(
298
273
  options=SeekrFlowRequest(
299
274
  method="POST",
300
- url="fine-tunes",
275
+ url="flow/fine-tunes",
301
276
  params=parameter_payload,
302
277
  ),
303
278
  stream=False,
@@ -322,7 +297,7 @@ class AsyncFineTuning:
322
297
  response, _, _ = await requestor.arequest(
323
298
  options=SeekrFlowRequest(
324
299
  method="GET",
325
- url="fine-tunes",
300
+ url="flow/fine-tunes",
326
301
  ),
327
302
  stream=False,
328
303
  )
@@ -349,7 +324,7 @@ class AsyncFineTuning:
349
324
  response, _, _ = await requestor.arequest(
350
325
  options=SeekrFlowRequest(
351
326
  method="GET",
352
- url=f"fine-tunes/{id}",
327
+ url=f"flow/fine-tunes/{id}",
353
328
  ),
354
329
  stream=False,
355
330
  )
@@ -376,7 +351,7 @@ class AsyncFineTuning:
376
351
  response, _, _ = await requestor.arequest(
377
352
  options=SeekrFlowRequest(
378
353
  method="POST",
379
- url=f"fine-tunes/{id}/cancel",
354
+ url=f"flow/fine-tunes/{id}/cancel",
380
355
  ),
381
356
  stream=False,
382
357
  )
@@ -403,7 +378,7 @@ class AsyncFineTuning:
403
378
  response, _, _ = await requestor.arequest(
404
379
  options=SeekrFlowRequest(
405
380
  method="GET",
406
- url=f"fine-tunes/{id}/events",
381
+ url=f"flow/fine-tunes/{id}/events",
407
382
  ),
408
383
  stream=False,
409
384
  )
seekrai/resources/images.py CHANGED
@@ -72,7 +72,7 @@ class Images:
72
72
  response, _, _ = requestor.request(
73
73
  options=SeekrFlowRequest(
74
74
  method="POST",
75
- url="images/generations",
75
+ url="inference/images/generations",
76
76
  params=parameter_payload,
77
77
  ),
78
78
  stream=False,
@@ -145,7 +145,7 @@ class AsyncImages:
145
145
  response, _, _ = await requestor.arequest(
146
146
  options=SeekrFlowRequest(
147
147
  method="POST",
148
- url="images/generations",
148
+ url="inference/images/generations",
149
149
  params=parameter_payload,
150
150
  ),
151
151
  stream=False,
seekrai/resources/models.py CHANGED
@@ -1,28 +1,62 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import List
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, List
6
+
7
+ from tqdm import tqdm
4
8
 
5
9
  from seekrai.abstract import api_requestor
10
+ from seekrai.constants import DISABLE_TQDM
6
11
  from seekrai.seekrflow_response import SeekrFlowResponse
7
- from seekrai.types import (
8
- ModelObject,
9
- SeekrFlowClient,
10
- SeekrFlowRequest,
11
- )
12
+ from seekrai.types import ModelList, ModelResponse, SeekrFlowClient, SeekrFlowRequest
13
+ from seekrai.types.models import ModelType
12
14
 
13
15
 
14
16
  class Models:
15
17
  def __init__(self, client: SeekrFlowClient) -> None:
16
18
  self._client = client
17
19
 
20
+ def upload(
21
+ self,
22
+ file: Path | str,
23
+ *,
24
+ model_type: ModelType | str = ModelType.OBJECT_DETECTION,
25
+ ) -> ModelResponse:
26
+ if isinstance(file, str):
27
+ file = Path(file)
28
+
29
+ requestor = api_requestor.APIRequestor(
30
+ client=self._client,
31
+ )
32
+ file_size = os.stat(file.as_posix()).st_size
33
+
34
+ with tqdm(
35
+ total=file_size,
36
+ unit="B",
37
+ unit_scale=True,
38
+ desc=f"Uploading file {file.name}",
39
+ disable=bool(DISABLE_TQDM),
40
+ ):
41
+ with file.open("rb") as f:
42
+ response, _, _ = requestor.request(
43
+ options=SeekrFlowRequest(
44
+ method="PUT",
45
+ url="flow/pt-models",
46
+ files={"files": f, "filename": file.name},
47
+ params={"purpose": model_type},
48
+ ),
49
+ )
50
+ return ModelResponse(**response.data)
51
+
18
52
  def list(
19
53
  self,
20
- ) -> List[ModelObject]:
54
+ ) -> ModelList:
21
55
  """
22
56
  Method to return list of models on the API
23
57
 
24
58
  Returns:
25
- List[ModelObject]: List of model objects
59
+ List[ModelResponse]: List of model objects
26
60
  """
27
61
 
28
62
  requestor = api_requestor.APIRequestor(
@@ -32,15 +66,81 @@ class Models:
32
66
  response, _, _ = requestor.request(
33
67
  options=SeekrFlowRequest(
34
68
  method="GET",
35
- url="models",
69
+ url="flow/pt-models",
36
70
  ),
37
71
  stream=False,
38
72
  )
39
73
 
40
74
  assert isinstance(response, SeekrFlowResponse)
41
- assert isinstance(response.data, list)
75
+ return ModelList(**response.data)
76
+
77
+ def promote(self, id: str) -> ModelResponse:
78
+ requestor = api_requestor.APIRequestor(
79
+ client=self._client,
80
+ )
81
+
82
+ response, _, _ = requestor.request(
83
+ options=SeekrFlowRequest(
84
+ method="GET",
85
+ url=f"flow/pt-models/{id}/promote-model",
86
+ params={"model_id": id},
87
+ ),
88
+ stream=False,
89
+ )
90
+
91
+ assert isinstance(response, SeekrFlowResponse)
92
+
93
+ return ModelResponse(**response.data)
94
+
95
+ def demote(self, id: str) -> ModelResponse:
96
+ requestor = api_requestor.APIRequestor(
97
+ client=self._client,
98
+ )
99
+
100
+ response, _, _ = requestor.request(
101
+ options=SeekrFlowRequest(
102
+ method="GET",
103
+ url=f"flow/pt-models/{id}/demote-model",
104
+ params={"model_id": id},
105
+ ),
106
+ stream=False,
107
+ )
108
+
109
+ assert isinstance(response, SeekrFlowResponse)
110
+
111
+ return ModelResponse(**response.data)
112
+
113
+ def predict(self, id: str, file: Path | str) -> Any:
114
+ requestor = api_requestor.APIRequestor(
115
+ client=self._client,
116
+ )
117
+
118
+ if isinstance(file, str):
119
+ file = Path(file)
120
+
121
+ file_size = os.stat(file.as_posix()).st_size
122
+
123
+ with tqdm(
124
+ total=file_size,
125
+ unit="B",
126
+ unit_scale=True,
127
+ desc=f"Uploading file {file.name}",
128
+ disable=bool(DISABLE_TQDM),
129
+ ):
130
+ with file.open("rb") as f:
131
+ response, _, _ = requestor.request(
132
+ options=SeekrFlowRequest(
133
+ method="POST",
134
+ url="flow/pt-models/predict",
135
+ files={"files": f, "filename": file.name},
136
+ params={"model_id": id},
137
+ ),
138
+ stream=False,
139
+ )
140
+
141
+ assert isinstance(response, SeekrFlowResponse)
42
142
 
43
- return [ModelObject(**model) for model in response.data]
143
+ return response.data
44
144
 
45
145
 
46
146
  class AsyncModels:
@@ -49,12 +149,12 @@ class AsyncModels:
49
149
 
50
150
  async def list(
51
151
  self,
52
- ) -> List[ModelObject]:
152
+ ) -> List[ModelResponse]:
53
153
  """
54
154
  Async method to return list of models on API
55
155
 
56
156
  Returns:
57
- List[ModelObject]: List of model objects
157
+ List[ModelResponse]: List of model objects
58
158
  """
59
159
 
60
160
  requestor = api_requestor.APIRequestor(
@@ -64,7 +164,7 @@ class AsyncModels:
64
164
  response, _, _ = await requestor.arequest(
65
165
  options=SeekrFlowRequest(
66
166
  method="GET",
67
- url="models",
167
+ url="flow/models",
68
168
  ),
69
169
  stream=False,
70
170
  )
@@ -72,4 +172,4 @@ class AsyncModels:
72
172
  assert isinstance(response, SeekrFlowResponse)
73
173
  assert isinstance(response.data, list)
74
174
 
75
- return [ModelObject(**model) for model in response.data]
175
+ return [ModelResponse(**model) for model in response.data]
seekrai/types/__init__.py CHANGED
@@ -26,14 +26,14 @@ from seekrai.types.finetune import (
26
26
  FinetuneListEvents,
27
27
  FinetuneRequest,
28
28
  FinetuneResponse,
29
- InfrastructureConfig,
30
- TrainingConfig,
29
+ InfrastructureConfig,
30
+ TrainingConfig,
31
31
  )
32
32
  from seekrai.types.images import (
33
33
  ImageRequest,
34
34
  ImageResponse,
35
35
  )
36
- from seekrai.types.models import ModelObject
36
+ from seekrai.types.models import ModelList, ModelResponse
37
37
 
38
38
 
39
39
  __all__ = [
@@ -63,5 +63,6 @@ __all__ = [
63
63
  "FileType",
64
64
  "ImageRequest",
65
65
  "ImageResponse",
66
- "ModelObject",
66
+ "ModelResponse",
67
+ "ModelList",
67
68
  ]
seekrai/types/common.py CHANGED
@@ -4,7 +4,6 @@ from enum import Enum
4
4
  from typing import Any, Dict, List
5
5
 
6
6
  from pydantic import ConfigDict
7
- from tqdm.utils import CallbackIOWrapper
8
7
 
9
8
  from seekrai.types.abstract import BaseModel
10
9
 
@@ -58,7 +57,7 @@ class SeekrFlowRequest(BaseModel):
58
57
  method: str
59
58
  url: str
60
59
  headers: Dict[str, str] | None = None
61
- params: Dict[str, Any] | CallbackIOWrapper | None = None
60
+ params: Dict[str, Any] | None = None
62
61
  files: Dict[str, Any] | None = None
63
62
  allow_redirects: bool = True
64
63
  override_headers: bool = False
seekrai/types/files.py CHANGED
@@ -1,23 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from datetime import datetime
3
4
  from enum import Enum
4
5
  from typing import List, Literal
5
6
 
6
- from pydantic import Field
7
-
8
7
  from seekrai.types.abstract import BaseModel
9
8
  from seekrai.types.common import (
10
9
  ObjectType,
11
10
  )
12
- from datetime import datetime
11
+
13
12
 
14
13
  class FilePurpose(str, Enum):
15
14
  FineTune = "fine-tune"
15
+ PreTrain = "pre-train"
16
16
 
17
17
 
18
18
  class FileType(str, Enum):
19
19
  jsonl = "jsonl"
20
20
  parquet = "parquet"
21
+ pytorch = "pt"
21
22
 
22
23
 
23
24
  class FileRequest(BaseModel):
@@ -25,22 +26,25 @@ class FileRequest(BaseModel):
25
26
  Files request type
26
27
  """
27
28
 
28
- # training file ID
29
- training_file: str
30
- # base model string
31
- model: str
32
- # number of epochs to train for
33
- n_epochs: int
34
- # training learning rate
35
- learning_rate: float
36
- # number of checkpoints to save
37
- n_checkpoints: int | None = None
38
- # training batch size
39
- batch_size: int | None = None
40
- # up to 40 character suffix for output model name
41
- suffix: str | None = None
42
- # weights & biases api key
43
- wandb_api_key: str | None = None
29
+ # # training file ID
30
+ # training_file: str
31
+ # # base model string
32
+ # model: str
33
+ # # number of epochs to train for
34
+ # n_epochs: int
35
+ # # training learning rate
36
+ # learning_rate: float
37
+ # # number of checkpoints to save
38
+ # n_checkpoints: int | None = None
39
+ # # training batch size
40
+ # batch_size: int | None = None
41
+ # # up to 40 character suffix for output model name
42
+ # suffix: str | None = None
43
+ # # weights & biases api key
44
+ # wandb_api_key: str | None = None
45
+ purpose: FilePurpose
46
+ filetype: FileType
47
+ filename: str
44
48
 
45
49
 
46
50
  class FileResponse(BaseModel):
seekrai/types/finetune.py CHANGED
@@ -1,16 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from datetime import datetime
3
4
  from enum import Enum
4
5
  from typing import List, Literal
5
6
 
6
- from pydantic import Field
7
-
8
7
  from seekrai.types.abstract import BaseModel
9
8
  from seekrai.types.common import (
10
9
  ObjectType,
11
10
  )
12
- from datetime import datetime
13
-
14
11
 
15
12
 
16
13
  class FinetuneJobStatus(str, Enum):
@@ -83,22 +80,10 @@ class FinetuneEvent(BaseModel):
83
80
  # object type
84
81
  object: Literal[ObjectType.FinetuneEvent]
85
82
  # created at datetime stamp
86
- created_at: str | None = None
87
- # event log level
88
- level: FinetuneEventLevels | None = None
89
- # event message string
90
- message: str | None = None
91
- # event type
92
- type: FinetuneEventType | None = None
93
- # optional: model parameter count
94
- param_count: int | None = None
95
- # optional: dataset token count
96
- token_count: int | None = None
97
- # optional: weights & biases url
98
- wandb_url: str | None = None
99
- # event hash
100
- hash: str | None = None
101
-
83
+ created_at: datetime | None = None
84
+ # metrics that we expose
85
+ loss: float | None = None
86
+ epoch: float | None = None
102
87
 
103
88
 
104
89
  class TrainingConfig(BaseModel):
@@ -118,22 +103,31 @@ class TrainingConfig(BaseModel):
118
103
  experiment_name: str | None = None
119
104
  # # weights & biases api key
120
105
  # wandb_key: str | None = None
106
+ # IFT by default
107
+ pre_train: bool = False
108
+
109
+
110
+ class AcceleratorType(str, Enum):
111
+ GAUDI2 = "GAUDI2"
112
+ GAUDI3 = "GAUDI3"
113
+ A100 = "A100"
114
+ H100 = "H100"
115
+
121
116
 
122
117
  class InfrastructureConfig(BaseModel):
123
- n_cpu: int
124
- n_gpu: int
118
+ accel_type: AcceleratorType
119
+ n_accel: int
120
+
125
121
 
126
122
  class FinetuneRequest(BaseModel):
127
123
  """
128
124
  Fine-tune request type
129
125
  """
126
+
130
127
  training_config: TrainingConfig
131
128
  infrastructure_config: InfrastructureConfig
132
129
 
133
130
 
134
-
135
-
136
-
137
131
  class FinetuneResponse(BaseModel):
138
132
  """
139
133
  Fine-tune API response type
@@ -170,7 +164,7 @@ class FinetuneResponse(BaseModel):
170
164
 
171
165
  # list of fine-tune events
172
166
  events: List[FinetuneEvent] | None = None
173
- inference_available: bool = False
167
+ inference_available: bool = False
174
168
  # dataset token count
175
169
  # TODO
176
170
  # token_count: int | None = None