torchrl-nightly 2025.8.10__cp312-cp312-manylinux1_x86_64.whl → 2025.8.11__cp312-cp312-manylinux1_x86_64.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
torchrl/modules/llm/policies/common.py

@@ -1300,11 +1300,11 @@ def _extract_responses_from_full_histories(
 def _batching(func):
     @wraps(func)
     def _batched_func(self, td_input: TensorDictBase, **kwargs):
-        # -- 0. skip if batching is disabled
+        # -- 0. Bypass if batching disabled
         if not self.batching:
             return func(self, td_input, **kwargs)

-        # ── 1. Normalise input ──────────────────────────────────────────────────
+        # -- 1. Normalise --------------------------------------------------------
         if td_input.batch_dims > 1:
             raise ValueError(
                 f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
@@ -1313,52 +1313,59 @@ def _batching(func):
         single = td_input.batch_dims == 0
         inputs = [td_input] if single else list(td_input.unbind(0))
         futures = [Future() for _ in inputs]
+        pending = set(futures)  # ← track our own Futures

-        # ── 2. Enqueue work and, if first in, do the draining ───────────────────
+        # -- 2. Enqueue ----------------------------------------------------------
         self._batch_queue.extend(inputs)
         self._futures.extend(futures)

         min_bs = getattr(self, "_min_batch_size", 1)
         max_bs = getattr(self, "_max_batch_size", None)

+        # -- 3. Drain while holding the lock ------------------------------------
         with self._batching_lock:
-            # Only the thread that managed to grab the lock will run the loop
-            while len(self._batch_queue) >= min_bs:
-                # Determine slice
-                slice_size = (
-                    len(self._batch_queue)
-                    if max_bs is None
-                    else min(max_bs, len(self._batch_queue))
-                )
-                batch = self._batch_queue[:slice_size]
-                fut_slice = self._futures[:slice_size]
+            if all(f.done() for f in futures):
+                # Our items were already processed by another thread.
+                # Skip draining; other workers will handle the rest of the queue.
+                pass
+            else:
+                while len(self._batch_queue) >= min_bs:
+                    slice_size = (
+                        len(self._batch_queue)
+                        if max_bs is None
+                        else min(max_bs, len(self._batch_queue))
+                    )
+                    batch = self._batch_queue[:slice_size]
+                    fut_slice = self._futures[:slice_size]
+
+                    try:
+                        results = func(self, lazy_stack(batch), **kwargs).unbind(0)
+                        if len(results) != slice_size:
+                            raise RuntimeError(
+                                f"Expected {slice_size} results, got {len(results)}"
+                            )
+                        for fut, res in zip(fut_slice, results):
+                            fut.set_result(res)
+                            pending.discard(fut)  # ← mark as done
+                    except Exception as exc:
+                        for fut in fut_slice:
+                            fut.set_exception(exc)
+                            pending.discard(fut)
+                        raise

-                # Execute model
-                try:
-                    results = func(self, lazy_stack(batch), **kwargs).unbind(0)
-                    if len(results) != slice_size:  # sanity
-                        raise RuntimeError(
-                            f"Expected {slice_size} results, got {len(results)}"
-                        )
-                    # Fulfil the corresponding futures
-                    for fut, res in zip(fut_slice, results):
-                        fut.set_result(res)
-                except Exception as exc:
-                    for fut in fut_slice:
-                        fut.set_exception(exc)
-                    # Propagate to caller; other waiters will read the exception from their future
-                    raise
-
-                # Pop processed work
-                del self._batch_queue[:slice_size]
-                del self._futures[:slice_size]
-
-        # ── 3. Outside the lock: wait only for OUR futures (they may already be done) ──
-        wait(
-            futures
-        )  # no timeout → immediate return if set_result()/set_exception() already called
-        result = [f.result() for f in futures]
-
-        return result[0] if single else lazy_stack(result)
+                    # Pop processed work
+                    del self._batch_queue[:slice_size]
+                    del self._futures[:slice_size]
+
+                    # ---- Early-exit: all *our* Futures are done -------------------
+                    if not pending:
+                        break
+
+        # -- 4. Outside the lock: wait only on remaining (rare) -----------------
+        if pending:  # usually empty; safety for min_bs > queue size
+            wait(pending)
+        results = [f.result() for f in futures]
+
+        return results[0] if single else lazy_stack(results)

     return _batched_func
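For readers who want to see the pattern this hunk rewrites in isolation, here is a minimal, self-contained sketch of the same enqueue/drain batching scheme, using plain lists and a stand-in process method instead of TensorDicts and the wrapped model call. Batcher and all its names are illustrative, not torchrl API.

import threading
from concurrent.futures import Future, wait

class Batcher:
    def __init__(self, min_batch_size=1, max_batch_size=None):
        self._queue = []                 # shared work queue
        self._futures = []               # one Future per queued item
        self._lock = threading.Lock()
        self._min_bs = min_batch_size
        self._max_bs = max_batch_size

    def process(self, batch):
        # Stand-in for the wrapped model call: square each item.
        return [x * x for x in batch]

    def __call__(self, items):
        futures = [Future() for _ in items]
        pending = set(futures)           # track only *our* futures
        self._queue.extend(items)
        self._futures.extend(futures)
        with self._lock:
            # Another thread may already have drained our items.
            if not all(f.done() for f in futures):
                while len(self._queue) >= self._min_bs:
                    n = len(self._queue)
                    if self._max_bs is not None:
                        n = min(self._max_bs, n)
                    batch, futs = self._queue[:n], self._futures[:n]
                    try:
                        for fut, res in zip(futs, self.process(batch)):
                            fut.set_result(res)
                            pending.discard(fut)
                    except Exception as exc:
                        for fut in futs:
                            fut.set_exception(exc)
                            pending.discard(fut)
                        raise
                    del self._queue[:n]
                    del self._futures[:n]
                    if not pending:      # our items are done; stop draining
                        break
        if pending:                      # rare: queue never reached min_bs
            wait(pending)
        return [f.result() for f in futures]

b = Batcher(max_batch_size=4)
print(b([1, 2, 3]))  # [1, 4, 9]

The pending set and the early-exit break mirror what this release adds: a caller whose items were already drained by another thread now stops draining the shared queue instead of continuing to process other callers' work.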
torchrl/modules/llm/policies/transformers_wrapper.py

@@ -23,7 +23,7 @@ from tensordict import (
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
-
+from torchrl import logger as torchrl_logger
 from torchrl.modules.llm.policies.common import (
     _batching,
     _extract_responses_from_full_histories,
@@ -2443,7 +2443,12 @@ class RemoteTransformersWrapper:
     """

     def __init__(
-        self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2458,10 +2463,23 @@ class RemoteTransformersWrapper:

         if not ray.is_initialized():
             ray.init()
-        # Create the remote actor
+
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(TransformersWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )

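Both transformers_wrapper.py (above) and vllm_wrapper.py (below) now use Ray's get-or-create pattern for named actors. A minimal standalone sketch of that pattern; the Counter actor is illustrative, not part of this diff:

import ray

@ray.remote
class Counter:
    def __init__(self):
        self.n = 0

    def incr(self):
        self.n += 1
        return self.n

def get_or_create(name):
    try:
        # ray.get_actor raises ValueError if no live actor has this name
        return ray.get_actor(name)
    except ValueError:
        return Counter.options(name=name).remote()

ray.init()
a = get_or_create("shared-counter")
b = get_or_create("shared-counter")  # same handle: the actor already exists
assert ray.get(a.incr.remote()) == 1
assert ray.get(b.incr.remote()) == 2

Note that check-then-create is not atomic: if two processes race through the ValueError branch, the second .options(name=...) call fails because the name is already taken. Recent Ray releases also offer .options(name=..., get_if_exists=True) to perform the lookup-or-create atomically.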
torchrl/modules/llm/policies/vllm_wrapper.py

@@ -23,6 +23,7 @@ from tensordict.tensorclass import from_dataclass, TensorClass
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
+from torchrl import logger as torchrl_logger

 from torchrl.envs.utils import _classproperty
 from torchrl.modules.llm.policies.common import (
@@ -2101,7 +2102,12 @@ class RemotevLLMWrapper:
     """

     def __init__(
-        self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2141,10 +2147,22 @@ class RemotevLLMWrapper:
         if not ray.is_initialized():
             ray.init()

-        # Create the remote actor
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(vLLMWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )

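A hypothetical usage sketch of the new actor_name parameter; the import path and model string are assumptions, not taken from this diff:

from torchrl.modules.llm.policies import RemotevLLMWrapper  # import path assumed

# The first construction creates the named actor ("Creating new actor ...");
# the second finds it via ray.get_actor and reuses it ("Using existing actor ..."),
# so a second copy of the model is never spawned.
w1 = RemotevLLMWrapper("Qwen/Qwen2.5-3B", actor_name="vllm-policy")  # model string illustrative
w2 = RemotevLLMWrapper("Qwen/Qwen2.5-3B", actor_name="vllm-policy")

Because __init__ returns early when the named actor already exists, the model argument and any extra **kwargs passed on the second construction are ignored.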
METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchrl-nightly
-Version: 2025.8.10
+Version: 2025.8.11
 Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
 Author-email: torchrl contributors <vmoens@fb.com>
 Maintainer-email: torchrl contributors <vmoens@fb.com>
RECORD

@@ -245,9 +245,9 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
 torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
 torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
 torchrl/modules/llm/policies/__init__.py,sha256=VYAiblw6ETlo4q1vSvaKaybuxwxuPXfC-QCFzZJk4PA,649
-torchrl/modules/llm/policies/common.py,sha256=LSJrDy9NW2xyqTttCrcurJHIqHVmk4F4jjYamjibfZs,56489
-torchrl/modules/llm/policies/transformers_wrapper.py,sha256=rh8Us_95U-NL-aM_AVYXQWfxneRl-z74ovHPTwTW12M,110340
-torchrl/modules/llm/policies/vllm_wrapper.py,sha256=SAh2cgmDmc4qiaQi6yHoaISqM3flLbrgriIbx9zQpIs,99125
+torchrl/modules/llm/policies/common.py,sha256=hjDx09kt-IG81AiwD_8SRwZU7Zf530nWJMpdvECilrE,56733
+torchrl/modules/llm/policies/transformers_wrapper.py,sha256=GzaJBoEDIO7NpLXzNWySiqjRNXelOapUISAHSn4dyx8,111044
+torchrl/modules/llm/policies/vllm_wrapper.py,sha256=_cXdE5HjYtJE0cuGSxSynhFu8cShWQuu0R0rxBnD4jc,99829
 torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
 torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
 torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
@@ -322,8 +322,8 @@ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Y
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.8.10.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
-torchrl_nightly-2025.8.10.dist-info/METADATA,sha256=paA13xYvsRhQHdq4kbORTI9utZkTu1QAIDElif9MAqw,41412
-torchrl_nightly-2025.8.10.dist-info/WHEEL,sha256=ziAMZrFEBAMOBaTDNVVFwf5i-WiFj1yXRFZ4MRxHC0g,104
-torchrl_nightly-2025.8.10.dist-info/top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
-torchrl_nightly-2025.8.10.dist-info/RECORD,,
+torchrl_nightly-2025.8.11.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.8.11.dist-info/METADATA,sha256=yho9bHnrIfIvRsr_JI-SLCVVViu_hSb2H94L8t8HDZ4,41412
+torchrl_nightly-2025.8.11.dist-info/WHEEL,sha256=ziAMZrFEBAMOBaTDNVVFwf5i-WiFj1yXRFZ4MRxHC0g,104
+torchrl_nightly-2025.8.11.dist-info/top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
+torchrl_nightly-2025.8.11.dist-info/RECORD,,