smplfitter 0.3.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. smplfitter/__init__.py +20 -0
  2. smplfitter/_version.py +34 -0
  3. smplfitter/common.py +250 -0
  4. smplfitter/decimation/__init__.py +1 -0
  5. smplfitter/decimation/decimate_body_models.py +55 -0
  6. smplfitter/decimation/make_post_lbs_joint_regressors.py +290 -0
  7. smplfitter/decimation/make_post_lbs_joint_regressors_tf.py +272 -0
  8. smplfitter/download.py +412 -0
  9. smplfitter/jax/__init__.py +25 -0
  10. smplfitter/jax/bodyconverter.py +156 -0
  11. smplfitter/jax/bodyfitter.py +1004 -0
  12. smplfitter/jax/bodymodel.py +189 -0
  13. smplfitter/jax/lstsq.py +119 -0
  14. smplfitter/jax/rotation.py +98 -0
  15. smplfitter/nb/__init__.py +25 -0
  16. smplfitter/nb/bodyconverter.py +153 -0
  17. smplfitter/nb/bodyfitter.py +1321 -0
  18. smplfitter/nb/bodymodel.py +443 -0
  19. smplfitter/nb/lstsq.py +315 -0
  20. smplfitter/nb/precompile.py +109 -0
  21. smplfitter/nb/rotation.py +120 -0
  22. smplfitter/nb/util.py +58 -0
  23. smplfitter/np/__init__.py +24 -0
  24. smplfitter/np/bodyconverter.py +153 -0
  25. smplfitter/np/bodyfitter.py +966 -0
  26. smplfitter/np/bodymodel.py +307 -0
  27. smplfitter/np/lstsq.py +69 -0
  28. smplfitter/np/rotation.py +78 -0
  29. smplfitter/np/util.py +18 -0
  30. smplfitter/pt/__init__.py +158 -0
  31. smplfitter/pt/bodyconverter.py +170 -0
  32. smplfitter/pt/bodyfitter.py +1093 -0
  33. smplfitter/pt/bodyfitter_opt.py +255 -0
  34. smplfitter/pt/bodyflipper.py +169 -0
  35. smplfitter/pt/bodyflipper_opt.py +181 -0
  36. smplfitter/pt/bodymodel.py +404 -0
  37. smplfitter/pt/handreplacer.py +79 -0
  38. smplfitter/pt/lstsq.py +90 -0
  39. smplfitter/pt/rotation.py +73 -0
  40. smplfitter/tf/__init__.py +188 -0
  41. smplfitter/tf/bodyconverter.py +168 -0
  42. smplfitter/tf/bodyfitter.py +1074 -0
  43. smplfitter/tf/bodymodel.py +329 -0
  44. smplfitter/tf/lstsq.py +91 -0
  45. smplfitter/tf/rotation.py +66 -0
  46. smplfitter/tf/util.py +12 -0
  47. smplfitter-0.3.2.dev0.dist-info/METADATA +287 -0
  48. smplfitter-0.3.2.dev0.dist-info/RECORD +51 -0
  49. smplfitter-0.3.2.dev0.dist-info/WHEEL +5 -0
  50. smplfitter-0.3.2.dev0.dist-info/licenses/LICENSE +21 -0
  51. smplfitter-0.3.2.dev0.dist-info/top_level.txt +1 -0
smplfitter/__init__.py ADDED
@@ -0,0 +1,20 @@
1
"""SMPLFitter provides forward and inverse kinematics for SMPL-family body models.

Main submodules:
- :mod:`smplfitter.np` - NumPy backend
- :mod:`smplfitter.pt` - PyTorch backend
- :mod:`smplfitter.tf` - TensorFlow backend
- :mod:`smplfitter.jax` - JAX backend
- :mod:`smplfitter.nb` - Numba backend
"""

from __future__ import annotations

from .common import ModelData, initialize

try:
    # _version.py is generated by setuptools-scm at build time.
    from ._version import version as __version__
except ImportError:
    # Fallback when running from a raw source checkout without a generated version module.
    __version__ = '0.0.0'

__all__ = ['ModelData', 'initialize', '__version__']
smplfitter/_version.py ADDED
@@ -0,0 +1,34 @@
1
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Lightweight stand-in for typing.TYPE_CHECKING: keeps the generated file free of
# runtime typing imports; the aliases below are only meaningful to type checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '0.3.2.dev0'
__version_tuple__ = version_tuple = (0, 3, 2, 'dev0')

