warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/build.py
CHANGED
|
@@ -39,6 +39,7 @@ def build_cuda(
|
|
|
39
39
|
fast_math=False,
|
|
40
40
|
fuse_fp=True,
|
|
41
41
|
lineinfo=False,
|
|
42
|
+
compile_time_trace=False,
|
|
42
43
|
ltoirs=None,
|
|
43
44
|
fatbins=None,
|
|
44
45
|
) -> None:
|
|
@@ -79,6 +80,7 @@ def build_cuda(
|
|
|
79
80
|
fast_math,
|
|
80
81
|
fuse_fp,
|
|
81
82
|
lineinfo,
|
|
83
|
+
compile_time_trace,
|
|
82
84
|
output_path,
|
|
83
85
|
num_link,
|
|
84
86
|
arr_link,
|
|
@@ -223,7 +225,7 @@ def get_cached_lto(path):
|
|
|
223
225
|
|
|
224
226
|
def get_cached_lto_meta(path, symbol):
|
|
225
227
|
if os.path.exists(path):
|
|
226
|
-
with open(path
|
|
228
|
+
with open(path) as f:
|
|
227
229
|
keys = json.load(f)
|
|
228
230
|
value = keys[symbol]
|
|
229
231
|
return value
|
|
@@ -231,9 +233,106 @@ def get_cached_lto_meta(path, symbol):
|
|
|
231
233
|
return None
|
|
232
234
|
|
|
233
235
|
|
|
236
|
+
def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
|
|
237
|
+
"""Generic LTO build function that handles caching, file operations and process management.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
lto_symbol: Unique identifier for the LTO operation
|
|
241
|
+
compile_func: Function to compile the specific LTO
|
|
242
|
+
(receives a dictionary of build paths)
|
|
243
|
+
builder: Builder object to store results
|
|
244
|
+
extra_files: Dictionary of additional file types to handle (e.g.,
|
|
245
|
+
{".meta": None, ".fatbin": None}). Values are the functions to get
|
|
246
|
+
the cached file data.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
Tuple containing lto_code_data followed by any extra data from extra_files
|
|
250
|
+
"""
|
|
251
|
+
if extra_files is None:
|
|
252
|
+
extra_files = {}
|
|
253
|
+
|
|
254
|
+
# Hash symbol and set up paths
|
|
255
|
+
h = hash_symbol(lto_symbol)
|
|
256
|
+
lto_dir = get_lto_cache_dir()
|
|
257
|
+
lto_name = f"{h[:7]}.lto"
|
|
258
|
+
lto_path = os.path.join(lto_dir, lto_name)
|
|
259
|
+
|
|
260
|
+
# Set up paths for extra files
|
|
261
|
+
file_paths = {".lto": lto_path}
|
|
262
|
+
temp_file_paths = {}
|
|
263
|
+
|
|
264
|
+
for ext, _ in extra_files.items():
|
|
265
|
+
name = f"{h[:7]}{ext}"
|
|
266
|
+
file_paths[ext] = os.path.join(lto_dir, name)
|
|
267
|
+
|
|
268
|
+
# Check if already built but not cached
|
|
269
|
+
lto_code_data = get_cached_lto(lto_path)
|
|
270
|
+
if lto_code_data is not None:
|
|
271
|
+
# Get the cached data for the extra files and early return
|
|
272
|
+
all_files_cached = True
|
|
273
|
+
for ext, getter in extra_files.items():
|
|
274
|
+
if getter and os.path.exists(file_paths[ext]):
|
|
275
|
+
cached_data = getter(file_paths[ext])
|
|
276
|
+
if cached_data is None:
|
|
277
|
+
all_files_cached = False
|
|
278
|
+
break
|
|
279
|
+
extra_files[ext] = cached_data
|
|
280
|
+
elif getter: # If there's a getter but file doesn't exist
|
|
281
|
+
all_files_cached = False
|
|
282
|
+
break
|
|
283
|
+
|
|
284
|
+
if all_files_cached:
|
|
285
|
+
if not extra_files:
|
|
286
|
+
return (lto_code_data,)
|
|
287
|
+
else:
|
|
288
|
+
return (lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
|
|
289
|
+
|
|
290
|
+
# Create process-dependent temporary build directory
|
|
291
|
+
build_dir = f"{lto_dir}_p{os.getpid()}"
|
|
292
|
+
Path(build_dir).mkdir(parents=True, exist_ok=True)
|
|
293
|
+
|
|
294
|
+
# Set up temporary paths for the build outputs
|
|
295
|
+
for ext, path in file_paths.items():
|
|
296
|
+
temp_file_paths[ext] = os.path.join(build_dir, os.path.basename(path))
|
|
297
|
+
|
|
298
|
+
# Compile LTO with the specialized function
|
|
299
|
+
result, outputs = compile_func(temp_file_paths)
|
|
300
|
+
|
|
301
|
+
if not result:
|
|
302
|
+
# Clean up and fail
|
|
303
|
+
for path in temp_file_paths.values():
|
|
304
|
+
if Path(path).exists():
|
|
305
|
+
Path(path).unlink()
|
|
306
|
+
raise RuntimeError(f"Failed to compile {lto_symbol}")
|
|
307
|
+
|
|
308
|
+
# Move outputs to cache
|
|
309
|
+
safe_rename(build_dir, lto_dir)
|
|
310
|
+
|
|
311
|
+
# If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
|
|
312
|
+
if os.path.exists(lto_dir):
|
|
313
|
+
for ext, path in file_paths.items():
|
|
314
|
+
if not os.path.exists(path):
|
|
315
|
+
try:
|
|
316
|
+
# copy output file to the destination lto dir
|
|
317
|
+
os.rename(temp_file_paths[ext], path)
|
|
318
|
+
except (OSError, FileExistsError):
|
|
319
|
+
# another process likely updated the lto dir first
|
|
320
|
+
pass
|
|
321
|
+
|
|
322
|
+
# Clean up the temporary build directory
|
|
323
|
+
if build_dir:
|
|
324
|
+
import shutil
|
|
325
|
+
|
|
326
|
+
shutil.rmtree(build_dir, ignore_errors=True)
|
|
327
|
+
|
|
328
|
+
if not extra_files:
|
|
329
|
+
return (outputs[".lto"],)
|
|
330
|
+
else:
|
|
331
|
+
return (outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
|
|
332
|
+
|
|
333
|
+
|
|
234
334
|
def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
|
|
235
|
-
|
|
236
|
-
arch = min(arch, 90)
|
|
335
|
+
arch = 120 if arch > 121 else arch
|
|
237
336
|
|
|
238
337
|
# Maps Python/Warp types to C++ types and enums
|
|
239
338
|
def cublasdx_type_map(dtype):
|
|
@@ -266,292 +365,182 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
|
|
|
266
365
|
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
267
366
|
|
|
268
367
|
if a_type != b_type or a_type != c_type:
|
|
269
|
-
raise TypeError("
|
|
368
|
+
raise TypeError("tile_matmul(A, B, C) requires all inputs to be real or complex")
|
|
270
369
|
|
|
271
370
|
element_type = a_type
|
|
272
371
|
|
|
273
372
|
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
274
373
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
374
|
+
def compile_lto_dot(temp_paths):
|
|
375
|
+
result = warp.context.runtime.core.cuda_compile_dot(
|
|
376
|
+
temp_paths[".lto"].encode("utf-8"),
|
|
377
|
+
lto_symbol.encode("utf-8"),
|
|
378
|
+
0,
|
|
379
|
+
None,
|
|
380
|
+
None,
|
|
381
|
+
arch,
|
|
382
|
+
M,
|
|
383
|
+
N,
|
|
384
|
+
K,
|
|
385
|
+
a_prec,
|
|
386
|
+
b_prec,
|
|
387
|
+
c_prec,
|
|
388
|
+
element_type,
|
|
389
|
+
a_arrangement,
|
|
390
|
+
b_arrangement,
|
|
391
|
+
c_arrangement,
|
|
392
|
+
num_threads,
|
|
393
|
+
)
|
|
281
394
|
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
395
|
+
if result:
|
|
396
|
+
with open(temp_paths[".lto"], "rb") as f:
|
|
397
|
+
lto_code_data = f.read()
|
|
398
|
+
return True, {".lto": lto_code_data}
|
|
399
|
+
return False, {}
|
|
285
400
|
|
|
286
|
-
#
|
|
287
|
-
|
|
401
|
+
# Early out if already cached in module
|
|
402
|
+
if lto_symbol in builder.ltoirs:
|
|
403
|
+
lto_code_data = builder.ltoirs[lto_symbol]
|
|
404
|
+
else:
|
|
405
|
+
(lto_code_data,) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
|
|
288
406
|
|
|
289
|
-
|
|
407
|
+
# Update builder
|
|
290
408
|
builder.ltoirs[lto_symbol] = lto_code_data
|
|
291
409
|
builder.ltoirs_decl[lto_symbol] = (
|
|
292
|
-
f"void {lto_symbol}({c_dtype}
|
|
410
|
+
f"void {lto_symbol}({c_dtype}*, {a_dtype}*, {b_dtype}*, {c_dtype}*, {c_dtype}*);"
|
|
293
411
|
)
|
|
294
412
|
|
|
295
|
-
return lto_symbol, lto_code_data
|
|
296
|
-
|
|
297
|
-
# create a temporary (process unique) dir for build outputs before moving to the binary dir
|
|
298
|
-
build_dir = f"{lto_dir}_p{os.getpid()}"
|
|
299
|
-
|
|
300
|
-
# dir may exist from previous attempts / runs / archs
|
|
301
|
-
Path(build_dir).mkdir(parents=True, exist_ok=True)
|
|
302
|
-
|
|
303
|
-
# temporary path to compile to in build_dir
|
|
304
|
-
temp_lto_path = os.path.join(build_dir, lto_name)
|
|
305
|
-
|
|
306
|
-
# compile LTO
|
|
307
|
-
result = warp.context.runtime.core.cuda_compile_dot(
|
|
308
|
-
temp_lto_path.encode("utf-8"),
|
|
309
|
-
lto_symbol.encode("utf-8"),
|
|
310
|
-
0,
|
|
311
|
-
None,
|
|
312
|
-
None,
|
|
313
|
-
arch,
|
|
314
|
-
M,
|
|
315
|
-
N,
|
|
316
|
-
K,
|
|
317
|
-
a_prec,
|
|
318
|
-
b_prec,
|
|
319
|
-
c_prec,
|
|
320
|
-
element_type,
|
|
321
|
-
a_arrangement,
|
|
322
|
-
b_arrangement,
|
|
323
|
-
c_arrangement,
|
|
324
|
-
num_threads,
|
|
325
|
-
)
|
|
326
|
-
|
|
327
|
-
if not result:
|
|
328
|
-
if Path(temp_lto_path).exists():
|
|
329
|
-
Path(temp_lto_path).unlink()
|
|
330
|
-
raise RuntimeError("Failed to compile tile_matmul")
|
|
331
|
-
else:
|
|
332
|
-
with open(temp_lto_path, "rb") as f:
|
|
333
|
-
lto_code_data = f.read()
|
|
334
|
-
|
|
335
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
336
|
-
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
|
|
337
|
-
|
|
338
|
-
# try to move process outputs to cache
|
|
339
|
-
safe_rename(build_dir, lto_dir)
|
|
340
|
-
|
|
341
|
-
if os.path.exists(lto_dir):
|
|
342
|
-
if not os.path.exists(lto_path):
|
|
343
|
-
# copy output file to the destination lto dir
|
|
344
|
-
try:
|
|
345
|
-
os.rename(temp_lto_path, lto_path)
|
|
346
|
-
except (OSError, FileExistsError):
|
|
347
|
-
# another process likely updated the lto dir first
|
|
348
|
-
pass
|
|
349
|
-
|
|
350
|
-
if build_dir:
|
|
351
|
-
import shutil
|
|
352
|
-
|
|
353
|
-
# clean up build_dir used for this process
|
|
354
|
-
shutil.rmtree(build_dir, ignore_errors=True)
|
|
355
|
-
|
|
356
413
|
return lto_symbol, lto_code_data
|
|
357
414
|
|
|
358
415
|
|
|
359
|
-
def build_lto_solver(
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
416
|
+
def build_lto_solver(
|
|
417
|
+
M,
|
|
418
|
+
N,
|
|
419
|
+
NRHS,
|
|
420
|
+
solver,
|
|
421
|
+
solver_enum,
|
|
422
|
+
side_enum,
|
|
423
|
+
diag_enum,
|
|
424
|
+
alayout,
|
|
425
|
+
blayout,
|
|
426
|
+
fill_mode,
|
|
427
|
+
arch,
|
|
428
|
+
precision_enum,
|
|
429
|
+
num_threads,
|
|
430
|
+
parameter_list,
|
|
431
|
+
builder,
|
|
432
|
+
):
|
|
433
|
+
arch = 120 if arch > 121 else arch
|
|
434
|
+
|
|
435
|
+
def cusolverdx_arrangement_map(layout):
|
|
436
|
+
if layout == "colmajor":
|
|
437
|
+
return 0 # CUSOLVERDX_ARRANGEMENT_COL_MAJOR
|
|
438
|
+
if layout == "rowmajor":
|
|
439
|
+
return 1 # CUSOLVERDX_ARRANGEMENT_ROW_MAJOR
|
|
440
|
+
raise ValueError("Unsupported layout in tile_matmul")
|
|
372
441
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
442
|
+
a_arrangement = cusolverdx_arrangement_map(alayout)
|
|
443
|
+
b_arrangement = cusolverdx_arrangement_map(blayout)
|
|
444
|
+
|
|
445
|
+
lto_symbol = f"{solver}_{M}_{N}_{NRHS}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{precision_enum}_{side_enum if side_enum >= 0 else 'x'}_{diag_enum if diag_enum >= 0 else 'x'}_{fill_mode}"
|
|
446
|
+
|
|
447
|
+
def compile_lto_solver(temp_paths):
|
|
448
|
+
# compile LTO
|
|
449
|
+
result = warp.context.runtime.core.cuda_compile_solver(
|
|
450
|
+
temp_paths["_fatbin.lto"].encode("utf-8"),
|
|
451
|
+
temp_paths[".lto"].encode("utf-8"),
|
|
452
|
+
lto_symbol.encode("utf-8"),
|
|
453
|
+
0,
|
|
454
|
+
None,
|
|
455
|
+
None,
|
|
456
|
+
arch,
|
|
457
|
+
M,
|
|
458
|
+
N,
|
|
459
|
+
NRHS,
|
|
460
|
+
solver_enum,
|
|
461
|
+
side_enum,
|
|
462
|
+
diag_enum,
|
|
463
|
+
precision_enum,
|
|
464
|
+
a_arrangement,
|
|
465
|
+
b_arrangement,
|
|
466
|
+
fill_mode,
|
|
467
|
+
num_threads,
|
|
468
|
+
)
|
|
376
469
|
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
470
|
+
if result:
|
|
471
|
+
with open(temp_paths[".lto"], "rb") as f:
|
|
472
|
+
lto_code_data = f.read()
|
|
473
|
+
with open(temp_paths["_fatbin.lto"], "rb") as f:
|
|
474
|
+
universal_fatbin_code_data = f.read()
|
|
475
|
+
return True, {".lto": lto_code_data, "_fatbin.lto": universal_fatbin_code_data}
|
|
476
|
+
return False, {}
|
|
380
477
|
|
|
381
|
-
|
|
382
|
-
|
|
478
|
+
# Early out if already cached in module
|
|
479
|
+
if lto_symbol in builder.ltoirs:
|
|
480
|
+
lto_code_data = builder.ltoirs[lto_symbol]
|
|
481
|
+
else:
|
|
482
|
+
lto_code_data, universal_fatbin_code_data = _build_lto_base(
|
|
483
|
+
lto_symbol, compile_lto_solver, builder, {"_fatbin.lto": get_cached_lto}
|
|
484
|
+
)
|
|
383
485
|
|
|
384
|
-
|
|
385
|
-
if lto_code_data is not None and universal_fatbin_code_data is not None:
|
|
486
|
+
# Update builder
|
|
386
487
|
builder.ltoirs[lto_symbol] = lto_code_data
|
|
387
|
-
builder.ltoirs_decl[lto_symbol] =
|
|
488
|
+
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}{parameter_list};"
|
|
388
489
|
builder.fatbins[lto_symbol] = universal_fatbin_code_data
|
|
389
490
|
|
|
390
|
-
return lto_symbol, lto_code_data
|
|
391
|
-
|
|
392
|
-
# create a temporary (process unique) dir for build outputs before moving to the binary dir
|
|
393
|
-
build_dir = f"{lto_dir}_p{os.getpid()}"
|
|
394
|
-
|
|
395
|
-
# dir may exist from previous attempts / runs / archs
|
|
396
|
-
Path(build_dir).mkdir(parents=True, exist_ok=True)
|
|
397
|
-
|
|
398
|
-
# temporary paths to compile to in build_dir
|
|
399
|
-
temp_lto_path = os.path.join(build_dir, lto_name)
|
|
400
|
-
temp_universal_fatbin_path = os.path.join(build_dir, universal_fatbin_name)
|
|
401
|
-
|
|
402
|
-
# compile LTO
|
|
403
|
-
result = warp.context.runtime.core.cuda_compile_solver(
|
|
404
|
-
temp_universal_fatbin_path.encode("utf-8"),
|
|
405
|
-
temp_lto_path.encode("utf-8"),
|
|
406
|
-
lto_symbol.encode("utf-8"),
|
|
407
|
-
0,
|
|
408
|
-
None,
|
|
409
|
-
None,
|
|
410
|
-
arch,
|
|
411
|
-
M,
|
|
412
|
-
N,
|
|
413
|
-
solver_enum,
|
|
414
|
-
precision_enum,
|
|
415
|
-
fill_mode,
|
|
416
|
-
num_threads,
|
|
417
|
-
)
|
|
418
|
-
|
|
419
|
-
if not result:
|
|
420
|
-
for path in [temp_universal_fatbin_path, temp_lto_path]:
|
|
421
|
-
if Path(path).exists():
|
|
422
|
-
Path(path).unlink()
|
|
423
|
-
raise RuntimeError("Failed to compile tile_cholesky")
|
|
424
|
-
|
|
425
|
-
else:
|
|
426
|
-
with open(temp_lto_path, "rb") as f:
|
|
427
|
-
lto_code_data = f.read()
|
|
428
|
-
with open(temp_universal_fatbin_path, "rb") as f:
|
|
429
|
-
universal_fatbin_code_data = f.read()
|
|
430
|
-
|
|
431
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
432
|
-
builder.ltoirs_decl[lto_symbol] = ltoir_decl
|
|
433
|
-
builder.fatbins[lto_symbol] = universal_fatbin_code_data
|
|
434
|
-
|
|
435
|
-
# try to move process outputs to lto cache
|
|
436
|
-
safe_rename(build_dir, lto_dir)
|
|
437
|
-
|
|
438
|
-
if os.path.exists(lto_dir):
|
|
439
|
-
for p in [(lto_path, temp_lto_path), (universal_fatbin_path, temp_universal_fatbin_path)]:
|
|
440
|
-
path, temp_path = p
|
|
441
|
-
if not os.path.exists(path):
|
|
442
|
-
# copy output file to the destination lto dir
|
|
443
|
-
try:
|
|
444
|
-
os.rename(temp_path, path)
|
|
445
|
-
except (OSError, FileExistsError):
|
|
446
|
-
# another process likely updated the lto dir first
|
|
447
|
-
pass
|
|
448
|
-
|
|
449
|
-
if build_dir:
|
|
450
|
-
import shutil
|
|
451
|
-
|
|
452
|
-
# clean up build_dir used for this process
|
|
453
|
-
shutil.rmtree(build_dir, ignore_errors=True)
|
|
454
|
-
|
|
455
491
|
return lto_symbol, lto_code_data
|
|
456
492
|
|
|
457
493
|
|
|
458
494
|
def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
|
|
459
|
-
|
|
460
|
-
arch = min(arch, 90)
|
|
495
|
+
arch = 120 if arch > 121 else arch
|
|
461
496
|
|
|
462
497
|
lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
|
|
463
498
|
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
499
|
+
def compile_lto_fft(temp_paths):
|
|
500
|
+
shared_memory_size = ctypes.c_int(0)
|
|
501
|
+
|
|
502
|
+
result = warp.context.runtime.core.cuda_compile_fft(
|
|
503
|
+
temp_paths[".lto"].encode("utf-8"),
|
|
504
|
+
lto_symbol.encode("utf-8"),
|
|
505
|
+
0,
|
|
506
|
+
None,
|
|
507
|
+
None,
|
|
508
|
+
arch,
|
|
509
|
+
size,
|
|
510
|
+
ept,
|
|
511
|
+
dir,
|
|
512
|
+
precision,
|
|
513
|
+
ctypes.byref(shared_memory_size),
|
|
514
|
+
)
|
|
478
515
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
516
|
+
if result:
|
|
517
|
+
with open(temp_paths[".lto"], "rb") as f:
|
|
518
|
+
lto_code_data = f.read()
|
|
482
519
|
|
|
483
|
-
|
|
484
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
485
|
-
builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
|
|
520
|
+
shared_memory_bytes = tile.round_up(shared_memory_size.value)
|
|
486
521
|
|
|
487
|
-
|
|
522
|
+
# output meta file with shared memory requirements for this lto_symbol
|
|
523
|
+
meta = {}
|
|
524
|
+
meta[lto_symbol] = shared_memory_bytes
|
|
488
525
|
|
|
489
|
-
|
|
490
|
-
|
|
526
|
+
with open(temp_paths[".meta"], "w") as meta_file:
|
|
527
|
+
json.dump(meta, meta_file)
|
|
491
528
|
|
|
492
|
-
|
|
493
|
-
Path(build_dir).mkdir(parents=True, exist_ok=True)
|
|
494
|
-
|
|
495
|
-
# temporary paths to compile to in build_dir
|
|
496
|
-
temp_lto_path = os.path.join(build_dir, lto_name)
|
|
497
|
-
temp_meta_path = os.path.join(build_dir, meta_name)
|
|
498
|
-
|
|
499
|
-
# compile LTO
|
|
500
|
-
shared_memory_size = ctypes.c_int(0)
|
|
501
|
-
|
|
502
|
-
result = warp.context.runtime.core.cuda_compile_fft(
|
|
503
|
-
temp_lto_path.encode("utf-8"),
|
|
504
|
-
lto_symbol.encode("utf-8"),
|
|
505
|
-
0,
|
|
506
|
-
None,
|
|
507
|
-
None,
|
|
508
|
-
arch,
|
|
509
|
-
size,
|
|
510
|
-
ept,
|
|
511
|
-
dir,
|
|
512
|
-
precision,
|
|
513
|
-
ctypes.byref(shared_memory_size),
|
|
514
|
-
)
|
|
515
|
-
|
|
516
|
-
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
529
|
+
return True, {".lto": lto_code_data, ".meta": shared_memory_bytes}
|
|
517
530
|
|
|
518
|
-
|
|
519
|
-
if Path(temp_lto_path).exists():
|
|
520
|
-
Path(temp_lto_path).unlink()
|
|
521
|
-
raise RuntimeError("Failed to compile tile_fft")
|
|
531
|
+
return False, {}
|
|
522
532
|
|
|
533
|
+
# Early out if already cached in module
|
|
534
|
+
if lto_symbol in builder.ltoirs and lto_symbol in builder.shared_memory_bytes:
|
|
535
|
+
lto_code_data = builder.ltoirs[lto_symbol]
|
|
536
|
+
shared_memory_bytes = builder.shared_memory_bytes[lto_symbol]
|
|
523
537
|
else:
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
# output meta file with shared memory requirements for this lto_symbol
|
|
528
|
-
meta = {}
|
|
529
|
-
meta[lto_symbol] = shared_memory_bytes
|
|
530
|
-
|
|
531
|
-
with open(temp_meta_path, "w") as meta_file:
|
|
532
|
-
json.dump(meta, meta_file)
|
|
533
|
-
|
|
534
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
535
|
-
builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
|
|
536
|
-
|
|
537
|
-
# try to move process outputs to cache
|
|
538
|
-
safe_rename(build_dir, lto_dir)
|
|
539
|
-
|
|
540
|
-
if os.path.exists(lto_dir):
|
|
541
|
-
for p in [(lto_path, temp_lto_path), (meta_path, temp_meta_path)]:
|
|
542
|
-
path, temp_path = p
|
|
543
|
-
if not os.path.exists(path):
|
|
544
|
-
# copy output file to the destination lto dir
|
|
545
|
-
try:
|
|
546
|
-
os.rename(temp_path, path)
|
|
547
|
-
except (OSError, FileExistsError):
|
|
548
|
-
# another process likely updated the lto dir first
|
|
549
|
-
pass
|
|
550
|
-
|
|
551
|
-
if build_dir:
|
|
552
|
-
import shutil
|
|
538
|
+
lto_code_data, shared_memory_bytes = _build_lto_base(
|
|
539
|
+
lto_symbol, compile_lto_fft, builder, {".meta": lambda path: get_cached_lto_meta(path, lto_symbol)}
|
|
540
|
+
)
|
|
553
541
|
|
|
554
|
-
#
|
|
555
|
-
|
|
542
|
+
# Update builder
|
|
543
|
+
builder.ltoirs[lto_symbol] = lto_code_data
|
|
544
|
+
builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
|
|
556
545
|
|
|
557
546
|
return lto_symbol, lto_code_data, shared_memory_bytes
|