nv-sgl 0.6.0__cp313-cp313-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142)
  1. include/tevclient.h +393 -0
  2. nv_sgl-0.6.0.dist-info/LICENSE +29 -0
  3. nv_sgl-0.6.0.dist-info/METADATA +21 -0
  4. nv_sgl-0.6.0.dist-info/RECORD +142 -0
  5. nv_sgl-0.6.0.dist-info/WHEEL +5 -0
  6. nv_sgl-0.6.0.dist-info/top_level.txt +1 -0
  7. sgl/__init__.py +15 -0
  8. sgl/__init__.pyi +6738 -0
  9. sgl/d3d12/D3D12Core.dll +0 -0
  10. sgl/d3d12/d3d12SDKLayers.dll +0 -0
  11. sgl/dxcompiler.dll +0 -0
  12. sgl/dxil.dll +0 -0
  13. sgl/gfx.dll +0 -0
  14. sgl/include/sgl/app/app.h +113 -0
  15. sgl/include/sgl/core/bitmap.h +302 -0
  16. sgl/include/sgl/core/crypto.h +89 -0
  17. sgl/include/sgl/core/data_type.h +46 -0
  18. sgl/include/sgl/core/dds_file.h +103 -0
  19. sgl/include/sgl/core/enum.h +201 -0
  20. sgl/include/sgl/core/error.h +161 -0
  21. sgl/include/sgl/core/file_stream.h +77 -0
  22. sgl/include/sgl/core/file_system_watcher.h +141 -0
  23. sgl/include/sgl/core/format.h +36 -0
  24. sgl/include/sgl/core/fwd.h +90 -0
  25. sgl/include/sgl/core/hash.h +45 -0
  26. sgl/include/sgl/core/input.h +522 -0
  27. sgl/include/sgl/core/logger.h +214 -0
  28. sgl/include/sgl/core/macros.h +184 -0
  29. sgl/include/sgl/core/maths.h +45 -0
  30. sgl/include/sgl/core/memory_mapped_file.h +112 -0
  31. sgl/include/sgl/core/memory_mapped_file_stream.h +32 -0
  32. sgl/include/sgl/core/memory_stream.h +74 -0
  33. sgl/include/sgl/core/object.h +683 -0
  34. sgl/include/sgl/core/platform.h +239 -0
  35. sgl/include/sgl/core/plugin.h +331 -0
  36. sgl/include/sgl/core/resolver.h +39 -0
  37. sgl/include/sgl/core/short_vector.h +141 -0
  38. sgl/include/sgl/core/static_vector.h +111 -0
  39. sgl/include/sgl/core/stream.h +54 -0
  40. sgl/include/sgl/core/string.h +276 -0
  41. sgl/include/sgl/core/struct.h +360 -0
  42. sgl/include/sgl/core/thread.h +28 -0
  43. sgl/include/sgl/core/timer.h +52 -0
  44. sgl/include/sgl/core/traits.h +15 -0
  45. sgl/include/sgl/core/type_utils.h +19 -0
  46. sgl/include/sgl/core/window.h +177 -0
  47. sgl/include/sgl/device/agility_sdk.h +24 -0
  48. sgl/include/sgl/device/blit.h +88 -0
  49. sgl/include/sgl/device/buffer_cursor.h +162 -0
  50. sgl/include/sgl/device/command.h +539 -0
  51. sgl/include/sgl/device/cuda_api.h +766 -0
  52. sgl/include/sgl/device/cuda_interop.h +39 -0
  53. sgl/include/sgl/device/cuda_utils.h +107 -0
  54. sgl/include/sgl/device/cursor_utils.h +129 -0
  55. sgl/include/sgl/device/device.h +668 -0
  56. sgl/include/sgl/device/device_resource.h +37 -0
  57. sgl/include/sgl/device/fence.h +91 -0
  58. sgl/include/sgl/device/formats.h +330 -0
  59. sgl/include/sgl/device/framebuffer.h +85 -0
  60. sgl/include/sgl/device/fwd.h +164 -0
  61. sgl/include/sgl/device/helpers.h +20 -0
  62. sgl/include/sgl/device/hot_reload.h +75 -0
  63. sgl/include/sgl/device/input_layout.h +74 -0
  64. sgl/include/sgl/device/kernel.h +69 -0
  65. sgl/include/sgl/device/memory_heap.h +155 -0
  66. sgl/include/sgl/device/native_formats.h +342 -0
  67. sgl/include/sgl/device/native_handle.h +73 -0
  68. sgl/include/sgl/device/native_handle_traits.h +65 -0
  69. sgl/include/sgl/device/pipeline.h +138 -0
  70. sgl/include/sgl/device/print.h +45 -0
  71. sgl/include/sgl/device/python/cursor_utils.h +853 -0
  72. sgl/include/sgl/device/query.h +52 -0
  73. sgl/include/sgl/device/raytracing.h +84 -0
  74. sgl/include/sgl/device/reflection.h +1254 -0
  75. sgl/include/sgl/device/resource.h +705 -0
  76. sgl/include/sgl/device/sampler.h +57 -0
  77. sgl/include/sgl/device/shader.h +516 -0
  78. sgl/include/sgl/device/shader_cursor.h +85 -0
  79. sgl/include/sgl/device/shader_object.h +94 -0
  80. sgl/include/sgl/device/shader_offset.h +67 -0
  81. sgl/include/sgl/device/shared_handle.h +12 -0
  82. sgl/include/sgl/device/slang_utils.h +54 -0
  83. sgl/include/sgl/device/swapchain.h +74 -0
  84. sgl/include/sgl/device/types.h +782 -0
  85. sgl/include/sgl/math/colorspace.h +56 -0
  86. sgl/include/sgl/math/constants.h +7 -0
  87. sgl/include/sgl/math/float16.h +146 -0
  88. sgl/include/sgl/math/matrix.h +6 -0
  89. sgl/include/sgl/math/matrix_math.h +746 -0
  90. sgl/include/sgl/math/matrix_types.h +207 -0
  91. sgl/include/sgl/math/python/primitivetype.h +33 -0
  92. sgl/include/sgl/math/quaternion.h +6 -0
  93. sgl/include/sgl/math/quaternion_math.h +484 -0
  94. sgl/include/sgl/math/quaternion_types.h +83 -0
  95. sgl/include/sgl/math/ray.h +47 -0
  96. sgl/include/sgl/math/scalar_math.h +249 -0
  97. sgl/include/sgl/math/scalar_types.h +107 -0
  98. sgl/include/sgl/math/vector.h +6 -0
  99. sgl/include/sgl/math/vector_math.h +1796 -0
  100. sgl/include/sgl/math/vector_types.h +336 -0
  101. sgl/include/sgl/python/nanobind.h +489 -0
  102. sgl/include/sgl/python/py_doc.h +11600 -0
  103. sgl/include/sgl/python/sgl_ext_pch.h +8 -0
  104. sgl/include/sgl/sgl.h +21 -0
  105. sgl/include/sgl/sgl_pch.h +6 -0
  106. sgl/include/sgl/stl/bit.h +377 -0
  107. sgl/include/sgl/tests/testing.h +54 -0
  108. sgl/include/sgl/ui/fwd.h +34 -0
  109. sgl/include/sgl/ui/imgui_config.h +43 -0
  110. sgl/include/sgl/ui/ui.h +71 -0
  111. sgl/include/sgl/ui/widgets.h +918 -0
  112. sgl/include/sgl/utils/python/slangpy.h +366 -0
  113. sgl/include/sgl/utils/renderdoc.h +50 -0
  114. sgl/include/sgl/utils/slangpy.h +153 -0
  115. sgl/include/sgl/utils/tev.h +93 -0
  116. sgl/include/sgl/utils/texture_loader.h +106 -0
  117. sgl/math/__init__.pyi +5083 -0
  118. sgl/platform/__init__.pyi +102 -0
  119. sgl/renderdoc/__init__.pyi +51 -0
  120. sgl/sgl.dll +0 -0
  121. sgl/sgl_ext.cp313-win_amd64.pyd +0 -0
  122. sgl/shaders/nvapi/nvHLSLExtns.h +2315 -0
  123. sgl/shaders/nvapi/nvHLSLExtnsInternal.h +758 -0
  124. sgl/shaders/nvapi/nvShaderExtnEnums.h +142 -0
  125. sgl/shaders/sgl/device/blit.slang +93 -0
  126. sgl/shaders/sgl/device/nvapi.slang +5 -0
  127. sgl/shaders/sgl/device/nvapi.slangh +7 -0
  128. sgl/shaders/sgl/device/print.slang +445 -0
  129. sgl/shaders/sgl/math/constants.slang +4 -0
  130. sgl/shaders/sgl/math/ray.slang +29 -0
  131. sgl/shaders/sgl/ui/imgui.slang +49 -0
  132. sgl/slang-glslang.dll +0 -0
  133. sgl/slang-llvm.dll +0 -0
  134. sgl/slang-rt.dll +0 -0
  135. sgl/slang.dll +0 -0
  136. sgl/slangpy/__init__.pyi +268 -0
  137. sgl/tev/__init__.pyi +108 -0
  138. sgl/tevclient.lib +0 -0
  139. sgl/thread/__init__.pyi +4 -0
  140. sgl/ui/__init__.pyi +1118 -0
  141. share/cmake/tevclient/tevclient-config-release.cmake +19 -0
  142. share/cmake/tevclient/tevclient-config.cmake +108 -0
