torchrl-nightly 2025.6.20__cp39-cp39-win_amd64.whl → 2025.6.22__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -65,6 +65,10 @@ class TransformersWrapper(CategoricalSequential):
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
             conserve type, batch-size, and device). Defaults to `True`.
+        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
+            applying the chat template to the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
+            Defaults to `None`.
 
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
@@ -131,6 +135,8 @@ class TransformersWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = True,
         inplace: Literal[True, False, "empty"] | None = True,
+        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()
 
@@ -143,6 +149,8 @@ class TransformersWrapper(CategoricalSequential):
         self.inplace = inplace
         self.pad_output = pad_output
         padding_value = None
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
 
         if not tokenizer_kwargs:
             tokenizer_kwargs = {}
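For illustration, a minimal usage sketch of how the new keyword arguments could be passed when building the wrapper; the import path, model name and positional layout are assumptions based on the signature shown in this diff, not something the diff itself confirms:

    # Hedged usage sketch; the torchrl.modules.llm import path and model are assumed.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from torchrl.modules.llm import TransformersWrapper

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative model choice
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    wrapper = TransformersWrapper(
        model,
        tokenizer=tokenizer,
        from_text=True,
        generate=False,             # score log-probs instead of generating
        pad_output=True,
        chat_template_name="qwen",  # new in this release; chat_template=<jinja string> is also accepted
    )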
@@ -300,7 +308,17 @@ class TransformersWrapper(CategoricalSequential):
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-            text = history.apply_chat_template(self.tokenizer)
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
         if not isinstance(text, (list, str)):
             text = text.tolist()
         tokens_in = self.tokenizer(text, **self.tokenizer_kwargs)
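The `add_generation_prompt` entry that is now forwarded mirrors the flag of the same name in Hugging Face's `apply_chat_template`; a small sketch of its effect (the model name is illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed model
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    # False: render the conversation exactly as stored (useful when the assistant
    # turn to be scored is already part of the history).
    as_is = tok.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
    # True: append an empty assistant header so the model can start a new turn.
    prompted = tok.apply_chat_template(history, tokenize=False, add_generation_prompt=True)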
@@ -325,7 +343,7 @@ class TransformersWrapper(CategoricalSequential):
             logits = torch.stack(list(tokens_out["logits"]), 1)
             logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
         log_probs, logits = self._log_probs_generate(
-            sequences, logits, pad_val=pad_val
+            sequences, logits, pad_val=-100
         )
         response_text = self.tokenizer.batch_decode(
             sequences, skip_special_tokens=False
@@ -407,17 +425,36 @@ class TransformersWrapper(CategoricalSequential):
         pad_val = self.tokenizer.pad_token_id
 
         prompt_txt = td.get(self.text_key)
-        if prompt_txt is None:
+        response_txt = td.get(self.text_response_key)
+        if prompt_txt is None or response_txt is None:
+            if prompt_txt is not None and response_txt is not None:
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError(
                     "No text or history provided to the TransformersWrapper."
                 )
-            prompt_txt = history.apply_chat_template(self.tokenizer)
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            response_txt = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(response_txt, list):
+                prompt_txt = ["" for _ in response_txt]
+            else:
+                prompt_txt = ""
+
         if not isinstance(prompt_txt, (list, str)):
             prompt_txt = prompt_txt.tolist()
-        response_txt = td.get(self.text_response_key)
         if not isinstance(response_txt, (list, str)):
             response_txt = response_txt.tolist()
         total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)]
@@ -450,6 +487,8 @@ class TransformersWrapper(CategoricalSequential):
         )
         sequences = [
             _total_input_ids[_prompt_input_ids.shape[-1] :]
+            if _prompt_input_ids.shape[-1] > 0
+            else _total_input_ids
             for _total_input_ids, _prompt_input_ids in zip(
                 total_input_ids, prompt_input_ids
             )
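The new conditional keeps the whole sequence when the prompt has zero length, which is exactly the situation created by the history fallback above (prompt set to ""). A tiny sketch of the indexing with made-up token ids:

    import torch

    total_input_ids = torch.tensor([101, 7592, 2088, 102, 2023, 2003])  # prompt + response
    prompt_input_ids = torch.tensor([101, 7592, 2088, 102])             # prompt only

    response_ids = total_input_ids[prompt_input_ids.shape[-1]:]  # tensor([2023, 2003])

    empty_prompt = torch.tensor([], dtype=torch.long)
    kept = (
        total_input_ids[empty_prompt.shape[-1]:]
        if empty_prompt.shape[-1] > 0
        else total_input_ids          # full sequence is treated as the "response"
    )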
@@ -484,7 +523,7 @@ class TransformersWrapper(CategoricalSequential):
 
         total_input_ids = [
             torch.cat([_prompt_input_ids, _response_input_ids], -1)
-            for _prompt_input_ids, _response_input_ids in zip(
+            for _prompt_input_ids, _response_input_ids in _zip_strict(
                 prompt_input_ids, response_input_ids
             )
         ]
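Replacing `zip` with `_zip_strict` makes mismatched prompt/response batches fail loudly instead of being silently truncated; `_zip_strict` is assumed to behave like Python 3.10's `zip(..., strict=True)`:

    prompt_ids = [[1, 2], [3, 4, 5]]
    response_ids = [[6], [7, 8], [9]]                 # one spurious extra element

    list(zip(prompt_ids, response_ids))               # silently drops the [9] entry
    list(zip(prompt_ids, response_ids, strict=True))  # raises ValueError: argument 2 is longer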
@@ -512,7 +551,7 @@ class TransformersWrapper(CategoricalSequential):
             total_input_ids, attention_mask=total_attention_mask, **kwargs
         )
         log_probs, logits = self._log_probs_from_logits(
-            total_tokens_out, response_input_ids, pad_val=pad_val
+            total_tokens_out, response_input_ids, pad_val=-100
         )
         # for i in range(log_probs.size(0)):
         #     assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1]
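Switching the padding value from the tokenizer's pad token id to -100 lines the targets up with `torch.nn.functional.cross_entropy`'s default `ignore_index`, so padded positions drop out of the log-prob computation without extra masking. A small sketch with made-up ids:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    response_input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8])]
    padded = pad_sequence(response_input_ids, batch_first=True, padding_value=-100)
    # tensor([[   5,    6,    7],
    #         [   8, -100, -100]])  -> the -100 entries are ignored by cross_entropy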
@@ -522,7 +561,7 @@ class TransformersWrapper(CategoricalSequential):
         return out
 
     @classmethod
-    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val):
+    def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val=-100):
         response_input_ids = pad_sequence(
             response_input_ids,
             padding_value=pad_val,
@@ -532,10 +571,21 @@ class TransformersWrapper(CategoricalSequential):
         pad_mask = response_input_ids != pad_val
 
         logits = total_tokens_out["logits"]
-        logits = logits.log_softmax(dim=-1)
-        logits = logits[:, -response_input_ids.shape[-1] - 1 : -1, :]
-
-        log_probs = logits.gather(-1, response_input_ids.unsqueeze(-1)).squeeze(-1)
+        # logits = logits.log_softmax(dim=-1)
+        if logits.shape[-2] != response_input_ids.shape[-1]:
+            logits = logits[..., -response_input_ids.shape[-1] - 1 : -1, :]
+
+        td = TensorDict(
+            logits=logits, response_input_ids=response_input_ids
+        ).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"],
+                tdflat["response_input_ids"],
+                reduce=False,
+                ignore_index=pad_val,
+            )
+        log_probs = td["log_probs"]
 
         # Recover the list
         log_probs = _unpad_tensors(log_probs, pad_mask)
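The gather-over-log-softmax computation is replaced by a negated token-level cross-entropy, which yields the same per-token log-probabilities while `ignore_index` masks the padded targets. A self-contained check of that equivalence (shapes are made up; `reduction="none"` is the non-deprecated spelling of the `reduce=False` used in the new code):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(5, 32)                   # 5 token positions, vocabulary of 32
    targets = torch.tensor([3, 7, 0, -100, 12])   # -100 marks a padded position

    # Previous approach: log-softmax then gather (padding handled separately).
    gathered = logits.log_softmax(-1).gather(-1, targets.clamp(min=0).unsqueeze(-1)).squeeze(-1)

    # New approach: negated cross-entropy with ignore_index handling the padding.
    neg_ce = -F.cross_entropy(logits, targets, reduction="none", ignore_index=-100)

    valid = targets != -100
    assert torch.allclose(gathered[valid], neg_ce[valid])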
@@ -543,7 +593,7 @@ class TransformersWrapper(CategoricalSequential):
         return log_probs, logits
 
     @classmethod
-    def _log_probs_generate(cls, sequences, logits, pad_val):
+    def _log_probs_generate(cls, sequences, logits, pad_val=-100):
         tokens = pad_sequence(
             sequences,
             padding_value=pad_val,
@@ -557,6 +607,12 @@ class TransformersWrapper(CategoricalSequential):
             padding_side="left",
         )
 
-        logits = logits.log_softmax(dim=-1)
-        log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        # logits = logits.log_softmax(dim=-1)
+        # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
+            )
+        log_probs = td["log_probs"]
         return log_probs, logits
@@ -74,6 +74,9 @@ class vLLMWrapper(CategoricalSequential):
             conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
             otherwise.
 
+        chat_template_name (str | None, optional): The name of the chat template to use for the history. Defaults to `None`.
+        chat_template (str | None, optional): The chat template to use for the history. Defaults to `None`.
+
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
         tokenization and padding.
@@ -120,6 +123,7 @@ class vLLMWrapper(CategoricalSequential):
     token_response_key: NestedKey = ("tokens_response",)
     text_response_key: NestedKey = ("text_response",)
     attention_mask_key: NestedKey = ("attention_mask",)
+    history_key: NestedKey = ("history",)
 
     def __init__(
         self,
@@ -137,6 +141,8 @@ class vLLMWrapper(CategoricalSequential):
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = False,
         inplace: Literal[True, False, "empty"] | None = None,
+        chat_template_name: str | None = None,
+        chat_template: str | None = None,
     ):
         super().__init__()
 
@@ -149,6 +155,8 @@ class vLLMWrapper(CategoricalSequential):
         self._device = device
         self.generate = generate
         self.pad_output = pad_output
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
         padding_value = None
 
         if not tokenizer_kwargs:
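As with the Transformers wrapper, a hedged sketch of how the new arguments might be used; import paths and model name are assumptions, not part of this diff:

    from vllm import LLM
    from transformers import AutoTokenizer
    from torchrl.modules.llm import vLLMWrapper   # assumed import path

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"       # illustrative model choice
    llm = LLM(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    wrapper = vLLMWrapper(
        llm,
        tokenizer=tokenizer,
        from_text=True,
        generate=True,
        chat_template_name="qwen",                # new keyword, mirrors the TransformersWrapper change
    )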
@@ -329,7 +337,12 @@ class vLLMWrapper(CategoricalSequential):
             history = td.get(self.history_key)
             if history is None:
                 raise ValueError("No text or history provided to the vLLMWrapper.")
-            text = history.apply_chat_template(self.tokenizer)
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs["chat_template_name"] = self.chat_template_name
+            if self.chat_template is not None:
+                tokenizer_kwargs["chat_template"] = self.chat_template
+            text = history.apply_chat_template(self.tokenizer, **tokenizer_kwargs)
         if self.pad_output:
             tokenizer_kwargs = self.tokenizer_kwargs
             if not isinstance(text, (list, str)):
@@ -385,15 +398,35 @@ class vLLMWrapper(CategoricalSequential):
 
     def _from_vllm_logprobs_text(self, td, sampling_params, out):
         text_prompt = td.get(self.text_key)
-        if text_prompt is None:
+        text_response = td.get(self.text_response_key)
+        if text_response is None or text_prompt is None:
+            if text_response is not None and text_prompt is not None:
+                raise ValueError(
+                    "No text or history provided to the vLLMWrapper. Either both are provided or none of them."
+                )
             # Fallback on history parsing
             history = td.get(self.history_key)
             if history is None:
-                raise ValueError("No text or history provided to the vLLMWrapper.")
-            text_prompt = history.apply_chat_template(self.tokenizer)
+                raise ValueError(
+                    "No text or history provided to the TransformersWrapper."
+                )
+            tokenizer_kwargs = {}
+            if self.chat_template_name is not None:
+                tokenizer_kwargs.setdefault(
+                    "chat_template_name", self.chat_template_name
+                )
+            if self.chat_template is not None:
+                tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+            tokenizer_kwargs.setdefault("add_generation_prompt", False)
+            text_response = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+            if isinstance(text_response, list):
+                text_prompt = ["" for _ in text_response]
+            else:
+                text_prompt = ""
         if not isinstance(text_prompt, list):
             text_prompt = text_prompt.tolist()
-        text_response = td.get(self.text_response_key)
         if not isinstance(text_response, list):
             text_response = text_response.tolist()
         text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)]
@@ -5,5 +5,6 @@
 from __future__ import annotations
 
 from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from .sft import SFTLoss, SFTLossOutput
 
-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"]
+__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
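With the re-export in place, the new loss becomes importable from the package namespace; assuming this file is `torchrl/objectives/llm/__init__.py` (consistent with the existing `.grpo` import), the corresponding import would be:

    # SFTLoss / SFTLossOutput are new in this release; GRPOLoss was already exported.
    from torchrl.objectives.llm import GRPOLoss, SFTLoss, SFTLossOutput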