torchrl_nightly-2025.6.19-cp39-cp39-win_amd64.whl → torchrl_nightly-2025.6.21-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
  2. torchrl/collectors/collectors.py +49 -24
  3. torchrl/collectors/llm/base.py +13 -6
  4. torchrl/collectors/llm/ray_collector.py +3 -0
  5. torchrl/data/__init__.py +2 -0
  6. torchrl/data/datasets/minari_data.py +1 -1
  7. torchrl/data/llm/__init__.py +2 -0
  8. torchrl/data/llm/chat.py +59 -9
  9. torchrl/data/llm/topk.py +186 -0
  10. torchrl/data/replay_buffers/ray_buffer.py +15 -1
  11. torchrl/data/replay_buffers/replay_buffers.py +50 -11
  12. torchrl/data/replay_buffers/samplers.py +98 -21
  13. torchrl/data/replay_buffers/storages.py +29 -2
  14. torchrl/envs/llm/__init__.py +2 -0
  15. torchrl/envs/llm/chat.py +4 -1
  16. torchrl/envs/llm/reward/gsm8k.py +15 -8
  17. torchrl/envs/llm/transforms/__init__.py +2 -1
  18. torchrl/envs/llm/transforms/kl.py +240 -4
  19. torchrl/envs/transforms/transforms.py +11 -27
  20. torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
  21. torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
  22. torchrl/objectives/llm/__init__.py +2 -1
  23. torchrl/objectives/llm/sft.py +465 -0
  24. torchrl/objectives/ppo.py +35 -12
  25. torchrl/version.py +2 -2
  26. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
  27. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
  28. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
  29. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
  30. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/envs/llm/transforms/kl.py
@@ -4,15 +4,25 @@
  # LICENSE file in the root directory of this source tree.
  from __future__ import annotations
 
+ import contextlib
+ import gc
+
  from copy import copy
 
  import torch
- from tensordict import NestedKey, TensorDictBase, unravel_key
+ from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
  from tensordict.nn import ProbabilisticTensorDictModule
- from tensordict.utils import is_seq_of_nested_key
+ from tensordict.utils import _zip_strict, is_seq_of_nested_key
  from torchrl.data import Composite, Unbounded
+ from torchrl.data.llm.chat import History
  from torchrl.envs import EnvBase, Transform
  from torchrl.envs.transforms.utils import _set_missing_tolerance
+ from torchrl.modules.llm.policies.common import CategoricalSequential
+
+ try:
+     import transformers
+ except ImportError:
+     transformers = None
 
 
  class KLRewardTransform(Transform):
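
The guarded `transformers` import above is the usual optional-dependency pattern: the module imports cleanly even without `transformers` installed, and code that needs it can check for `None` at call time. A minimal sketch of the pattern (`require_transformers` is a hypothetical helper, not part of torchrl):

    try:
        import transformers
    except ImportError:
        transformers = None  # dependent features check this at call time

    def require_transformers() -> None:
        # Fail lazily, at call time, rather than at import time.
        if transformers is None:
            raise ImportError("this feature requires the `transformers` package")
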
@@ -141,8 +151,8 @@ class KLRewardTransform(Transform):
                  f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
                  f"or pass the action_key argument directly to {type(self).__name__} constructor."
              )
-         action = tensordict.get(action_key, None)
-         if action is None:
+         response_txt = tensordict.get(action_key, None)
+         if response_txt is None:
              if not self.missing_tolerance:
                  raise RuntimeError(
                      f"Action with key {action_key} not found data {tensordict}"
@@ -269,3 +279,229 @@ class KLRewardTransform(Transform):
          observation_spec[self.out_keys[1]] = reward_spec.clone()
 
          return output_spec
+
+
+ class RetrieveLogProb(Transform):
+     """A transform to retrieve the log-probs of a text given a reference model.
+
+     Args:
+         actor (CategoricalSequential): the reference model.
+
+     Keyword Args:
+         history_key (NestedKey): the key where the history is stored. Defaults to `"history"`.
+         log_prob_key (NestedKey): the key where the log-probs are stored. Defaults to `"ref_log_prob"`.
+         assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
+             where the role is `"assistant"`). Defaults to `False`.
+
+             .. note:: The template must accommodate the `return_assistant_tokens_mask` keyword argument.
+                 This may not be the case for all templates. In this case, you can pass a custom template to the
+                 `apply_chat_template` method via the `tokenizer_kwargs` argument:
+                 `tokenizer_kwargs = {"chat_template_name": "qwen"}` or `tokenizer_kwargs = {"chat_template": my_template}`.
+
+         tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
+             To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
+             Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": True, "add_generation_prompt": False}`.
+         tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor`.
+         detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
+         device (torch.device): the device to use for tensor creation. Defaults to `None`.
+
+     Examples:
+         >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+         >>> from torchrl.modules.llm import TransformersWrapper
+         >>> from torchrl.objectives.llm.sft import SFTLoss
+         >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+         >>> from tensordict import TensorDict, lazy_stack, set_list_to_stack
+         >>> import torch
+         >>>
+         >>> set_list_to_stack(True).set()
+         >>>
+         >>> # Create chat data
+         >>> chats = [
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "Hello, how are you?"},
+         ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+         ...     ],
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "What's the weather like?"},
+         ...         {"role": "assistant", "content": "I can't check the weather for you."},
+         ...     ],
+         ... ]
+         >>> history = History.from_chats(chats)
+         >>> print(f"Created history with shape: {history.shape}")
+         Created history with shape: torch.Size([2, 3])
+         >>>
+         >>> # Setup tokenizer and model
+         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+         >>> model = OPTForCausalLM(OPTConfig()).eval()
+         >>>
+         >>> # Create training and reference policies
+         >>> policy_train = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>> policy_ref = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     return_log_probs=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>>
+         >>> # Create the RetrieveLogProb transform
+         >>> transform = RetrieveLogProb(
+         ...     policy_ref,
+         ...     assistant_only=True,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     tokenizer=tokenizer,
+         ... )
+         >>>
+         >>> # Prepare data
+         >>> text = history[:, :-1].apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+         ... )
+         >>> text_response = history.apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+         ... )
+         >>> text_response = [
+         ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+         ... ]
+         >>> td = TensorDict(
+         ...     text=text,
+         ...     text_response=text_response,
+         ...     history=history,
+         ...     next=TensorDict(
+         ...         reward=torch.randn(2, 1),
+         ...         done=torch.zeros(2, dtype=torch.bool),
+         ...         history=history,
+         ...     ),
+         ...     batch_size=(2,),
+         ... )
+         >>> data = lazy_stack(list(td.unbind(0)))
+         >>>
+         >>> # Apply the transform to get reference log probabilities
+         >>> data = transform(data)
+         >>> # You can get a padded tensor for batching:
+         >>> ref_log_probs = data.get(("next", "ref_log_prob"), as_padded_tensor=True)
+         >>> print(f"Type: {type(ref_log_probs)}, Length: {len(ref_log_probs)}")
+         Type: <class 'torch.Tensor'>, Length: 2
+         >>> print(f"Example shapes: {[x.shape for x in ref_log_probs]}")
+         Example shapes: [torch.Size([35]), torch.Size([35])]
+         >>> print(ref_log_probs.shape)  # (batch, max_seq_len)
+         torch.Size([2, 35])
+         >>>
+         >>> # Use with SFTLoss for KL regularization
+         >>> loss = SFTLoss(
+         ...     actor_network=policy_train,
+         ...     tokenizer=tokenizer,
+         ...     reduction="mean",
+         ...     normalize_by_seq_length=True,
+         ...     kl_to_ref_coeff=0.1,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ... )
+         >>> loss_vals = loss(data)
+         >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+         SFT Loss: 10.7856
+         >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+         KL to Reference Loss: 0.0000
+         >>> print(f"Total Loss: {loss_vals.sum(reduce=True).item():.4f}")
+         Total Loss: 10.7856
+
+     Note:
+         By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
+         Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
+         The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
+
+     """
+
+     def __init__(
+         self,
+         actor: CategoricalSequential,
+         *,
+         history_key: NestedKey | None = None,
+         log_prob_key: NestedKey = "ref_log_prob",
+         assistant_only: bool = False,
+         tokenizer_kwargs: dict | None = None,
+         detach: bool = True,
+         device: torch.device | None = None,
+         tokenizer: transformers.AutoTokenizer | None = None,
+     ):
+         if history_key is None:
+             history_key = "history"
+         self.history_key = history_key
+         self.log_prob_key = log_prob_key
+         super().__init__(in_keys=[history_key], out_keys=[log_prob_key])
+         self.actor = actor
+         if not getattr(actor, "return_log_probs", True):
+             raise ValueError(
+                 "The actor must have `return_log_probs=True` to use the `AssistantLogProb` transform."
+             )
+         if getattr(actor, "generate", True):
+             raise ValueError(
+                 "The actor must have `generate=False` to use the `AssistantLogProb` transform."
+             )
+         if not getattr(actor, "from_text", False):
+             raise ValueError(
+                 "The actor must have `from_text=True` to use the `AssistantLogProb` transform. If `from_text=False` is required, please file an issue on GitHub."
+             )
+         # if getattr(self.actor, "tokenizer_kwargs", {}).get("add_generation_prompt", True):
+         #     raise ValueError("The actor must have `tokenizer_kwargs['add_generation_prompt']=False` to use the `AssistantLogProb` transform.")
+         self.assistant_only = assistant_only
+         if tokenizer_kwargs is None:
+             tokenizer_kwargs = {}
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("return_tensors", "pt")
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         self.tokenizer_kwargs = tokenizer_kwargs
+         self.tokenizer = tokenizer
+         self.detach = detach
+         self.device = device
+
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         next_td = self._step(tensordict, tensordict.get("next"))
+         return tensordict.set("next", next_td)
+
+     @set_list_to_stack(True)
+     def _step(
+         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+     ) -> TensorDictBase:
+         td = next_tensordict.select(self.history_key)
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext(), torch.no_grad() if self.detach else contextlib.nullcontext():
+             result = self.actor(td.select(self.history_key))
+             td.update(result.select(getattr(self.actor, "log_prob_key", "log_probs")))
+             td.rename_key_(
+                 getattr(self.actor, "log_prob_key", "log_probs"), self.log_prob_key
+             )
+         if torch.cuda.is_available():
+             gc.collect()
+             torch.cuda.empty_cache()
+         if self.assistant_only:
+             with torch.device(
+                 self.device
+             ) if self.device is not None else contextlib.nullcontext():
+                 # Get assistant mask
+                 history: History = td.get(self.history_key)
+                 proc = history.apply_chat_template(
+                     tokenizer=self.actor.tokenizer
+                     if self.tokenizer is None
+                     else self.tokenizer,
+                     **self.tokenizer_kwargs,
+                 )
+                 assistant_masks = proc.get("assistant_masks", as_list=True)
+                 log_probs = td.get(self.log_prob_key, as_list=True)
+                 log_probs = [
+                     lp[mask.bool()]
+                     for lp, mask in _zip_strict(log_probs, assistant_masks)
+                 ]
+                 td = td.set(self.log_prob_key, log_probs)
+         return next_tensordict.update(td)
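
The heart of `_step` when `assistant_only=True` is the per-sample mask filtering at the end of the hunk. A self-contained sketch of that step, with illustrative tensor sizes (`filter_assistant_log_probs` is a hypothetical name):

    import torch

    def filter_assistant_log_probs(log_probs, assistant_masks):
        # Keep only the log-probs at positions flagged as assistant tokens,
        # returning one variable-length tensor per sample.
        if len(log_probs) != len(assistant_masks):
            raise ValueError("expected one mask per log-prob tensor")
        return [lp[mask.bool()] for lp, mask in zip(log_probs, assistant_masks)]

    log_probs = [torch.randn(5), torch.randn(7)]
    masks = [torch.tensor([0, 0, 1, 1, 1]), torch.tensor([0, 1, 1, 1, 1, 0, 0])]
    print([t.shape for t in filter_assistant_log_probs(log_probs, masks)])
    # [torch.Size([3]), torch.Size([4])]
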
torchrl/envs/transforms/transforms.py
@@ -8726,8 +8726,10 @@ class Reward2GoTransform(Transform):
  class ActionMask(Transform):
      """An adaptive action masker.
 
-     This transform reads the mask from the input tensordict after the step is executed,
-     and adapts the mask of the one-hot / categorical action spec.
+     This transform is useful to ensure that randomly generated actions
+     respect legal actions, by masking the action specs.
+     It reads the mask from the input tensordict after the step is executed,
+     and adapts the mask of the finite action spec.
 
      .. note:: This transform will fail when used without an environment.
 
@@ -8773,8 +8775,6 @@ class ActionMask(Transform):
      >>> base_env = MaskedEnv()
      >>> env = TransformedEnv(base_env, ActionMask())
      >>> r = env.rollout(10)
-     >>> env = TransformedEnv(base_env, ActionMask())
-     >>> r = env.rollout(10)
      >>> r["action_mask"]
      tensor([[ True, True, True, True],
              [ True, True, False, True],
@@ -8810,15 +8810,8 @@ class ActionMask(Transform):
          raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
 
      @property
-     def action_spec(self):
-         action_spec = self.container.full_action_spec
-         keys = self.container.action_keys
-         if len(keys) == 1:
-             action_spec = action_spec[keys[0]]
-         else:
-             raise ValueError(
-                 f"Too many action keys for {self.__class__.__name__}: {keys=}"
-             )
+     def action_spec(self) -> TensorSpec:
+         action_spec = self.container.full_action_spec[self.in_keys[0]]
          if not isinstance(action_spec, self.ACCEPTED_SPECS):
              raise ValueError(
                  self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
@@ -8826,29 +8819,20 @@ class ActionMask(Transform):
          return action_spec
 
      def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
-         parent = self.parent
-         if parent is None:
+         if self.parent is None:
              raise RuntimeError(
                  f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
              )
+
          mask = next_tensordict.get(self.in_keys[1])
-         action_spec = self.action_spec
-         action_spec.update_mask(mask.to(action_spec.device))
+         self.action_spec.update_mask(mask.to(self.action_spec.device))
+
          return next_tensordict
 
      def _reset(
          self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
      ) -> TensorDictBase:
-         action_spec = self.action_spec
-         mask = tensordict.get(self.in_keys[1], None)
-         if mask is not None:
-             mask = mask.to(action_spec.device)
-         action_spec.update_mask(mask)
-
-         # TODO: Check that this makes sense
-         with _set_missing_tolerance(self, True):
-             tensordict_reset = self._call(tensordict_reset)
-         return tensordict_reset
+         return self._call(tensordict_reset)
 
 
  class VecGymEnvTransform(Transform):
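
The simplified `_call` and `_reset` above both reduce to a single `update_mask` call on the action spec. A minimal sketch of that mechanism, assuming the `OneHot` spec exported by `torchrl.data`:

    import torch
    from torchrl.data import OneHot

    spec = OneHot(4)
    spec.update_mask(torch.tensor([True, True, False, True]))  # action 2 is illegal
    sample = spec.rand()  # random sampling now respects the mask
    assert not sample[2].item()
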
torchrl/modules/llm/policies/transformers_wrapper.py
@@ -65,6 +65,10 @@ class TransformersWrapper(CategoricalSequential):
          operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
          created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
          conserve type, batch-size, and device). Defaults to `True`.
+     chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
+         applying the chat template to the history. Defaults to `None`.
+     chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
+         Defaults to `None`.
 
      .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
          required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
@@ -131,6 +135,8 @@ class TransformersWrapper(CategoricalSequential):
          tokenizer_kwargs: dict | None = None,
          pad_output: bool = True,
          inplace: Literal[True, False, "empty"] | None = True,
+         chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+         chat_template: str | None = None,
      ):
          super().__init__()
 
@@ -143,6 +149,8 @@ class TransformersWrapper(CategoricalSequential):
          self.inplace = inplace
          self.pad_output = pad_output
          padding_value = None
+         self.chat_template_name = chat_template_name
+         self.chat_template = chat_template
 
          if not tokenizer_kwargs:
              tokenizer_kwargs = {}
@@ -300,7 +308,17 @@ class TransformersWrapper(CategoricalSequential):
              raise ValueError(
                  "No text or history provided to the TransformersWrapper."
              )
-         text = history.apply_chat_template(self.tokenizer)
+         tokenizer_kwargs = {}
+         if self.chat_template_name is not None:
+             tokenizer_kwargs.setdefault(
+                 "chat_template_name", self.chat_template_name
+             )
+         if self.chat_template is not None:
+             tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         text = history.apply_chat_template(
+             tokenizer=self.tokenizer, **tokenizer_kwargs
+         )
          if not isinstance(text, (list, str)):
              text = text.tolist()
          tokens_in = self.tokenizer(text, **self.tokenizer_kwargs)
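
A hedged construction example for the new keyword arguments, mirroring the wrapper usage shown in the `RetrieveLogProb` docstring earlier in this diff (model name and template choice are illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from torchrl.modules.llm import TransformersWrapper

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    policy = TransformersWrapper(
        model,
        tokenizer=tokenizer,
        from_text=True,
        generate=False,
        return_log_probs=True,
        chat_template_name="qwen",  # forwarded to History.apply_chat_template
    )
    # With a `history` entry and no `text` input, the wrapper now applies the chat
    # template with chat_template_name="qwen" and add_generation_prompt=False.
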
@@ -325,7 +343,7 @@ class TransformersWrapper(CategoricalSequential):
          logits = torch.stack(list(tokens_out["logits"]), 1)
          logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
          log_probs, logits = self._log_probs_generate(
-             sequences, logits, pad_val=pad_val
+             sequences, logits, pad_val=-100
          )
          response_text = self.tokenizer.batch_decode(
              sequences, skip_special_tokens=False
@@ -407,17 +425,36 @@ class TransformersWrapper(CategoricalSequential):
          pad_val = self.tokenizer.pad_token_id
 
          prompt_txt = td.get(self.text_key)
-         if prompt_txt is None:
+         response_txt = td.get(self.text_response_key)
+         if prompt_txt is None or response_txt is None:
+             if prompt_txt is not None and response_txt is not None:
+                 raise ValueError(
+                     "No text or history provided to the TransformersWrapper. Either both are provided or none of them."
+                 )
              # Fallback on history parsing
              history = td.get(self.history_key)
              if history is None:
                  raise ValueError(
                      "No text or history provided to the TransformersWrapper."
                  )
-             prompt_txt = history.apply_chat_template(self.tokenizer)
+             tokenizer_kwargs = {}
+             if self.chat_template_name is not None:
+                 tokenizer_kwargs.setdefault(
+                     "chat_template_name", self.chat_template_name
+                 )
+             if self.chat_template is not None:
+                 tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+             tokenizer_kwargs.setdefault("add_generation_prompt", False)
+             response_txt = history.apply_chat_template(
+                 tokenizer=self.tokenizer, **tokenizer_kwargs
+             )
+             if isinstance(response_txt, list):
+                 prompt_txt = ["" for _ in response_txt]
+             else:
+                 prompt_txt = ""
+
          if not isinstance(prompt_txt, (list, str)):
              prompt_txt = prompt_txt.tolist()
-         response_txt = td.get(self.text_response_key)
          if not isinstance(response_txt, (list, str)):
              response_txt = response_txt.tolist()
          total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)]
@@ -450,6 +487,8 @@ class TransformersWrapper(CategoricalSequential):
          )
          sequences = [
              _total_input_ids[_prompt_input_ids.shape[-1] :]
+             if _prompt_input_ids.shape[-1] > 0
+             else _total_input_ids
              for _total_input_ids, _prompt_input_ids in zip(
                  total_input_ids, prompt_input_ids
              )
@@ -484,7 +523,7 @@ class TransformersWrapper(CategoricalSequential):
 
          total_input_ids = [
              torch.cat([_prompt_input_ids, _response_input_ids], -1)
-             for _prompt_input_ids, _response_input_ids in zip(
+             for _prompt_input_ids, _response_input_ids in _zip_strict(
                  prompt_input_ids, response_input_ids
              )
          ]
@@ -512,7 +551,7 @@ class TransformersWrapper(CategoricalSequential):
              total_input_ids, attention_mask=total_attention_mask, **kwargs
          )
          log_probs, logits = self._log_probs_from_logits(
-             total_tokens_out, response_input_ids, pad_val=pad_val
+             total_tokens_out, response_input_ids, pad_val=-100
          )
          # for i in range(log_probs.size(0)):
          #     assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1]
@@ -522,7 +561,7 @@ class TransformersWrapper(CategoricalSequential):
          return out
 
      @classmethod
-     def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val):
+     def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val=-100):
          response_input_ids = pad_sequence(
              response_input_ids,
              padding_value=pad_val,
@@ -532,10 +571,21 @@ class TransformersWrapper(CategoricalSequential):
          pad_mask = response_input_ids != pad_val
 
          logits = total_tokens_out["logits"]
-         logits = logits.log_softmax(dim=-1)
-         logits = logits[:, -response_input_ids.shape[-1] - 1 : -1, :]
-
-         log_probs = logits.gather(-1, response_input_ids.unsqueeze(-1)).squeeze(-1)
+         # logits = logits.log_softmax(dim=-1)
+         if logits.shape[-2] != response_input_ids.shape[-1]:
+             logits = logits[..., -response_input_ids.shape[-1] - 1 : -1, :]
+
+         td = TensorDict(
+             logits=logits, response_input_ids=response_input_ids
+         ).auto_batch_size_()
+         with td.flatten() as tdflat:
+             tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                 tdflat["logits"],
+                 tdflat["response_input_ids"],
+                 reduce=False,
+                 ignore_index=pad_val,
+             )
+         log_probs = td["log_probs"]
 
          # Recover the list
          log_probs = _unpad_tensors(log_probs, pad_mask)
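
The switch from `log_softmax` + `gather` to a negated `cross_entropy` is an exact identity, and `ignore_index=-100` explains the new pad value: padded positions contribute zero instead of gathering at an out-of-range token id. A small numerical check (note that `reduce=False` in the hunk is the deprecated spelling of `reduction="none"`):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(6, 32)  # (tokens, vocab)
    ids = torch.randint(0, 32, (6,))

    via_gather = logits.log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    via_ce = -F.cross_entropy(logits, ids, reduction="none")
    torch.testing.assert_close(via_gather, via_ce)

    ids[0] = -100  # a padded position
    masked = -F.cross_entropy(logits, ids, reduction="none", ignore_index=-100)
    assert masked[0] == 0  # pad contributes nothing
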
@@ -543,7 +593,7 @@ class TransformersWrapper(CategoricalSequential):
          return log_probs, logits
 
      @classmethod
-     def _log_probs_generate(cls, sequences, logits, pad_val):
+     def _log_probs_generate(cls, sequences, logits, pad_val=-100):
          tokens = pad_sequence(
              sequences,
              padding_value=pad_val,
@@ -557,6 +607,12 @@ class TransformersWrapper(CategoricalSequential):
              padding_side="left",
          )
 
-         logits = logits.log_softmax(dim=-1)
-         log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+         # logits = logits.log_softmax(dim=-1)
+         # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+         td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
+         with td.flatten() as tdflat:
+             tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                 tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
+             )
+         log_probs = td["log_probs"]
          return log_probs, logits
torchrl/modules/llm/policies/vllm_wrapper.py
@@ -74,6 +74,9 @@ class vLLMWrapper(CategoricalSequential):
          conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
          otherwise.
 
+     chat_template_name (str | None, optional): The name of the chat template to use for the history. Defaults to `None`.
+     chat_template (str | None, optional): The chat template to use for the history. Defaults to `None`.
+
      .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
          required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
          tokenization and padding.
@@ -120,6 +123,7 @@ class vLLMWrapper(CategoricalSequential):
      token_response_key: NestedKey = ("tokens_response",)
      text_response_key: NestedKey = ("text_response",)
      attention_mask_key: NestedKey = ("attention_mask",)
+     history_key: NestedKey = ("history",)
 
      def __init__(
          self,
@@ -137,6 +141,8 @@ class vLLMWrapper(CategoricalSequential):
          tokenizer_kwargs: dict | None = None,
          pad_output: bool = False,
          inplace: Literal[True, False, "empty"] | None = None,
+         chat_template_name: str | None = None,
+         chat_template: str | None = None,
      ):
          super().__init__()
 
@@ -149,6 +155,8 @@ class vLLMWrapper(CategoricalSequential):
          self._device = device
          self.generate = generate
          self.pad_output = pad_output
+         self.chat_template_name = chat_template_name
+         self.chat_template = chat_template
          padding_value = None
 
          if not tokenizer_kwargs:
@@ -329,7 +337,12 @@ class vLLMWrapper(CategoricalSequential):
          history = td.get(self.history_key)
          if history is None:
              raise ValueError("No text or history provided to the vLLMWrapper.")
-         text = history.apply_chat_template(self.tokenizer)
+         tokenizer_kwargs = {}
+         if self.chat_template_name is not None:
+             tokenizer_kwargs["chat_template_name"] = self.chat_template_name
+         if self.chat_template is not None:
+             tokenizer_kwargs["chat_template"] = self.chat_template
+         text = history.apply_chat_template(self.tokenizer, **tokenizer_kwargs)
          if self.pad_output:
              tokenizer_kwargs = self.tokenizer_kwargs
          if not isinstance(text, (list, str)):
@@ -385,15 +398,35 @@ class vLLMWrapper(CategoricalSequential):
 
      def _from_vllm_logprobs_text(self, td, sampling_params, out):
          text_prompt = td.get(self.text_key)
-         if text_prompt is None:
+         text_response = td.get(self.text_response_key)
+         if text_response is None or text_prompt is None:
+             if text_response is not None and text_prompt is not None:
+                 raise ValueError(
+                     "No text or history provided to the vLLMWrapper. Either both are provided or none of them."
+                 )
              # Fallback on history parsing
              history = td.get(self.history_key)
              if history is None:
-                 raise ValueError("No text or history provided to the vLLMWrapper.")
-             text_prompt = history.apply_chat_template(self.tokenizer)
+                 raise ValueError(
+                     "No text or history provided to the TransformersWrapper."
+                 )
+             tokenizer_kwargs = {}
+             if self.chat_template_name is not None:
+                 tokenizer_kwargs.setdefault(
+                     "chat_template_name", self.chat_template_name
+                 )
+             if self.chat_template is not None:
+                 tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+             tokenizer_kwargs.setdefault("add_generation_prompt", False)
+             text_response = history.apply_chat_template(
+                 tokenizer=self.tokenizer, **tokenizer_kwargs
+             )
+             if isinstance(text_response, list):
+                 text_prompt = ["" for _ in text_response]
+             else:
+                 text_prompt = ""
          if not isinstance(text_prompt, list):
              text_prompt = text_prompt.tolist()
-         text_response = td.get(self.text_response_key)
          if not isinstance(text_response, list):
              text_response = text_response.tolist()
          text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)]
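
As in `TransformersWrapper`, the history fallback treats the whole templated chat as the response and leaves the prompt empty, so log-probs cover the full sequence. A schematic of the resulting concatenation (hypothetical helper, illustrative strings):

    def assemble_scored_text(text_prompt, text_response):
        # Mirrors the hunk above: empty prompts mean "score the whole chat".
        return [p + r for p, r in zip(text_prompt, text_response)]

    templated = ["<templated chat 1>", "<templated chat 2>"]
    prompts = ["" for _ in templated]  # fallback path: no explicit prompt
    print(assemble_scored_text(prompts, templated))
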
torchrl/objectives/llm/__init__.py
@@ -5,5 +5,6 @@
  from __future__ import annotations
 
  from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+ from .sft import SFTLoss, SFTLossOutput
 
- __all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"]
+ __all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
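
With the new re-exports, the SFT objective is importable from the subpackage root. A usage sketch reusing the names from the `RetrieveLogProb` docstring example above (`policy_train`, `tokenizer`, and `data` are defined there):

    from torchrl.objectives.llm import SFTLoss, SFTLossOutput

    loss = SFTLoss(
        actor_network=policy_train,  # a TransformersWrapper with generate=False
        tokenizer=tokenizer,
        reduction="mean",
        normalize_by_seq_length=True,
        kl_to_ref_coeff=0.1,
        tokenizer_kwargs={"chat_template_name": "qwen"},
    )
    loss_vals = loss(data)  # data carries ("next", "ref_log_prob") from RetrieveLogProb
    total = loss_vals.sum(reduce=True)
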