torchrl-nightly 2025.7.16__cp312-cp312-macosx_10_13_universal2.whl → 2025.7.18__cp312-cp312-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-312-darwin.so +0 -0
- torchrl/collectors/collectors.py +13 -3
- torchrl/data/llm/history.py +36 -0
- torchrl/modules/distributions/discrete.py +1 -1
- torchrl/modules/llm/policies/common.py +37 -15
- torchrl/modules/llm/policies/transformers_wrapper.py +90 -53
- torchrl/modules/llm/policies/vllm_wrapper.py +50 -23
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/RECORD +13 -13
- {torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-312-darwin.so
CHANGED
Binary file
torchrl/collectors/collectors.py
CHANGED
@@ -686,6 +686,10 @@ class SyncDataCollector(DataCollectorBase):
             policy = RandomPolicy(env.full_action_spec)
         elif policy_factory is not None:
             raise TypeError("policy_factory cannot be used with policy argument.")
+        # If the underlying policy has a state_dict, we keep a reference to the policy and
+        # do all policy weight saving/loading through it
+        if hasattr(policy, "state_dict"):
+            self._policy_w_state_dict = policy
 
         if trust_policy is None:
             trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
@@ -1686,8 +1690,8 @@ class SyncDataCollector(DataCollectorBase):
         else:
             env_state_dict = OrderedDict()
 
-        if hasattr(self
-        policy_state_dict = self.
+        if hasattr(self, "_policy_w_state_dict"):
+            policy_state_dict = self._policy_w_state_dict.state_dict()
         state_dict = OrderedDict(
             policy_state_dict=policy_state_dict,
             env_state_dict=env_state_dict,
@@ -1711,7 +1715,13 @@ class SyncDataCollector(DataCollectorBase):
         if strict or "env_state_dict" in state_dict:
             self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
         if strict or "policy_state_dict" in state_dict:
-            self
+            if not hasattr(self, "_policy_w_state_dict"):
+                raise ValueError(
+                    "Underlying policy does not have state_dict to load policy_state_dict into."
+                )
+            self._policy_w_state_dict.load_state_dict(
+                state_dict["policy_state_dict"], **kwargs
+            )
         self._frames = state_dict["frames"]
         self._iter = state_dict["iter"]
 
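The collector now keeps a direct reference to any policy that exposes a state_dict and routes weight saving/loading through it. A minimal sketch of that round trip, using a hypothetical TinyCollector rather than the real SyncDataCollector:

import copy
from collections import OrderedDict

import torch
from torch import nn


class TinyCollector:
    """Hypothetical stand-in for SyncDataCollector, reduced to the state_dict logic."""

    def __init__(self, policy):
        # Keep a reference only if the policy can actually produce a state_dict.
        if hasattr(policy, "state_dict"):
            self._policy_w_state_dict = policy

    def state_dict(self) -> OrderedDict:
        policy_state_dict = OrderedDict()
        if hasattr(self, "_policy_w_state_dict"):
            policy_state_dict = self._policy_w_state_dict.state_dict()
        return OrderedDict(policy_state_dict=policy_state_dict)

    def load_state_dict(self, state_dict) -> None:
        if not hasattr(self, "_policy_w_state_dict"):
            raise ValueError(
                "Underlying policy does not have state_dict to load policy_state_dict into."
            )
        self._policy_w_state_dict.load_state_dict(state_dict["policy_state_dict"])


policy = nn.Linear(4, 2)
collector = TinyCollector(policy)
snapshot = copy.deepcopy(collector.state_dict())  # clone, since state_dict holds live tensors
with torch.no_grad():
    policy.weight.zero_()
collector.load_state_dict(snapshot)  # weights come back from the snapshot
assert policy.weight.abs().sum() > 0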
torchrl/data/llm/history.py
CHANGED
@@ -713,6 +713,42 @@ class History(TensorClass["nocast"]):
         | transformers.AutoProcessor # noqa: F821
         | None = None,
     ) -> History:
+        r"""Inverts a chat template into a History object.
+
+        Args:
+            text (str | list[str]): The chat template to invert.
+            chat_template_name (str, optional): The name of the chat template to use.
+            tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
+
+        Returns:
+            History: The inverted History object.
+
+        Examples:
+            >>> from torchrl.data.llm.history import History
+            >>> from transformers import AutoTokenizer
+            >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
+            >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
+            >>> history = History.from_text(text, tokenizer=tokenizer)
+            >>> print(history)
+            History(
+                content=NonTensorStack(
+                    ['You are a helpful assistant.', 'Write a python s...,
+                    batch_size=torch.Size([3]),
+                    device=None),
+                is_complete=NonTensorStack(
+                    [True, True, False],
+                    batch_size=torch.Size([3]),
+                    device=None),
+                role=NonTensorStack(
+                    ['system', 'user', 'assistant'],
+                    batch_size=torch.Size([3]),
+                    device=None),
+                tool_calls=None,
+                tool_responses=None,
+                batch_size=torch.Size([3]),
+                device=None,
+                is_shared=False)
+        """
         if chat_template_name is None:
             if chat_template is not None:
                 # TODO: find best match given template
torchrl/modules/distributions/discrete.py
CHANGED
@@ -352,7 +352,7 @@ class MaskedCategorical(D.Categorical):
         logits = self.logits
         if logits.ndim > 2:
             # Bring channels in 2nd dim
-            logits = logits.
+            logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
         original_value_shape = None
         if logits.ndim == 1 and value.ndim >= 1:
             if value.ndim >= 2:
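The one-line fix above moves the class dimension of the logits from the last position to dim 1 before the cross-entropy call. A small standalone illustration (not torchrl code) of why that permute is needed:

import torch

# Distribution-style logits carry the class dim last, i.e. (N, d1, ..., dk, C),
# while torch.nn.functional.cross_entropy expects it in dim 1, i.e. (N, C, d1, ..., dk).
logits = torch.randn(4, 5, 7, 10)        # (batch, d1, d2, num_classes)
value = torch.randint(10, (4, 5, 7))     # class indices with shape (batch, d1, d2)

# Bring the class dim from last position to dim 1, keeping the remaining dims in order.
logits_c_first = logits.permute(0, -1, *range(1, logits.ndim - 1))  # (4, 10, 5, 7)

loss = torch.nn.functional.cross_entropy(logits_c_first, value, reduction="none")
print(loss.shape)  # torch.Size([4, 5, 7])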
torchrl/modules/llm/policies/common.py
CHANGED
@@ -9,8 +9,8 @@ import weakref
 from typing import Any, Literal, overload
 
 import torch
-from tensordict import NestedKey, TensorDictBase
-from tensordict.nn import TensorDictModuleBase
+from tensordict import lazy_stack, NestedKey, TensorDictBase
+from tensordict.nn import TensorDictModuleBase
 from tensordict.tensorclass import TensorClass
 from tensordict.utils import _zip_strict
 from torch import distributions as D
@@ -175,29 +175,35 @@ class ChatHistory(TensorClass["nocast"]):
     def __post_init__(self):
         # Check that all history objects have one more batch dimension than the ChatHistory object
         if self.prompt is not None:
-            if self.prompt
+            if getattr(self.prompt, "batch_dims", None) == self.batch_dims:
                 warnings.warn(
                     "Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
                     f"got {self.prompt.batch_dims} and {self.batch_dims}. "
                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
                 )
-                self.prompt =
+                self.prompt = lazy_stack(
+                    [self.prompt], -1
+                )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
         if self.response is not None:
-            if self.response
+            if getattr(self.response, "batch_dims", None) == self.batch_dims:
                 warnings.warn(
                     "Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
                     f"got {self.response.batch_dims} and {self.batch_dims}. "
                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
                 )
-                self.response =
+                self.response = lazy_stack(
+                    [self.response], -1
+                )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
         if self.full is not None:
-            if self.full
+            if getattr(self.full, "batch_dims", None) == self.batch_dims:
                 warnings.warn(
                     "Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
                     f"got {self.full.batch_dims} and {self.batch_dims}. "
                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
                 )
-                self.full =
+                self.full = lazy_stack(
+                    [self.full], -1
+                )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
 
 
 class LogProbs(TensorClass["nocast"]):
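The __post_init__ change replaces a plain unsqueeze with lazy_stack([obj], -1). A minimal sketch of the difference, assuming tensordict's lazy_stack accepts a negative stacking dim here (as the diff itself relies on):

import torch
from tensordict import TensorDict, lazy_stack

td = TensorDict({"tokens": torch.arange(6).view(3, 2)}, batch_size=[3])

# Stacking a single tensordict along a new last dim adds one batch dimension,
# like unsqueeze(-1), but the result stays a lazy (non-consolidated) stack.
stacked = lazy_stack([td], -1)
print(stacked.batch_size)        # torch.Size([3, 1])
print(type(stacked).__name__)    # LazyStackedTensorDict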
@@ -482,7 +488,7 @@ class LLMWrapperBase(TensorDictModuleBase):
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
 
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
 
         # Get logits/log-probs
         if as_padded_tensor is None:
@@ -557,7 +563,7 @@ class LLMWrapperBase(TensorDictModuleBase):
                 "get_dist_with_prompt_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
 
         # Try to get prompt tokens first
         if self.pad_output:
@@ -668,7 +674,7 @@ class LLMWrapperBase(TensorDictModuleBase):
                 "get_dist_with_assistant_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
            )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         # Update the tokens key to reflect the tokenized history when querying the log-probs
         tensordict.update(
             td_out,
@@ -737,7 +743,7 @@ class LLMWrapperBase(TensorDictModuleBase):
                 "get_dist_with_attention_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         if self.pad_output:
             logits = td_out.get(logits_key)
             attention_mask = td_out.get(attention_mask_key)
@@ -794,7 +800,7 @@ class LLMWrapperBase(TensorDictModuleBase):
                 "get_dist_with_custom_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         if self.pad_output:
             logits = td_out.get(logits_key)
         else:
@@ -841,8 +847,24 @@
         """
         return self._get_dist_with_attention_mask(tensordict, **kwargs)
 
-
-
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        *,
+        tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
+        **kwargs,
+    ) -> TensorDictBase:  # noqa: D417
+        """Forward pass for the LLM policy.
+
+        Args:
+            tensordict (TensorDictBase): The input tensordict.
+
+        Keyword Args:
+            tensordict_out (TensorDictBase | None): The output tensordict.
+            logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
+        """
+        raise NotImplementedError
 
     def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
         """Check that a value is a padded tensor."""
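The new abstract forward adds a logits_only keyword that the get_dist* helpers above now pass. A minimal sketch of that contract with a hypothetical TinyWrapper (not one of the real wrappers):

import torch
from tensordict import TensorDict
from torch import distributions as D


class TinyWrapper:
    vocab_size = 16  # assumption for the sketch

    def forward(self, td: TensorDict, *, tensordict_out=None, logits_only: bool = False, **kwargs):
        # Stand-in "model call": produce logits; compute log-probs only when asked to.
        logits = torch.randn(*td["tokens"].shape, self.vocab_size)
        out = td.copy().set("logits", logits)
        if not logits_only:
            log_probs = logits.log_softmax(-1).gather(-1, td["tokens"].unsqueeze(-1)).squeeze(-1)
            out.set("log_probs", log_probs)
        return out

    def _get_dist(self, td: TensorDict) -> D.Categorical:
        td_out = self.forward(td.copy(), logits_only=True)  # mirrors the diff
        return D.Categorical(logits=td_out.get("logits"))


td = TensorDict({"tokens": torch.randint(16, (2, 5))}, batch_size=[2])
print(TinyWrapper()._get_dist(td).sample().shape)  # torch.Size([2, 5])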
torchrl/modules/llm/policies/transformers_wrapper.py
CHANGED
@@ -13,6 +13,7 @@ from typing import Literal
 import torch
 from tensordict import (
     lazy_stack,
+    LazyStackedTensorDict,
     MetaData,
     NonTensorStack,
     set_list_to_stack,
@@ -468,19 +469,32 @@ class TransformersWrapper(LLMWrapperBase):
     def forward(
         self,
         tensordict: TensorDictBase,
+        *,
         tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
         **kwargs,
     ) -> TensorDictBase:
+        tensordict_orig = tensordict
         if not tensordict.ndim:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                    "please submit an issue on github."
+                )
             # unsqueeze - squeeze the input
-
-            return self(lazy_stack([tensordict])).squeeze(0)
-        except Exception as e:
-            raise RuntimeError(
-                f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
-            ) from e
+            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
         elif tensordict.ndim > 1:
-
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                    "please submit an issue on github."
+                )
+            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                tensordict.shape
+            )
+
+        if not isinstance(tensordict, LazyStackedTensorDict):
+            tensordict = tensordict.to_lazystack(0)
 
         _source_device = None
         if self._device:
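The rewritten forward no longer relies on a try/except around squeeze/unsqueeze; it dispatches on the batch dimensionality instead. A minimal standalone sketch of that dispatch (hypothetical forward/_forward_flat helpers, not the wrapper itself):

import torch
from tensordict import TensorDict, lazy_stack


def _forward_flat(td):
    # Stand-in for the real 1-d code path.
    return td.set("out", td["x"] * 2)


def forward(td):
    if not td.ndim:
        # 0-d input: stack into a 1-element batch, process, index back out.
        return forward(lazy_stack([td]))[0]
    if td.ndim > 1:
        # >1-d input: flatten the batch dims, process, then restore the shape.
        return forward(td.reshape(-1)).view(td.shape)
    return _forward_flat(td)


td = TensorDict({"x": torch.ones(2, 3)}, batch_size=[2, 3])
print(forward(td)["out"].shape)  # torch.Size([2, 3])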
@@ -517,17 +531,23 @@ class TransformersWrapper(LLMWrapperBase):
             if self.generate:
                 out = self._from_transformers_generate_history(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_history(
+                out = self._from_transformers_logprobs_history(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
         elif self.input_mode == "text":
             if self.generate:
                 out = self._from_transformers_generate_text(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_text(
+                out = self._from_transformers_logprobs_text(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
         elif self.input_mode == "tokens":
             if self.generate:
                 out = self._from_transformers_generate_tokens(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_tokens(
+                out = self._from_transformers_logprobs_tokens(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
 
         if _source_device:
             out = out.to(_source_device)
@@ -535,7 +555,7 @@ class TransformersWrapper(LLMWrapperBase):
         if tensordict_out is None:
             if self.inplace is True:
                 # The output is the input
-                tensordict_out =
+                tensordict_out = tensordict_orig
             elif self.inplace is False:
                 # The output is the new structure
                 tensordict_out = out
@@ -690,7 +710,7 @@ class TransformersWrapper(LLMWrapperBase):
         result.set(self.history_key, history_chat)
         return result
 
-    def _from_transformers_logprobs_history(self, td, cfg, out):
+    def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
         """Compute log-probs from history input."""
         from torchrl.data.llm import History
 
@@ -731,7 +751,9 @@ class TransformersWrapper(LLMWrapperBase):
             raise ValueError(
                 f"Expected TensorDictBase for history input, got {type(response_tokens)}"
             )
-        result = self._logprobs_from_history_tokens(
+        result = self._logprobs_from_history_tokens(
+            response_tokens, cfg, out, logits_only=logits_only
+        )
         text_result = Text._from_tensordict(result.empty())
         result.set(self.text_key, text_result)
         result[self.text_key, "full"] = text_full
@@ -952,7 +974,9 @@ class TransformersWrapper(LLMWrapperBase):
             result = result.to(cast)
         return result
 
-    def _logprobs_from_history_tokens(
+    def _logprobs_from_history_tokens(
+        self, response_tokens, cfg, out, logits_only=False
+    ):
         """Compute log-probs from history tokens."""
         pad_val = self.tokenizer.pad_token_id
 
@@ -996,6 +1020,7 @@ class TransformersWrapper(LLMWrapperBase):
             tokens_full_padded,
             attention_mask_full_padded,
             pad_val,
+            logits_only=logits_only,
         )
 
         # Build output TensorClass objects
@@ -1051,19 +1076,20 @@ class TransformersWrapper(LLMWrapperBase):
         tokens_obj.padded = MetaData(self.pad_output)
         out.set(self.tokens_key, tokens_obj)
 
-
-
-
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
             )
-
-
-
-
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
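With logits_only=True, the whole LogProbs construction above is skipped. For the unpadded branch, _unpad_tensors is a torchrl-internal helper; a rough sketch (not that helper) of what unpadding with an attention mask amounts to:

import torch

log_probs_full_padded = torch.tensor([[0.1, 0.2, 0.0], [0.3, 0.0, 0.0]])
attention_mask_full_padded = torch.tensor([[True, True, False], [True, False, False]])

# Keep, per row, only the positions the mask marks as real tokens.
log_probs_full_unpadded = [
    lp[mask] for lp, mask in zip(log_probs_full_padded, attention_mask_full_padded)
]
print([t.tolist() for t in log_probs_full_unpadded])  # [[0.1, 0.2], [0.3]]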
@@ -1095,7 +1121,7 @@ class TransformersWrapper(LLMWrapperBase):
             raise ValueError(f"Expected list of text for text input, got {type(text)}")
         return self._generate_from_text(text, cfg, out)
 
-    def _from_transformers_logprobs_text(self, td, cfg, out):
+    def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
         """Compute log-probs from text input."""
         # Validate input
         if self.input_key not in td:
@@ -1168,6 +1194,7 @@ class TransformersWrapper(LLMWrapperBase):
             input_ids_full_padded,
             attention_mask_full_padded,
             self.tokenizer.pad_token_id,
+            logits_only=logits_only,
         )
 
         # Build output TensorClass objects
@@ -1212,19 +1239,20 @@ class TransformersWrapper(LLMWrapperBase):
         masks_obj.padded = MetaData(self.pad_output)
         out.set(self.masks_key, masks_obj)
 
-
-
-
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
            )
-
-
-
-
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
@@ -1416,7 +1444,11 @@ class TransformersWrapper(LLMWrapperBase):
         return out
 
     def _from_transformers_logprobs_tokens(
-        self,
+        self,
+        td: TensorDictBase,
+        cfg: dict | None,
+        out: TensorDictBase,
+        logits_only=False,
     ) -> TensorDictBase:
         """Compute log-probs from tokens input."""
         # Validate input
@@ -1470,6 +1502,7 @@ class TransformersWrapper(LLMWrapperBase):
             input_ids_full_padded,
             attention_mask_full_padded,
             self.tokenizer.pad_token_id,
+            logits_only=logits_only,
         )
 
         # Build output TensorClass objects
@@ -1514,19 +1547,20 @@ class TransformersWrapper(LLMWrapperBase):
         masks_obj.padded = MetaData(self.pad_output)
         out.set(self.masks_key, masks_obj)
 
-
-
-
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
             )
-
-
-
-
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
@@ -1567,7 +1601,7 @@ class TransformersWrapper(LLMWrapperBase):
         return log_probs, logits
 
     def _compute_log_probs_from_model_output(
-        self, model_output, input_ids, attention_mask, pad_val
+        self, model_output, input_ids, attention_mask, pad_val, logits_only=False
     ):
         """Compute log-probs from model output without modifying original tensors.
 
@@ -1576,6 +1610,7 @@ class TransformersWrapper(LLMWrapperBase):
            input_ids: Original input token ids
            attention_mask: Original attention mask
            pad_val: Padding token value to ignore in loss computation
+           logits_only: Whether to return only the logits.
 
         Returns:
            tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
@@ -1600,6 +1635,8 @@ class TransformersWrapper(LLMWrapperBase):
             raise ValueError(
                 f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
             )
+        if logits_only:
+            return None, shifted_logits
 
         # Compute log-probs
         td = TensorDict(
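A rough standalone sketch of the shape logic behind _compute_log_probs_from_model_output and the new early return (hypothetical, simplified helper rather than the real method): logits are shifted so that position i predicts token i+1, and when only the logits are needed the per-token log-prob gather is skipped entirely.

import torch


def compute_log_probs_from_logits(logits, input_ids, logits_only=False):
    shifted_logits = logits[:, :-1]        # predictions for tokens 1..T-1
    shifted_input_ids = input_ids[:, 1:]   # the tokens those positions predict
    if logits_only:
        return None, shifted_logits
    log_probs = (
        shifted_logits.log_softmax(-1)
        .gather(-1, shifted_input_ids.unsqueeze(-1))
        .squeeze(-1)
    )
    return log_probs, shifted_logits


logits = torch.randn(2, 6, 32)
input_ids = torch.randint(32, (2, 6))
log_probs, shifted = compute_log_probs_from_logits(logits, input_ids)
print(log_probs.shape, shifted.shape)  # torch.Size([2, 5]) torch.Size([2, 5, 32])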
torchrl/modules/llm/policies/vllm_wrapper.py
CHANGED
@@ -11,6 +11,7 @@ from typing import Any, Literal
 import torch
 from tensordict import (
     lazy_stack,
+    LazyStackedTensorDict,
     MetaData,
     NonTensorStack,
     set_list_to_stack,
@@ -500,19 +501,32 @@ class vLLMWrapper(LLMWrapperBase):
     def forward(
         self,
         tensordict: TensorDictBase,
+        *,
         tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
         **kwargs,
     ) -> TensorDictBase:
+        tensordict_orig = tensordict
         if not tensordict.ndim:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                    "please submit an issue on github."
+                )
             # unsqueeze - squeeze the input
-
-            return self(lazy_stack([tensordict])).squeeze(0)
-        except Exception as e:
-            raise RuntimeError(
-                f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
-            ) from e
+            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
         elif tensordict.ndim > 1:
-
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                    "please submit an issue on github."
+                )
+            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                tensordict.shape
+            )
+
+        if not isinstance(tensordict, LazyStackedTensorDict):
+            tensordict = tensordict.to_lazystack(0)
 
         _source_device = None
         if self._device:
@@ -567,7 +581,7 @@ class vLLMWrapper(LLMWrapperBase):
         if tensordict_out is None:
             if self.inplace is True:
                 # The output is the input
-                tensordict_out =
+                tensordict_out = tensordict_orig
             elif self.inplace is False:
                 # The output is the new structure
                 tensordict_out = out
@@ -1242,12 +1256,14 @@ class vLLMWrapper(LLMWrapperBase):
 
         generate_kwargs = {"sampling_params": sampling_params}
         args = ()
+        empirical_attention_mask = None
 
         if tokens_prompt_unpadded is None:
             # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
             # the user passed an unpadded tensor in the first place.
+            empirical_attention_mask = tokens_prompt_padded != self.padding_value
             tokens_prompt_list = self._to_list(
-                tokens_prompt_padded,
+                tokens_prompt_padded, empirical_attention_mask
             )
         else:
             tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
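A small sketch of the "empirical attention mask" idea used above: when only a padded token tensor is available, positions equal to the padding value are treated as padding, and the mask recovers per-row token lists.

import torch

padding_value = 0
tokens_prompt_padded = torch.tensor([[11, 12, 0, 0], [21, 22, 23, 0]])
empirical_attention_mask = tokens_prompt_padded != padding_value

tokens_prompt_list = [
    row[mask].tolist()
    for row, mask in zip(tokens_prompt_padded, empirical_attention_mask)
]
print(tokens_prompt_list)  # [[11, 12], [21, 22, 23]]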
@@ -1365,6 +1381,22 @@ class vLLMWrapper(LLMWrapperBase):
                 padding_value=self.padding_value,
                 padding_side="right",
             )
+            if (
+                prompt_logprobs_padded.shape[-1]
+                != tokens_prompt_padded.shape[-1]
+            ):
+                tshape = tokens_prompt_padded.shape
+                oshape = prompt_logprobs_padded.shape
+                # it could be that the input was padded already - padding again then
+                prompt_logprobs_padded = torch.cat(
+                    [
+                        prompt_logprobs_padded.new_zeros(
+                            tshape[:-1] + (tshape[-1] - oshape[-1],)
+                        ),
+                        prompt_logprobs_padded,
+                    ],
+                    -1,
+                )
         else:
             prompt_logprobs_list = request_output_tc.get(
                 "prompt_logprobs",
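A small sketch of the re-padding step above: if vLLM returns prompt log-probs shorter than the (already padded) prompt tokens, zeros are prepended on the left so both tensors line up position by position.

import torch

tokens_prompt_padded = torch.zeros(2, 6)     # padded prompt, length 6
prompt_logprobs_padded = torch.randn(2, 4)   # log-probs only for 4 positions

tshape, oshape = tokens_prompt_padded.shape, prompt_logprobs_padded.shape
if oshape[-1] != tshape[-1]:
    prompt_logprobs_padded = torch.cat(
        [
            prompt_logprobs_padded.new_zeros(tshape[:-1] + (tshape[-1] - oshape[-1],)),
            prompt_logprobs_padded,
        ],
        -1,
    )
print(prompt_logprobs_padded.shape)  # torch.Size([2, 6])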
@@ -1490,26 +1522,21 @@ class vLLMWrapper(LLMWrapperBase):
 
         request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_stuct)
 
+        # For unpadded case, extract from each sequence
+        log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)
+
         # Extract log-probs from prompt_logprobs
         if self.pad_output:
             # For padded case, use all prompt_logprobs
-
-
-
-
-                padding_side="left",
+            if attention_mask_full_padded is not None:
+                attention_mask_full_padded = tokens_full_padded != self.padding_value
+            log_probs_full_padded = torch.zeros_like(
+                tokens_full_padded, dtype=torch.get_default_dtype()
             )
-
-
-            attention_mask_full_padded = tokens_full_padded != self.padding_value
-            log_probs_full_padded = torch.where(
-                attention_mask_full_padded, log_probs_full_padded, 0.0
+            log_probs_full_padded[attention_mask_full_padded] = torch.cat(
+                log_probs_full_unpadded, -1
             )
         else:
-            # For unpadded case, extract from each sequence
-            log_probs_full_unpadded = request_output_tc.get(
-                "prompt_logprobs", as_list=True
-            )
             self._check_not_padded(log_probs_full_unpadded)
 
         assistant_mask_full_padded = None
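A small sketch of the padded-case assembly above: the per-sequence (unpadded) prompt log-probs are concatenated and written into a zero tensor at the positions where the padded token tensor holds real, non-padding tokens.

import torch

padding_value = 0
tokens_full_padded = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
log_probs_full_unpadded = [torch.tensor([-0.1, -0.2, -0.3]), torch.tensor([-0.4, -0.5])]

attention_mask_full_padded = tokens_full_padded != padding_value
log_probs_full_padded = torch.zeros_like(tokens_full_padded, dtype=torch.get_default_dtype())
log_probs_full_padded[attention_mask_full_padded] = torch.cat(log_probs_full_unpadded, -1)
print(log_probs_full_padded)
# tensor([[-0.1000, -0.2000, -0.3000,  0.0000],
#         [-0.4000, -0.5000,  0.0000,  0.0000]])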
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
-__version__ = '2025.7.16'
-git_version = '
+__version__ = '2025.7.18'
+git_version = '4001d9cb73cea4498b0fdfe420effc58a5a336be'
{torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/RECORD
RENAMED
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
 build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
 torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
 torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
-torchrl/_torchrl.cpython-312-darwin.so,sha256=
+torchrl/_torchrl.cpython-312-darwin.so,sha256=k_jisocYRQ0Z1X52Pu2ym9-02BXMnLrD7q0fOpElam8,1692224
 torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
-torchrl/version.py,sha256=
+torchrl/version.py,sha256=MHs4CxNjQupYI_f84bY7dOAAfPSU9yN6TOyxiS7tS8c,83
 torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
-torchrl/collectors/collectors.py,sha256=
+torchrl/collectors/collectors.py,sha256=HpaW-y0bQOaOql8_7VyEPJ084CulrVwn6iBpGYoHyH4,178287
 torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
 torchrl/collectors/weight_update.py,sha256=nSIfs8ALsfggLoC2ylg1oOAqdGku1tt4e-50JCZJBww,21073
 torchrl/collectors/distributed/__init__.py,sha256=_24P0ALFunLhL-ls7EsssGUhJkZ_m3nw7krfMTwPqS0,705
@@ -41,7 +41,7 @@ torchrl/data/datasets/utils.py,sha256=nAFDTlBIPyEoPoJC-Hc_fcOhzE7UZQE4BwKxq15Vhv
 torchrl/data/datasets/vd4rl.py,sha256=z90MqrxKzod8TPGK0uzkC6vw5wQIE4cgrDAC4e72jyk,18262
 torchrl/data/llm/__init__.py,sha256=B4Ekok-w5PMiWcfmAGXaseaN6hWdNOr4WebeLrHfBVQ,975
 torchrl/data/llm/dataset.py,sha256=t-41hAzQcjrdoKwpHIMbcrT7pRcQ7DHl2a1-lr6E7W4,20703
-torchrl/data/llm/history.py,sha256=
+torchrl/data/llm/history.py,sha256=l9JSxIO5eLUFwHH5IZkANSrByYa8BGmtxMlNXYf2fbs,59640
 torchrl/data/llm/prompt.py,sha256=bg5LzJfwOq5Ns72KQMciIprMWAmDDinzdopwdopU04c,8380
 torchrl/data/llm/reward.py,sha256=FbPchNXG3smJV9NCbB5Yk4grsCa2Se4KZ_tojVLKWQM,8404
 torchrl/data/llm/topk.py,sha256=mYXCgJS4TuEVLZfTNccQd6kmC858AAh2Ygy0q_K1hlY,8365
@@ -139,7 +139,7 @@ torchrl/envs/transforms/vip.py,sha256=kmygbenw75rEYsKRq4X1hzEH_CRe1406NZZ8Hg2R_V
 torchrl/modules/__init__.py,sha256=XlAO0hulhDQNcKhbu3cFi8KJOHXNiAgmXeTfny0WBqE,4157
 torchrl/modules/distributions/__init__.py,sha256=RDFoYD9IX1FhwXk5R4M8khq42gdTOcVnUnKHfWCTZBQ,1597
 torchrl/modules/distributions/continuous.py,sha256=VPBugDuavJmyZ-RzemyLIFA02UCMLsm-rzBQrKcTlIA,25667
-torchrl/modules/distributions/discrete.py,sha256=
+torchrl/modules/distributions/discrete.py,sha256=7UE6X8LeTZkaTRFvKNcFSOoug_tOcD_u-FOh-39ZSC4,35581
 torchrl/modules/distributions/truncated_normal.py,sha256=-qM8vwxTzv3VsWphZwcueDQpHQ67IRnkDFKlTDkQQnY,5937
 torchrl/modules/distributions/utils.py,sha256=kXRvNHeKUePIgKgn7DnKqbhQ6ImFGgkFVRxITX2dwNU,7567
 torchrl/modules/llm/__init__.py,sha256=BTkn-8QKp_8sW_NTKP02yoWSJUsX0XL6L9chTJl6epc,737
@@ -147,9 +147,9 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
 torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
 torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
 torchrl/modules/llm/policies/__init__.py,sha256=nfZ2mcVuucxnY3WCuzIQrTLIf1yEd36k8-AlvwnSa8Y,545
-torchrl/modules/llm/policies/common.py,sha256=
-torchrl/modules/llm/policies/transformers_wrapper.py,sha256=
-torchrl/modules/llm/policies/vllm_wrapper.py,sha256=
+torchrl/modules/llm/policies/common.py,sha256=Kvn1cJQbp1EZtxWpAQ50TzZkwVtLAmryqiBHH2nK_wM,39112
+torchrl/modules/llm/policies/transformers_wrapper.py,sha256=oi-2KALM0pkH-u-Kd6WlnxfH9eGV2GzBqM410ANpPeM,75777
+torchrl/modules/llm/policies/vllm_wrapper.py,sha256=ReBvi2M9IAiwwBAR7GpDLSQhX0aC-dXPnHYb082Q0To,79632
 torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
 torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
 torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
@@ -223,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Y
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
+torchrl_nightly-2025.7.18.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.7.18.dist-info/METADATA,sha256=K_Nmn84sw1xeD28lqIPqdhLjdaFchSMXuG2vjAajTn0,42990
+torchrl_nightly-2025.7.18.dist-info/WHEEL,sha256=9_3tTSxMJq-dgdzMiScNvtT5eTBVd3l6RgHS7HwTzpA,115
+torchrl_nightly-2025.7.18.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.7.18.dist-info/RECORD,,
{torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/WHEEL
RENAMED
File without changes
{torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/licenses/LICENSE
RENAMED
File without changes
{torchrl_nightly-2025.7.16.dist-info → torchrl_nightly-2025.7.18.dist-info}/top_level.txt
RENAMED
File without changes