@@ -0,0 +1,853 @@
1
#pragma once

#include <cstring>
#include <optional>

#include "nanobind.h"

#include "sgl/device/reflection.h"
#include "sgl/device/cursor_utils.h"

#include "sgl/math/vector_types.h"
#include "sgl/math/matrix_types.h"
12
+
13
+ namespace sgl {
14
+
15
+ /// Helper to convert from numpy type mask to slang scalar type.
16
+ inline std::optional<TypeReflection::ScalarType> dtype_to_scalar_type(nb::dlpack::dtype dtype)
17
+ {
18
+ switch (dtype.code) {
19
+ case uint8_t(nb::dlpack::dtype_code::Int):
20
+ switch (dtype.bits) {
21
+ case 8:
22
+ return TypeReflection::ScalarType::int8;
23
+ case 16:
24
+ return TypeReflection::ScalarType::int16;
25
+ case 32:
26
+ return TypeReflection::ScalarType::int32;
27
+ case 64:
28
+ return TypeReflection::ScalarType::int64;
29
+ }
30
+ break;
31
+ case uint8_t(nb::dlpack::dtype_code::UInt):
32
+ switch (dtype.bits) {
33
+ case 8:
34
+ return TypeReflection::ScalarType::uint8;
35
+ case 16:
36
+ return TypeReflection::ScalarType::uint16;
37
+ case 32:
38
+ return TypeReflection::ScalarType::uint32;
39
+ case 64:
40
+ return TypeReflection::ScalarType::uint64;
41
+ }
42
+ break;
43
+ case uint8_t(nb::dlpack::dtype_code::Float):
44
+ switch (dtype.bits) {
45
+ case 16:
46
+ return TypeReflection::ScalarType::float16;
47
+ case 32:
48
+ return TypeReflection::ScalarType::float32;
49
+ case 64:
50
+ return TypeReflection::ScalarType::float64;
51
+ }
52
+ break;
53
+ }
54
+ return {};
55
+ }
56
+
57
+ /// Enforces properties needed for converting to/from vector.
58
+ template<typename T>
59
+ concept IsSpecializationOfVector = requires {
60
+ T::dimension;
61
+ typename T::value_type;
62
+ };
63
+
64
+ /// Enforces properties needed for converting to/from matrix.
65
+ template<typename T>
66
+ concept IsSpecializationOfMatrix = requires {
67
+ T::rows;
68
+ T::cols;
69
+ typename T::value_type;
70
+ };
71
+
72
+ #define scalar_case(c_type, scalar_type) \
73
+ m_read_scalar[(int)TypeReflection::ScalarType::scalar_type] \
74
+ = [](const CursorType& self) { return _read_scalar<c_type>(self); };
75
+
76
+ #define vector_case(c_type, scalar_type) \
77
+ m_read_vector[(int)TypeReflection::ScalarType::scalar_type][c_type::dimension] \
78
+ = [](const CursorType& self) { return _read_vector<c_type>(self); };
79
+
80
+ #define matrix_case(c_type, scalar_type) \
81
+ m_read_matrix[(int)TypeReflection::ScalarType::scalar_type][c_type::rows][c_type::cols] \
82
+ = [](const CursorType& self) { return _read_matrix<c_type>(self); };
83
+
84
+ /// Table of converters based on slang scalar type and shape.
85
+ template<typename CursorType>
86
+ class ReadConverterTable {
87
+ public:
88
+ ReadConverterTable()
89
+ {
90
+ // Initialize all entries to an error function that throws an exception.
91
+ auto read_err_func = [](const CursorType&) -> nb::object { SGL_THROW("Unsupported element type"); };
92
+ for (int i = 0; i < (int)TypeReflection::ScalarType::COUNT; i++) {
93
+ m_read_scalar[i] = read_err_func;
94
+ for (int j = 0; j < 5; ++j) {
95
+ m_read_vector[i][j] = read_err_func;
96
+ for (int k = 0; k < 5; ++k) {
97
+ m_read_matrix[i][j][k] = read_err_func;
98
+ }
99
+ }
100
+ }
101
+
102
+ // Register converters for all supported scalar types.
103
+ scalar_case(bool, bool_);
104
+ scalar_case(int8_t, int8);
105
+ scalar_case(uint8_t, uint8);
106
+ scalar_case(int16_t, int16);
107
+ scalar_case(uint16_t, uint16);
108
+ scalar_case(int32_t, int32);
109
+ scalar_case(uint32_t, uint32);
110
+ scalar_case(int64_t, int64);
111
+ scalar_case(uint64_t, uint64);
112
+ scalar_case(float16_t, float16);
113
+ scalar_case(float, float32);
114
+ scalar_case(double, float64);
115
+ scalar_case(intptr_t, intptr);
116
+ scalar_case(uintptr_t, uintptr);
117
+
118
+ // Register converters for all supported vector types.
119
+ vector_case(bool1, bool_);
120
+ vector_case(float1, float32);
121
+ vector_case(float16_t1, float16);
122
+ vector_case(int1, int32);
123
+ vector_case(uint1, uint32);
124
+ vector_case(bool2, bool_);
125
+ vector_case(float2, float32);
126
+ vector_case(float16_t2, float16);
127
+ vector_case(int2, int32);
128
+ vector_case(uint2, uint32);
129
+ vector_case(bool3, bool_);
130
+ vector_case(float3, float32);
131
+ vector_case(float16_t3, float16);
132
+ vector_case(int3, int32);
133
+ vector_case(uint3, uint32);
134
+ vector_case(bool4, bool_);
135
+ vector_case(float4, float32);
136
+ vector_case(float16_t4, float16);
137
+ vector_case(int4, int32);
138
+ vector_case(uint4, uint32);
139
+
140
+ // Register converters for all supported matrix types.
141
+ matrix_case(float2x2, float32);
142
+ matrix_case(float3x3, float32);
143
+ matrix_case(float2x4, float32);
144
+ matrix_case(float3x4, float32);
145
+ matrix_case(float4x4, float32);
146
+ }
147
+
148
+ /// Read function inspects the slang type and attempts to convert it
149
+ /// to a matching python type. For structs and arrays, generates
150
+ /// a nested dictionary or list and recurses.
151
+ nb::object read(const CursorType& self)
152
+ {
153
+ m_stack.clear();
154
+ try {
155
+ return read_internal(self);
156
+ } catch (const std::exception& err) {
157
+ SGL_THROW("{}: {}", build_error(), err.what());
158
+ }
159
+ }
160
+
161
+ private:
162
+ std::function<nb::object(const CursorType&)> m_read_scalar[(int)TypeReflection::ScalarType::COUNT];
163
+ std::function<nb::object(const CursorType&)> m_read_vector[(int)TypeReflection::ScalarType::COUNT][5];
164
+ std::function<nb::object(const CursorType&)> m_read_matrix[(int)TypeReflection::ScalarType::COUNT][5][5];
165
+ std::vector<const char*> m_stack;
166
+
167
+ std::string build_error() { return fmt::format("{}", fmt::join(m_stack, ".")); }
168
+
169
+ nb::object read_internal(const CursorType& self)
170
+ {
171
+ if (!self.is_valid())
172
+ return nb::none();
173
+ auto type = self.type();
174
+ if (type) {
175
+ switch (type->kind()) {
176
+ case TypeReflection::Kind::scalar: {
177
+ return m_read_scalar[(int)type->scalar_type()](self);
178
+ }
179
+ case TypeReflection::Kind::vector: {
180
+ return m_read_vector[(int)type->scalar_type()][type->col_count()](self);
181
+ }
182
+ case TypeReflection::Kind::matrix: {
183
+ return m_read_matrix[(int)type->scalar_type()][type->row_count()][type->col_count()](self);
184
+ }
185
+ case TypeReflection::Kind::struct_: {
186
+ nb::dict res;
187
+ for (uint32_t i = 0; i < type->field_count(); i++) {
188
+ auto field = type->get_field_by_index(i);
189
+ const char* name = field->name();
190
+ m_stack.push_back(name);
191
+ res[name] = read_internal(self[name]);
192
+ m_stack.pop_back();
193
+ }
194
+ return res;
195
+ }
196
+ case TypeReflection::Kind::array: {
197
+ nb::list res;
198
+ m_stack.push_back("[]");
199
+ for (uint32_t i = 0; i < type->element_count(); i++) {
200
+ res.append(read_internal(self[i]));
201
+ }
202
+ m_stack.pop_back();
203
+ return res;
204
+ }
205
+ default:
206
+ break;
207
+ }
208
+ }
209
+ SGL_THROW("Unsupported element type");
210
+ }
211
+
212
+ /// Read scalar value from buffer element cursor and convert to Python object.
213
+ template<typename ValType>
214
+ inline static nb::object _read_scalar(const CursorType& self)
215
+ {
216
+ ValType res;
217
+ self.get(res);
218
+ return nb::cast(res);
219
+ }
220
+
221
+ /// Read vector value from buffer element cursor and convert to Python object.
222
+ template<typename ValType>
223
+ requires IsSpecializationOfVector<ValType>
224
+ inline static nb::object _read_vector(const CursorType& self)
225
+ {
226
+ ValType res;
227
+ self.get(res);
228
+ return nb::cast(res);
229
+ }
230
+
231
+ /// Read matrix value from buffer element cursor and convert to Python object.
232
+ template<typename ValType>
233
+ requires IsSpecializationOfMatrix<ValType>
234
+ inline static nb::object _read_matrix(const CursorType& self)
235
+ {
236
+ ValType res;
237
+ self.get(res);
238
+ return nb::cast(res);
239
+ }
240
+ };
241
+
242
+ #undef scalar_case
243
+ #undef vector_case
244
+ #undef matrix_case
245
+
246
+ #define scalar_case(c_type, scalar_type) \
247
+ m_write_scalar[(int)TypeReflection::ScalarType::scalar_type] \
248
+ = [](CursorType& self, nb::object nbval) { _write_scalar<c_type>(self, nbval); };
249
+
250
+ #define vector_case(c_type, scalar_type) \
251
+ m_write_vector[(int)TypeReflection::ScalarType::scalar_type][c_type::dimension] \
252
+ = [](CursorType& self, nb::object nbval) { _write_vector<c_type>(self, nbval); };
253
+
254
+ #define bool_vector_case(c_type, scalar_type) \
255
+ m_write_vector[(int)TypeReflection::ScalarType::scalar_type][c_type::dimension] \
256
+ = [](CursorType& self, nb::object nbval) { _write_bool_vector<c_type>(self, nbval); };
257
+
258
+ #define matrix_case(c_type, scalar_type) \
259
+ m_write_matrix[(int)TypeReflection::ScalarType::scalar_type][c_type::rows][c_type::cols] \
260
+ = [](CursorType& self, nb::object nbval) { _write_matrix<c_type>(self, nbval); };
261
+
262
+ /// Table of converters based on slang scalar type and shape.
263
+ template<typename CursorType>
264
+ class WriteConverterTable {
265
+ public:
266
+ WriteConverterTable()
267
+ {
268
+ // Initialize all entries to an error function that throws an exception.
269
+ auto write_err_func = [](const CursorType&, nb::object) { SGL_THROW("Unsupported element type"); };
270
+ for (int i = 0; i < (int)TypeReflection::ScalarType::COUNT; i++) {
271
+ m_write_scalar[i] = write_err_func;
272
+ for (int j = 0; j < 5; ++j) {
273
+ m_write_vector[i][j] = write_err_func;
274
+ for (int k = 0; k < 5; ++k) {
275
+ m_write_matrix[i][j][k] = write_err_func;
276
+ }
277
+ }
278
+ }
279
+
280
+ // Register converters for all supported scalar types.
281
+ scalar_case(bool, bool_);
282
+ scalar_case(int8_t, int8);
283
+ scalar_case(uint8_t, uint8);
284
+ scalar_case(int16_t, int16);
285
+ scalar_case(uint16_t, uint16);
286
+ scalar_case(int32_t, int32);
287
+ scalar_case(uint32_t, uint32);
288
+ scalar_case(int64_t, int64);
289
+ scalar_case(uint64_t, uint64);
290
+ scalar_case(float16_t, float16);
291
+ scalar_case(float, float32);
292
+ scalar_case(double, float64);
293
+ scalar_case(intptr_t, intptr);
294
+ scalar_case(uintptr_t, uintptr);
295
+
296
+ // Register converters for all supported vector types.
297
+ bool_vector_case(bool1, bool_);
298
+ vector_case(float1, float32);
299
+ vector_case(float16_t1, float16);
300
+ vector_case(int1, int32);
301
+ vector_case(uint1, uint32);
302
+ bool_vector_case(bool2, bool_);
303
+ vector_case(float2, float32);
304
+ vector_case(float16_t2, float16);
305
+ vector_case(int2, int32);
306
+ vector_case(uint2, uint32);
307
+ bool_vector_case(bool3, bool_);
308
+ vector_case(float3, float32);
309
+ vector_case(float16_t3, float16);
310
+ vector_case(int3, int32);
311
+ vector_case(uint3, uint32);
312
+ bool_vector_case(bool4, bool_);
313
+ vector_case(float4, float32);
314
+ vector_case(float16_t4, float16);
315
+ vector_case(int4, int32);
316
+ vector_case(uint4, uint32);
317
+
318
+ // Register converters for all supported matrix types.
319
+ matrix_case(float2x2, float32);
320
+ matrix_case(float3x3, float32);
321
+ matrix_case(float2x4, float32);
322
+ matrix_case(float3x4, float32);
323
+ matrix_case(float4x4, float32);
324
+ }
325
+
326
+ /// Virtual for writing none-basic value types.
327
+ virtual bool write_value(CursorType& self, nb::object nbval)
328
+ {
329
+ SGL_UNUSED(self);
330
+ SGL_UNUSED(nbval);
331
+ return false;
332
+ }
333
+
334
+ /// Write function inspects the slang type and uses it to try
335
+ /// and convert a Python input to the correct c++ type. For structs
336
+ /// and arrays, expects a dict, sequence type or numpy array.
337
+ void write(CursorType& self, nb::object nbval)
338
+ {
339
+ m_stack.clear();
340
+ try {
341
+ write_internal(self, nbval);
342
+ } catch (const std::exception& err) {
343
+ SGL_THROW("{}: {}", build_error(), err.what());
344
+ }
345
+ }
346
+
347
+ private:
348
+ std::function<void(CursorType&, nb::object)> m_write_scalar[(int)TypeReflection::ScalarType::COUNT];
349
+ std::function<void(CursorType&, nb::object)> m_write_vector[(int)TypeReflection::ScalarType::COUNT][5];
350
+ std::function<void(CursorType&, nb::object)> m_write_matrix[(int)TypeReflection::ScalarType::COUNT][5][5];
351
+ std::vector<const char*> m_stack;
352
+
353
+ std::string build_error() { return fmt::format("{}", fmt::join(m_stack, ".")); }
354
+
355
+ void write_internal(CursorType& self, nb::object nbval)
356
+ {
357
+ if (!self.is_valid())
358
+ return;
359
+
360
+ ref<const TypeLayoutReflection> type_layout = self.type_layout();
361
+ auto kind = type_layout->kind();
362
+
363
+ switch (kind) {
364
+ case TypeReflection::Kind::scalar: {
365
+ auto type = type_layout->type();
366
+ SGL_ASSERT(type);
367
+ return m_write_scalar[(int)type->scalar_type()](self, nbval);
368
+ }
369
+ case TypeReflection::Kind::vector: {
370
+ auto type = type_layout->type();
371
+ SGL_ASSERT(type);
372
+ return m_write_vector[(int)type->scalar_type()][type->col_count()](self, nbval);
373
+ }
374
+ case TypeReflection::Kind::matrix: {
375
+ auto type = type_layout->type();
376
+ SGL_ASSERT(type);
377
+ return m_write_matrix[(int)type->scalar_type()][type->row_count()][type->col_count()](self, nbval);
378
+ }
379
+ case TypeReflection::Kind::constant_buffer:
380
+ case TypeReflection::Kind::parameter_block:
381
+ case TypeReflection::Kind::struct_: {
382
+ // Unwrap constant buffers or parameter blocks
383
+ if (kind != TypeReflection::Kind::struct_)
384
+ type_layout = type_layout->element_type_layout();
385
+
386
+ // Handle shader object if possible.
387
+ if constexpr (requires { self.set_object(nullptr); }) {
388
+ if (nb::isinstance<MutableShaderObject>(nbval)) {
389
+ self.set_object(nb::cast<ref<MutableShaderObject>>(nbval));
390
+ return;
391
+ }
392
+ }
393
+
394
+ // Expect a dict for a slang struct.
395
+ if (nb::isinstance<nb::dict>(nbval)) {
396
+ auto dict = nb::cast<nb::dict>(nbval);
397
+ for (uint32_t i = 0; i < type_layout->field_count(); i++) {
398
+ auto field = type_layout->get_field_by_index(i);
399
+ const char* name = field->name();
400
+ auto child = self[name];
401
+ if (dict.contains(name)) {
402
+ m_stack.push_back(name);
403
+ write_internal(child, dict[name]);
404
+ m_stack.pop_back();
405
+ }
406
+ }
407
+ return;
408
+ } else {
409
+ SGL_THROW("Expected dict");
410
+ }
411
+ }
412
+ case TypeReflection::Kind::array: {
413
+ // Expect numpy array or sequence for a slang array.
414
+ if (nb::isinstance<nb::ndarray<nb::numpy>>(nbval)) {
415
+ // TODO: Should be able to do better job of interpreting nb array values by reading
416
+ // data type and extracting individual elements.
417
+ auto nbarray = nb::cast<nb::ndarray<nb::numpy>>(nbval);
418
+ SGL_CHECK(nbarray.ndim() == 1, "numpy array must have 1 dimension.");
419
+ SGL_CHECK(nbarray.shape(0) == type_layout->element_count(), "numpy array is the wrong length.");
420
+ SGL_CHECK(is_ndarray_contiguous(nbarray), "data is not contiguous");
421
+ self._set_array(
422
+ nbarray.data(),
423
+ nbarray.nbytes(),
424
+ type_layout->element_type_layout()->type()->scalar_type(),
425
+ narrow_cast<int>(nbarray.shape(0))
426
+ );
427
+ return;
428
+ } else if (nb::isinstance<nb::sequence>(nbval)) {
429
+ auto seq = nb::cast<nb::sequence>(nbval);
430
+ SGL_CHECK(
431
+ nb::len(seq) == type_layout->element_count(),
432
+ "sequence is the wrong length accessing type {}: {} != {}.",
433
+ type_layout->type()->full_name(),
434
+ nb::len(seq),
435
+ type_layout->element_count()
436
+ );
437
+ m_stack.push_back("[]");
438
+ for (uint32_t i = 0; i < type_layout->element_count(); i++) {
439
+ auto child = self[i];
440
+ write_internal(child, seq[i]);
441
+ }
442
+ m_stack.pop_back();
443
+ return;
444
+ } else {
445
+ SGL_THROW("Expected list");
446
+ }
447
+ }
448
+ default:
449
+ break;
450
+ }
451
+
452
+ // In default case call the virtual write_value, and fail if it returns false.
453
+ if (write_value(self, nbval))
454
+ return;
455
+
456
+ SGL_THROW("Unsupported element type");
457
+ }
458
+
459
+
460
+ /// Write scalar value to buffer element cursor from Python object.
461
+ template<typename ValType>
462
+ inline static void _write_scalar(CursorType& self, nb::object nbval)
463
+ {
464
+ auto val = nb::cast<ValType>(nbval);
465
+ self.set(val);
466
+ }
467
+
468
+ /// Default implementation of write vector from numpy array.
469
+ template<typename ValType>
470
+ requires IsSpecializationOfVector<ValType>
471
+ inline static void _write_vector_from_numpy(CursorType& self, nb::ndarray<nb::numpy> nbarray)
472
+ {
473
+ SGL_CHECK(nbarray.nbytes() == sizeof(ValType), "numpy array has wrong size.");
474
+ auto val = *reinterpret_cast<const ValType*>(nbarray.data());
475
+ self.set(val);
476
+ }
477
+
478
+ /// Version of vector write specifically for bool vectors (which are stored as uint32_t)
479
+ template<typename ValType>
480
+ requires IsSpecializationOfVector<ValType>
481
+ inline static void _write_bool_vector_from_numpy(CursorType& self, nb::ndarray<nb::numpy> nbarray)
482
+ {
483
+ SGL_CHECK(nbarray.nbytes() == ValType::dimension * 4, "numpy array has wrong size.");
484
+ self._set_vector(nbarray.data(), nbarray.nbytes(), TypeReflection::ScalarType::bool_, ValType::dimension);
485
+ }
486
+
487
+ /// Write vector value to buffer element cursor from Python object
488
+ template<typename ValType>
489
+ requires IsSpecializationOfVector<ValType>
490
+ inline static void _write_vector(CursorType& self, nb::object nbval)
491
+ {
492
+ if (nb::isinstance<ValType>(nbval)) {
493
+ // A vector of the correct type - just convert it.
494
+ auto val = nb::cast<ValType>(nbval);
495
+ self.set(val);
496
+ } else if (nb::isinstance<nb::ndarray<nb::numpy>>(nbval)) {
497
+ // A numpy array. Reinterpret numpy memory as vector type.
498
+ nb::ndarray<nb::numpy> nbarray = nb::cast<nb::ndarray<nb::numpy>>(nbval);
499
+ SGL_CHECK(is_ndarray_contiguous(nbarray), "data is not contiguous");
500
+ SGL_CHECK(nbarray.ndim() == 1 || nbarray.ndim() == 2, "numpy array must have 1 or 2 dimensions.");
501
+ size_t dimension = 1;
502
+ for (size_t i = 0; i < nbarray.ndim(); ++i)
503
+ dimension *= nbarray.shape(i);
504
+ SGL_CHECK(dimension == ValType::dimension, "numpy array has wrong dimension.");
505
+ _write_vector_from_numpy<ValType>(self, nbarray);
506
+ } else if (nb::isinstance<nb::sequence>(nbval)) {
507
+ // A list or tuple. Attempt to cast each element of list to element of vector.
508
+ auto seq = nb::cast<nb::sequence>(nbval);
509
+ SGL_CHECK(nb::len(seq) == ValType::dimension, "sequence has wrong dimension.");
510
+ ValType val;
511
+ for (int i = 0; i < ValType::dimension; i++) {
512
+ val[i] = nb::cast<typename ValType::value_type>(seq[i]);
513
+ }
514
+ self.set(val);
515
+ } else {
516
+ SGL_THROW("Expected numpy array or vector");
517
+ }
518
+ }
519
+
520
+ /// Bespoke vector implementation for bools.
521
+ template<typename ValType>
522
+ requires IsSpecializationOfVector<ValType>
523
+ inline static void _write_bool_vector(CursorType& self, nb::object nbval)
524
+ {
525
+ if (nb::isinstance<ValType>(nbval)) {
526
+ // A vector of the correct type - just convert it.
527
+ auto val = nb::cast<ValType>(nbval);
528
+ self.set(val);
529
+ } else if (nb::isinstance<nb::ndarray<nb::numpy>>(nbval)) {
530
+ // A numpy array. Reinterpret numpy memory as vector type.
531
+ nb::ndarray<nb::numpy> nbarray = nb::cast<nb::ndarray<nb::numpy>>(nbval);
532
+ SGL_CHECK(is_ndarray_contiguous(nbarray), "data is not contiguous");
533
+ SGL_CHECK(nbarray.ndim() == 1 || nbarray.ndim() == 2, "numpy array must have 1 or 2 dimensions.");
534
+ size_t dimension = 1;
535
+ for (size_t i = 0; i < nbarray.ndim(); ++i)
536
+ dimension *= nbarray.shape(i);
537
+ SGL_CHECK(dimension == ValType::dimension, "numpy array has wrong dimension.");
538
+ _write_bool_vector_from_numpy<ValType>(self, nbarray);
539
+ } else if (nb::isinstance<nb::sequence>(nbval)) {
540
+ // A list or tuple. Attempt to cast each element of list to element of vector.
541
+ auto seq = nb::cast<nb::sequence>(nbval);
542
+ SGL_CHECK(nb::len(seq) == ValType::dimension, "sequence has wrong dimension.");
543
+ ValType val;
544
+ for (int i = 0; i < ValType::dimension; i++) {
545
+ val[i] = nb::cast<typename ValType::value_type>(seq[i]);
546
+ }
547
+ self.set(val);
548
+ } else {
549
+ SGL_THROW("Expected numpy array or vector");
550
+ }
551
+ }
552
+
553
+ /// Write matrix value to buffer element cursor from Python object.
554
+ template<typename ValType>
555
+ requires IsSpecializationOfMatrix<ValType>
556
+ inline static void _write_matrix(CursorType& self, nb::object nbval)
557
+ {
558
+ if (nb::isinstance<ValType>(nbval)) {
559
+ // Matrix of correct type
560
+ auto val = nb::cast<ValType>(nbval);
561
+ self.set(val);
562
+ } else if (nb::isinstance<nb::ndarray<nb::numpy>>(nbval)) {
563
+ // A numpy array. We have a python cast from numpy->matrix,
564
+ // so can just call it here to convert properly.
565
+ auto val = nb::cast<ValType>(nbval);
566
+ self.set(val);
567
+ } else {
568
+ SGL_THROW("Expected numpy array or matrix");
569
+ }
570
+ }
571
+ };
572
+
573
+ #undef scalar_case
574
+ #undef vector_case
575
+ #undef bool_vector_case
576
+ #undef matrix_case
577
+
578
+
579
+ template<typename CursorType>
580
+ requires TraversableCursor<CursorType>
581
+ inline void bind_traversable_cursor(nanobind::class_<CursorType>& cursor)
582
+ {
583
+ cursor //
584
+ .def_prop_ro("_type_layout", &CursorType::type_layout, D_NA(CursorType, type_layout))
585
+ .def_prop_ro("_type", &CursorType::type, D_NA(CursorType, type))
586
+ .def("is_valid", &CursorType::is_valid, D_NA(CursorType, is_valid))
587
+ .def("find_field", &CursorType::find_field, "name"_a, D_NA(CursorType, find_field))
588
+ .def("find_element", &CursorType::find_element, "index"_a, D_NA(CursorType, find_element))
589
+ .def("has_field", &CursorType::has_field, "name"_a, D_NA(CursorType, has_field))
590
+ .def("has_element", &CursorType::has_element, "index"_a, D_NA(CursorType, has_element))
591
+ .def("__getitem__", [](CursorType& self, std::string_view name) { return self[name]; })
592
+ .def("__getitem__", [](CursorType& self, int index) { return self[index]; })
593
+ // note: __getattr__ should not except if field is not found
594
+ .def("__getattr__", [](CursorType& self, std::string_view name) { return self.find_field(name); });
595
+ }
596
+
597
+ template<typename CursorType>
598
+ inline void bind_writable_cursor(WriteConverterTable<CursorType>& table, nanobind::class_<CursorType>& cursor)
599
+ {
600
+ // __setitem__ and __setattr__ functions are overloaded to allow direct setting
601
+ // of fields and elements.
602
+ cursor //
603
+ .def(
604
+ "__setattr__",
605
+ [&table](CursorType& self, std::string_view name, nb::object nbval)
606
+ {
607
+ auto child = self[name];
608
+ table.write(child, nbval);
609
+ },
610
+ "name"_a,
611
+ "val"_a,
612
+ D_NA(CursorType, write)
613
+ )
614
+ .def(
615
+ "__setitem__",
616
+ [&table](CursorType& self, std::string_view name, nb::object nbval)
617
+ {
618
+ auto child = self[name];
619
+ table.write(child, nbval);
620
+ },
621
+ "index"_a,
622
+ "val"_a,
623
+ D_NA(CursorType, write)
624
+ )
625
+ .def(
626
+ "__setitem__",
627
+ [&table](CursorType& self, int index, nb::object nbval)
628
+ {
629
+ auto child = self[index];
630
+ table.write(child, nbval);
631
+ },
632
+ "index"_a,
633
+ "val"_a,
634
+ D_NA(CursorType, write)
635
+ )
636
+ .def(
637
+ "set_data",
638
+ [](CursorType& self, nb::ndarray<nb::device::cpu> data)
639
+ {
640
+ SGL_CHECK(is_ndarray_contiguous(data), "data is not contiguous");
641
+ self.set_data(data.data(), data.nbytes());
642
+ },
643
+ "data"_a,
644
+ D(ShaderCursor, set_data)
645
+ )
646
+ .def(
647
+ "write",
648
+ [&table](CursorType& self, nb::object nbval) { table.write(self, nbval); },
649
+ "val"_a,
650
+ D_NA(CursorType, write)
651
+ );
652
+ }
653
+
654
+ template<typename CursorType>
655
+ inline void bind_readable_cursor(ReadConverterTable<CursorType>& table, nanobind::class_<CursorType>& cursor)
656
+ {
657
+ // __setitem__ and __setattr__ functions are overloaded to allow direct setting
658
+ // of fields and elements.
659
+ cursor //
660
+ .def(
661
+ "read",
662
+ [&table](CursorType& self) { return table.read(self); },
663
+ D_NA(CursorType, read)
664
+ );
665
+ }
666
+
667
+ template<typename CursorType>
668
+ requires WritableCursor<CursorType>
669
+ inline void bind_writable_cursor_basic_types(nanobind::class_<CursorType>& cursor)
670
+ {
671
+ #define def_setter(type) \
672
+ cursor.def("__setitem__", [](CursorType& self, std::string_view name, type value) { self[name] = value; }); \
673
+ cursor.def("__setattr__", [](CursorType& self, std::string_view name, type value) { self[name] = value; });
674
+
675
+ def_setter(bool);
676
+ def_setter(bool1);
677
+ def_setter(bool2);
678
+ def_setter(bool3);
679
+ def_setter(bool4);
680
+
681
+ def_setter(uint1);
682
+ def_setter(uint2);
683
+ def_setter(uint3);
684
+ def_setter(uint4);
685
+
686
+ def_setter(int1);
687
+ def_setter(int2);
688
+ def_setter(int3);
689
+ def_setter(int4);
690
+
691
+ def_setter(float1);
692
+ def_setter(float2);
693
+ def_setter(float3);
694
+ def_setter(float4);
695
+
696
+ def_setter(float2x2);
697
+ def_setter(float3x3);
698
+ def_setter(float2x4);
699
+ def_setter(float3x4);
700
+ def_setter(float4x4);
701
+
702
+ def_setter(float16_t2);
703
+ def_setter(float16_t3);
704
+ def_setter(float16_t4);
705
+
706
+ #undef def_setter
707
+
708
+ auto set_int_field = [](CursorType& self, std::string_view name, nb::int_ value)
709
+ {
710
+ ref<const TypeReflection> type = self[name].type();
711
+ SGL_CHECK(type->kind() == TypeReflection::Kind::scalar, "Field \"{}\" is not a scalar type.", name);
712
+ switch (type->scalar_type()) {
713
+ case TypeReflection::ScalarType::int16:
714
+ self[name] = nb::cast<int16_t>(value);
715
+ break;
716
+ case TypeReflection::ScalarType::int32:
717
+ self[name] = nb::cast<int32_t>(value);
718
+ break;
719
+ case TypeReflection::ScalarType::int64:
720
+ self[name] = nb::cast<int64_t>(value);
721
+ break;
722
+ case TypeReflection::ScalarType::uint16:
723
+ self[name] = nb::cast<uint16_t>(value);
724
+ break;
725
+ case TypeReflection::ScalarType::uint32:
726
+ self[name] = nb::cast<uint32_t>(value);
727
+ break;
728
+ case TypeReflection::ScalarType::uint64:
729
+ self[name] = nb::cast<uint64_t>(value);
730
+ break;
731
+ default:
732
+ SGL_THROW("Field \"{}\" is not an integer type.");
733
+ break;
734
+ }
735
+ };
736
+
737
+ auto set_int_element = [](CursorType& self, int index, nb::int_ value)
738
+ {
739
+ ref<const TypeReflection> type = self[index].type();
740
+ SGL_CHECK(type->kind() == TypeReflection::Kind::scalar, "Element {} is not a scalar type.", index);
741
+ switch (type->scalar_type()) {
742
+ case TypeReflection::ScalarType::int16:
743
+ self[index] = nb::cast<int16_t>(value);
744
+ break;
745
+ case TypeReflection::ScalarType::int32:
746
+ self[index] = nb::cast<int32_t>(value);
747
+ break;
748
+ case TypeReflection::ScalarType::int64:
749
+ self[index] = nb::cast<int64_t>(value);
750
+ break;
751
+ case TypeReflection::ScalarType::uint16:
752
+ self[index] = nb::cast<uint16_t>(value);
753
+ break;
754
+ case TypeReflection::ScalarType::uint32:
755
+ self[index] = nb::cast<uint32_t>(value);
756
+ break;
757
+ case TypeReflection::ScalarType::uint64:
758
+ self[index] = nb::cast<uint64_t>(value);
759
+ break;
760
+ default:
761
+ SGL_THROW("Element {} is not an integer type.");
762
+ break;
763
+ }
764
+ };
765
+
766
+ cursor.def("__setitem__", set_int_field);
767
+ cursor.def("__setitem__", set_int_element);
768
+ cursor.def("__setattr__", set_int_field);
769
+
770
+ auto set_float_field = [](CursorType& self, std::string_view name, nb::float_ value)
771
+ {
772
+ ref<const TypeReflection> type = self[name].type();
773
+ SGL_CHECK(type->kind() == TypeReflection::Kind::scalar, "Field \"{}\" is not a scalar type.", name);
774
+ switch (type->scalar_type()) {
775
+ case TypeReflection::ScalarType::float16:
776
+ self[name] = float16_t(nb::cast<float>(value));
777
+ break;
778
+ case TypeReflection::ScalarType::float32:
779
+ self[name] = nb::cast<float>(value);
780
+ break;
781
+ case TypeReflection::ScalarType::float64:
782
+ self[name] = nb::cast<double>(value);
783
+ break;
784
+ default:
785
+ SGL_THROW("Field \"{}\" is not a floating point type.");
786
+ break;
787
+ }
788
+ };
789
+
790
+ auto set_float_element = [](CursorType& self, int index, nb::float_ value)
791
+ {
792
+ ref<const TypeReflection> type = self[index].type();
793
+ SGL_CHECK(type->kind() == TypeReflection::Kind::scalar, "Element {} is not a scalar type.", index);
794
+ switch (type->scalar_type()) {
795
+ case TypeReflection::ScalarType::float16:
796
+ self[index] = float16_t(nb::cast<float>(value));
797
+ break;
798
+ case TypeReflection::ScalarType::float32:
799
+ self[index] = nb::cast<float>(value);
800
+ break;
801
+ case TypeReflection::ScalarType::float64:
802
+ self[index] = nb::cast<double>(value);
803
+ break;
804
+ default:
805
+ SGL_THROW("Element {} is not a floating point type.");
806
+ break;
807
+ }
808
+ };
809
+
810
+ cursor.def("__setitem__", set_float_field);
811
+ cursor.def("__setitem__", set_float_element);
812
+ cursor.def("__setattr__", set_float_field);
813
+
814
+ auto set_numpy_field = [](CursorType& self, std::string_view name, nb::ndarray<nb::numpy> value)
815
+ {
816
+ ref<const TypeReflection> type = self[name].type();
817
+ auto src_scalar_type = dtype_to_scalar_type(value.dtype());
818
+ SGL_CHECK(src_scalar_type, "numpy array has unsupported dtype.");
819
+ SGL_CHECK(is_ndarray_contiguous(value), "numpy array is not contiguous.");
820
+
821
+ switch (type->kind()) {
822
+ case TypeReflection::Kind::array:
823
+ SGL_CHECK(value.ndim() == 1, "numpy array must have 1 dimension.");
824
+ self[name]._set_array(value.data(), value.nbytes(), *src_scalar_type, narrow_cast<int>(value.shape(0)));
825
+ break;
826
+ case TypeReflection::Kind::matrix:
827
+ SGL_CHECK(value.ndim() == 2, "numpy array must have 2 dimensions.");
828
+ self[name]._set_matrix(
829
+ value.data(),
830
+ value.nbytes(),
831
+ *src_scalar_type,
832
+ narrow_cast<int>(value.shape(0)),
833
+ narrow_cast<int>(value.shape(1))
834
+ );
835
+ break;
836
+ case TypeReflection::Kind::vector: {
837
+ SGL_CHECK(value.ndim() == 1 || value.ndim() == 2, "numpy array must have 1 or 2 dimensions.");
838
+ size_t dimension = 1;
839
+ for (size_t i = 0; i < value.ndim(); ++i)
840
+ dimension *= value.shape(i);
841
+ self[name]._set_vector(value.data(), value.nbytes(), *src_scalar_type, narrow_cast<int>(dimension));
842
+ break;
843
+ }
844
+ default:
845
+ SGL_THROW("Field \"{}\" is not a vector, matrix, or array type.", name);
846
+ }
847
+ };
848
+
849
+ cursor.def("__setitem__", set_numpy_field);
850
+ cursor.def("__setattr__", set_numpy_field);
851
+ }
852
+
853
+ } // namespace sgl