seekrai 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- seekrai/types/finetune.py +206 -3
- {seekrai-0.5.26.dist-info → seekrai-0.5.28.dist-info}/METADATA +1 -1
- {seekrai-0.5.26.dist-info → seekrai-0.5.28.dist-info}/RECORD +6 -6
- {seekrai-0.5.26.dist-info → seekrai-0.5.28.dist-info}/WHEEL +1 -1
- {seekrai-0.5.26.dist-info → seekrai-0.5.28.dist-info}/entry_points.txt +0 -0
- {seekrai-0.5.26.dist-info → seekrai-0.5.28.dist-info}/licenses/LICENSE +0 -0
seekrai/types/finetune.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from datetime import datetime
|
|
4
5
|
from enum import Enum
|
|
5
|
-
from typing import Any, Dict, List, Literal, Optional
|
|
6
|
+
from typing import Any, Callable, Dict, List, Literal, Optional
|
|
6
7
|
|
|
7
|
-
from pydantic import Field
|
|
8
|
+
from pydantic import Field, field_validator, model_serializer, model_validator
|
|
8
9
|
|
|
9
10
|
from seekrai.types.abstract import BaseModel
|
|
10
11
|
from seekrai.types.common import (
|
|
11
12
|
ObjectType,
|
|
12
13
|
)
|
|
14
|
+
from seekrai.utils._log import log_info
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
class FinetuneJobStatus(str, Enum):
|
|
@@ -77,8 +79,28 @@ class FinetuneEventType(str, Enum):
|
|
|
77
79
|
|
|
78
80
|
class FineTuneType(str, Enum):
    """Kind of fine-tuning job to run.

    ``GRPO`` is a deprecated alias kept for backwards compatibility;
    new code should use ``REINFORCEMENT`` instead.
    """

    STANDARD = "STANDARD"
    # Deprecated alias for REINFORCEMENT (see TrainingConfig's validator,
    # which emits a DeprecationWarning when this value is used).
    GRPO = "GRPO"  # deprecated
    PREFERENCE = "PREFERENCE"
    REINFORCEMENT = "REINFORCEMENT"
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class GraderType(str, Enum):
    """Grading strategy used by a reward ``Grader``.

    ``STRING_CHECK`` and ``TEXT_SIMILARITY`` additionally require an
    operation (see ``Grader.validate_operation``); the other two must
    not carry one.
    """

    FORMAT_CHECK = "format_check"
    MATH_ACCURACY = "math_accuracy"
    STRING_CHECK = "string_check"
    TEXT_SIMILARITY = "text_similarity"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class StringOperation(str, Enum):
    """Comparison operation for a ``string_check`` grader."""

    EQUALS = "equals"
    NOT_EQUALS = "not_equals"
    CONTAINS = "contains"
    CASE_INSENSITIVE_CONTAINS = "case_insensitive_contains"
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class TextSimilarityOperation(str, Enum):
    """Similarity metric for a ``text_similarity`` grader."""

    BLEU = "bleu"
    ROUGE = "rouge"
|
|
82
104
|
|
|
83
105
|
|
|
84
106
|
class FinetuneEvent(BaseModel):
|
|
@@ -93,6 +115,17 @@ class FinetuneEvent(BaseModel):
|
|
|
93
115
|
# metrics that we expose
|
|
94
116
|
loss: float | None = None
|
|
95
117
|
epoch: float | None = None
|
|
118
|
+
reward: float | None = None
|
|
119
|
+
|
|
120
|
+
@model_serializer(mode="wrap")
def serialize_model(
    self, handler: Callable[[Any], dict[str, Any]]
) -> dict[str, Any]:
    """Serialize the event, omitting ``reward`` when it was never set.

    Delegates to pydantic's default serializer via *handler*, then
    strips a ``None`` ``reward`` key so non-RL events do not carry it.
    """
    serialized = handler(self)
    if serialized.get("reward") is None:
        del serialized["reward"]
    return serialized