__commit_id__ = commit_id = None
smplfitter/common.py ADDED
@@ -0,0 +1,250 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import os
5
+ import os.path as osp
6
+ import pickle
7
+ import sys
8
+ import warnings
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+
13
+
14
@dataclass
class ModelData:
    """Data loaded from a SMPL-family body model file.

    This dataclass holds all arrays and metadata needed to instantiate a body model
    in any backend (NumPy, PyTorch, TensorFlow, JAX, Numba).
    """

    # NOTE: field order defines the generated __init__ signature; do not reorder.

    # Tensor data (numpy arrays, to be converted to framework-specific tensors)
    v_template: np.ndarray
    """Vertex template in T-pose, shape (num_vertices, 3)."""

    shapedirs: np.ndarray
    """Shape blend shapes, shape (num_vertices, 3, num_betas)."""

    posedirs: np.ndarray
    """Pose blend shapes, shape (num_vertices, 3, (num_joints-1)*9)."""

    J_regressor_post_lbs: np.ndarray
    """Joint regressor for post-LBS joint locations, shape (num_joints, num_vertices)."""

    J_template: np.ndarray
    """Joint template positions, shape (num_joints, 3)."""

    J_shapedirs: np.ndarray
    """Joint shape directions, shape (num_joints, 3, num_betas)."""

    kid_shapedir: np.ndarray
    """Kid shape blend shape for vertices, shape (num_vertices, 3)."""

    kid_J_shapedir: np.ndarray
    """Kid shape blend shape for joints, shape (num_joints, 3)."""

    weights: np.ndarray
    """Skinning weights, shape (num_vertices, num_joints)."""

    # Non-tensor data (metadata)
    kintree_parents: list[int]
    """Parent joint indices for kinematic tree."""

    faces: np.ndarray
    """Face indices, shape (num_faces, 3)."""

    num_joints: int
    """Number of joints in the body model."""

    num_vertices: int
    """Number of vertices in the body model mesh."""

    vertex_subset: np.ndarray
    """Indices of vertices used (for partial models)."""
