torchrl-nightly 2025.6.19-cp39-cp39-win_amd64.whl → 2025.6.21-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
- torchrl/collectors/collectors.py +49 -24
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/datasets/minari_data.py +1 -1
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -9
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/envs/transforms/transforms.py +11 -27
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/objectives/ppo.py +35 -12
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/envs/llm/transforms/kl.py
@@ -4,15 +4,25 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import contextlib
+import gc
+
 from copy import copy
 
 import torch
-from tensordict import NestedKey, TensorDictBase, unravel_key
+from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
 from tensordict.nn import ProbabilisticTensorDictModule
-from tensordict.utils import is_seq_of_nested_key
+from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torchrl.data import Composite, Unbounded
+from torchrl.data.llm.chat import History
 from torchrl.envs import EnvBase, Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance
+from torchrl.modules.llm.policies.common import CategoricalSequential
+
+try:
+    import transformers
+except ImportError:
+    transformers = None
 
 
 class KLRewardTransform(Transform):
@@ -141,8 +151,8 @@ class KLRewardTransform(Transform):
                 f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
                 f"or pass the action_key argument directly to {type(self).__name__} constructor."
             )
-
-        if
+        response_txt = tensordict.get(action_key, None)
+        if response_txt is None:
             if not self.missing_tolerance:
                 raise RuntimeError(
                     f"Action with key {action_key} not found data {tensordict}"
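For context: the fix above leans on `TensorDictBase.get` accepting a default value, so a missing action key yields `None` and flows into the `missing_tolerance` branch instead of raising a `KeyError`. A minimal sketch:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"observation": torch.zeros(3)}, batch_size=[])
    assert td.get("action", None) is None  # no KeyError; enables the missing-tolerance path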
@@ -269,3 +279,229 @@ class KLRewardTransform(Transform):
         observation_spec[self.out_keys[1]] = reward_spec.clone()
 
         return output_spec
+
+
+class RetrieveLogProb(Transform):
+    """A transform to retrieve the log-probs of a text given a reference model.
+
+    Args:
+        actor (CategoricalSequential): the reference model.
+
+    Keyword Args:
+        history_key (NestedKey): the key where the history is stored. Defaults to `"history"`.
+        log_prob_key (NestedKey): the key where the log-probs are stored. Defaults to `"ref_log_prob"`.
+        assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
+            where the role is `"assistant"`). Defaults to `False`.
+
+            .. note:: The template must accommodate the `return_assistant_tokens_mask` keyword argument.
+                This may not be the case for all templates. In this case, you can pass a custom template to the `apply_chat_template` method
+                via the `tokenizer_kwargs` argument: `tokenizer_kwargs = {"chat_template_name": "qwen"}` or `tokenizer_kwargs = {"chat_template": my_template}`.
+
+        tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
+            To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
+            Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": False, "add_generation_prompt": False}`.
+        tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor`.
+        detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
+        device (torch.device): the device to use for tensor creation. Defaults to `None`.
+
+    Examples:
+        >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+        >>> from torchrl.modules.llm import TransformersWrapper
+        >>> from torchrl.objectives.llm.sft import SFTLoss
+        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+        >>> from tensordict import TensorDict, lazy_stack, set_list_to_stack
+        >>> import torch
+        >>>
+        >>> set_list_to_stack(True).set()
+        >>>
+        >>> # Create chat data
+        >>> chats = [
+        ...     [
+        ...         {"role": "system", "content": "You are a helpful assistant."},
+        ...         {"role": "user", "content": "Hello, how are you?"},
+        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+        ...     ],
+        ...     [
+        ...         {"role": "system", "content": "You are a helpful assistant."},
+        ...         {"role": "user", "content": "What's the weather like?"},
+        ...         {"role": "assistant", "content": "I can't check the weather for you."},
+        ...     ],
+        ... ]
+        >>> history = History.from_chats(chats)
+        >>> print(f"Created history with shape: {history.shape}")
+        Created history with shape: torch.Size([2, 3])
+        >>>
+        >>> # Setup tokenizer and model
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        >>> tokenizer.pad_token = tokenizer.eos_token
+        >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+        >>> model = OPTForCausalLM(OPTConfig()).eval()
+        >>>
+        >>> # Create training and reference policies
+        >>> policy_train = TransformersWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     generate=False,
+        ...     from_text=True,
+        ...     chat_template_name="qwen",
+        ... )
+        >>> policy_ref = TransformersWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     generate=False,
+        ...     from_text=True,
+        ...     return_log_probs=True,
+        ...     chat_template_name="qwen",
+        ... )
+        >>>
+        >>> # Create the RetrieveLogProb transform
+        >>> transform = RetrieveLogProb(
+        ...     policy_ref,
+        ...     assistant_only=True,
+        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+        ...     tokenizer=tokenizer,
+        ... )
+        >>>
+        >>> # Prepare data
+        >>> text = history[:, :-1].apply_chat_template(
+        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+        ... )
+        >>> text_response = history.apply_chat_template(
+        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+        ... )
+        >>> text_response = [
+        ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+        ... ]
+        >>> td = TensorDict(
+        ...     text=text,
+        ...     text_response=text_response,
+        ...     history=history,
+        ...     next=TensorDict(
+        ...         reward=torch.randn(2, 1),
+        ...         done=torch.zeros(2, dtype=torch.bool),
+        ...         history=history,
+        ...     ),
+        ...     batch_size=(2,),
+        ... )
+        >>> data = lazy_stack(list(td.unbind(0)))
+        >>>
+        >>> # Apply the transform to get reference log probabilities
+        >>> data = transform(data)
+        >>> # You can get a padded tensor for batching:
+        >>> ref_log_probs = data.get(("next", "ref_log_prob"), as_padded_tensor=True)
+        >>> print(f"Type: {type(ref_log_probs)}, Length: {len(ref_log_probs)}")
+        Type: <class 'torch.Tensor'>, Length: 2
+        >>> print(f"Example shapes: {[x.shape for x in ref_log_probs]}")
+        Example shapes: [torch.Size([35]), torch.Size([35])]
+        >>> print(ref_log_probs.shape)  # (batch, max_seq_len)
+        torch.Size([2, 35])
+        >>>
+        >>> # Use with SFTLoss for KL regularization
+        >>> loss = SFTLoss(
+        ...     actor_network=policy_train,
+        ...     tokenizer=tokenizer,
+        ...     reduction="mean",
+        ...     normalize_by_seq_length=True,
+        ...     kl_to_ref_coeff=0.1,
+        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+        ... )
+        >>> loss_vals = loss(data)
+        >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+        SFT Loss: 10.7856
+        >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+        KL to Reference Loss: 0.0000
+        >>> print(f"Total Loss: {loss_vals.sum(reduce=True).item():.4f}")
+        Total Loss: 10.7856
+
+    Note:
+        By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
+        Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
+        The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
+
+    """
+
+    def __init__(
+        self,
+        actor: CategoricalSequential,
+        *,
+        history_key: NestedKey | None = None,
+        log_prob_key: NestedKey = "ref_log_prob",
+        assistant_only: bool = False,
+        tokenizer_kwargs: dict | None = None,
+        detach: bool = True,
+        device: torch.device | None = None,
+        tokenizer: transformers.AutoTokenizer | None = None,
+    ):
+        if history_key is None:
+            history_key = "history"
+        self.history_key = history_key
+        self.log_prob_key = log_prob_key
+        super().__init__(in_keys=[history_key], out_keys=[log_prob_key])
+        self.actor = actor
+        if not getattr(actor, "return_log_probs", True):
+            raise ValueError(
+                "The actor must have `return_log_probs=True` to use the `AssistantLogProb` transform."
+            )
+        if getattr(actor, "generate", True):
+            raise ValueError(
+                "The actor must have `generate=False` to use the `AssistantLogProb` transform."
+            )
+        if not getattr(actor, "from_text", False):
+            raise ValueError(
+                "The actor must have `from_text=True` to use the `AssistantLogProb` transform. If `from_text=False` is required, please file an issue on GitHub."
+            )
+        # if getattr(self.actor, "tokenizer_kwargs", {}).get("add_generation_prompt", True):
+        #     raise ValueError("The actor must have `tokenizer_kwargs['add_generation_prompt']=False` to use the `AssistantLogProb` transform.")
+        self.assistant_only = assistant_only
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+        tokenizer_kwargs.setdefault("tokenize", True)
+        tokenizer_kwargs.setdefault("return_tensors", "pt")
+        tokenizer_kwargs.setdefault("padding", False)
+        tokenizer_kwargs.setdefault("add_generation_prompt", False)
+        self.tokenizer_kwargs = tokenizer_kwargs
+        self.tokenizer = tokenizer
+        self.detach = detach
+        self.device = device
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        next_td = self._step(tensordict, tensordict.get("next"))
+        return tensordict.set("next", next_td)
+
+    @set_list_to_stack(True)
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        td = next_tensordict.select(self.history_key)
+        with torch.device(
+            self.device
+        ) if self.device is not None else contextlib.nullcontext(), torch.no_grad() if self.detach else contextlib.nullcontext():
+            result = self.actor(td.select(self.history_key))
+            td.update(result.select(getattr(self.actor, "log_prob_key", "log_probs")))
+            td.rename_key_(
+                getattr(self.actor, "log_prob_key", "log_probs"), self.log_prob_key
+            )
+        if torch.cuda.is_available():
+            gc.collect()
+            torch.cuda.empty_cache()
+        if self.assistant_only:
+            with torch.device(
+                self.device
+            ) if self.device is not None else contextlib.nullcontext():
+                # Get assistant mask
+                history: History = td.get(self.history_key)
+                proc = history.apply_chat_template(
+                    tokenizer=self.actor.tokenizer
+                    if self.tokenizer is None
+                    else self.tokenizer,
+                    **self.tokenizer_kwargs,
+                )
+                assistant_masks = proc.get("assistant_masks", as_list=True)
+                log_probs = td.get(self.log_prob_key, as_list=True)
+                log_probs = [
+                    lp[mask.bool()]
+                    for lp, mask in _zip_strict(log_probs, assistant_masks)
+                ]
+                td = td.set(self.log_prob_key, log_probs)
+        return next_tensordict.update(td)
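For context: `_step` composes two optional contexts on one line (default-device placement via `torch.device(...)` and gradient disabling via `torch.no_grad()`), using `contextlib.nullcontext()` as a placeholder. A minimal sketch of the same pattern written with `contextlib.ExitStack`, which reads more clearly and stays compatible with Python 3.9 (an illustration, not the code shipped in the wheel):

    import contextlib

    import torch

    def call_reference_model(fn, *args, device=None, detach=True):
        # Enter only the contexts that apply, instead of nullcontext placeholders.
        with contextlib.ExitStack() as stack:
            if device is not None:
                # torch.device works as a context manager (default device) in torch >= 2.0.
                stack.enter_context(torch.device(device))
            if detach:
                stack.enter_context(torch.no_grad())
            return fn(*args)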
torchrl/envs/transforms/transforms.py
@@ -8726,8 +8726,10 @@ class Reward2GoTransform(Transform):
 class ActionMask(Transform):
     """An adaptive action masker.
 
-    This transform
-
+    This transform is useful to ensure that randomly generated actions
+    respect legal actions, by masking the action specs.
+    It reads the mask from the input tensordict after the step is executed,
+    and adapts the mask of the finite action spec.
 
     .. note:: This transform will fail when used without an environment.
 
@@ -8773,8 +8775,6 @@ class ActionMask(Transform):
         >>> base_env = MaskedEnv()
         >>> env = TransformedEnv(base_env, ActionMask())
         >>> r = env.rollout(10)
-        >>> env = TransformedEnv(base_env, ActionMask())
-        >>> r = env.rollout(10)
         >>> r["action_mask"]
         tensor([[ True,  True,  True,  True],
                 [ True,  True, False,  True],
@@ -8810,15 +8810,8 @@ class ActionMask(Transform):
         raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
 
     @property
-    def action_spec(self):
-        action_spec = self.container.full_action_spec
-        keys = self.container.action_keys
-        if len(keys) == 1:
-            action_spec = action_spec[keys[0]]
-        else:
-            raise ValueError(
-                f"Too many action keys for {self.__class__.__name__}: {keys=}"
-            )
+    def action_spec(self) -> TensorSpec:
+        action_spec = self.container.full_action_spec[self.in_keys[0]]
         if not isinstance(action_spec, self.ACCEPTED_SPECS):
             raise ValueError(
                 self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
@@ -8826,29 +8819,20 @@ class ActionMask(Transform):
         return action_spec
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
-
-        if parent is None:
+        if self.parent is None:
             raise RuntimeError(
                 f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
             )
+
         mask = next_tensordict.get(self.in_keys[1])
-        action_spec
-
+        self.action_spec.update_mask(mask.to(self.action_spec.device))
+
         return next_tensordict
 
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
-
-        mask = tensordict.get(self.in_keys[1], None)
-        if mask is not None:
-            mask = mask.to(action_spec.device)
-            action_spec.update_mask(mask)
-
-        # TODO: Check that this makes sense
-        with _set_missing_tolerance(self, True):
-            tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
+        return self._call(tensordict_reset)
 
 
 class VecGymEnvTransform(Transform):
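The rewritten `ActionMask` now routes both `_call` and `_reset` through a single `self.action_spec.update_mask(...)` call. The effect of `update_mask` on a categorical spec, as a small sketch (assuming the current `torchrl.data` spec API):

    import torch
    from torchrl.data import Categorical

    spec = Categorical(4)
    spec.update_mask(torch.tensor([True, True, False, True]))
    samples = torch.stack([spec.rand() for _ in range(100)])
    assert (samples != 2).all()  # the masked-out action is never sampled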
torchrl/modules/llm/policies/transformers_wrapper.py
@@ -65,6 +65,10 @@ class TransformersWrapper(CategoricalSequential):
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
             conserve type, batch-size, and device). Defaults to `True`.
+        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
+            applying the chat template to the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
+            Defaults to `None`.
 
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
@@ -131,6 +135,8 @@ class TransformersWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = True,
         inplace: Literal[True, False, "empty"] | None = True,
+        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()
 
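The two new constructor arguments mirror the `RetrieveLogProb` docstring above; a sketch of building a reference policy with an explicit chat template (model and tokenizer choices follow that docstring example):

    from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
    from torchrl.modules.llm import TransformersWrapper

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    model = OPTForCausalLM(OPTConfig()).eval()

    policy_ref = TransformersWrapper(
        model,
        tokenizer=tokenizer,
        generate=False,
        from_text=True,
        return_log_probs=True,
        chat_template_name="qwen",  # new in this release; chat_template=<jinja string> also works
    )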
@@ -143,6 +149,8 @@ class TransformersWrapper(CategoricalSequential):
         self.inplace = inplace
         self.pad_output = pad_output
         padding_value = None
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
 
         if not tokenizer_kwargs:
             tokenizer_kwargs = {}
@@ -300,7 +308,17 @@ class TransformersWrapper(CategoricalSequential):
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
         if not isinstance(text, (list, str)):
             text = text.tolist()
         tokens_in = self.tokenizer(text, **self.tokenizer_kwargs)
@@ -325,7 +343,7 @@ class TransformersWrapper(CategoricalSequential):
             logits = torch.stack(list(tokens_out["logits"]), 1)
             logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
             log_probs, logits = self._log_probs_generate(
-                sequences, logits, pad_val
+                sequences, logits, pad_val=-100
             )
             response_text = self.tokenizer.batch_decode(
                 sequences, skip_special_tokens=False
@@ -407,17 +425,36 @@ class TransformersWrapper(CategoricalSequential):
         pad_val = self.tokenizer.pad_token_id
 
         prompt_txt = td.get(self.text_key)
-
+        response_txt = td.get(self.text_response_key)
+        if prompt_txt is None or response_txt is None:
+            if prompt_txt is not None and response_txt is not None:
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            response_txt = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(response_txt, list):
+                prompt_txt = ["" for _ in response_txt]
+            else:
+                prompt_txt = ""
+
         if not isinstance(prompt_txt, (list, str)):
             prompt_txt = prompt_txt.tolist()
-        response_txt = td.get(self.text_response_key)
         if not isinstance(response_txt, (list, str)):
             response_txt = response_txt.tolist()
         total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)]
@@ -450,6 +487,8 @@ class TransformersWrapper(CategoricalSequential):
         )
         sequences = [
             _total_input_ids[_prompt_input_ids.shape[-1] :]
+            if _prompt_input_ids.shape[-1] > 0
+            else _total_input_ids
             for _total_input_ids, _prompt_input_ids in zip(
                 total_input_ids, prompt_input_ids
             )
@@ -484,7 +523,7 @@ class TransformersWrapper(CategoricalSequential):
 
         total_input_ids = [
             torch.cat([_prompt_input_ids, _response_input_ids], -1)
-            for _prompt_input_ids, _response_input_ids in
+            for _prompt_input_ids, _response_input_ids in _zip_strict(
                 prompt_input_ids, response_input_ids
             )
         ]
@@ -512,7 +551,7 @@ class TransformersWrapper(CategoricalSequential):
             total_input_ids, attention_mask=total_attention_mask, **kwargs
         )
         log_probs, logits = self._log_probs_from_logits(
-            total_tokens_out, response_input_ids, pad_val
+            total_tokens_out, response_input_ids, pad_val=-100
         )
         # for i in range(log_probs.size(0)):
         #     assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1]
@@ -522,7 +561,7 @@ class TransformersWrapper(CategoricalSequential):
         return out
 
     @classmethod
-    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val):
+    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val=-100):
         response_input_ids = pad_sequence(
             response_input_ids,
             padding_value=pad_val,
@@ -532,10 +571,21 @@ class TransformersWrapper(CategoricalSequential):
         pad_mask = response_input_ids != pad_val
 
         logits = total_tokens_out["logits"]
-        logits = logits.log_softmax(dim=-1)
-
-
-
+        # logits = logits.log_softmax(dim=-1)
+        if logits.shape[-2] != response_input_ids.shape[-1]:
+            logits = logits[..., -response_input_ids.shape[-1] - 1 : -1, :]
+
+        td = TensorDict(
+            logits=logits, response_input_ids=response_input_ids
+        ).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"],
+                tdflat["response_input_ids"],
+                reduce=False,
+                ignore_index=pad_val,
+            )
+        log_probs = td["log_probs"]
 
         # Recover the list
         log_probs = _unpad_tensors(log_probs, pad_mask)
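The `log_softmax` + `gather` pair is replaced by a negated `torch.nn.functional.cross_entropy`, which computes the same per-token log-probabilities in one fused call and, with `ignore_index=pad_val`, zeroes out padded positions. A self-contained check of the identity (note that the wheel's code spells the non-reducing mode with the deprecated `reduce=False`; `reduction="none"` is the current equivalent):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(3, 7, 11)      # (batch, seq, vocab)
    ids = torch.randint(0, 11, (3, 7))  # token ids
    ref = logits.log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    lp = -F.cross_entropy(
        logits.flatten(0, 1), ids.flatten(), reduction="none"
    ).view(3, 7)
    assert torch.allclose(ref, lp, atol=1e-5)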
@@ -543,7 +593,7 @@ class TransformersWrapper(CategoricalSequential):
         return log_probs, logits
 
     @classmethod
-    def _log_probs_generate(cls, sequences, logits, pad_val):
+    def _log_probs_generate(cls, sequences, logits, pad_val=-100):
         tokens = pad_sequence(
             sequences,
             padding_value=pad_val,
@@ -557,6 +607,12 @@ class TransformersWrapper(CategoricalSequential):
             padding_side="left",
         )
 
-        logits = logits.log_softmax(dim=-1)
-        log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        # logits = logits.log_softmax(dim=-1)
+        # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
+            )
+        log_probs = td["log_probs"]
         return log_probs, logits
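Both `_log_probs_*` helpers now default `pad_val` to `-100` rather than the tokenizer's `pad_token_id`. This matters because tokenizers often reuse a real token as padding (the docstring example above sets `tokenizer.pad_token = tokenizer.eos_token`), whereas `-100` is `cross_entropy`'s conventional ignore value and can never collide with a vocabulary id. Its effect on the unreduced loss:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 5)
    targets = torch.tensor([1, 2, -100, 3])
    loss = F.cross_entropy(logits, targets, reduction="none", ignore_index=-100)
    assert loss[2] == 0.0  # the ignored position contributes nothing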
torchrl/modules/llm/policies/vllm_wrapper.py
@@ -74,6 +74,9 @@ class vLLMWrapper(CategoricalSequential):
             conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
             otherwise.
 
+        chat_template_name (str | None, optional): The name of the chat template to use for the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use for the history. Defaults to `None`.
+
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
         tokenization and padding.
@@ -120,6 +123,7 @@ class vLLMWrapper(CategoricalSequential):
     token_response_key: NestedKey = ("tokens_response",)
     text_response_key: NestedKey = ("text_response",)
     attention_mask_key: NestedKey = ("attention_mask",)
+    history_key: NestedKey = ("history",)
 
     def __init__(
         self,
@@ -137,6 +141,8 @@ class vLLMWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = False,
         inplace: Literal[True, False, "empty"] | None = None,
+        chat_template_name: str | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()
 
@@ -149,6 +155,8 @@ class vLLMWrapper(CategoricalSequential):
         self._device = device
         self.generate = generate
         self.pad_output = pad_output
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
         padding_value = None
 
         if not tokenizer_kwargs:
@@ -329,7 +337,12 @@ class vLLMWrapper(CategoricalSequential):
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError("No text or history provided to the vLLMWrapper.")
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs["chat_template_name"] = self.chat_template_name
+            if self.chat_template is not None:
+                tokenizer_kwargs["chat_template"] = self.chat_template
+            text = history.apply_chat_template(self.tokenizer, **tokenizer_kwargs)
         if self.pad_output:
             tokenizer_kwargs = self.tokenizer_kwargs
             if not isinstance(text, (list, str)):
@@ -385,15 +398,35 @@ class vLLMWrapper(CategoricalSequential):
 
     def _from_vllm_logprobs_text(self, td, sampling_params, out):
         text_prompt = td.get(self.text_key)
-
+        text_response = td.get(self.text_response_key)
+        if text_response is None or text_prompt is None:
+            if text_response is not None and text_prompt is not None:
+                raise ValueError(
+                    "No text or history provided to the vLLMWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
-            raise ValueError(
-
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper."
+                )
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text_response = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(text_response, list):
+                text_prompt = ["" for _ in text_response]
+            else:
+                text_prompt = ""
         if not isinstance(text_prompt, list):
             text_prompt = text_prompt.tolist()
-        text_response = td.get(self.text_response_key)
         if not isinstance(text_response, list):
             text_response = text_response.tolist()
         text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)]
torchrl/objectives/llm/__init__.py
@@ -5,5 +5,6 @@
 from __future__ import annotations
 
 from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from .sft import SFTLoss, SFTLossOutput
 
-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"]
+__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
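With `SFTLoss` now exported from `torchrl.objectives.llm`, the loss pairs with the `ref_log_prob` entries written by `RetrieveLogProb`; a sketch reusing the names from the docstring example above (`policy_train`, `tokenizer` are assumed to be built as shown there):

    from torchrl.objectives.llm import SFTLoss

    loss = SFTLoss(
        actor_network=policy_train,
        tokenizer=tokenizer,
        reduction="mean",
        normalize_by_seq_length=True,
        kl_to_ref_coeff=0.1,  # weights the KL against ("next", "ref_log_prob")
        tokenizer_kwargs={"chat_template_name": "qwen"},
    )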