dmpo 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dmpo/__init__.py +1 -0
- dmpo/dmpo.py +1 -0
- dmpo/mpo.py +1 -0
- dmpo/tpo.py +399 -0
- dmpo/vmpo.py +1 -0
- dmpo-0.0.2.dist-info/METADATA +133 -0
- dmpo-0.0.2.dist-info/RECORD +9 -0
- dmpo-0.0.2.dist-info/WHEEL +4 -0
- dmpo-0.0.2.dist-info/licenses/LICENSE +21 -0
dmpo/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from dmpo.tpo import TPO
|
dmpo/dmpo.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
dmpo/mpo.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
dmpo/tpo.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
from collections import deque, namedtuple
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.nn import Module
|
|
6
|
+
from torch.optim import Adam
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from einops import rearrange, einsum, reduce
|
|
10
|
+
from torch_einops_utils import masked_mean, lens_to_mask
|
|
11
|
+
|
|
12
|
+
from accelerate import Accelerator
|
|
13
|
+
from memmap_replay_buffer import ReplayBuffer
|
|
14
|
+
from discrete_continuous_embed_readout import ParameterlessReadout
|
|
15
|
+
|
|
16
|
+
# helpers
|
|
17
|
+
|
|
18
|
+
def exists(val):
|
|
19
|
+
return val is not None
|
|
20
|
+
|
|
21
|
+
def default(val, d):
|
|
22
|
+
return val if exists(val) else d
|
|
23
|
+
|
|
24
|
+
def z_score(t, eps = 1e-8):
|
|
25
|
+
return (t - t.mean()) / (t.std(unbiased = False) + eps)
|
|
26
|
+
|
|
27
|
+
# tpo loss functions
|
|
28
|
+
|
|
29
|
+
LogScoreReturn = namedtuple('LogScoreReturn', ['log_scores', 'logits'])
|
|
30
|
+
|
|
31
|
+
def tpo_target(log_scores, u, eta = 1.0):
|
|
32
|
+
return F.log_softmax(log_scores + u / eta, dim = -1)
|
|
33
|
+
|
|
34
|
+
def tpo_forward_kl_loss(log_p, log_q):
|
|
35
|
+
q = log_q.exp()
|
|
36
|
+
return -einsum(q, log_p, '... k, ... k -> ...').mean()
|
|
37
|
+
|
|
38
|
+
def tpo_reverse_kl_loss(log_p, log_q):
|
|
39
|
+
p = log_p.exp()
|
|
40
|
+
return einsum(p, log_p - log_q, '... k, ... k -> ...').mean()
|
|
41
|
+
|
|
42
|
+
def tpo_js_loss(log_p, log_q, weight = 0.5, eps = 1e-10):
|
|
43
|
+
p = log_p.exp()
|
|
44
|
+
q = log_q.exp()
|
|
45
|
+
|
|
46
|
+
m = q.lerp(p, weight)
|
|
47
|
+
log_m = m.clamp(min = eps).log()
|
|
48
|
+
|
|
49
|
+
kl_p_m = einsum(p, log_p - log_m, '... k, ... k -> ...').mean()
|
|
50
|
+
kl_q_m = einsum(q, log_q - log_m, '... k, ... k -> ...').mean()
|
|
51
|
+
|
|
52
|
+
return kl_q_m.lerp(kl_p_m, weight)
|
|
53
|
+
|
|
54
|
+
TPO_LOSS_FNS = dict(
|
|
55
|
+
forward_kl = tpo_forward_kl_loss,
|
|
56
|
+
reverse_kl = tpo_reverse_kl_loss,
|
|
57
|
+
js = tpo_js_loss
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# environments
|
|
61
|
+
|
|
62
|
+
class GymEnvironment(Module):
|
|
63
|
+
def __init__(
|
|
64
|
+
self,
|
|
65
|
+
env,
|
|
66
|
+
readout,
|
|
67
|
+
maybe_reshape_logits,
|
|
68
|
+
action_fields,
|
|
69
|
+
is_discrete,
|
|
70
|
+
is_continuous,
|
|
71
|
+
num_continuous = None,
|
|
72
|
+
num_discrete_categories = None,
|
|
73
|
+
num_discrete_logits = None,
|
|
74
|
+
group_size = 64,
|
|
75
|
+
max_timesteps = None,
|
|
76
|
+
buffer_folder = './tpo_buffer',
|
|
77
|
+
overwrite_buffer_on_start = True
|
|
78
|
+
):
|
|
79
|
+
super().__init__()
|
|
80
|
+
self.env = env
|
|
81
|
+
self.readout = readout
|
|
82
|
+
|
|
83
|
+
self.is_discrete = is_discrete
|
|
84
|
+
self.is_continuous = is_continuous
|
|
85
|
+
self.num_continuous = num_continuous
|
|
86
|
+
self.maybe_reshape_logits = maybe_reshape_logits
|
|
87
|
+
|
|
88
|
+
self.num_discrete_categories = num_discrete_categories
|
|
89
|
+
|
|
90
|
+
if exists(num_discrete_categories):
|
|
91
|
+
categories = torch.tensor(num_discrete_categories)
|
|
92
|
+
self.register_buffer('categories', categories)
|
|
93
|
+
self.register_buffer('divisors', torch.cat((torch.tensor([1]), categories.cumprod(dim = 0)[:-1])))
|
|
94
|
+
|
|
95
|
+
self.group_size = group_size
|
|
96
|
+
|
|
97
|
+
obs_dim = int(env.observation_space.shape[0])
|
|
98
|
+
max_timesteps = default(max_timesteps, group_size * 1000)
|
|
99
|
+
|
|
100
|
+
self.buffer = ReplayBuffer(
|
|
101
|
+
folder = buffer_folder,
|
|
102
|
+
max_episodes = group_size,
|
|
103
|
+
max_timesteps = max_timesteps,
|
|
104
|
+
fields = dict(
|
|
105
|
+
state = ('float', (obs_dim,)),
|
|
106
|
+
**action_fields
|
|
107
|
+
),
|
|
108
|
+
meta_fields = dict(
|
|
109
|
+
cum_reward = 'float'
|
|
110
|
+
),
|
|
111
|
+
circular = False,
|
|
112
|
+
overwrite = overwrite_buffer_on_start
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def is_multi_discrete(self):
|
|
117
|
+
return exists(self.num_discrete_categories)
|
|
118
|
+
|
|
119
|
+
def get_discrete_env_action(self, discrete_tensor):
|
|
120
|
+
if not self.is_multi_discrete:
|
|
121
|
+
return discrete_tensor.item()
|
|
122
|
+
return ((discrete_tensor // self.divisors.to(discrete_tensor.device)) % self.categories.to(discrete_tensor.device)).cpu().numpy()
|
|
123
|
+
|
|
124
|
+
def action_to_env(self, action_tensor):
|
|
125
|
+
if self.is_continuous and not self.is_discrete:
|
|
126
|
+
return action_tensor.cpu().numpy()
|
|
127
|
+
|
|
128
|
+
if self.is_discrete and not self.is_continuous:
|
|
129
|
+
return self.get_discrete_env_action(action_tensor)
|
|
130
|
+
|
|
131
|
+
discrete_tensor, continuous_tensor = action_tensor
|
|
132
|
+
return (self.get_discrete_env_action(discrete_tensor), continuous_tensor.cpu().numpy())
|
|
133
|
+
|
|
134
|
+
def forward(self, actor):
|
|
135
|
+
device = next(actor.parameters()).device
|
|
136
|
+
self.buffer.clear()
|
|
137
|
+
|
|
138
|
+
for k in range(self.group_size):
|
|
139
|
+
state, _ = self.env.reset()
|
|
140
|
+
episode_reward = 0.
|
|
141
|
+
done = False
|
|
142
|
+
|
|
143
|
+
while not done:
|
|
144
|
+
state_t = torch.tensor(state, dtype = torch.float32, device = device)
|
|
145
|
+
|
|
146
|
+
with torch.no_grad():
|
|
147
|
+
logits = self.maybe_reshape_logits(actor(state_t))
|
|
148
|
+
action_tensor = self.readout.sample(logits)
|
|
149
|
+
|
|
150
|
+
action = self.action_to_env(action_tensor)
|
|
151
|
+
|
|
152
|
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
|
153
|
+
done = terminated or truncated
|
|
154
|
+
|
|
155
|
+
store_kwargs = dict(state = state)
|
|
156
|
+
|
|
157
|
+
if self.is_discrete:
|
|
158
|
+
t = action_tensor[0] if self.is_continuous else action_tensor
|
|
159
|
+
store_kwargs['action_discrete'] = t.item()
|
|
160
|
+
|
|
161
|
+
if self.is_continuous:
|
|
162
|
+
t = action_tensor[1] if self.is_discrete else action_tensor
|
|
163
|
+
store_kwargs['action_continuous'] = t.cpu().numpy()
|
|
164
|
+
|
|
165
|
+
self.buffer.store(**store_kwargs)
|
|
166
|
+
|
|
167
|
+
episode_reward += reward
|
|
168
|
+
state = next_state
|
|
169
|
+
|
|
170
|
+
self.buffer.store_meta_datapoint(k, 'cum_reward', episode_reward)
|
|
171
|
+
self.buffer.advance_episode()
|
|
172
|
+
|
|
173
|
+
return self.buffer.get_all_data(device = device)
|
|
174
|
+
|
|
175
|
+
# main class
|
|
176
|
+
|
|
177
|
+
class TPO(Module):
|
|
178
|
+
def __init__(
|
|
179
|
+
self,
|
|
180
|
+
actor,
|
|
181
|
+
environment,
|
|
182
|
+
*,
|
|
183
|
+
action_num_discrete = None,
|
|
184
|
+
action_num_continuous = None,
|
|
185
|
+
buffer_folder = './tpo_buffer',
|
|
186
|
+
overwrite_buffer_on_start = True,
|
|
187
|
+
max_timesteps = None,
|
|
188
|
+
epochs = 4,
|
|
189
|
+
group_size = 64,
|
|
190
|
+
optim = None,
|
|
191
|
+
optim_kwargs = dict(),
|
|
192
|
+
lr = 3e-4,
|
|
193
|
+
max_grad_norm = None,
|
|
194
|
+
eta = 1.0,
|
|
195
|
+
min_rewards_std = 1e-4,
|
|
196
|
+
entropy_coef = 0.01,
|
|
197
|
+
divergence = 'forward_kl',
|
|
198
|
+
reward_moving_average_len = 20,
|
|
199
|
+
cpu = False,
|
|
200
|
+
on_result = None,
|
|
201
|
+
**readout_kwargs
|
|
202
|
+
):
|
|
203
|
+
super().__init__()
|
|
204
|
+
|
|
205
|
+
self.has_discrete = exists(action_num_discrete)
|
|
206
|
+
self.has_continuous = exists(action_num_continuous)
|
|
207
|
+
|
|
208
|
+
assert self.has_discrete or self.has_continuous, 'must specify at least one of action_num_discrete or action_num_continuous'
|
|
209
|
+
|
|
210
|
+
# readout
|
|
211
|
+
|
|
212
|
+
readout_params = dict(**readout_kwargs)
|
|
213
|
+
|
|
214
|
+
if self.has_discrete:
|
|
215
|
+
readout_params['num_discrete'] = action_num_discrete
|
|
216
|
+
|
|
217
|
+
if self.has_continuous:
|
|
218
|
+
readout_params['num_continuous'] = action_num_continuous
|
|
219
|
+
|
|
220
|
+
self.readout = ParameterlessReadout(**readout_params)
|
|
221
|
+
|
|
222
|
+
# derive buffer field and action conversion from config
|
|
223
|
+
|
|
224
|
+
action_fields = dict()
|
|
225
|
+
num_discrete_categories = None
|
|
226
|
+
self.num_discrete_logits = None
|
|
227
|
+
|
|
228
|
+
if self.has_discrete:
|
|
229
|
+
is_multi = isinstance(action_num_discrete, (tuple, list))
|
|
230
|
+
action_fields['action_discrete'] = 'int'
|
|
231
|
+
num_discrete_categories = tuple(action_num_discrete) if is_multi else None
|
|
232
|
+
self.num_discrete_logits = sum(action_num_discrete) if is_multi else action_num_discrete
|
|
233
|
+
|
|
234
|
+
if self.has_continuous:
|
|
235
|
+
action_fields['action_continuous'] = ('float', (action_num_continuous,))
|
|
236
|
+
|
|
237
|
+
# setup environment
|
|
238
|
+
|
|
239
|
+
if not callable(environment):
|
|
240
|
+
self.environment = GymEnvironment(
|
|
241
|
+
environment,
|
|
242
|
+
readout = self.readout,
|
|
243
|
+
maybe_reshape_logits = self.maybe_reshape_logits,
|
|
244
|
+
action_fields = action_fields,
|
|
245
|
+
is_discrete = self.has_discrete,
|
|
246
|
+
is_continuous = self.has_continuous,
|
|
247
|
+
num_continuous = action_num_continuous,
|
|
248
|
+
num_discrete_categories = num_discrete_categories,
|
|
249
|
+
num_discrete_logits = self.num_discrete_logits,
|
|
250
|
+
group_size = group_size,
|
|
251
|
+
max_timesteps = max_timesteps,
|
|
252
|
+
buffer_folder = buffer_folder,
|
|
253
|
+
overwrite_buffer_on_start = overwrite_buffer_on_start
|
|
254
|
+
)
|
|
255
|
+
else:
|
|
256
|
+
self.environment = environment
|
|
257
|
+
|
|
258
|
+
# store refs
|
|
259
|
+
|
|
260
|
+
self.num_continuous = action_num_continuous
|
|
261
|
+
|
|
262
|
+
self.actor = actor
|
|
263
|
+
|
|
264
|
+
self.accelerator = Accelerator(cpu = cpu)
|
|
265
|
+
self.device = self.accelerator.device
|
|
266
|
+
|
|
267
|
+
if exists(optim):
|
|
268
|
+
self.optimizer = optim
|
|
269
|
+
else:
|
|
270
|
+
self.optimizer = Adam(self.actor.parameters(), lr = lr, **optim_kwargs)
|
|
271
|
+
|
|
272
|
+
self.actor, self.readout, self.optimizer = self.accelerator.prepare(
|
|
273
|
+
self.actor, self.readout, self.optimizer
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
self.epochs = epochs
|
|
277
|
+
self.eta = eta
|
|
278
|
+
self.min_rewards_std = min_rewards_std
|
|
279
|
+
self.max_grad_norm = max_grad_norm
|
|
280
|
+
self.entropy_coef = entropy_coef
|
|
281
|
+
self.reward_moving_average_len = reward_moving_average_len
|
|
282
|
+
|
|
283
|
+
assert divergence in TPO_LOSS_FNS, f'divergence must be one of {list(TPO_LOSS_FNS.keys())}'
|
|
284
|
+
self.tpo_loss_fn = TPO_LOSS_FNS[divergence]
|
|
285
|
+
|
|
286
|
+
self.on_result = on_result
|
|
287
|
+
|
|
288
|
+
def maybe_reshape_logits(self, logits):
|
|
289
|
+
if self.has_discrete and not self.has_continuous:
|
|
290
|
+
return logits
|
|
291
|
+
|
|
292
|
+
if self.has_continuous and not self.has_discrete:
|
|
293
|
+
return rearrange(logits, '... (c d) -> ... c d', c = self.num_continuous)
|
|
294
|
+
|
|
295
|
+
discrete_logits, continuous_logits = logits.split([self.num_discrete_logits, self.num_continuous * 2], dim = -1)
|
|
296
|
+
continuous_params = rearrange(continuous_logits, '... (c d) -> ... c d', c = self.num_continuous)
|
|
297
|
+
|
|
298
|
+
return (discrete_logits, continuous_params)
|
|
299
|
+
|
|
300
|
+
def calculate_log_scores(self, states, actions, mask, episode_lens_float):
|
|
301
|
+
logits = self.maybe_reshape_logits(self.actor(states))
|
|
302
|
+
|
|
303
|
+
neg_log_probs = self.readout.calculate_loss(
|
|
304
|
+
logits,
|
|
305
|
+
targets = actions,
|
|
306
|
+
mask = mask,
|
|
307
|
+
return_unreduced_loss = True
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
log_scores = reduce(-neg_log_probs, 'b ... -> b', 'sum')
|
|
311
|
+
log_scores = log_scores / episode_lens_float
|
|
312
|
+
|
|
313
|
+
return LogScoreReturn(log_scores, logits)
|
|
314
|
+
|
|
315
|
+
def forward(
|
|
316
|
+
self,
|
|
317
|
+
num_iterations = 2000
|
|
318
|
+
):
|
|
319
|
+
device = self.device
|
|
320
|
+
recent_rewards = deque(maxlen = self.reward_moving_average_len)
|
|
321
|
+
pbar = tqdm(range(num_iterations), desc = 'tpo training')
|
|
322
|
+
|
|
323
|
+
for it in pbar:
|
|
324
|
+
|
|
325
|
+
# get rollout
|
|
326
|
+
|
|
327
|
+
data = self.environment(self.actor)
|
|
328
|
+
|
|
329
|
+
# unpack data
|
|
330
|
+
|
|
331
|
+
states = data['state']
|
|
332
|
+
|
|
333
|
+
if 'action' in data:
|
|
334
|
+
actions = data['action']
|
|
335
|
+
elif self.has_discrete and self.has_continuous:
|
|
336
|
+
actions = (data['action_discrete'], data['action_continuous'])
|
|
337
|
+
elif self.has_discrete:
|
|
338
|
+
actions = data['action_discrete']
|
|
339
|
+
else:
|
|
340
|
+
actions = data['action_continuous']
|
|
341
|
+
|
|
342
|
+
rewards = data.get('cum_reward', data.get('reward'))
|
|
343
|
+
episode_lens = data.get('episode_lens')
|
|
344
|
+
|
|
345
|
+
# log reward
|
|
346
|
+
|
|
347
|
+
recent_rewards.extend(rewards.tolist())
|
|
348
|
+
|
|
349
|
+
avg_reward = sum(recent_rewards) / max(1, len(recent_rewards))
|
|
350
|
+
|
|
351
|
+
if exists(self.on_result):
|
|
352
|
+
self.on_result(avg_reward, pbar)
|
|
353
|
+
else:
|
|
354
|
+
pbar.set_postfix(avg_reward = f'{avg_reward:.2f}')
|
|
355
|
+
|
|
356
|
+
# calculate baseline and mask
|
|
357
|
+
|
|
358
|
+
if rewards.std(unbiased = False) < self.min_rewards_std:
|
|
359
|
+
u = torch.zeros_like(rewards)
|
|
360
|
+
else:
|
|
361
|
+
u = z_score(rewards)
|
|
362
|
+
|
|
363
|
+
mask = data.get('mask')
|
|
364
|
+
|
|
365
|
+
if not exists(mask):
|
|
366
|
+
assert exists(episode_lens), 'episode_lens must be returned by environment if mask is not provided'
|
|
367
|
+
mask = lens_to_mask(episode_lens, max_len = states.shape[1])
|
|
368
|
+
|
|
369
|
+
mask = mask.to(device)
|
|
370
|
+
|
|
371
|
+
episode_lens_float = mask.sum(dim = 1).clamp(min = 1.).float()
|
|
372
|
+
|
|
373
|
+
# target distribution
|
|
374
|
+
|
|
375
|
+
with torch.no_grad():
|
|
376
|
+
out = self.calculate_log_scores(states, actions, mask, episode_lens_float)
|
|
377
|
+
log_q = tpo_target(out.log_scores, u, self.eta)
|
|
378
|
+
|
|
379
|
+
# train policy
|
|
380
|
+
|
|
381
|
+
for epoch in range(self.epochs):
|
|
382
|
+
self.optimizer.zero_grad()
|
|
383
|
+
|
|
384
|
+
out = self.calculate_log_scores(states, actions, mask, episode_lens_float)
|
|
385
|
+
|
|
386
|
+
log_p = F.log_softmax(out.log_scores, dim = -1)
|
|
387
|
+
|
|
388
|
+
entropy = self.readout.entropy(out.logits)
|
|
389
|
+
entropy = masked_mean(entropy, mask)
|
|
390
|
+
|
|
391
|
+
loss = self.tpo_loss_fn(log_p, log_q)
|
|
392
|
+
loss = loss - self.entropy_coef * entropy
|
|
393
|
+
|
|
394
|
+
self.accelerator.backward(loss)
|
|
395
|
+
|
|
396
|
+
if exists(self.max_grad_norm):
|
|
397
|
+
self.accelerator.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
|
|
398
|
+
|
|
399
|
+
self.optimizer.step()
|
dmpo/vmpo.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dmpo
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Maximum a Posteriori Policy Optimization and Related Algorithms
|
|
5
|
+
Project-URL: Homepage, https://pypi.org/project/dmpo/
|
|
6
|
+
Project-URL: Repository, https://codeberg.org/lucidrains/dmpo
|
|
7
|
+
Author-email: Phil Wang <lucidrains@gmail.com>
|
|
8
|
+
License: MIT License
|
|
9
|
+
|
|
10
|
+
Copyright (c) 2026 Phil Wang
|
|
11
|
+
|
|
12
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
13
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
14
|
+
in the Software without restriction, including without limitation the rights
|
|
15
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
16
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
17
|
+
furnished to do so, subject to the following conditions:
|
|
18
|
+
|
|
19
|
+
The above copyright notice and this permission notice shall be included in all
|
|
20
|
+
copies or substantial portions of the Software.
|
|
21
|
+
|
|
22
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
23
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
24
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
25
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
26
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
27
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
28
|
+
SOFTWARE.
|
|
29
|
+
License-File: LICENSE
|
|
30
|
+
Keywords: artificial intelligence,deep learning,mpo,reinforcement learning,tpo
|
|
31
|
+
Classifier: Development Status :: 4 - Beta
|
|
32
|
+
Classifier: Intended Audience :: Developers
|
|
33
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
34
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
35
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
|
+
Requires-Python: >=3.10
|
|
37
|
+
Requires-Dist: accelerate
|
|
38
|
+
Requires-Dist: discrete-continuous-embed-readout
|
|
39
|
+
Requires-Dist: einops>=0.8.1
|
|
40
|
+
Requires-Dist: memmap-replay-buffer>=0.1.4
|
|
41
|
+
Requires-Dist: torch-einops-utils>=0.1.2
|
|
42
|
+
Requires-Dist: torch>=2.5
|
|
43
|
+
Requires-Dist: tqdm
|
|
44
|
+
Requires-Dist: x-mlps-pytorch
|
|
45
|
+
Requires-Dist: x-transformers
|
|
46
|
+
Provides-Extra: examples
|
|
47
|
+
Provides-Extra: test
|
|
48
|
+
Requires-Dist: pytest; extra == 'test'
|
|
49
|
+
Description-Content-Type: text/markdown
|
|
50
|
+
|
|
51
|
+
## DMPO (wip)
|
|
52
|
+
|
|
53
|
+
Implementation and explorations into [MPO](https://arxiv.org/abs/1806.06920) / DMPO
|
|
54
|
+
|
|
55
|
+
## Citations
|
|
56
|
+
|
|
57
|
+
```bibtex
|
|
58
|
+
@article{Haarnoja_2024,
|
|
59
|
+
title = {Learning agile soccer skills for a bipedal robot with deep reinforcement learning},
|
|
60
|
+
volume = {9},
|
|
61
|
+
ISSN = {2470-9476},
|
|
62
|
+
url = {http://dx.doi.org/10.1126/scirobotics.adi8022},
|
|
63
|
+
DOI = {10.1126/scirobotics.adi8022},
|
|
64
|
+
number = {89},
|
|
65
|
+
journal = {Science Robotics},
|
|
66
|
+
publisher = {American Association for the Advancement of Science (AAAS)},
|
|
67
|
+
author = {Haarnoja, Tuomas and Moran, Ben and Lever, Guy and Huang, Sandy H. and Tirumala, Dhruva and Humplik, Jan and Wulfmeier, Markus and Tunyasuvunakool, Saran and Siegel, Noah Y. and Hafner, Roland and Bloesch, Michael and Hartikainen, Kristian and Byravan, Arunkumar and Hasenclever, Leonard and Tassa, Yuval and Sadeghi, Fereshteh and Batchelor, Nathan and Casarini, Federico and Saliceti, Stefano and Game, Charles and Sreendra, Neil and Patel, Kushal and Gwira, Marlon and Huber, Andrea and Hurley, Nicole and Nori, Francesco and Hadsell, Raia and Heess, Nicolas},
|
|
68
|
+
year = {2024},
|
|
69
|
+
month = {Apr}
|
|
70
|
+
}
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
```bibtex
|
|
74
|
+
@misc{abdolmaleki2018maximumposterioripolicyoptimisation,
|
|
75
|
+
title = {Maximum a Posteriori Policy Optimisation},
|
|
76
|
+
author = {Abbas Abdolmaleki and Jost Tobias Springenberg and Yuval Tassa and Remi Munos and Nicolas Heess and Martin Riedmiller},
|
|
77
|
+
year = {2018},
|
|
78
|
+
eprint = {1806.06920},
|
|
79
|
+
archivePrefix = {arXiv},
|
|
80
|
+
primaryClass = {cs.LG},
|
|
81
|
+
url = {https://arxiv.org/abs/1806.06920}
|
|
82
|
+
}
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
```bibtex
|
|
86
|
+
@misc{song2019vmpoonpolicymaximumposteriori,
|
|
87
|
+
title = {V-MPO: On-Policy Maximum a Posteriori Policy Optimization for Discrete and Continuous Control},
|
|
88
|
+
author = {H. Francis Song and Abbas Abdolmaleki and Jost Tobias Springenberg and Aidan Clark and Hubert Soyer and Jack W. Rae and Seb Noury and Arun Ahuja and Siqi Liu and Dhruva Tirumala and Nicolas Heess and Dan Belov and Martin Riedmiller and Matthew M. Botvinick},
|
|
89
|
+
year = {2019},
|
|
90
|
+
eprint = {1909.12238},
|
|
91
|
+
archivePrefix = {arXiv},
|
|
92
|
+
primaryClass = {cs.AI},
|
|
93
|
+
url = {https://arxiv.org/abs/1909.12238}
|
|
94
|
+
}
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
```bibtex
|
|
98
|
+
@InProceedings{pmlr-v235-li24z,
|
|
99
|
+
title = {Value-Evolutionary-Based Reinforcement Learning},
|
|
100
|
+
author = {Li, Pengyi and Hao, Jianye and Tang, Hongyao and Zheng, Yan and Barez, Fazl},
|
|
101
|
+
booktitle = {Proceedings of the 41st International Conference on Machine Learning},
|
|
102
|
+
pages = {27875--27889},
|
|
103
|
+
year = {2024},
|
|
104
|
+
editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
|
|
105
|
+
volume = {235},
|
|
106
|
+
series = {Proceedings of Machine Learning Research},
|
|
107
|
+
month = {21--27 Jul},
|
|
108
|
+
publisher = {PMLR},
|
|
109
|
+
pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/li24z/li24z.pdf},
|
|
110
|
+
url = {https://proceedings.mlr.press/v235/li24z.html}
|
|
111
|
+
}
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
```bibtex
|
|
115
|
+
@article{kaddour2026target,
|
|
116
|
+
title = {Target Policy Optimization},
|
|
117
|
+
author = {Kaddour, Jean},
|
|
118
|
+
journal = {arXiv preprint arXiv:2604.06159},
|
|
119
|
+
year = {2026}
|
|
120
|
+
}
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
```bibtex
|
|
124
|
+
@misc{qu2026listwisepolicyoptimizationgroupbased,
|
|
125
|
+
title = {Listwise Policy Optimization: Group-based RLVR as Target-Projection on the LLM Response Simplex},
|
|
126
|
+
author = {Yun Qu and Qi Wang and Yixiu Mao and Heming Zou and Yuhang Jiang and Yingyue Li and Wutong Xu and Lizhou Cai and Weijie Liu and Clive Bai and Kai Yang and Yangkun Chen and Saiyong Yang and Xiangyang Ji},
|
|
127
|
+
year = {2026},
|
|
128
|
+
eprint = {2605.06139},
|
|
129
|
+
archivePrefix = {arXiv},
|
|
130
|
+
primaryClass = {cs.LG},
|
|
131
|
+
url = {https://arxiv.org/abs/2605.06139},
|
|
132
|
+
}
|
|
133
|
+
```
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
dmpo/__init__.py,sha256=xr_TuLuVsdqzvKD60YjGfkm-5M-2b4QNokgCv56GzAU,25
|
|
2
|
+
dmpo/dmpo.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
3
|
+
dmpo/mpo.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
4
|
+
dmpo/tpo.py,sha256=VOdr11o6TfG9sFaiPrwhw22oK0IDB9GgosgdmmzsLrw,13085
|
|
5
|
+
dmpo/vmpo.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
6
|
+
dmpo-0.0.2.dist-info/METADATA,sha256=VGKdkbOJaX969X7JLnM1841Y1WB0GtMRN6zmbACJwgE,6154
|
|
7
|
+
dmpo-0.0.2.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
8
|
+
dmpo-0.0.2.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
|
|
9
|
+
dmpo-0.0.2.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Phil Wang
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|