65
+
66
+
67
def initialize(
    model_name: str,
    gender: str,
    model_root: str | None = None,
    num_betas: int | None = None,
    vertex_subset_size: int | None = None,
    vertex_subset=None,
    faces=None,
    joint_regressor_post_lbs=None,
) -> 'ModelData':
    """Load a SMPL-family body model file and return its data as a :class:`ModelData`.

    Args:
        model_name: One of 'smpl', 'smplx', 'smplxlh', 'smplxmoyo', 'smplh', 'smplh16',
            'mano'.
        gender: Gender string; only its first letter is used ('f'/'m'/'n'). Ignored for
            'mano' (only the right-hand model file is loaded).
        model_root: Directory containing the model files. If None, resolved from the
            SMPLFITTER_BODY_MODELS or DATA_ROOT environment variables.
        num_betas: Number of shape components to keep; None keeps all (slice [:None]).
        vertex_subset_size: If given, loads a precomputed vertex subset of this size
            (with matching faces and post-LBS joint regressor) from model_root.
        vertex_subset: Explicit vertex indices to keep; None keeps all vertices.
        faces: Explicit face array; None uses the model's own faces.
        joint_regressor_post_lbs: Explicit post-LBS joint regressor; None uses the
            model's J_regressor.

    Returns:
        ModelData with all arrays restricted to the selected vertex subset.

    Raises:
        ValueError: If model_name is not recognized.
        FileNotFoundError: If the model file or kid template cannot be found.
    """
    if model_root is None:
        body_models_dir = os.getenv('SMPLFITTER_BODY_MODELS')
        if body_models_dir is None:
            data_root = os.getenv('DATA_ROOT', '.')
            body_models_dir = f'{data_root}/body_models'
        model_root = f'{body_models_dir}/{model_name}'

    # The patches are needed while unpickling: old SMPL pickle files import chumpy,
    # which references numpy/inspect names removed in recent versions.
    with monkey_patched_for_chumpy():
        if model_name == 'smpl':
            gender_str = dict(f='f', m='m', n='neutral')[gender[0]]
            filename = f'basicmodel_{gender_str}_lbs_10_207_0_v1.1.0.pkl'
        elif model_name in ('smplx', 'smplxlh', 'smplxmoyo'):
            gender_str = dict(f='FEMALE', m='MALE', n='NEUTRAL')[gender[0]]
            filename = f'SMPLX_{gender_str}.npz'
        elif model_name == 'smplh':
            # SMPL-H has no neutral variant in this file layout.
            gender_str = dict(f='female', m='male')[gender[0]]
            filename = f'SMPLH_{gender_str}.pkl'
        elif model_name == 'smplh16':
            gender_str = dict(f='female', m='male', n='neutral')[gender[0]]
            filename = osp.join(gender_str, 'model.npz')
        elif model_name == 'mano':
            filename = 'MANO_RIGHT.pkl'
        else:
            raise ValueError(f'Unknown model name: {model_name}')

        filepath = osp.join(model_root, filename)
        try:
            if filename.endswith('.npz'):
                smpl_data = np.load(filepath)
            else:
                with open(filepath, 'rb') as f:
                    smpl_data = pickle.load(f, encoding='latin1')
        except FileNotFoundError:
            raise FileNotFoundError(
                f'Body model file not found: {filepath}\n\n'
                f'Set the body model location using one of:\n'
                f" 1. BodyModel('{model_name}', '{gender}', "
                f"model_root='/your/path/body_models/{model_name}')\n"
                f' 2. export SMPLFITTER_BODY_MODELS=/your/path/body_models\n'
                f' 3. export DATA_ROOT=/your/path '
                f'(looks for $DATA_ROOT/body_models/)\n\n'
                f'Download models: python -m smplfitter.download\n'
                f'Register first at https://smpl.is.tue.mpg.de/'
            ) from None

    res = {}
    res['shapedirs'] = np.array(smpl_data['shapedirs'], dtype=np.float64)
    res['posedirs'] = np.array(smpl_data['posedirs'], dtype=np.float64)
    res['v_template'] = np.array(smpl_data['v_template'], dtype=np.float64)

    # J_regressor may be stored as a scipy sparse matrix (has .toarray) in pickle files.
    if not isinstance(smpl_data['J_regressor'], np.ndarray):
        res['J_regressor'] = np.array(smpl_data['J_regressor'].toarray(), dtype=np.float64)
    else:
        res['J_regressor'] = smpl_data['J_regressor'].astype(np.float64)

    res['weights'] = np.array(smpl_data['weights'])
    res['faces'] = np.array(smpl_data['f'].astype(np.int32))
    res['kintree_parents'] = smpl_data['kintree_table'][0].tolist()
    res['num_joints'] = len(res['kintree_parents'])
    res['num_vertices'] = len(res['v_template'])

    # Kid model has an additional shape parameter which pulls the mesh towards the SMIL mean
    # template
    if model_name.lower().startswith('smpl'):
        kid_path = os.path.join(model_root, 'kid_template.npy')
        try:
            v_template_smil = np.load(kid_path).astype(np.float64)
        except FileNotFoundError:
            raise FileNotFoundError(
                f'Kid template not found: {kid_path}\n'
                f'Download it: python -m smplfitter.download'
            ) from None
        # Center the SMIL template before taking the difference to the adult template.
        res['kid_shapedir'] = (
            v_template_smil - np.mean(v_template_smil, axis=0) - res['v_template']
        )
        res['kid_J_shapedir'] = res['J_regressor'] @ res['kid_shapedir']
    else:
        res['kid_shapedir'] = np.zeros_like(res['v_template'])
        res['kid_J_shapedir'] = np.zeros((res['num_joints'], 3))

    # Fall back to regressing joint quantities from the vertex-level data when the file
    # does not store them directly.
    if 'J_shapedirs' in smpl_data:
        res['J_shapedirs'] = np.array(smpl_data['J_shapedirs'], dtype=np.float64)
    else:
        res['J_shapedirs'] = np.einsum('jv,vcs->jcs', res['J_regressor'], res['shapedirs'])

    if 'J_template' in smpl_data:
        res['J_template'] = np.array(smpl_data['J_template'], dtype=np.float64)
    else:
        res['J_template'] = res['J_regressor'] @ res['v_template']

    # Subtract the pose-blendshape contribution of the identity rotation (flattened 3x3
    # identities, one per non-root joint) from the template — presumably so downstream
    # code can contract posedirs with raw rotation matrices instead of (R - I).
    res['v_template'] = res['v_template'] - np.einsum(
        'vcx,x->vc',
        res['posedirs'],
        np.reshape(np.tile(np.eye(3, dtype=np.float64), [res['num_joints'] - 1, 1]), [-1]),
    )

    # A precomputed subset bundle overrides any explicitly passed subset/faces/regressor.
    if vertex_subset_size is not None:
        vertex_subset_dict = np.load(f'{model_root}/vertex_subset_{vertex_subset_size}.npz')
        vertex_subset = vertex_subset_dict['i_verts']
        faces = vertex_subset_dict['faces']
        joint_regressor_post_lbs = np.load(
            f'{model_root}/vertex_subset_joint_regr_post_lbs_{vertex_subset_size}.npy'
        )

    if vertex_subset is None:
        vertex_subset = np.arange(res['num_vertices'], dtype=np.int64)
    else:
        vertex_subset = np.array(vertex_subset, dtype=np.int64)

    if faces is None:
        faces = res['faces']

    if joint_regressor_post_lbs is None:
        joint_regressor_post_lbs = res['J_regressor']

    return ModelData(
        v_template=res['v_template'][vertex_subset],
        shapedirs=res['shapedirs'][vertex_subset, :, :num_betas],
        posedirs=res['posedirs'][vertex_subset],
        J_regressor_post_lbs=joint_regressor_post_lbs,
        J_template=res['J_template'],
        J_shapedirs=res['J_shapedirs'][:, :, :num_betas],
        kid_shapedir=res['kid_shapedir'][vertex_subset],
        kid_J_shapedir=res['kid_J_shapedir'],
        weights=res['weights'][vertex_subset],
        kintree_parents=res['kintree_parents'],
        faces=faces,
        num_joints=res['num_joints'],
        num_vertices=len(vertex_subset),
        vertex_subset=vertex_subset,
    )
