tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/__init__.py
CHANGED
tuft/auth.py
ADDED
@@ -0,0 +1,35 @@
+"""Authentication utilities for TuFT server.
+
+The current implementation is only for demonstration purposes and
+should be replaced with a proper authentication system in the future.
+Planned improvements:
+
+1. persistent storage
+2. API key hashing (store hashed key instead of actual keys in persistent storage)
+3. API key format with format validation to avoid hitting db every time
+4. API key expiry
+"""
+
+
+class User:
+    """A simple user representation. Enhance it in the future."""
+
+    def __init__(self, user_id: str):
+        self.user_id = user_id
+
+
+class AuthenticationDB:
+    """A simple in-memory authentication database.
+    It maps API keys to User instances.
+    """
+
+    def __init__(self, authorized_users: dict[str, str]):
+        """Initialize the authentication database."""
+        self.authorized_users = authorized_users
+
+    def authenticate(self, api_key: str) -> User | None:
+        """Authenticate a user by API key."""
+        user_id = self.authorized_users.get(api_key)
+        if user_id:
+            return User(user_id)
+        return None
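For reference, a minimal sketch of how this in-memory database could be exercised; the API key and user id values are hypothetical, only the AuthenticationDB/User API comes from the module above.

    from tuft.auth import AuthenticationDB

    # Hypothetical key/user pair; the DB simply maps API keys to user ids.
    db = AuthenticationDB(authorized_users={"demo-key-123": "alice"})

    user = db.authenticate("demo-key-123")
    assert user is not None and user.user_id == "alice"

    # Unknown keys return None rather than raising.
    assert db.authenticate("unknown-key") is None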
tuft/backend.py
ADDED
@@ -0,0 +1,254 @@
+"""Toy backend implementations used by the local TuFT server."""
+
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass, field
+from typing import Any, Sequence
+
+import numpy as np
+from tinker import types
+
+
+def _safe_mean(values: Sequence[int]) -> float:
+    if not values:
+        return 0.0
+    return float(sum(values) / len(values))
+
+
+@dataclass
+class ModelBackend:
+    """Deterministic toy backend with lightweight gradient/optimizer tracking."""
+
+    base_model: str
+    lora_rank: int
+    seed: int = 0
+    hidden_dim: int = 16
+
+    _lock: threading.Lock = field(init=False, repr=False)
+    _weights: np.ndarray = field(init=False, repr=False)
+    _adam_m: np.ndarray = field(init=False, repr=False)
+    _adam_v: np.ndarray = field(init=False, repr=False)
+    _beta1_power: float = field(init=False, default=1.0, repr=False)
+    _beta2_power: float = field(init=False, default=1.0, repr=False)
+    _pending_grad: np.ndarray | None = field(init=False, default=None, repr=False)
+    _pending_examples: int = field(init=False, default=0, repr=False)
+    _embedding_cache: dict[int, np.ndarray] = field(init=False, default_factory=dict, repr=False)
+    step: int = field(init=False, default=0)
+
+    def __post_init__(self) -> None:
+        rng = np.random.default_rng(self.seed or 0)
+        self._lock = threading.Lock()
+        self._weights = rng.standard_normal(self.hidden_dim, dtype=np.float32)
+        self._adam_m = np.zeros_like(self._weights)
+        self._adam_v = np.zeros_like(self._weights)
+
+    # ------------------------------------------------------------------
+    # Forward / backward helpers
+    # ------------------------------------------------------------------
+    def forward(
+        self, data: list[types.Datum], _: types.LossFnType, __: dict[str, float] | None
+    ) -> types.ForwardBackwardOutput:
+        return self._run_step(data, backward=False)
+
+    def forward_backward(
+        self,
+        data: list[types.Datum],
+        _: types.LossFnType,
+        __: dict[str, float] | None,
+    ) -> types.ForwardBackwardOutput:
+        return self._run_step(data, backward=True)
+
+    def _run_step(self, data: list[types.Datum], *, backward: bool) -> types.ForwardBackwardOutput:
+        outputs: list[types.LossFnOutput] = []
+        total_loss = 0.0
+        grad_accum = np.zeros_like(self._weights)
+        for datum in data:
+            prompt_tokens = datum.model_input.to_ints()
+            target_tokens = self._target_tokens(datum)
+            prompt_vec = self._vectorize(prompt_tokens)
+            target_scalar = self._target_scalar(target_tokens)
+            prediction = float(np.dot(self._weights, prompt_vec))
+            loss = (prediction - target_scalar) ** 2
+            total_loss += loss
+            if backward:
+                grad = 2 * (prediction - target_scalar) * prompt_vec
+                grad_accum += grad
+            logprob_tensor = types.TensorData(
+                data=[float(-abs(prediction - target_scalar))] * max(len(target_tokens), 1),
+                dtype="float32",
+                shape=[max(len(target_tokens), 1)],
+            )
+            outputs.append({"logprobs": logprob_tensor})
+
+        metrics = {
+            "loss:mean": total_loss / max(len(data), 1),
+            "step:max": float(self.step),
+        }
+        if backward:
+            grad_norm = float(np.linalg.norm(grad_accum) / max(len(data), 1))
+            metrics["grad_norm:mean"] = grad_norm
+            with self._lock:
+                if self._pending_grad is None:
+                    self._pending_grad = grad_accum
+                else:
+                    self._pending_grad += grad_accum
+                self._pending_examples += len(data)
+
+        return types.ForwardBackwardOutput(
+            loss_fn_output_type="ToyLoss",
+            loss_fn_outputs=outputs,
+            metrics=metrics,
+        )
+
+    # ------------------------------------------------------------------
+    # Optimizer
+    # ------------------------------------------------------------------
+    def optim_step(self, adam_params: types.AdamParams) -> types.OptimStepResponse:
+        with self._lock:
+            grad = self._pending_grad
+            examples = self._pending_examples
+            self._pending_grad = None
+            self._pending_examples = 0
+
+        if grad is None or not np.any(grad):
+            return types.OptimStepResponse(
+                metrics={
+                    "learning_rate:mean": adam_params.learning_rate,
+                    "step:max": float(self.step),
+                }
+            )
+
+        grad = grad / max(examples, 1)
+        if adam_params.grad_clip_norm > 0:
+            norm = np.linalg.norm(grad)
+            if norm > adam_params.grad_clip_norm:
+                grad *= adam_params.grad_clip_norm / max(norm, 1e-12)
+
+        if adam_params.weight_decay:
+            grad += adam_params.weight_decay * self._weights
+
+        beta1 = adam_params.beta1
+        beta2 = adam_params.beta2
+
+        self._adam_m = beta1 * self._adam_m + (1 - beta1) * grad
+        self._adam_v = beta2 * self._adam_v + (1 - beta2) * (grad**2)
+        self._beta1_power *= beta1
+        self._beta2_power *= beta2
+        m_hat = self._adam_m / (1 - self._beta1_power + 1e-12)
+        v_hat = self._adam_v / (1 - self._beta2_power + 1e-12)
+
+        update = adam_params.learning_rate * m_hat / (np.sqrt(v_hat) + adam_params.eps)
+        self._weights -= update
+        self.step += 1
+
+        metrics = {
+            "learning_rate:mean": adam_params.learning_rate,
+            "step:max": float(self.step),
+            "update_norm:mean": float(np.linalg.norm(update)),
+        }
+        return types.OptimStepResponse(metrics=metrics)
+
+    # ------------------------------------------------------------------
+    # Sampling
+    # ------------------------------------------------------------------
+    def sample(
+        self,
+        prompt: types.ModelInput,
+        num_samples: int,
+        sampling_params: types.SamplingParams,
+        include_prompt_logprobs: bool,
+        topk_prompt_logprobs: int,
+    ) -> types.SampleResponse:
+        prompt_tokens = prompt.to_ints()
+        max_tokens = sampling_params.max_tokens or 16
+        sequences: list[types.SampledSequence] = []
+        for _ in range(num_samples):
+            generated = self._generate_tokens(prompt_tokens, max_tokens)
+            seq = types.SampledSequence(
+                stop_reason="length",
+                tokens=generated,
+                logprobs=[-0.3 for _ in generated],
+            )
+            sequences.append(seq)
+        prompt_logprobs = None
+        topk_prompt = None
+        if include_prompt_logprobs:
+            prompt_logprobs = [-0.1 if tok is not None else None for tok in prompt_tokens]
+        if topk_prompt_logprobs > 0:
+            topk_prompt = [
+                [
+                    (token, round(-0.05 - idx * 0.01, 4))
+                    for idx, token in enumerate(prompt_tokens[:topk_prompt_logprobs])
+                ]
+                if token is not None
+                else None
+                for token in prompt_tokens
+            ]
+        return types.SampleResponse(
+            sequences=sequences,
+            prompt_logprobs=prompt_logprobs,
+            topk_prompt_logprobs=topk_prompt,
+        )
+
+    # ------------------------------------------------------------------
+    # Checkpoint I/O
+    # ------------------------------------------------------------------
+    def snapshot_state(self) -> dict[str, Any]:
+        return {
+            "weights": self._weights.astype(float).tolist(),
+            "adam_m": self._adam_m.astype(float).tolist(),
+            "adam_v": self._adam_v.astype(float).tolist(),
+            "step": self.step,
+            "beta1_power": self._beta1_power,
+            "beta2_power": self._beta2_power,
+        }
+
+    def load_state(self, state: dict[str, Any]) -> None:
+        with self._lock:
+            self._weights = np.array(state.get("weights", self._weights.tolist()), dtype=np.float32)
+            self._adam_m = np.array(state.get("adam_m", self._adam_m.tolist()), dtype=np.float32)
+            self._adam_v = np.array(state.get("adam_v", self._adam_v.tolist()), dtype=np.float32)
+            self.step = int(state.get("step", self.step))
+            self._beta1_power = float(state.get("beta1_power", self._beta1_power))
+            self._beta2_power = float(state.get("beta2_power", self._beta2_power))
+            self._pending_grad = None
+            self._pending_examples = 0
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+    def _target_tokens(self, datum: types.Datum) -> list[int]:
+        if not datum.loss_fn_inputs:
+            return datum.model_input.to_ints()
+        tensor = datum.loss_fn_inputs.get("target_tokens")
+        if tensor is None:
+            return datum.model_input.to_ints()
+        return [int(value) for value in tensor.data]
+
+    def _vectorize(self, tokens: Sequence[int]) -> np.ndarray:
+        if not tokens:
+            return np.zeros(self.hidden_dim, dtype=np.float32)
+        vecs = [self._token_embedding(token) for token in tokens]
+        return np.mean(vecs, axis=0)
+
+    def _token_embedding(self, token_id: int) -> np.ndarray:
+        cached = self._embedding_cache.get(token_id)
+        if cached is None:
+            rng = np.random.default_rng(self.seed + token_id)
+            cached = rng.standard_normal(self.hidden_dim, dtype=np.float32)
+            self._embedding_cache[token_id] = cached
+        return cached
+
+    def _target_scalar(self, tokens: Sequence[int]) -> float:
+        if not tokens:
+            return 0.0
+        return np.tanh(_safe_mean(tokens) / 100.0)
+
+    def _generate_tokens(self, prompt_tokens: list[int], max_tokens: int) -> list[int]:
+        start = prompt_tokens[-1] if prompt_tokens else (abs(self.seed) % 32000) + 1
+        return [(start + i) % 32000 for i in range(1, max_tokens + 1)]
+
+
+def build_backend(base_model: str, lora_rank: int, seed: int | None = None) -> ModelBackend:
+    return ModelBackend(base_model=base_model, lora_rank=lora_rank, seed=seed or 0)
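As an illustration of the checkpoint I/O surface above, here is a minimal sketch of a snapshot/restore round trip; the import path and argument values are assumptions, but build_backend, snapshot_state and load_state are the functions shown in this file.

    from tuft.backend import build_backend  # assumed import path for this module

    # Two toy backends; the source is seeded so its weights are deterministic.
    src = build_backend(base_model="demo-model", lora_rank=8, seed=42)
    dst = build_backend(base_model="demo-model", lora_rank=8, seed=0)

    # snapshot_state() returns plain lists/floats, so it is JSON-serializable.
    state = src.snapshot_state()
    dst.load_state(state)

    assert dst.step == src.step
    assert dst.snapshot_state()["weights"] == state["weights"]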
tuft/backends/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from .sampling_backend import BaseSamplingBackend, VLLMSamplingBackend
+from .training_backend import BaseTrainingBackend, HFTrainingBackend
+
+
+__all__ = [
+    "BaseSamplingBackend",
+    "VLLMSamplingBackend",
+    "BaseTrainingBackend",
+    "HFTrainingBackend",
+]
tuft/backends/base_backend.py
ADDED
@@ -0,0 +1,112 @@
+import os
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Optional
+
+from tinker import types
+
+from ..checkpoints import CheckpointRecord
+from ..config import ModelConfig
+
+
+class BaseBackend(ABC):
+    """Base class for all backends."""
+
+    def __init__(self, config: ModelConfig) -> None:
+        self.base_model = config.model_name
+        self.config = config
+
+    @abstractmethod
+    async def async_init(self) -> None:
+        """Asynchronous initialization if needed."""
+
+
+class BaseSamplingBackend(BaseBackend):
+    """Abstract sampling backend."""
+
+    @abstractmethod
+    async def sample(
+        self,
+        prompt: types.ModelInput,
+        num_samples: int,
+        sampling_params: types.SamplingParams,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_id: Optional[str] = None,
+    ) -> types.SampleResponse:
+        """Abstract method for sampling."""
+
+    @abstractmethod
+    async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
+        """Add LoRA adapter to the backend."""
+
+    @abstractmethod
+    async def remove_adapter(self, lora_id: str) -> None:
+        """Remove LoRA adapter from the backend."""
+
+    @classmethod
+    def create_backend(cls, config: ModelConfig) -> "BaseSamplingBackend":
+        """Factory method to create a sampling backend instance."""
+        if os.getenv("TUFT_CPU_TEST", "0") == "1":
+            from ..backends.sampling_backend import DummySamplingBackend
+
+            return DummySamplingBackend(config)
+        else:
+            from ..backends.sampling_backend import VLLMSamplingBackend
+
+            return VLLMSamplingBackend(config)
+
+
+class BaseTrainingBackend(BaseBackend):
+    """Abstract training backend."""
+
+    @abstractmethod
+    async def forward(
+        self,
+        data: list[types.Datum],
+        lora_id: str,
+        loss_fn: types.LossFnType,
+        loss_fn_config: dict[str, float] | None,
+        backward: bool = False,
+    ) -> types.ForwardBackwardOutput:
+        """Abstract method for forward pass."""
+
+    @abstractmethod
+    async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
+        """Abstract method for creating LoRA adapter."""
+
+    @abstractmethod
+    async def remove_adapter(self, lora_id: str) -> None:
+        """Abstract method for removing LoRA adapter."""
+
+    @abstractmethod
+    async def optim_step(
+        self,
+        adam_params: types.AdamParams,
+        lora_id: str,
+    ) -> types.OptimStepResponse:
+        """Abstract method for optimization step."""
+
+    @abstractmethod
+    async def save_state(
+        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
+    ) -> None:
+        """Abstract method for saving model state."""
+
+    @abstractmethod
+    async def load_state(
+        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
+    ) -> None:
+        """Abstract method for loading model state."""
+
+    @classmethod
+    def create_backend(cls, config: ModelConfig) -> "BaseTrainingBackend":
+        """Factory method to create a training backend instance."""
+        if os.getenv("TUFT_CPU_TEST", "0") == "1":
+            from ..backends.training_backend import DummyTrainingBackend
+
+            return DummyTrainingBackend(config)
+        else:
+            from ..backends.training_backend import HFTrainingBackend
+
+            return HFTrainingBackend(config)
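For orientation, a sketch of the override surface a concrete sampling backend has to provide. The class name EchoSamplingBackend and the stub bodies are hypothetical; the method signatures mirror the abstract interface above.

    from pathlib import Path
    from typing import Optional

    from tinker import types

    from tuft.backends.base_backend import BaseSamplingBackend  # assumed import path


    class EchoSamplingBackend(BaseSamplingBackend):
        """Hypothetical subclass; every abstract method above must be overridden."""

        async def async_init(self) -> None:
            pass  # e.g. load weights or warm up an inference engine

        async def sample(
            self,
            prompt: types.ModelInput,
            num_samples: int,
            sampling_params: types.SamplingParams,
            include_prompt_logprobs: bool = False,
            topk_prompt_logprobs: int = 0,
            lora_id: Optional[str] = None,
        ) -> types.SampleResponse:
            raise NotImplementedError  # real backends generate sequences here

        async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
            pass  # register a LoRA adapter for sampling

        async def remove_adapter(self, lora_id: str) -> None:
            pass  # drop the adapter once it is no longer needed

Note that BaseSamplingBackend.create_backend only chooses between DummySamplingBackend (when TUFT_CPU_TEST=1) and VLLMSamplingBackend, so a custom subclass like this would need its own wiring.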