smplfitter 0.3.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smplfitter/__init__.py +20 -0
- smplfitter/_version.py +34 -0
- smplfitter/common.py +250 -0
- smplfitter/decimation/__init__.py +1 -0
- smplfitter/decimation/decimate_body_models.py +55 -0
- smplfitter/decimation/make_post_lbs_joint_regressors.py +290 -0
- smplfitter/decimation/make_post_lbs_joint_regressors_tf.py +272 -0
- smplfitter/download.py +412 -0
- smplfitter/jax/__init__.py +25 -0
- smplfitter/jax/bodyconverter.py +156 -0
- smplfitter/jax/bodyfitter.py +1004 -0
- smplfitter/jax/bodymodel.py +189 -0
- smplfitter/jax/lstsq.py +119 -0
- smplfitter/jax/rotation.py +98 -0
- smplfitter/nb/__init__.py +25 -0
- smplfitter/nb/bodyconverter.py +153 -0
- smplfitter/nb/bodyfitter.py +1321 -0
- smplfitter/nb/bodymodel.py +443 -0
- smplfitter/nb/lstsq.py +315 -0
- smplfitter/nb/precompile.py +109 -0
- smplfitter/nb/rotation.py +120 -0
- smplfitter/nb/util.py +58 -0
- smplfitter/np/__init__.py +24 -0
- smplfitter/np/bodyconverter.py +153 -0
- smplfitter/np/bodyfitter.py +966 -0
- smplfitter/np/bodymodel.py +307 -0
- smplfitter/np/lstsq.py +69 -0
- smplfitter/np/rotation.py +78 -0
- smplfitter/np/util.py +18 -0
- smplfitter/pt/__init__.py +158 -0
- smplfitter/pt/bodyconverter.py +170 -0
- smplfitter/pt/bodyfitter.py +1093 -0
- smplfitter/pt/bodyfitter_opt.py +255 -0
- smplfitter/pt/bodyflipper.py +169 -0
- smplfitter/pt/bodyflipper_opt.py +181 -0
- smplfitter/pt/bodymodel.py +404 -0
- smplfitter/pt/handreplacer.py +79 -0
- smplfitter/pt/lstsq.py +90 -0
- smplfitter/pt/rotation.py +73 -0
- smplfitter/tf/__init__.py +188 -0
- smplfitter/tf/bodyconverter.py +168 -0
- smplfitter/tf/bodyfitter.py +1074 -0
- smplfitter/tf/bodymodel.py +329 -0
- smplfitter/tf/lstsq.py +91 -0
- smplfitter/tf/rotation.py +66 -0
- smplfitter/tf/util.py +12 -0
- smplfitter-0.3.2.dev0.dist-info/METADATA +287 -0
- smplfitter-0.3.2.dev0.dist-info/RECORD +51 -0
- smplfitter-0.3.2.dev0.dist-info/WHEEL +5 -0
- smplfitter-0.3.2.dev0.dist-info/licenses/LICENSE +21 -0
- smplfitter-0.3.2.dev0.dist-info/top_level.txt +1 -0
smplfitter/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""SMPLFitter provides forward and inverse kinematics for SMPL-family body models.
|
|
2
|
+
|
|
3
|
+
Main submodules:
|
|
4
|
+
- :mod:`smplfitter.np` - NumPy backend
|
|
5
|
+
- :mod:`smplfitter.pt` - PyTorch backend
|
|
6
|
+
- :mod:`smplfitter.tf` - TensorFlow backend
|
|
7
|
+
- :mod:`smplfitter.jax` - JAX backend
|
|
8
|
+
- :mod:`smplfitter.nb` - Numba backend
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from .common import ModelData, initialize
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from ._version import version as __version__
|
|
17
|
+
except ImportError:
|
|
18
|
+
__version__ = '0.0.0'
|
|
19
|
+
|
|
20
|
+
__all__ = ['ModelData', 'initialize', '__version__']
|
smplfitter/_version.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"__version__",
|
|
6
|
+
"__version_tuple__",
|
|
7
|
+
"version",
|
|
8
|
+
"version_tuple",
|
|
9
|
+
"__commit_id__",
|
|
10
|
+
"commit_id",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
TYPE_CHECKING = False
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
from typing import Union
|
|
17
|
+
|
|
18
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
19
|
+
COMMIT_ID = Union[str, None]
|
|
20
|
+
else:
|
|
21
|
+
VERSION_TUPLE = object
|
|
22
|
+
COMMIT_ID = object
|
|
23
|
+
|
|
24
|
+
version: str
|
|
25
|
+
__version__: str
|
|
26
|
+
__version_tuple__: VERSION_TUPLE
|
|
27
|
+
version_tuple: VERSION_TUPLE
|
|
28
|
+
commit_id: COMMIT_ID
|
|
29
|
+
__commit_id__: COMMIT_ID
|
|
30
|
+
|
|
31
|
+
__version__ = version = '0.3.2.dev0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 2, 'dev0')
|
|
33
|
+
|
|
34
|
+
__commit_id__ = commit_id = None
|
smplfitter/common.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import os
|
|
5
|
+
import os.path as osp
|
|
6
|
+
import pickle
|
|
7
|
+
import sys
|
|
8
|
+
import warnings
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ModelData:
    """Arrays and metadata loaded from a SMPL-family body model file.

    Holds everything a backend (NumPy, PyTorch, TensorFlow, JAX, Numba) needs to build a body
    model. Array-valued fields are plain NumPy arrays; each backend converts them to its own
    tensor type.
    """

    # --- Tensor data (NumPy arrays, converted to framework tensors by each backend) ---

    #: Vertex template in T-pose, shape (num_vertices, 3).
    v_template: np.ndarray

    #: Shape blend shapes, shape (num_vertices, 3, num_betas).
    shapedirs: np.ndarray

    #: Pose blend shapes, shape (num_vertices, 3, (num_joints-1)*9).
    posedirs: np.ndarray

    #: Joint regressor for post-LBS joint locations, shape (num_joints, num_vertices).
    J_regressor_post_lbs: np.ndarray

    #: Joint template positions, shape (num_joints, 3).
    J_template: np.ndarray

    #: Joint shape directions, shape (num_joints, 3, num_betas).
    J_shapedirs: np.ndarray

    #: Kid shape blend shape for vertices, shape (num_vertices, 3).
    kid_shapedir: np.ndarray

    #: Kid shape blend shape for joints, shape (num_joints, 3).
    kid_J_shapedir: np.ndarray

    #: Skinning weights, shape (num_vertices, num_joints).
    weights: np.ndarray

    # --- Non-tensor data (metadata) ---

    #: Parent joint indices for the kinematic tree.
    kintree_parents: list[int]

    #: Face indices, shape (num_faces, 3).
    faces: np.ndarray

    #: Number of joints in the body model.
    num_joints: int

    #: Number of vertices in the body model mesh.
    num_vertices: int

    #: Indices of the vertices used (for partial models).
    vertex_subset: np.ndarray
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Entries read from a raw model file; 'J_shapedirs' and 'J_template' are optional.
_MODEL_FILE_KEYS = (
    'shapedirs',
    'posedirs',
    'v_template',
    'J_regressor',
    'weights',
    'f',
    'kintree_table',
    'J_shapedirs',
    'J_template',
)


def _model_filename(model_name, gender):
    """Return the path of the model file relative to the model root directory.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        KeyError: If the first letter of ``gender`` is not available for this model.
    """
    if model_name == 'smpl':
        gender_str = dict(f='f', m='m', n='neutral')[gender[0]]
        return f'basicmodel_{gender_str}_lbs_10_207_0_v1.1.0.pkl'
    if model_name in ('smplx', 'smplxlh', 'smplxmoyo'):
        gender_str = dict(f='FEMALE', m='MALE', n='NEUTRAL')[gender[0]]
        return f'SMPLX_{gender_str}.npz'
    if model_name == 'smplh':
        gender_str = dict(f='female', m='male')[gender[0]]
        return f'SMPLH_{gender_str}.pkl'
    if model_name == 'smplh16':
        gender_str = dict(f='female', m='male', n='neutral')[gender[0]]
        return osp.join(gender_str, 'model.npz')
    if model_name == 'mano':
        return 'MANO_RIGHT.pkl'
    raise ValueError(f'Unknown model name: {model_name}')


def _read_model_file(filepath, model_name, gender):
    """Read a raw .npz or pickle model file into a mapping of its entries.

    For .npz files only the entries in ``_MODEL_FILE_KEYS`` are read, and the file is closed
    before returning. Pickle files may reference chumpy, so call this inside
    ``monkey_patched_for_chumpy()``.

    Raises:
        FileNotFoundError: With setup instructions, if ``filepath`` does not exist.
    """
    try:
        if filepath.endswith('.npz'):
            # An NpzFile keeps the archive open until closed, so copy out what we need.
            with np.load(filepath) as npz:
                return {key: npz[key] for key in _MODEL_FILE_KEYS if key in npz}
        # NOTE: unpickling can execute code from the file; only load trusted model files.
        with open(filepath, 'rb') as f:
            return pickle.load(f, encoding='latin1')
    except FileNotFoundError:
        raise FileNotFoundError(
            f'Body model file not found: {filepath}\n\n'
            f'Set the body model location using one of:\n'
            f"  1. BodyModel('{model_name}', '{gender}', "
            f"model_root='/your/path/body_models/{model_name}')\n"
            f'  2. export SMPLFITTER_BODY_MODELS=/your/path/body_models\n'
            f'  3. export DATA_ROOT=/your/path  '
            f'(looks for $DATA_ROOT/body_models/)\n\n'
            f'Download models: python -m smplfitter.download\n'
            f'Register first at https://smpl.is.tue.mpg.de/'
        ) from None


def initialize(
    model_name,
    gender,
    model_root=None,
    num_betas=None,
    vertex_subset_size=None,
    vertex_subset=None,
    faces=None,
    joint_regressor_post_lbs=None,
):
    """Load a SMPL-family body model file and return its data as a :class:`ModelData`.

    Args:
        model_name: One of 'smpl', 'smplx', 'smplxlh', 'smplxmoyo', 'smplh', 'smplh16', 'mano'.
        gender: Only its first character is used: 'f', 'm' or 'n'. 'n' is not available for
            'smplh', and the gender is ignored for 'mano'.
        model_root: Directory that holds the model files. If None, it is
            ``$SMPLFITTER_BODY_MODELS/<model_name>``. If that variable is unset, it is
            ``$DATA_ROOT/body_models/<model_name>``, where ``DATA_ROOT`` defaults to '.'.
        num_betas: Number of shape components to keep. None keeps all of them.
        vertex_subset_size: If given, the precomputed vertex subset, faces and post-LBS joint
            regressor of this size are loaded from ``model_root``. They override
            ``vertex_subset``, ``faces`` and ``joint_regressor_post_lbs``.
        vertex_subset: Indices of the vertices to keep. None keeps all vertices.
        faces: Faces to store. None uses the model's own faces.
        joint_regressor_post_lbs: Post-LBS joint regressor. None uses the model's J_regressor.

    NOTE(review): if ``vertex_subset`` is given without ``faces`` / ``joint_regressor_post_lbs``,
    the full-mesh faces and full-mesh J_regressor are returned unchanged.

    Returns:
        ModelData with float64 arrays, sliced to ``vertex_subset`` and ``num_betas``.

    Raises:
        ValueError: If ``model_name`` is unknown.
        KeyError: If ``gender`` is not supported for the model.
        FileNotFoundError: If the model file is missing, or (for 'smpl*' models) the kid template.
    """
    if model_root is None:
        body_models_dir = os.getenv('SMPLFITTER_BODY_MODELS')
        if body_models_dir is None:
            data_root = os.getenv('DATA_ROOT', '.')
            body_models_dir = f'{data_root}/body_models'
        model_root = f'{body_models_dir}/{model_name}'

    filepath = osp.join(model_root, _model_filename(model_name, gender))

    with monkey_patched_for_chumpy():
        smpl_data = _read_model_file(filepath, model_name, gender)

        res = {}
        # np.array(...) also converts chumpy objects from pickle files into plain arrays.
        res['shapedirs'] = np.array(smpl_data['shapedirs'], dtype=np.float64)
        res['posedirs'] = np.array(smpl_data['posedirs'], dtype=np.float64)
        res['v_template'] = np.array(smpl_data['v_template'], dtype=np.float64)

        if not isinstance(smpl_data['J_regressor'], np.ndarray):
            # Presumably a scipy.sparse matrix (pickle files); densify it.
            res['J_regressor'] = np.array(smpl_data['J_regressor'].toarray(), dtype=np.float64)
        else:
            res['J_regressor'] = smpl_data['J_regressor'].astype(np.float64)

        res['weights'] = np.array(smpl_data['weights'])
        res['faces'] = np.array(smpl_data['f'].astype(np.int32))
        res['kintree_parents'] = smpl_data['kintree_table'][0].tolist()
        res['num_joints'] = len(res['kintree_parents'])
        res['num_vertices'] = len(res['v_template'])

        # Kid model has an additional shape parameter which pulls the mesh towards the SMIL mean
        # template
        if model_name.lower().startswith('smpl'):
            kid_path = os.path.join(model_root, 'kid_template.npy')
            try:
                v_template_smil = np.load(kid_path).astype(np.float64)
            except FileNotFoundError:
                raise FileNotFoundError(
                    f'Kid template not found: {kid_path}\n'
                    f'Download it: python -m smplfitter.download'
                ) from None
            res['kid_shapedir'] = (
                v_template_smil - np.mean(v_template_smil, axis=0) - res['v_template']
            )
            res['kid_J_shapedir'] = res['J_regressor'] @ res['kid_shapedir']
        else:
            res['kid_shapedir'] = np.zeros_like(res['v_template'])
            res['kid_J_shapedir'] = np.zeros((res['num_joints'], 3))

        if 'J_shapedirs' in smpl_data:
            res['J_shapedirs'] = np.array(smpl_data['J_shapedirs'], dtype=np.float64)
        else:
            res['J_shapedirs'] = np.einsum('jv,vcs->jcs', res['J_regressor'], res['shapedirs'])

        if 'J_template' in smpl_data:
            res['J_template'] = np.array(smpl_data['J_template'], dtype=np.float64)
        else:
            res['J_template'] = res['J_regressor'] @ res['v_template']

        # Fold the identity rotations (flattened 3x3 eye per non-root joint) into the template.
        # This lets the pose blend shapes be applied to rotation matrices directly instead of (R - I).
        res['v_template'] = res['v_template'] - np.einsum(
            'vcx,x->vc',
            res['posedirs'],
            np.reshape(np.tile(np.eye(3, dtype=np.float64), [res['num_joints'] - 1, 1]), [-1]),
        )

    if vertex_subset_size is not None:
        with np.load(f'{model_root}/vertex_subset_{vertex_subset_size}.npz') as subset_data:
            vertex_subset = subset_data['i_verts']
            faces = subset_data['faces']
        joint_regressor_post_lbs = np.load(
            f'{model_root}/vertex_subset_joint_regr_post_lbs_{vertex_subset_size}.npy'
        )

    if vertex_subset is None:
        vertex_subset = np.arange(res['num_vertices'], dtype=np.int64)
    else:
        vertex_subset = np.array(vertex_subset, dtype=np.int64)

    if faces is None:
        faces = res['faces']

    if joint_regressor_post_lbs is None:
        joint_regressor_post_lbs = res['J_regressor']

    return ModelData(
        v_template=res['v_template'][vertex_subset],
        shapedirs=res['shapedirs'][vertex_subset, :, :num_betas],
        posedirs=res['posedirs'][vertex_subset],
        J_regressor_post_lbs=joint_regressor_post_lbs,
        J_template=res['J_template'],
        J_shapedirs=res['J_shapedirs'][:, :, :num_betas],
        kid_shapedir=res['kid_shapedir'][vertex_subset],
        kid_J_shapedir=res['kid_J_shapedir'],
        weights=res['weights'][vertex_subset],
        kintree_parents=res['kintree_parents'],
        faces=faces,
        num_joints=res['num_joints'],
        num_vertices=len(vertex_subset),
        vertex_subset=vertex_subset,
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@contextlib.contextmanager
def monkey_patched_for_chumpy():
    """Temporarily provide the NumPy / ``inspect`` names that chumpy still imports.

    The pickle files of some models (e.g. SMPL-H) import chumpy, which tries to import
    ``np.bool``, ``np.int``, ``np.unicode``, ``inspect.getargspec`` and similar names. Recent
    NumPy and Python versions no longer provide these.

    The missing ``numpy.<name>`` aliases (registered in ``sys.modules``) and
    ``inspect.getargspec`` are removed again on exit. This also happens when the body raises.
    ``FutureWarning`` is suppressed inside the block.

    NOTE: ``np.float``, ``np.complex`` and ``np.NINF`` and their ``sys.modules`` entries are set
    unconditionally and are intentionally left in place after exit.
    """
    import inspect

    added = []
    for name in ['bool', 'int', 'object', 'str']:
        if name not in dir(np):
            try:
                sys.modules[f'numpy.{name}'] = getattr(np, name + '_')
                added.append(name)
            except AttributeError:
                pass

    sys.modules['numpy.float'] = float
    sys.modules['numpy.complex'] = np.complex128
    sys.modules['numpy.NINF'] = -np.inf
    np.NINF = -np.inf  # type: ignore[misc]
    np.complex = np.complex128  # type: ignore[misc]
    np.float = float  # type: ignore[misc]

    if 'unicode' not in dir(np):
        sys.modules['numpy.unicode'] = np.str_
        added.append('unicode')

    added_getargspec = False
    if not hasattr(inspect, 'getargspec'):
        inspect.getargspec = inspect.getfullargspec
        added_getargspec = True

    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', FutureWarning)
            yield
    finally:
        # Undo the temporary patches even if loading the model file failed.
        for name in added:
            sys.modules.pop(f'numpy.{name}', None)
        if added_getargspec and hasattr(inspect, 'getargspec'):
            del inspect.getargspec
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from .. import np as smplfitter_np
|
|
6
|
+
import trimesh
|
|
7
|
+
from scipy.optimize import linear_sum_assignment
|
|
8
|
+
from scipy.spatial.distance import cdist
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main():
    """Precompute decimated vertex subsets for the SMPL and SMPL-X body models.

    For every target size, the mesh from ``body_model.single()`` is decimated with
    :func:`decimate`. Sizes that come out exactly right are saved as
    ``$DATA_ROOT/body_models/<model>/vertex_subset_<n>.npz``, with arrays ``i_verts`` and
    ``faces``. Sizes that do not are reported and skipped.
    """
    data_root = os.getenv('DATA_ROOT', default='.')
    for model_name in ('smpl', 'smplx'):
        # NOTE(review): common.initialize() looks in $SMPLFITTER_BODY_MODELS first. If that
        # variable is set, the files written here are not where the loader looks for them.
        model_root = f'{data_root}/body_models/{model_name}'
        body_model = smplfitter_np.get_cached_body_model(model_name)
        template_verts = body_model.single()['vertices']

        for n in (32, 64, 128, 256, 512, 1024):
            i_verts, faces = decimate(template_verts, body_model.faces, n)
            if len(i_verts) != n:
                print(f'Failed to decimate to {n} vertices')
                continue

            print(f'Decimated to {n} vertices')
            np.savez(f'{model_root}/vertex_subset_{n}.npz', i_verts=i_verts, faces=faces)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _decimate(verts, faces, n_verts_out=128):
    """Run one quadric decimation and map the result back onto original vertex indices.

    The requested face count ``2 * n_verts_out - 4`` presumably comes from Euler's formula for
    a closed genus-0 triangle mesh. trimesh may not return exactly that many vertices.

    Returns:
        ``(i_verts, faces)``:
            - ``i_verts[k]`` is the index of the original vertex matched to decimated vertex
              ``k``. The matching is one-to-one and minimizes the total Euclidean distance.
            - ``faces`` are the decimated mesh's faces.
    """
    target_face_count = 2 * n_verts_out - 4
    simplified = trimesh.Trimesh(verts, faces).simplify_quadric_decimation(
        face_count=target_face_count
    )
    cost = cdist(verts, simplified.vertices)
    orig_idx, simplified_idx = linear_sum_assignment(cost)
    # Reorder so position k holds the original vertex assigned to decimated vertex k.
    order = np.argsort(simplified_idx)
    return orig_idx[order], simplified.faces
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def decimate(verts, faces, n_verts_out=128, n_trials=100):
    """Decimate a mesh, retrying until the result has exactly ``n_verts_out`` vertices.

    The decimation algorithm does not always return the requested vertex count (its
    documentation says so explicitly). After each attempt, the request is shifted by the
    shortfall or excess and decimation runs again, at most ``n_trials`` extra times.

    Returns:
        ``(i_verts, faces)`` from the last attempt. Callers must check ``len(i_verts)``, because
        the target may still be missed after all trials.
    """
    requested = n_verts_out
    i_verts, new_faces = _decimate(verts, faces, requested)

    trials_left = n_trials
    while trials_left > 0 and i_verts.shape[0] != n_verts_out:
        requested += n_verts_out - i_verts.shape[0]
        i_verts, new_faces = _decimate(verts, faces, requested)
        trials_left -= 1

    return i_verts, new_faces
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
if __name__ == '__main__':
|
|
55
|
+
main()
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
# This module finds joint regressors for subsets of body model vertices.
|
|
3
|
+
# That is, it finds a linear regressor that maps a subset of vertices to the joints.
|
|
4
|
+
# And this is done for the post-LBS case, i.e. the vertices and joints are posed with linear
|
|
5
|
+
# blend skinning.
|
|
6
|
+
# This is useful in body model fitting in case the user does not provide the joint positions.
|
|
7
|
+
# In that case we must run a regressor on the (already posed) vertices.
|
|
8
|
+
#
|
|
9
|
+
# The joint regressors need to create convex combinations, i.e. the weights must be non-negative
|
|
10
|
+
# and sum to 1 for each joint. They should also be sparse and spatially compact.
|
|
11
|
+
# That is, each joint should only depend on a few vertices that are spatially close to each other.
|
|
12
|
+
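# This is achieved by a regularization term that encourages the weights to be sparse (an L1/2
# quasi-norm penalty, see soft_sqrt). A spatial-compactness measure, the weighted variance of the
# template vertices around each joint (see spatial_support), is also implemented. It is currently
# only reported as a training metric and is not part of the loss.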
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
from .. import pt as smplfitter_pt
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn as nn
|
|
21
|
+
import torch.optim as optim
|
|
22
|
+
from torch.utils.data import DataLoader, IterableDataset
|
|
23
|
+
from tqdm import tqdm
|
|
24
|
+
|
|
25
|
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
26
|
+
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
|
|
27
|
+
torch.backends.cudnn.benchmark = True
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def main(skip_existing=False):
    """Fit and save post-LBS joint regressors for SMPL and SMPL-X vertex subsets.

    Sizes are processed largest first: the full mesh, then the precomputed decimated subsets.
    For each body model and subset size, a ConvexCombiningRegressor is trained in two phases:
    - a sparsity-regularized phase, followed by thresholding;
    - a fine-tuning phase, followed by thresholding again.

    The resulting (num_joints, n_verts_subset) weight matrix is saved to
    ``$DATA_ROOT/body_models/<model>/vertex_subset_joint_regr_post_lbs_<n>.npy``.

    Args:
        skip_existing: If True, sizes whose output file already exists are skipped. Defaults to
            False, which always refits and overwrites.
    """
    phase1_steps = 37500
    phase2_steps = 12500

    for body_model_name in ['smpl', 'smplx']:
        canonical_verts = torch.from_numpy(
            np.load(f'{DATA_ROOT}/nlf/canonical_vertices_{body_model_name}.npy')
        ).float()
        body_model = smplfitter_pt.BodyModel(body_model_name).to(DEVICE)
        dataset = RandomBodyParamDataset(num_betas=16, num_joints=body_model.num_joints)
        dataloader = DataLoader(dataset, batch_size=None, num_workers=4, pin_memory=True)

        for n_verts_subset in reversed(
            [32, 64, 128, 256, 512, 1024, 2048, 4096, body_model.num_vertices]
        ):
            print(f'Fitting joint regressor for {n_verts_subset} vertices')
            out_path = (
                f'{DATA_ROOT}/body_models/{body_model_name}/'
                f'vertex_subset_joint_regr_post_lbs_{n_verts_subset}.npy'
            )
            if skip_existing and os.path.exists(out_path):
                continue

            if n_verts_subset == body_model.num_vertices:
                i_verts = torch.arange(body_model.num_vertices)
            else:
                with np.load(
                    f'{DATA_ROOT}/body_models/{body_model_name}/vertex_subset_{n_verts_subset}.npz'
                ) as subset:
                    i_verts = torch.from_numpy(subset['i_verts'])

            model = ConvexCombiningRegressor(len(i_verts), body_model.num_joints).to(DEVICE)
            trainer = ConvexCombiningRegressorTrainer(
                model=model,
                body_model=body_model,
                template_verts=canonical_verts[i_verts],
                vertex_subset=i_verts,
                regul_lambda=6e-5,
            )
            optimizer = optim.Adam(model.parameters(), lr=1e0)

            # Phase 1: lr 1.0 for the first 90% of the steps, then 1e-3.
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda step: 1.0 if step < int(phase1_steps * 0.9) else 1e-3
            )
            trainer.train_loop(
                dataloader, total_steps=phase1_steps, optimizer=optimizer, scheduler=scheduler
            )
            model.threshold_for_sparsity(1e-3)

            # Phase 2: LambdaLR scales the optimizer's initial lr (1e0) by 30, so this phase
            # runs at lr 30. NOTE(review): confirm this large fine-tuning lr is intended.
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 30)
            trainer.train_loop(
                dataloader,
                total_steps=phase1_steps + phase2_steps,
                optimizer=optimizer,
                scheduler=scheduler,
            )
            model.threshold_for_sparsity(1e-3)

            J_subset = model.get_w().cpu().detach().numpy().T
            print(f'Sparsity ratio: {sparsity_ratio(J_subset)}')
            np.save(out_path, J_subset)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ConvexCombiningRegressor(nn.Module):
    """Linear regressor whose output points are convex combinations of its input points.

    The weight matrix has shape (n_in_points, n_out_points). Each column is non-negative and
    sums to one: free parameters go through softplus, are multiplied by a binary sparsity mask,
    and are normalized per output point.
    """

    def __init__(self, n_in_points, n_out_points):
        super().__init__()
        self.n_in = n_in_points
        self.n_out = n_out_points

        # Binary sparsity mask (1 = weight allowed). It is a buffer rather than a frozen
        # Parameter, so it moves with .to() and is saved in state_dict but is never given to
        # optimizers via .parameters().
        self.register_buffer('weight_mask', torch.ones(n_in_points, n_out_points))
        self.weights = nn.Parameter(torch.empty(n_in_points, n_out_points))
        nn.init.uniform_(self.weights, -1, 1)

    def forward(self, x):
        """Map (batch, n_in_points, C) points to (batch, n_out_points, C) points."""
        w = self.get_w()
        return torch.einsum('bjc,jJ->bJc', x, w)

    def get_w(self):
        """Return the (n_in_points, n_out_points) weight matrix; each column sums to one."""
        w = torch.nn.functional.softplus(self.weights) * self.weight_mask
        return w / w.sum(dim=0, keepdim=True)

    def threshold_for_sparsity(self, threshold=1e-3):
        """Permanently mask out weights whose normalized value is at most ``threshold``.

        The largest weight of every output point is always kept. Each column therefore keeps
        at least one input, and get_w() never divides by zero, even if all of a column's
        weights are below the threshold.
        """
        with torch.no_grad():
            w = self.get_w()  # non-negative: softplus times a {0, 1} mask, normalized
            keep = w > threshold
            columns = torch.arange(w.shape[1], device=w.device)
            keep[w.argmax(dim=0), columns] = True
            self.weight_mask.copy_(keep.to(self.weight_mask.dtype))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ConvexCombiningRegressorTrainer:
    """Trains a ConvexCombiningRegressor to reproduce body-model joints from posed vertices.

    Each step runs the body model on a batch of random pose/shape parameters. The regressor is
    applied to the posed vertices in ``vertex_subset`` and compared with the ``'joints'``
    returned by the body model.
    """

    def __init__(self, model, body_model, template_verts, vertex_subset, regul_lambda=3e-5):
        self.model = model
        self.body_model = body_model
        # Template positions of the subset vertices; only used for the max_supp metric.
        self.template_verts = template_verts.to(DEVICE)
        self.vertex_subset = vertex_subset
        self.regul_lambda = regul_lambda
        # Global step counter, shared across train_loop calls.
        self.current_step = 0

    def train_loop(self, dataloader, total_steps, optimizer, scheduler=None):
        """Take optimization steps until ``self.current_step`` reaches ``total_steps``.

        ``total_steps`` is an absolute count, so a later call with a larger value continues
        where the previous call stopped. If it has already been reached, no step is taken.
        Losses and metrics are shown on the progress bar every 100 steps.
        """
        if self.current_step >= total_steps:
            return

        with tqdm(
            total=total_steps,
            desc='Training',
            unit='step',
            dynamic_ncols=True,
            initial=self.current_step,
        ) as progress_bar:
            for batch in dataloader:
                batch = EasyDict({k: v.to(DEVICE) for k, v in batch.items()})
                need_metrics = self.current_step % 100 == 0
                losses, metrics = self.train_step(
                    batch, optimizer, scheduler, need_metrics=need_metrics
                )
                self.current_step += 1
                progress_bar.update(1)

                if need_metrics:
                    losses_and_metrics_str = ', '.join(
                        f'{k}: {v.item():.4f}' for k, v in (losses | metrics).items()
                    )
                    progress_bar.set_postfix_str(losses_and_metrics_str)

                if self.current_step >= total_steps:
                    break

    def forward_train(self, inps):
        """Run the body model and the regressor on a batch with ``pose`` and ``shape``.

        Returns:
            EasyDict with three entries:
            - ``pose3d``: joints regressed from the subset vertices;
            - ``pose3d_gt``: the body model's joints;
            - ``postjoints``: the body model's own post-LBS regressor applied to all vertices.
        """
        preds = EasyDict()
        r = self.body_model(inps.pose, inps.shape)
        preds.postjoints = self.body_model.J_regressor_post_lbs @ r['vertices']
        preds.pose3d = self.model(r['vertices'][:, self.vertex_subset])
        preds.pose3d_gt = r['joints']
        return preds

    def train_step(self, inps, optimizer, scheduler=None, need_metrics=True):
        """Take one optimizer (and scheduler) step; return (losses, metrics or None)."""
        self.model.train()
        optimizer.zero_grad()
        preds = self.forward_train(inps)
        losses = self.compute_losses(inps, preds)
        losses.loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if need_metrics:
            with torch.no_grad():
                metrics = self.compute_metrics(inps, preds)
        else:
            metrics = None

        return losses, metrics

    def compute_losses(self, inps, preds):
        """Return the mean absolute joint error plus ``regul_lambda`` times an L1/2 sparsity term.

        The sparsity term is the sum of soft_sqrt(|w|) over all weights, divided by the number
        of output joints.
        """
        losses = EasyDict()
        losses.main_loss = torch.mean(torch.abs(preds.pose3d_gt - preds.pose3d))

        w = self.model.get_w()
        losses.regul = torch.sum(soft_sqrt(torch.abs(w), 1e-5)) / w.shape[1]
        # NOTE: no spatial-compactness penalty is used in the loss; spatial_support() is only
        # reported as the max_supp metric.
        losses.loss = losses.main_loss + self.regul_lambda * losses.regul
        return losses

    def compute_metrics(self, inps, preds):
        """Return PCK at several thresholds, mean/L1 errors, max spatial support, and post-LBS PCK."""
        m = EasyDict()
        dist = torch.norm(preds.pose3d_gt - preds.pose3d, dim=-1)
        m.pck1 = pck(dist, 0.01)
        m.pck2 = pck(dist, 0.02)
        m.pck3 = pck(dist, 0.03)
        m.pck7 = pck(dist, 0.07)
        m.euclidean = torch.mean(dist)
        m.l1 = torch.mean(torch.abs(preds.pose3d_gt - preds.pose3d))
        m.max_supp = torch.max(
            torch.sqrt(spatial_support(self.template_verts, self.model.get_w()))
        )

        dist = torch.norm(preds.pose3d_gt - preds.postjoints, dim=-1)
        m.pck1_post = pck(dist, 0.01)
        return m
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class RandomBodyParamDataset(IterableDataset):
|
|
202
|
+
def __init__(self, batch_size=72, num_joints=24, num_betas=16):
|
|
203
|
+
self.batch_size = batch_size
|
|
204
|
+
self.num_joints = num_joints
|
|
205
|
+
self.num_betas = num_betas
|
|
206
|
+
self.rng = torch.Generator()
|
|
207
|
+
|
|
208
|
+
def __iter__(self):
|
|
209
|
+
with torch.no_grad():
|
|
210
|
+
while True:
|
|
211
|
+
pose = self.random_smpl_pose() / 3
|
|
212
|
+
shape = torch.empty((self.batch_size, self.num_betas))
|
|
213
|
+
nn.init.trunc_normal_(shape, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=self.rng)
|
|
214
|
+
yield {'pose': pose, 'shape': shape * 10}
|
|
215
|
+
|
|
216
|
+
def random_rotvec(self, batch_size):
|
|
217
|
+
rand = torch.rand(3, batch_size, generator=self.rng)
|
|
218
|
+
r1 = torch.sqrt(1 - rand[0])
|
|
219
|
+
r2 = torch.sqrt(rand[0])
|
|
220
|
+
t1 = 2 * np.pi * rand[1]
|
|
221
|
+
t2 = 2 * np.pi * rand[2]
|
|
222
|
+
cost2 = torch.cos(t2)
|
|
223
|
+
xyz = torch.stack([torch.sin(t1) * r1, torch.cos(t1) * r1, torch.sin(t2) * r2], dim=0)
|
|
224
|
+
return (xyz / torch.sqrt(1 - cost2**2 * rand[0]) * 2 * torch.acos(cost2 * r2)).T
|
|
225
|
+
|
|
226
|
+
def random_smpl_pose(self):
|
|
227
|
+
return self.random_rotvec(self.batch_size * self.num_joints).reshape(
|
|
228
|
+
self.batch_size, self.num_joints * 3
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def sparsity_ratio(J):
    """Return the mean number of non-negligible (|value| > 1e-4) entries per row of ``J``.

    For a (num_joints, num_vertices) regressor, this is the average number of vertices each
    joint depends on. Despite the name, it is a count, not a fraction.
    """
    n_significant = np.count_nonzero(np.abs(J) > 1e-4)
    n_rows = J.shape[0]
    return n_significant / n_rows
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def soft_sqrt(x, eps):
    """Smooth surrogate of ``sqrt(x)`` for ``x >= 0``: ``x / sqrt(x + eps)``.

    It approaches ``sqrt(x)`` for ``x >> eps`` but has the finite gradient ``1/sqrt(eps)``
    at ``x = 0``, which makes it usable as an L1/2 sparsity penalty.
    """
    denom = torch.sqrt(x + eps)
    return x / denom
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def spatial_support(template, weights):
    """Weighted spatial variance of the template points around each regressed point.

    Args:
        template: (V, D) template point positions (D is presumably 3).
        weights: (V, J) regressor weights.

    Returns:
        (J,) tensor whose entry j is ``sum_v |w[v, j]| * ||template[v] - m_j||^2``,
        where ``m_j = sum_v w[v, j] * template[v]``.
    """
    centers = weights.t() @ template  # (J, D)
    offsets = template[None, :] - centers[:, None]  # (J, V, D)
    sq_dists = (offsets**2).sum(dim=-1)  # (J, V)
    return torch.einsum('Jj,jJ->J', sq_dists, weights.abs())
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def pck(x, t):
    """Fraction of the distances in ``x`` that are at most ``t`` (percentage of correct keypoints)."""
    within = x <= t
    return within.float().mean()
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class EasyDict(dict):
    """A dict whose items can also be read, set and deleted as attributes (``d.k`` is ``d['k']``)."""

    def __init__(self, d=None, **kwargs):
        super().__init__()
        d = {} if d is None else dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Copy non-dunder class attributes (e.g. from subclasses) into the dict items.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        super().__setitem__(name, value)

    def __getattr__(self, name):
        try:
            return super().__getitem__(name)
        except KeyError:
            raise AttributeError(name) from None

    def __delattr__(self, name):
        # Values live in the dict items, not in the instance __dict__.
        try:
            super().__delitem__(name)
        except KeyError:
            raise AttributeError(name) from None

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Update from mapping/iterable ``e`` and keyword items, without mutating ``e``."""
        d = dict(e) if e else {}
        d.update(f)
        for k, v in d.items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        """Remove ``k`` and return its value, or ``d`` (default None) if it is missing."""
        return super().pop(k, d)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
if __name__ == '__main__':
|
|
290
|
+
main()
|