208
+
209
+
210
@contextlib.contextmanager
def monkey_patched_for_chumpy():
    """Temporarily patch numpy and inspect so chumpy-based pickle files can be loaded.

    The pickle files of SMPL/SMPL-H import chumpy, which in turn uses aliases that were
    removed from recent NumPy (``np.bool``, ``np.float``, ``np.complex``, ``np.NINF``, ...)
    and Python (``inspect.getargspec``). This context manager installs the minimal set of
    shims and — unlike a plain sequential patch/unpatch — restores the previous state in a
    ``finally`` block, so nothing leaks even if the body raises. (The previous
    implementation left the ``numpy.float``/``numpy.complex``/``numpy.NINF`` sys.modules
    entries and all ``np`` attribute patches in place permanently.)
    """
    _missing = object()
    added_modules = []  # sys.modules keys we inserted
    saved_np_attrs = {}  # name -> previous np attribute value (or _missing)

    def add_module(name, value):
        # Register a fake submodule so `from numpy import <name>` succeeds via the
        # import system's sys.modules lookup.
        key = f'numpy.{name}'
        if key not in sys.modules:
            sys.modules[key] = value
            added_modules.append(key)

    def patch_np_attr(name, value):
        # Remember the original attribute (or its absence) before overwriting.
        if name not in saved_np_attrs:
            saved_np_attrs[name] = getattr(np, name, _missing)
        setattr(np, name, value)

    for name in ['bool', 'int', 'object', 'str']:
        if name not in dir(np):
            try:
                add_module(name, getattr(np, name + '_'))
            except AttributeError:
                pass

    add_module('float', float)
    add_module('complex', np.complex128)
    add_module('NINF', -np.inf)
    patch_np_attr('NINF', -np.inf)
    patch_np_attr('complex', np.complex128)
    patch_np_attr('float', float)

    if 'unicode' not in dir(np):
        add_module('unicode', np.str_)

    import inspect

    added_getargspec = False
    if not hasattr(inspect, 'getargspec'):
        inspect.getargspec = inspect.getfullargspec
        added_getargspec = True

    try:
        with warnings.catch_warnings():
            # chumpy triggers deprecation chatter on import; silence it for the load.
            warnings.simplefilter('ignore', FutureWarning)
            yield
    finally:
        for key in added_modules:
            sys.modules.pop(key, None)
        for name, old in saved_np_attrs.items():
            if old is _missing:
                try:
                    delattr(np, name)
                except AttributeError:
                    pass
            else:
                setattr(np, name, old)
        if added_getargspec:
            del inspect.getargspec
