torchrl-nightly 2025.8.10__cp313-cp313-manylinux1_x86_64.whl → 2025.8.11__cp313-cp313-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/modules/llm/policies/common.py +47 -40
- torchrl/modules/llm/policies/transformers_wrapper.py +22 -4
- torchrl/modules/llm/policies/vllm_wrapper.py +21 -3
- {torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/RECORD +8 -8
- {torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/top_level.txt +0 -0
torchrl/modules/llm/policies/common.py CHANGED

@@ -1300,11 +1300,11 @@ def _extract_responses_from_full_histories(
 def _batching(func):
     @wraps(func)
     def _batched_func(self, td_input: TensorDictBase, **kwargs):
-        # -- 0.
+        # -- 0. Bypass if batching disabled
         if not self.batching:
            return func(self, td_input, **kwargs)

-        #
+        # -- 1. Normalise --------------------------------------------------------
        if td_input.batch_dims > 1:
            raise ValueError(
                f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
@@ -1313,52 +1313,59 @@ def _batching(func):
         single = td_input.batch_dims == 0
         inputs = [td_input] if single else list(td_input.unbind(0))
         futures = [Future() for _ in inputs]
+        pending = set(futures)  # ← track our own Futures

-        #
+        # -- 2. Enqueue ----------------------------------------------------------
         self._batch_queue.extend(inputs)
         self._futures.extend(futures)

         min_bs = getattr(self, "_min_batch_size", 1)
         max_bs = getattr(self, "_max_batch_size", None)

+        # -- 3. Drain while holding the lock ------------------------------------
         with self._batching_lock:
-            #
-            [old lines 1325-1334 not recovered from the diff view]
+            if all(f.done() for f in futures):
+                # Our items were already processed by another thread.
+                # Skip draining; other workers will handle the rest of the queue.
+                pass
+            else:
+                while len(self._batch_queue) >= min_bs:
+                    slice_size = (
+                        len(self._batch_queue)
+                        if max_bs is None
+                        else min(max_bs, len(self._batch_queue))
+                    )
+                    batch = self._batch_queue[:slice_size]
+                    fut_slice = self._futures[:slice_size]
+
+                    try:
+                        results = func(self, lazy_stack(batch), **kwargs).unbind(0)
+                        if len(results) != slice_size:
+                            raise RuntimeError(
+                                f"Expected {slice_size} results, got {len(results)}"
+                            )
+                        for fut, res in zip(fut_slice, results):
+                            fut.set_result(res)
+                            pending.discard(fut)  # ← mark as done
+                    except Exception as exc:
+                        for fut in fut_slice:
+                            fut.set_exception(exc)
+                            pending.discard(fut)
+                        raise

-            [old lines 1336-1349 not recovered from the diff view]
-                        raise
-
-            # Pop processed work
-            del self._batch_queue[:slice_size]
-            del self._futures[:slice_size]
-
-        # ── 3. Outside the lock: wait only for OUR futures (they may already be done) ──
-        wait(
-            futures
-        )  # no timeout → immediate return if set_result()/set_exception() already called
-        result = [f.result() for f in futures]
-
-        return result[0] if single else lazy_stack(result)
+                    # Pop processed work
+                    del self._batch_queue[:slice_size]
+                    del self._futures[:slice_size]
+
+                    # ---- Early-exit: all *our* Futures are done -------------------
+                    if not pending:
+                        break
+
+        # -- 4. Outside the lock: wait only on remaining (rare) -----------------
+        if pending:  # usually empty; safety for min_bs > queue size
+            wait(pending)
+        results = [f.result() for f in futures]
+
+        return results[0] if single else lazy_stack(results)

     return _batched_func
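The new draining loop pops processed work inside the loop and lets each caller return as soon as its own Futures resolve, replacing the old unconditional `wait(futures)` after the lock. Below is a minimal, self-contained sketch of the same queue-and-futures coalescing pattern; `MiniBatcher` and its doubling "model call" are illustrative stand-ins, not torchrl code.

```python
import threading
from concurrent.futures import Future, wait

class MiniBatcher:
    """Toy version of the pattern above: callers enqueue items plus Futures;
    whichever thread holds the lock drains the shared queue in batches."""

    def __init__(self, min_batch_size=1, max_batch_size=None):
        self._queue = []
        self._futures = []
        self._lock = threading.Lock()
        self._min_bs, self._max_bs = min_batch_size, max_batch_size

    def process(self, item):
        fut = Future()
        pending = {fut}                      # track only *our* Future
        self._queue.append(item)             # enqueue before taking the lock,
        self._futures.append(fut)            # as in the decorator above
        with self._lock:
            if not fut.done():               # another thread may have served us already
                while len(self._queue) >= self._min_bs:
                    n = (len(self._queue) if self._max_bs is None
                         else min(self._max_bs, len(self._queue)))
                    batch, futs = self._queue[:n], self._futures[:n]
                    # Stand-in for the real batched model call:
                    for f, r in zip(futs, [x * 2 for x in batch]):
                        f.set_result(r)
                        pending.discard(f)
                    del self._queue[:n]      # pop processed work inside the loop
                    del self._futures[:n]
                    if not pending:          # early exit once our items are done
                        break
        if pending:                          # rare: queue still below min_batch_size
            wait(pending)
        return fut.result()

batcher = MiniBatcher()
print(batcher.process(21))  # -> 42
```

The early-exit on `pending` is the behavioral point of the diff: a caller whose items were already resolved by another worker no longer drains the whole shared queue before returning.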
torchrl/modules/llm/policies/transformers_wrapper.py CHANGED

@@ -23,7 +23,7 @@ from tensordict import (
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
-
+from torchrl import logger as torchrl_logger
 from torchrl.modules.llm.policies.common import (
     _batching,
     _extract_responses_from_full_histories,
@@ -2443,7 +2443,12 @@ class RemoteTransformersWrapper:
     """

     def __init__(
-        self,
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2458,10 +2463,23 @@ class RemoteTransformersWrapper:

         if not ray.is_initialized():
             ray.init()
-
+
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(TransformersWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )
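Both remote wrappers adopt Ray's named-actor get-or-create idiom: `ray.get_actor(name)` raises `ValueError` when no live actor is registered under that name, and the constructor catches this to fall through to creation with `.options(name=actor_name)`. A standalone sketch of the idiom, with an illustrative `Counter` actor in place of `TransformersWrapper`:

```python
import ray

@ray.remote
class Counter:
    def __init__(self):
        self.n = 0

    def incr(self):
        self.n += 1
        return self.n

def get_or_create(name: str):
    try:
        return ray.get_actor(name)                  # raises ValueError if absent
    except ValueError:
        return Counter.options(name=name).remote()  # register under `name`

ray.init()
a = get_or_create("shared-counter")
b = get_or_create("shared-counter")                 # same handle, shared state
assert ray.get(a.incr.remote()) == 1
assert ray.get(b.incr.remote()) == 2
```

Note that the check-then-create pair is not atomic: two processes racing can both miss `ray.get_actor`, and the second `.options(name=...)` call will then fail with a duplicate-name error. Ray's `get_if_exists=True` option exists for an atomic variant of this pattern.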
torchrl/modules/llm/policies/vllm_wrapper.py CHANGED

@@ -23,6 +23,7 @@ from tensordict.tensorclass import from_dataclass, TensorClass
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
+from torchrl import logger as torchrl_logger

 from torchrl.envs.utils import _classproperty
 from torchrl.modules.llm.policies.common import (
@@ -2101,7 +2102,12 @@ class RemotevLLMWrapper:
     """

     def __init__(
-        self,
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2141,10 +2147,22 @@ class RemotevLLMWrapper:
         if not ray.is_initialized():
             ray.init()

-
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(vLLMWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )
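Reading the new signature, the intended usage appears to be that callers passing the same `actor_name` share a single engine actor. A hedged sketch under that assumption (the model id is illustrative):

```python
from torchrl.modules.llm.policies.vllm_wrapper import RemotevLLMWrapper

# Both constructions resolve to the same Ray actor; the second call returns
# early once ray.get_actor() succeeds, so a vLLM engine is loaded only once.
w1 = RemotevLLMWrapper("Qwen/Qwen2.5-0.5B-Instruct", actor_name="vllm-shared")
w2 = RemotevLLMWrapper("Qwen/Qwen2.5-0.5B-Instruct", actor_name="vllm-shared")
```

One consequence visible in the diff: when the named actor already exists, the constructor returns before `ray.remote(...)`, so `model`, `max_concurrency`, and any `**kwargs` passed on reattachment are silently ignored.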
{torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/METADATA RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchrl-nightly
-Version: 2025.8.10
+Version: 2025.8.11
 Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
 Author-email: torchrl contributors <vmoens@fb.com>
 Maintainer-email: torchrl contributors <vmoens@fb.com>
{torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/RECORD RENAMED

@@ -245,9 +245,9 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
 torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
 torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
 torchrl/modules/llm/policies/__init__.py,sha256=VYAiblw6ETlo4q1vSvaKaybuxwxuPXfC-QCFzZJk4PA,649
-torchrl/modules/llm/policies/common.py,sha256=
-torchrl/modules/llm/policies/transformers_wrapper.py,sha256=
-torchrl/modules/llm/policies/vllm_wrapper.py,sha256=
+torchrl/modules/llm/policies/common.py,sha256=hjDx09kt-IG81AiwD_8SRwZU7Zf530nWJMpdvECilrE,56733
+torchrl/modules/llm/policies/transformers_wrapper.py,sha256=GzaJBoEDIO7NpLXzNWySiqjRNXelOapUISAHSn4dyx8,111044
+torchrl/modules/llm/policies/vllm_wrapper.py,sha256=_cXdE5HjYtJE0cuGSxSynhFu8cShWQuu0R0rxBnD4jc,99829
 torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
 torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
 torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
@@ -322,8 +322,8 @@ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Y
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.8.10.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
-torchrl_nightly-2025.8.10.dist-info/METADATA,sha256=
-torchrl_nightly-2025.8.10.dist-info/WHEEL,sha256=lfRLw2w0c5GQxgaYG7pjqx4f74UCjBeXDAmVwz2UOJI,104
-torchrl_nightly-2025.8.10.dist-info/top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
-torchrl_nightly-2025.8.10.dist-info/RECORD,,
+torchrl_nightly-2025.8.11.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.8.11.dist-info/METADATA,sha256=yho9bHnrIfIvRsr_JI-SLCVVViu_hSb2H94L8t8HDZ4,41412
+torchrl_nightly-2025.8.11.dist-info/WHEEL,sha256=lfRLw2w0c5GQxgaYG7pjqx4f74UCjBeXDAmVwz2UOJI,104
+torchrl_nightly-2025.8.11.dist-info/top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
+torchrl_nightly-2025.8.11.dist-info/RECORD,,
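Context for the hash churn above: a wheel RECORD line has the form `path,sha256=<urlsafe-base64 sha256 digest with padding stripped>,<size in bytes>`, which is why the three edited modules get new digests while LICENSE, WHEEL, and top_level.txt keep theirs across the rename. A sketch of how such an entry is computed:

```python
import base64
import hashlib

def record_entry(path: str) -> str:
    """Build a wheel RECORD line for one file (PEP 376/427 style)."""
    with open(path, "rb") as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# e.g. record_entry("torchrl/modules/llm/policies/common.py")
# -> "torchrl/modules/llm/policies/common.py,sha256=hjDx09kt-...,56733"
```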
{torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/WHEEL RENAMED
File without changes

{torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/licenses/LICENSE RENAMED
File without changes

{torchrl_nightly-2025.8.10.dist-info → torchrl_nightly-2025.8.11.dist-info}/top_level.txt RENAMED
File without changes