|
|
96
129
|
|
|
97
130
|
|
|
98
131
|
class LoRAConfig(BaseModel):
|
|
@@ -111,6 +144,126 @@ class LoRAConfig(BaseModel):
|
|
|
111
144
|
extras: Dict[str, Any] = Field(default_factory=dict)
|
|
112
145
|
|
|
113
146
|
|
|
147
|
+
class Grader(BaseModel):
    """A single reward grader used by reinforcement fine-tuning.

    Attributes:
        type: Grading strategy to apply.
        weight: Relative reward weight in (0, 1]. ``None`` means
            "unspecified" — ``RewardComponents`` will assign/normalize it.
        operation: Sub-operation. Required for ``string_check`` and
            ``text_similarity`` graders, forbidden for ``format_check``
            and ``math_accuracy``.
    """

    type: GraderType
    weight: float | None = Field(default=None, gt=0.0, le=1.0)
    operation: StringOperation | TextSimilarityOperation | None = Field(default=None)

    @model_validator(mode="before")
    @classmethod
    def validate_operation(cls, data: Any) -> Any:
        """Cross-validate ``type`` and ``operation`` before field parsing.

        Coerces string operations (case-insensitively) to the enum that
        matches the grader type, and rejects operations on grader types
        that do not take one.

        Raises:
            ValueError: If a required operation is missing, invalid, or
                supplied to a grader type that forbids it.
        """
        if not isinstance(data, dict):
            # Non-dict payloads (e.g. already-built models) pass through.
            return data

        grader_type = data.get("type")
        operation_value = data.get("operation")

        if grader_type == GraderType.STRING_CHECK:
            if not operation_value:
                raise ValueError(
                    "string_check grader is missing required StringOperation"
                )
            if isinstance(operation_value, str):
                try:
                    # Convert to enum to validate it's a valid value
                    data["operation"] = StringOperation(operation_value.lower())
                except ValueError:
                    # `from None`: the enum's own ValueError adds nothing
                    # beyond this message, so suppress the chained traceback.
                    raise ValueError(
                        f"Invalid operation for string_check grader: "
                        f"expected StringOperation, but got type '{type(operation_value).__name__}' with value '{operation_value}'"
                    ) from None
        elif grader_type == GraderType.TEXT_SIMILARITY:
            if not operation_value:
                raise ValueError(
                    "text_similarity grader is missing required TextSimilarityOperation"
                )
            if isinstance(operation_value, str):
                try:
                    data["operation"] = TextSimilarityOperation(operation_value.lower())
                except ValueError:
                    raise ValueError(
                        f"Invalid operation for text_similarity grader: "
                        f"expected TextSimilarityOperation, got type '{type(operation_value).__name__}' with value '{operation_value}'"
                    ) from None

        elif grader_type in (GraderType.FORMAT_CHECK, GraderType.MATH_ACCURACY):
            if operation_value:
                raise ValueError(f"{grader_type} grader cannot have an operation")
            # Normalize "falsy but present" (e.g. empty string) to None.
            data["operation"] = None

        return data
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class RewardComponents(BaseModel):
    """Weighted set of graders composing the RL reward signal.

    ``format_reward_weight`` plus the grader weights always end up
    summing to 1.0: either the caller supplies a complete, consistent
    set of weights, or unspecified grader weights are filled in and
    scaled to cover ``1 - format_reward_weight``.
    """

    format_reward_weight: float = Field(default=0.1, gt=0.0, le=1.0)
    graders: list[Grader] = Field(min_length=1)

    @model_validator(mode="after")
    def validate_weights(self) -> "RewardComponents":
        """Reconcile grader weights with ``format_reward_weight``.

        Raises:
            ValueError: If only a subset of graders carry weights, or if
                a fully specified weight set does not sum to 1.0.
        """
        format_weight_given = "format_reward_weight" in self.model_fields_set
        weight_flags = [g.weight is not None for g in self.graders]

        # Partially weighted grader lists are ambiguous — reject them.
        if any(weight_flags) and not all(weight_flags):
            raise ValueError(
                "Only some graders have weights specified. Either all graders must have weights specified, or none of them."
            )

        if all(weight_flags):
            if format_weight_given:
                # Everything explicit: demand an exact budget of 1.0.
                self._validate_weights_sum_to_one()
            else:
                # Grader weights explicit, format weight defaulted:
                # rescale graders around the default format weight.
                self._normalize_grader_weights()
        else:
            # No grader weights at all: start equal, then rescale.
            self._initialize_grader_weights()
            self._normalize_grader_weights()

        return self

    def _validate_weights_sum_to_one(self) -> None:
        """Validate that format_reward_weight and grader weights sum to 1.0"""
        total = self.format_reward_weight + sum(  # type: ignore[operator]
            g.weight  # type: ignore[misc]
            for g in self.graders
        )

        if abs(total - 1.0) > 1e-10:
            raise ValueError(
                f"When all weights are explicitly provided, they must sum to 1.0. "
                f"Got format_reward_weight={self.format_reward_weight}, "
                f"graders={self.graders}"
            )

    def _normalize_grader_weights(self) -> None:
        """Normalize only grader weights to fill (1 - format_reward_weight)"""
        current_total = sum(g.weight for g in self.graders)  # type: ignore[misc]
        target_total = 1.0 - self.format_reward_weight

        # Already consistent (within float tolerance): nothing to do.
        if abs(current_total - target_total) <= 1e-10:
            return

        scale = target_total / current_total
        for g in self.graders:
            previous = g.weight
            g.weight *= scale  # type: ignore[operator]
            log_info(
                f"{g.type}'s weight scaled from {previous} to {g.weight:.2f}"
            )

    def _initialize_grader_weights(self) -> None:
        """Initialize all grader weights when none are provided"""
        for g in self.graders:
            g.weight = 1.0