@@ -0,0 +1 @@
1
+ from __future__ import annotations
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+ import os
3
+
4
+ import numpy as np
5
+ from .. import np as smplfitter_np
6
+ import trimesh
7
+ from scipy.optimize import linear_sum_assignment
8
+ from scipy.spatial.distance import cdist
9
+
10
+
11
def main():
    """Create decimated vertex subsets for SMPL and SMPL-X and save them to disk."""
    data_root = os.getenv('DATA_ROOT', default='.')
    for model_name in ['smpl', 'smplx']:
        model_root = f'{data_root}/body_models/{model_name}'
        body_model = smplfitter_np.get_cached_body_model(model_name)
        verts = body_model.single()['vertices']

        for n in [32, 64, 128, 256, 512, 1024]:
            i_verts, faces = decimate(verts, body_model.faces, n)
            if len(i_verts) != n:
                print(f'Failed to decimate to {n} vertices')
                continue
            print(f'Decimated to {n} vertices')
            np.savez(f'{model_root}/vertex_subset_{n}.npz', i_verts=i_verts, faces=faces)
26
+
27
+
28
def _decimate(verts, faces, n_verts_out=128):
    """Run one quadric-decimation pass and map the result back to original vertex indices.

    The decimated mesh's vertices are matched to the closest original vertices via optimal
    assignment on the pairwise distance matrix; returns those original indices (ordered to
    correspond to the decimated vertices) together with the decimated faces.
    """
    # A closed triangle mesh satisfies F = 2V - 4 (Euler's formula).
    target_face_count = 2 * n_verts_out - 4
    mesh = trimesh.Trimesh(verts, faces).simplify_quadric_decimation(
        face_count=target_face_count
    )
    dist_matrix = cdist(verts, mesh.vertices)
    row_ind, col_ind = linear_sum_assignment(dist_matrix)
    i_verts = row_ind[np.argsort(col_ind)]
    return i_verts, mesh.faces
34
+
35
+
36
def decimate(verts, faces, n_verts_out=128, n_trials=100):
    """Decimate a mesh to (approximately) ``n_verts_out`` vertices, retrying as needed.

    The quadric decimation routine does not always return exactly the requested number of
    vertices (its documentation explicitly states this), so the requested size is adjusted
    by the remaining deficit and the decimation repeated, up to ``n_trials`` retries.
    """
    requested = n_verts_out
    i_verts, new_faces = _decimate(verts, faces, requested)

    for _ in range(n_trials):
        if i_verts.shape[0] == n_verts_out:
            break
        # Compensate for over/undershoot by bumping the request by the deficit.
        requested += n_verts_out - i_verts.shape[0]
        i_verts, new_faces = _decimate(verts, faces, requested)

    return i_verts, new_faces
52
+
53
+
54
if __name__ == '__main__':
    # Allow running this module directly as a script.
    main()
