torchrl-nightly 2025.6.20 (cp313-cp313-win_amd64.whl) → 2025.6.22 (cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp313-win_amd64.pyd +0 -0
- torchrl/collectors/collectors.py +8 -5
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -8
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/RECORD +27 -25
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/top_level.txt +0 -0
--- a/torchrl/modules/llm/policies/transformers_wrapper.py
+++ b/torchrl/modules/llm/policies/transformers_wrapper.py
@@ -65,6 +65,10 @@ class TransformersWrapper(CategoricalSequential):
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
             conserve type, batch-size, and device). Defaults to `True`.
+        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
+            applying the chat template to the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
+            Defaults to `None`.

     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
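The two new arguments let callers pick how a stored conversation history is rendered before tokenization. A minimal construction sketch follows; the model/tokenizer setup and the `from_text`/`generate` flags are illustrative assumptions based on the surrounding docstring, not part of this diff:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from torchrl.modules.llm import TransformersWrapper

    # Hypothetical small instruct model; any HF causal LM with a chat template would do.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    wrapper = TransformersWrapper(
        model,
        tokenizer=tokenizer,
        from_text=True,
        generate=False,
        chat_template_name="qwen",  # new: named template applied to the history
        # chat_template="...",      # new: or pass a raw template string instead
    )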
@@ -131,6 +135,8 @@ class TransformersWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = True,
         inplace: Literal[True, False, "empty"] | None = True,
+        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()

@@ -143,6 +149,8 @@ class TransformersWrapper(CategoricalSequential):
         self.inplace = inplace
         self.pad_output = pad_output
         padding_value = None
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template

         if not tokenizer_kwargs:
             tokenizer_kwargs = {}
@@ -300,7 +308,17 @@ class TransformersWrapper(CategoricalSequential):
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
         if not isinstance(text, (list, str)):
             text = text.tolist()
         tokens_in = self.tokenizer(text, **self.tokenizer_kwargs)
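In the history fallback above, the wrapper now forwards the configured template options to `History.apply_chat_template` and forces `add_generation_prompt=False` before tokenizing. A rough sketch of the equivalent direct call; the `History` construction below (field names `role`/`content`, the batch size) is an assumption about `torchrl.data.llm.chat.History`, and only the `apply_chat_template` keyword arguments are taken from this diff:

    from transformers import AutoTokenizer
    from torchrl.data.llm.chat import History

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # Assumed tensorclass construction: one conversation with two turns.
    history = History(
        role=["user", "assistant"],
        content=["What is 2 + 2?", "4"],
        batch_size=(2,),
    )

    # Mirrors the wrapper's fallback: named template (if any) and no generation prompt.
    text = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="qwen",
        add_generation_prompt=False,
    )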
@@ -325,7 +343,7 @@ class TransformersWrapper(CategoricalSequential):
             logits = torch.stack(list(tokens_out["logits"]), 1)
             logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
             log_probs, logits = self._log_probs_generate(
-                sequences, logits, pad_val
+                sequences, logits, pad_val=-100
             )
             response_text = self.tokenizer.batch_decode(
                 sequences, skip_special_tokens=False
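The hard-coded `pad_val=-100` matches the default `ignore_index` of `torch.nn.functional.cross_entropy`, which the new log-probability helpers below rely on: padded target positions contribute exactly zero instead of a spurious log-probability. A small standalone check, using `reduction="none"` as the non-deprecated spelling of the `reduce=False` flag that appears in the diff:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)               # 4 positions, vocabulary of 10
    targets = torch.tensor([3, 7, -100, 1])   # -100 marks a padded position

    loss = F.cross_entropy(logits, targets, reduction="none", ignore_index=-100)
    print(loss)  # the third element is exactly 0.0: ignored rather than scored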
@@ -407,17 +425,36 @@ class TransformersWrapper(CategoricalSequential):
         pad_val = self.tokenizer.pad_token_id

         prompt_txt = td.get(self.text_key)
-
+        response_txt = td.get(self.text_response_key)
+        if prompt_txt is None or response_txt is None:
+            if prompt_txt is not None and response_txt is not None:
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            response_txt = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(response_txt, list):
+                prompt_txt = ["" for _ in response_txt]
+            else:
+                prompt_txt = ""
+
         if not isinstance(prompt_txt, (list, str)):
             prompt_txt = prompt_txt.tolist()
-        response_txt = td.get(self.text_response_key)
         if not isinstance(response_txt, (list, str)):
             response_txt = response_txt.tolist()
         total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)]
@@ -484,7 +523,7 @@ class TransformersWrapper(CategoricalSequential):

         total_input_ids = [
             torch.cat([_prompt_input_ids, _response_input_ids], -1)
-            for _prompt_input_ids, _response_input_ids in
+            for _prompt_input_ids, _response_input_ids in _zip_strict(
                 prompt_input_ids, response_input_ids
             )
         ]
@@ -512,7 +551,7 @@ class TransformersWrapper(CategoricalSequential):
             total_input_ids, attention_mask=total_attention_mask, **kwargs
         )
         log_probs, logits = self._log_probs_from_logits(
-            total_tokens_out, response_input_ids, pad_val
+            total_tokens_out, response_input_ids, pad_val=-100
         )
         # for i in range(log_probs.size(0)):
         #     assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1]
@@ -522,7 +561,7 @@ class TransformersWrapper(CategoricalSequential):
         return out

     @classmethod
-    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val):
+    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val=-100):
         response_input_ids = pad_sequence(
             response_input_ids,
             padding_value=pad_val,
@@ -532,10 +571,21 @@ class TransformersWrapper(CategoricalSequential):
         pad_mask = response_input_ids != pad_val

         logits = total_tokens_out["logits"]
-        logits = logits.log_softmax(dim=-1)
-
-
-
+        # logits = logits.log_softmax(dim=-1)
+        if logits.shape[-2] != response_input_ids.shape[-1]:
+            logits = logits[..., -response_input_ids.shape[-1] - 1 : -1, :]
+
+        td = TensorDict(
+            logits=logits, response_input_ids=response_input_ids
+        ).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"],
+                tdflat["response_input_ids"],
+                reduce=False,
+                ignore_index=pad_val,
+            )
+        log_probs = td["log_probs"]

         # Recover the list
         log_probs = _unpad_tensors(log_probs, pad_mask)
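The rewritten `_log_probs_from_logits` trades the explicit `log_softmax` + `gather` for a negated, unreduced cross-entropy, which produces the same per-token log-probabilities for non-ignored positions while letting `ignore_index=pad_val` zero out padding before `_unpad_tensors` strips it. A self-contained equivalence check (shapes and seed are arbitrary):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(2, 5, 32)          # [batch, time, vocab]
    tokens = torch.randint(0, 32, (2, 5))   # target token ids

    # Previous approach: full log-softmax, then gather the target column.
    ref = logits.log_softmax(-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    # New approach: negated token-level cross-entropy over flattened positions.
    new = -F.cross_entropy(
        logits.flatten(0, 1), tokens.flatten(), reduction="none"
    ).view_as(tokens)

    assert torch.allclose(ref, new, atol=1e-6)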
@@ -543,7 +593,7 @@ class TransformersWrapper(CategoricalSequential):
         return log_probs, logits

     @classmethod
-    def _log_probs_generate(cls, sequences, logits, pad_val):
+    def _log_probs_generate(cls, sequences, logits, pad_val=-100):
         tokens = pad_sequence(
             sequences,
             padding_value=pad_val,
@@ -557,6 +607,12 @@ class TransformersWrapper(CategoricalSequential):
             padding_side="left",
         )

-        logits = logits.log_softmax(dim=-1)
-        log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        # logits = logits.log_softmax(dim=-1)
+        # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
+            )
+        log_probs = td["log_probs"]
         return log_probs, logits
--- a/torchrl/modules/llm/policies/vllm_wrapper.py
+++ b/torchrl/modules/llm/policies/vllm_wrapper.py
@@ -74,6 +74,9 @@ class vLLMWrapper(CategoricalSequential):
             conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
             otherwise.

+        chat_template_name (str | None, optional): The name of the chat template to use for the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use for the history. Defaults to `None`.
+
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
         tokenization and padding.
@@ -120,6 +123,7 @@ class vLLMWrapper(CategoricalSequential):
     token_response_key: NestedKey = ("tokens_response",)
     text_response_key: NestedKey = ("text_response",)
     attention_mask_key: NestedKey = ("attention_mask",)
+    history_key: NestedKey = ("history",)

     def __init__(
         self,
@@ -137,6 +141,8 @@ class vLLMWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = False,
         inplace: Literal[True, False, "empty"] | None = None,
+        chat_template_name: str | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()

@@ -149,6 +155,8 @@ class vLLMWrapper(CategoricalSequential):
         self._device = device
         self.generate = generate
         self.pad_output = pad_output
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
         padding_value = None

         if not tokenizer_kwargs:
@@ -329,7 +337,12 @@ class vLLMWrapper(CategoricalSequential):
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError("No text or history provided to the vLLMWrapper.")
-
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs["chat_template_name"] = self.chat_template_name
+            if self.chat_template is not None:
+                tokenizer_kwargs["chat_template"] = self.chat_template
+            text = history.apply_chat_template(self.tokenizer, **tokenizer_kwargs)
         if self.pad_output:
             tokenizer_kwargs = self.tokenizer_kwargs
             if not isinstance(text, (list, str)):
@@ -385,15 +398,35 @@ class vLLMWrapper(CategoricalSequential):

     def _from_vllm_logprobs_text(self, td, sampling_params, out):
         text_prompt = td.get(self.text_key)
-
+        text_response = td.get(self.text_response_key)
+        if text_response is None or text_prompt is None:
+            if text_response is not None and text_prompt is not None:
+                raise ValueError(
+                    "No text or history provided to the vLLMWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
-                raise ValueError(
-
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper."
+                )
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text_response = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(text_response, list):
+                text_prompt = ["" for _ in text_response]
+            else:
+                text_prompt = ""
         if not isinstance(text_prompt, list):
             text_prompt = text_prompt.tolist()
-        text_response = td.get(self.text_response_key)
         if not isinstance(text_response, list):
             text_response = text_response.tolist()
         text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)]
--- a/torchrl/objectives/llm/__init__.py
+++ b/torchrl/objectives/llm/__init__.py
@@ -5,5 +5,6 @@
 from __future__ import annotations

 from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from .sft import SFTLoss, SFTLossOutput

-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"]
+__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]