|
|
265
|
+
|
|
266
|
+
|
|
114
267
|
class TrainingConfig(BaseModel):
|
|
115
268
|
# training file ID
|
|
116
269
|
training_files: List[str]
|
|
@@ -136,6 +289,45 @@ class TrainingConfig(BaseModel):
|
|
|
136
289
|
fine_tune_type: FineTuneType = FineTuneType.STANDARD
|
|
137
290
|
# LoRA config
|
|
138
291
|
lora_config: Optional[LoRAConfig] = None
|
|
292
|
+
# reward_components are REINFORCEMENT-specific
|
|
293
|
+
reward_components: Optional[RewardComponents] = None
|
|
294
|
+
|
|
295
|
+
@model_validator(mode="after")
def validate_reward_components(self) -> "TrainingConfig":
    """Ensure ``reward_components`` is consistent with ``fine_tune_type``.

    RL job types (REINFORCEMENT and the deprecated GRPO alias) get a
    default reward configuration when none is supplied; non-RL job
    types must not carry one.

    Raises:
        ValueError: If reward components are provided for STANDARD or
            PREFERENCE fine-tuning.
    """
    rl_types = (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)

    # TODO: make reward_components required for REINFORCEMENT instead of
    # defaulting here; currently defaulted for backwards-compatibility.
    if self.fine_tune_type in rl_types and not self.reward_components:
        self.reward_components = RewardComponents(
            format_reward_weight=0.1,
            graders=[Grader(type=GraderType.MATH_ACCURACY, weight=0.9)],
        )

    if self.reward_components:
        if self.fine_tune_type == FineTuneType.STANDARD:
            raise ValueError(
                "Reward components are incompatible with standard fine-tuning"
            )
        if self.fine_tune_type == FineTuneType.PREFERENCE:
            raise ValueError(
                "Reward components are incompatible with preference fine-tuning"
            )

    return self
|
|
321
|
+
|
|
322
|
+
@field_validator("fine_tune_type")
def validate_fine_tune_type(cls, v: Any) -> Any:
    """Pass the value through, warning if the deprecated GRPO alias is used."""
    if v != FineTuneType.GRPO:
        return v
    warnings.warn(
        "FineTuneType.GRPO is deprecated and will be removed in a future version. Use FineTuneType.REINFORCEMENT",
        DeprecationWarning,
        stacklevel=2,
    )
    return v