@@ -0,0 +1,290 @@
1
+ from __future__ import annotations
2
+ # This module finds joint regressors for subsets of body model vertices.
3
+ # That is, it finds a linear regressor that maps a subset of vertices to the joints.
4
+ # And this is done for the post-LBS case, i.e. the vertices and joints are posed with linear
5
+ # blend skinning.
6
+ # This is useful in body model fitting in case the user does not provide the joint positions.
7
+ # In that case we must run a regressor on the (already posed) vertices.
8
+ #
9
+ # The joint regressors need to create convex combinations, i.e. the weights must be non-negative
10
+ # and sum to 1 for each joint. They should also be sparse and spatially compact.
11
+ # That is, each joint should only depend on a few vertices that are spatially close to each other.
12
+ # This is achieved by adding a regularization term that encourages the weights to be
13
+ # sparse (L1/2 norm), and a term that computes the weighted variance of the template vertices.
14
+
15
+ import os
16
+
17
+ import numpy as np
18
+ from .. import pt as smplfitter_pt
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ from torch.utils.data import DataLoader, IterableDataset
23
+ from tqdm import tqdm
24
+
25
# Use the GPU when available; training works on CPU too, just slower.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
# Input sizes are fixed across steps, so cuDNN autotuning pays off.
torch.backends.cudnn.benchmark = True
28
+
29
+
30
def main():
    """Fit sparse convex joint regressors for several vertex-subset sizes of SMPL/SMPL-X.

    For each body model and subset size: trains a ConvexCombiningRegressor against joints
    produced by the body model itself on random pose/shape samples, prunes small weights,
    fine-tunes with the pruning mask applied, and saves the resulting regressor matrix
    (transposed to joints x vertices) to disk.
    """
    for body_model_name in ['smpl', 'smplx']:
        # Canonical vertex positions of the model; used as the template for the
        # spatial-support metric of the learned weights.
        canonical_verts = torch.from_numpy(
            np.load(f'{DATA_ROOT}/nlf/canonical_vertices_{body_model_name}.npy')
        ).float()
        body_model = smplfitter_pt.BodyModel(body_model_name).to(DEVICE)
        dataset = RandomBodyParamDataset(num_betas=16, num_joints=body_model.num_joints)
        # batch_size=None: the dataset already yields full batches.
        dataloader = DataLoader(dataset, batch_size=None, num_workers=4, pin_memory=True)

        # Largest subsets first (reversed order).
        for n_verts_subset in reversed(
            [32, 64, 128, 256, 512, 1024, 2048, 4096, body_model.num_vertices]
        ):
            print(f'Fitting joint regressor for {n_verts_subset} vertices')
            out_path = (
                f'{DATA_ROOT}/body_models/{body_model_name}/'
                f'vertex_subset_joint_regr_post_lbs_{n_verts_subset}.npy'
            )
            # if os.path.exists(out_path):
            #     continue

            if n_verts_subset == body_model.num_vertices:
                # Full model: every vertex participates.
                i_verts = torch.arange(body_model.num_vertices)
            else:
                subset = np.load(
                    f'{DATA_ROOT}/body_models/{body_model_name}/vertex_subset_{n_verts_subset}.npz'
                )
                i_verts = torch.from_numpy(subset['i_verts'])

            model = ConvexCombiningRegressor(len(i_verts), body_model.num_joints).to(DEVICE)
            trainer = ConvexCombiningRegressorTrainer(
                model=model,
                body_model=body_model,
                template_verts=canonical_verts[i_verts],
                vertex_subset=i_verts,
                regul_lambda=6e-5,
            )
            # Phase 1: high LR with a sharp decay for the last 10% of the steps.
            optimizer = optim.Adam(model.parameters(), lr=1e0)
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda step: 1.0 if step < int(37500 * 0.9) else 1e-3
            )
            trainer.train_loop(
                dataloader, total_steps=37500, optimizer=optimizer, scheduler=scheduler
            )
            model.threshold_for_sparsity(1e-3)

            # Phase 2: fine-tune the surviving weights at a constant (scaled) LR,
            # then prune once more.
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 30)
            trainer.train_loop(
                dataloader, total_steps=37500 + 12500, optimizer=optimizer, scheduler=scheduler
            )
            model.threshold_for_sparsity(1e-3)

            # Save as (num_joints, num_vertices_subset), matching J_regressor layout.
            J_subset = model.get_w().cpu().detach().numpy().T
            print(f'Sparsity ratio: {sparsity_ratio(J_subset)}')
            np.save(out_path, J_subset)
84
+
85
+
86
class ConvexCombiningRegressor(nn.Module):
    """Predicts output points as convex combinations of input points.

    The combination weights come from a softplus reparameterization followed by
    per-output-column normalization, so they are non-negative and sum to one over the
    inputs. A non-trainable binary mask allows permanently pruning small weights.
    """

    def __init__(self, n_in_points, n_out_points):
        super().__init__()
        self.n_in = n_in_points
        self.n_out = n_out_points

        # Binary pruning mask; excluded from gradient-based updates.
        self.weight_mask = nn.Parameter(
            torch.ones(n_in_points, n_out_points), requires_grad=False
        )
        raw = torch.empty(n_in_points, n_out_points)
        nn.init.uniform_(raw, -1, 1)
        self.weights = nn.Parameter(raw)

    def forward(self, x):
        # x: (batch, n_in, channels) -> (batch, n_out, channels)
        return torch.einsum('bjc,jJ->bJc', x, self.get_w())

    def get_w(self):
        """Return the (n_in, n_out) convex-combination weight matrix."""
        positive = torch.nn.functional.softplus(self.weights) * self.weight_mask
        normalizer = positive.sum(dim=0, keepdim=True)
        return positive / normalizer

    def threshold_for_sparsity(self, threshold=1e-3):
        """Permanently mask out all combination weights not exceeding ``threshold``."""
        with torch.no_grad():
            mask = (torch.abs(self.get_w()) > threshold).float()
            self.weight_mask.data = mask