|
|
139
331
|
|
|
140
332
|
|
|
141
333
|
class AcceleratorType(str, Enum):
|
|
@@ -172,6 +364,7 @@ class FinetuneResponse(BaseModel):
|
|
|
172
364
|
id: str | None = None
|
|
173
365
|
# fine-tune type
|
|
174
366
|
fine_tune_type: FineTuneType = FineTuneType.STANDARD
|
|
367
|
+
reward_components: Optional[RewardComponents] = None
|
|
175
368
|
# training file id
|
|
176
369
|
training_files: List[str] | None = None
|
|
177
370
|
# validation file id
|
|
@@ -228,6 +421,16 @@ class FinetuneResponse(BaseModel):
|
|
|
228
421
|
# training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
|
|
229
422
|
# training_file_size: int | None = Field(None, alias="TrainingFileSize")
|
|
230
423
|
|
|
424
|
+
@model_serializer(mode="wrap")
|
|
425
|
+
def serialize_model(
|
|
426
|
+
self, handler: Callable[[Any], dict[str, Any]]
|
|
427
|
+
) -> dict[str, Any]:
|
|
428
|
+
# Remove 'reward_components' if it's None
|
|
429
|
+
dump_dict = handler(self)
|
|
430
|
+
if dump_dict.get("reward_components") is None:
|
|
431
|
+
del dump_dict["reward_components"]
|
|
432
|
+
return dump_dict
|
|
433
|
+
|
|
231
434
|
|
|
232
435
|
class FinetuneList(BaseModel):
|
|
233
436
|
# object type
|
|
@@ -59,7 +59,7 @@ seekrai/types/enums.py,sha256=sQ1CW-ctbhpV2jM1cEAEy7ZUdzZa0IC85YvycjvudHE,633
|
|
|
59
59
|
seekrai/types/error.py,sha256=uTKISs9aRC4_6zwirtNkanxepN8KY-SqCq0kNbfZylQ,370
|
|
60
60
|
seekrai/types/explainability.py,sha256=Ih-8hCm5r22EMMtr83cDy8vePo7_Ik7UdUcXhsj5Zm0,835
|
|
61
61
|
seekrai/types/files.py,sha256=kOy4s8D4tlsenyWmiiEyAS0jDAdxMScBu5j1GwQCf3E,2808
|
|
62
|
-
seekrai/types/finetune.py,sha256
|
|
62
|
+
seekrai/types/finetune.py,sha256=-dRSjRqJVu2-dEfykOJYTuuzt6Ok1nx91gJzQ_WAqEU,15341
|
|
63
63
|
seekrai/types/images.py,sha256=Fusj8OhVYFsT8kz636lRGGivLbPXo_ZNgakKwmzJi3U,914
|
|
64
64
|
seekrai/types/ingestion.py,sha256=uUdKOR4xqSfAXWQOR1UOltSlOnuyAwKVA1Q2a6Yslk8,919
|
|
65
65
|
seekrai/types/models.py,sha256=9Z0nvLdlAfpF8mNRW5-IqBdDHoE-3qQ5przmIDJgwLo,1345
|
|
@@ -72,8 +72,8 @@ seekrai/utils/api_helpers.py,sha256=0Y8BblNIr9h_R12zdmhkxgTlxgoRkbq84QNi4nNWGu8,
|
|
|
72
72
|
seekrai/utils/files.py,sha256=7ixn_hgV-6pEhYqLyOp-EN0o8c1CzUwJzX9n3PQ5oqo,7164
|
|
73
73
|
seekrai/utils/tools.py,sha256=jgJTL-dOIouDbEJLdQpQfpXhqaz_poQYS52adyUtBjo,1781
|
|
74
74
|
seekrai/version.py,sha256=q6iGQVFor8zXiPP5F-3vy9TndOxKv5JXbaNJ2kdOQws,125
|
|
75
|
-
seekrai-0.5.
|
|
76
|
-
seekrai-0.5.
|
|
77
|
-
seekrai-0.5.
|
|
78
|
-
seekrai-0.5.
|
|
79
|
-
seekrai-0.5.
|
|
75
|
+
seekrai-0.5.28.dist-info/METADATA,sha256=LQLnopWdiNd6l3OJP92tv-DzgJWOa_jw_0lbXooI2JE,4788
|
|
76
|
+
seekrai-0.5.28.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
|
|
77
|
+
seekrai-0.5.28.dist-info/entry_points.txt,sha256=N49yOEGi1sK7Xr13F_rkkcOxQ88suyiMoOmRhUHTZ_U,48
|
|
78
|
+
seekrai-0.5.28.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
79
|
+
seekrai-0.5.28.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|