107
+
108
+
109
class ConvexCombiningRegressorTrainer:
    """Trains a ConvexCombiningRegressor to predict joints from a subset of posed vertices.

    Ground truth comes from running the body model itself on random pose/shape batches;
    the regressor learns to map the posed vertex subset to the model's joint locations.
    """

    def __init__(self, model, body_model, template_verts, vertex_subset, regul_lambda=3e-5):
        # model: the ConvexCombiningRegressor being trained
        # body_model: SMPL-family body model producing vertices and joints
        # template_verts: canonical positions of the selected vertices (support metric)
        # vertex_subset: indices of the vertices fed to the regressor
        # regul_lambda: weight of the sparsity (soft L1/2) regularizer
        self.model = model
        self.body_model = body_model
        self.template_verts = template_verts.to(DEVICE)
        self.vertex_subset = vertex_subset
        self.regul_lambda = regul_lambda
        # Persists across train_loop calls, so a second call continues the step count.
        self.current_step = 0

    def train_loop(self, dataloader, total_steps, optimizer, scheduler=None):
        """Run training steps until ``self.current_step`` reaches ``total_steps``."""
        progress_bar = tqdm(
            total=total_steps,
            desc='Training',
            unit='step',
            dynamic_ncols=True,
            initial=self.current_step,
        )

        for batch in dataloader:
            batch = EasyDict({k: v.to(DEVICE) for k, v in batch.items()})
            # Metrics are comparatively expensive; compute them every 100th step only.
            need_metrics = self.current_step % 100 == 0
            losses, metrics = self.train_step(
                batch, optimizer, scheduler, need_metrics=need_metrics
            )
            self.current_step += 1
            progress_bar.update(1)

            if need_metrics:
                losses_and_metrics_str = ', '.join(
                    f'{k}: {v.item():.4f}' for k, v in (losses | metrics).items()
                )
                progress_bar.set_postfix_str(losses_and_metrics_str)

            if self.current_step >= total_steps:
                break

    def forward_train(self, inps):
        """Run the body model and the regressor; return predictions and ground truth."""
        preds = EasyDict()
        r = self.body_model(inps.pose, inps.shape)
        # Reference: joints regressed from the full posed mesh with the post-LBS regressor.
        preds.postjoints = self.body_model.J_regressor_post_lbs @ r['vertices']
        # Prediction: joints from only the selected vertex subset.
        preds.pose3d = self.model(r['vertices'][:, self.vertex_subset])
        preds.pose3d_gt = r['joints']
        return preds

    def train_step(self, inps, optimizer, scheduler=None, need_metrics=True):
        """One optimization step; returns (losses, metrics-or-None)."""
        self.model.train()
        optimizer.zero_grad()
        preds = self.forward_train(inps)
        losses = self.compute_losses(inps, preds)
        losses.loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if need_metrics:
            with torch.no_grad():
                metrics = self.compute_metrics(inps, preds)
        else:
            metrics = None

        return losses, metrics

    def compute_losses(self, inps, preds):
        """L1 joint error plus a soft-L1/2 sparsity regularizer on the weights."""
        losses = EasyDict()
        losses.main_loss = torch.mean(torch.abs(preds.pose3d_gt - preds.pose3d))

        w = self.model.get_w()
        # soft_sqrt of |w| approximates an L1/2 penalty; averaged over output joints.
        losses.regul = torch.sum(soft_sqrt(torch.abs(w), 1e-5)) / w.shape[1]
        # losses.supp = mean_spatial_support(self.template_verts, w)
        losses.loss = (
            losses.main_loss + self.regul_lambda * losses.regul  # + self.supp_lambda * losses.supp
        )
        return losses

    def compute_metrics(self, inps, preds):
        """PCK at several thresholds, mean errors, and spatial support of the weights."""
        m = EasyDict()
        dist = torch.norm(preds.pose3d_gt - preds.pose3d, dim=-1)
        m.pck1 = pck(dist, 0.01)
        m.pck2 = pck(dist, 0.02)
        m.pck3 = pck(dist, 0.03)
        m.pck7 = pck(dist, 0.07)
        m.euclidean = torch.mean(dist)
        m.l1 = torch.mean(torch.abs(preds.pose3d_gt - preds.pose3d))
        # Largest RMS spread of supporting vertices around any joint.
        m.max_supp = torch.max(
            torch.sqrt(spatial_support(self.template_verts, self.model.get_w()))
        )

        # Baseline: accuracy of the post-LBS regressor on the full mesh.
        dist = torch.norm(preds.pose3d_gt - preds.postjoints, dim=-1)
        m.pck1_post = pck(dist, 0.01)
        return m
199
+
200
+
201
+ class RandomBodyParamDataset(IterableDataset):
202
+ def __init__(self, batch_size=72, num_joints=24, num_betas=16):
203
+ self.batch_size = batch_size
204
+ self.num_joints = num_joints
205
+ self.num_betas = num_betas
206
+ self.rng = torch.Generator()
207
+
208
+ def __iter__(self):
209
+ with torch.no_grad():
210
+ while True:
211
+ pose = self.random_smpl_pose() / 3
212
+ shape = torch.empty((self.batch_size, self.num_betas))
213
+ nn.init.trunc_normal_(shape, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=self.rng)
214
+ yield {'pose': pose, 'shape': shape * 10}
215
+
216
+ def random_rotvec(self, batch_size):
217
+ rand = torch.rand(3, batch_size, generator=self.rng)
218
+ r1 = torch.sqrt(1 - rand[0])
219
+ r2 = torch.sqrt(rand[0])
220
+ t1 = 2 * np.pi * rand[1]
221
+ t2 = 2 * np.pi * rand[2]
222
+ cost2 = torch.cos(t2)
223
+ xyz = torch.stack([torch.sin(t1) * r1, torch.cos(t1) * r1, torch.sin(t2) * r2], dim=0)
224
+ return (xyz / torch.sqrt(1 - cost2**2 * rand[0]) * 2 * torch.acos(cost2 * r2)).T
225
+
226
+ def random_smpl_pose(self):
227
+ return self.random_rotvec(self.batch_size * self.num_joints).reshape(
228
+ self.batch_size, self.num_joints * 3
229
+ )
230
+
231
+
232
def sparsity_ratio(J):
    """Average count of non-negligible entries (|J| > 1e-4) per row of the matrix."""
    significant = np.abs(J) > 1e-4
    return np.count_nonzero(significant) / J.shape[0]
234
+
235
+
236
def soft_sqrt(x, eps):
    """Smooth sqrt surrogate: x / sqrt(x + eps), which approaches sqrt(x) for x >> eps."""
    denom = torch.sqrt(x + eps)
    return x / denom
238
+
239
+
240
def spatial_support(template, weights):
    """Weighted squared spread of template points around each output's weighted mean.

    template: (num_points, 3); weights: (num_points, num_outputs).
    Returns a (num_outputs,) tensor: for each output, the |weight|-weighted sum of squared
    distances between each template point and that output's weighted-mean location.
    """
    centers = weights.t() @ template
    diff = template.unsqueeze(0) - centers.unsqueeze(1)
    sq_dists = (diff ** 2).sum(dim=-1)
    return torch.einsum('Jj,jJ->J', sq_dists, torch.abs(weights))
245
+
246
+
247
def pck(x, t):
    """Fraction of values in ``x`` that lie within threshold ``t`` (inclusive)."""
    within = (x <= t).float()
    return within.mean()
249
+
250
+
251
class EasyDict(dict):
    """Dictionary with attribute-style access (``d.key`` is equivalent to ``d['key']``).

    Attribute reads and writes are forwarded to the underlying dict items, so the two
    access styles always stay in sync.
    """

    def __init__(self, d=None, **kwargs):
        super().__init__()
        if d is None:
            d = {}
        else:
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Copy non-dunder class attributes into the instance (a hook for subclasses
        # that define defaults as class attributes); no-op for EasyDict itself.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # Store attributes as dict items so both access styles see the same data.
        super().__setitem__(name, value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; fall back to the dict items.
        try:
            return super().__getitem__(name)
        except KeyError:
            raise AttributeError(name)

    __setitem__ = __setattr__

    def __delattr__(self, name):
        # Attributes live in the dict items, so deletion must remove the item.
        try:
            super().__delitem__(name)
        except KeyError:
            raise AttributeError(name)

    def update(self, e=None, **f):
        # Work on a copy so the caller's mapping is never mutated (the previous
        # implementation updated the passed-in dict in place).
        d = dict(e) if e else {}
        d.update(f)
        for k, v in d.items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        # Bug fix: the previous delattr-based implementation always raised
        # AttributeError, because keys are stored as dict items, not instance
        # attributes (and __delattr__ was not overridden).
        return super().pop(k, d)
286
+ return super().pop(k, d)
287
+
288
+
289
if __name__ == '__main__':
    # Entry point for running the regressor fitting as a script.
    main()