halide 19.0.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. halide/__init__.py +39 -0
  2. halide/_generator_helpers.py +835 -0
  3. halide/bin/Halide.dll +0 -0
  4. halide/bin/adams2019_retrain_cost_model.exe +0 -0
  5. halide/bin/adams2019_weightsdir_to_weightsfile.exe +0 -0
  6. halide/bin/anderson2021_retrain_cost_model.exe +0 -0
  7. halide/bin/anderson2021_weightsdir_to_weightsfile.exe +0 -0
  8. halide/bin/featurization_to_sample.exe +0 -0
  9. halide/bin/gengen.exe +0 -0
  10. halide/bin/get_host_target.exe +0 -0
  11. halide/halide_.cp310-win_amd64.pyd +0 -0
  12. halide/imageio.py +60 -0
  13. halide/include/Halide.h +35293 -0
  14. halide/include/HalideBuffer.h +2618 -0
  15. halide/include/HalidePyTorchCudaHelpers.h +64 -0
  16. halide/include/HalidePyTorchHelpers.h +120 -0
  17. halide/include/HalideRuntime.h +2221 -0
  18. halide/include/HalideRuntimeCuda.h +89 -0
  19. halide/include/HalideRuntimeD3D12Compute.h +91 -0
  20. halide/include/HalideRuntimeHexagonDma.h +104 -0
  21. halide/include/HalideRuntimeHexagonHost.h +157 -0
  22. halide/include/HalideRuntimeMetal.h +112 -0
  23. halide/include/HalideRuntimeOpenCL.h +119 -0
  24. halide/include/HalideRuntimeQurt.h +32 -0
  25. halide/include/HalideRuntimeVulkan.h +137 -0
  26. halide/include/HalideRuntimeWebGPU.h +44 -0
  27. halide/lib/Halide.lib +0 -0
  28. halide/lib/HalidePyStubs.lib +0 -0
  29. halide/lib/Halide_GenGen.lib +0 -0
  30. halide/lib/autoschedule_adams2019.dll +0 -0
  31. halide/lib/autoschedule_anderson2021.dll +0 -0
  32. halide/lib/autoschedule_li2018.dll +0 -0
  33. halide/lib/autoschedule_mullapudi2016.dll +0 -0
  34. halide/lib/cmake/Halide/FindHalide_LLVM.cmake +152 -0
  35. halide/lib/cmake/Halide/FindV8.cmake +33 -0
  36. halide/lib/cmake/Halide/Halide-shared-deps.cmake +0 -0
  37. halide/lib/cmake/Halide/Halide-shared-targets-release.cmake +29 -0
  38. halide/lib/cmake/Halide/Halide-shared-targets.cmake +154 -0
  39. halide/lib/cmake/Halide/HalideConfig.cmake +162 -0
  40. halide/lib/cmake/Halide/HalideConfigVersion.cmake +65 -0
  41. halide/lib/cmake/HalideHelpers/FindHalide_WebGPU.cmake +27 -0
  42. halide/lib/cmake/HalideHelpers/Halide-Interfaces-release.cmake +112 -0
  43. halide/lib/cmake/HalideHelpers/Halide-Interfaces.cmake +236 -0
  44. halide/lib/cmake/HalideHelpers/HalideGeneratorHelpers.cmake +1056 -0
  45. halide/lib/cmake/HalideHelpers/HalideHelpersConfig.cmake +28 -0
  46. halide/lib/cmake/HalideHelpers/HalideHelpersConfigVersion.cmake +54 -0
  47. halide/lib/cmake/HalideHelpers/HalideTargetHelpers.cmake +99 -0
  48. halide/lib/cmake/HalideHelpers/MutexCopy.ps1 +31 -0
  49. halide/lib/cmake/HalideHelpers/TargetExportScript.cmake +55 -0
  50. halide/lib/cmake/Halide_Python/Halide_Python-targets-release.cmake +29 -0
  51. halide/lib/cmake/Halide_Python/Halide_Python-targets.cmake +125 -0
  52. halide/lib/cmake/Halide_Python/Halide_PythonConfig.cmake +26 -0
  53. halide/lib/cmake/Halide_Python/Halide_PythonConfigVersion.cmake +65 -0
  54. halide/share/doc/Halide/LICENSE.txt +233 -0
  55. halide/share/doc/Halide/README.md +439 -0
  56. halide/share/doc/Halide/doc/BuildingHalideWithCMake.md +626 -0
  57. halide/share/doc/Halide/doc/CodeStyleCMake.md +393 -0
  58. halide/share/doc/Halide/doc/FuzzTesting.md +104 -0
  59. halide/share/doc/Halide/doc/HalideCMakePackage.md +812 -0
  60. halide/share/doc/Halide/doc/Hexagon.md +73 -0
  61. halide/share/doc/Halide/doc/Python.md +844 -0
  62. halide/share/doc/Halide/doc/RunGen.md +283 -0
  63. halide/share/doc/Halide/doc/Testing.md +125 -0
  64. halide/share/doc/Halide/doc/Vulkan.md +287 -0
  65. halide/share/doc/Halide/doc/WebAssembly.md +228 -0
  66. halide/share/doc/Halide/doc/WebGPU.md +128 -0
  67. halide/share/tools/RunGen.h +1470 -0
  68. halide/share/tools/RunGenMain.cpp +642 -0
  69. halide/share/tools/adams2019_autotune_loop.sh +227 -0
  70. halide/share/tools/anderson2021_autotune_loop.sh +591 -0
  71. halide/share/tools/halide_benchmark.h +240 -0
  72. halide/share/tools/halide_image.h +31 -0
  73. halide/share/tools/halide_image_info.h +318 -0
  74. halide/share/tools/halide_image_io.h +2794 -0
  75. halide/share/tools/halide_malloc_trace.h +102 -0
  76. halide/share/tools/halide_thread_pool.h +161 -0
  77. halide/share/tools/halide_trace_config.h +559 -0
  78. halide-19.0.0.data/data/share/cmake/Halide/HalideConfig.cmake +6 -0
  79. halide-19.0.0.data/data/share/cmake/Halide/HalideConfigVersion.cmake +65 -0
  80. halide-19.0.0.data/data/share/cmake/HalideHelpers/HalideHelpersConfig.cmake +6 -0
  81. halide-19.0.0.data/data/share/cmake/HalideHelpers/HalideHelpersConfigVersion.cmake +54 -0
  82. halide-19.0.0.dist-info/METADATA +301 -0
  83. halide-19.0.0.dist-info/RECORD +85 -0
  84. halide-19.0.0.dist-info/WHEEL +5 -0
  85. halide-19.0.0.dist-info/licenses/LICENSE.txt +233 -0
@@ -0,0 +1,1470 @@
1
+ #include "HalideBuffer.h"
2
+ #include "HalideRuntime.h"
3
+ #include "halide_benchmark.h"
4
+ #include "halide_image_io.h"
5
+
6
+ #include <cstdio>
7
+ #include <cstdlib>
8
+ #include <iomanip>
9
+ #include <iostream>
10
+ #include <map>
11
+ #include <mutex>
12
+ #include <random>
13
+ #include <set>
14
+ #include <sstream>
15
+ #include <string>
16
+ #include <utility>
17
+
18
+ #include <vector>
19
+
20
namespace Halide {
namespace RunGen {

using ::Halide::Runtime::Buffer;

// Buffer<> uses "shape" to mean "array of halide_dimension_t", but doesn't
// provide a typedef for it (and doesn't use a vector for it in any event).
using Shape = std::vector<halide_dimension_t>;

// A ShapePromise is a function that returns a Shape. If the Promise can't
// return a valid Shape, it may fail. This allows us to defer error reporting
// for situations until the Shape is actually needed; in particular, it allows
// us to attempt doing bounds-query for the shape of input buffers early,
// but to ignore the error unless we actually need it... which we won't if an
// estimate is provided for the input in question.
using ShapePromise = std::function<Shape()>;
36
+
37
+ // Standard stream output for halide_type_t
38
+ inline std::ostream &operator<<(std::ostream &stream, const halide_type_t &type) {
39
+ if (type.code == halide_type_uint && type.bits == 1) {
40
+ stream << "bool";
41
+ } else {
42
+ assert(type.code >= 0 && type.code <= 3);
43
+ static const char *const names[4] = {"int", "uint", "float", "handle"};
44
+ stream << names[type.code] << (int)type.bits;
45
+ }
46
+ if (type.lanes > 1) {
47
+ stream << "x" << (int)type.lanes;
48
+ }
49
+ return stream;
50
+ }
51
+
52
+ // Standard stream output for halide_dimension_t
53
+ inline std::ostream &operator<<(std::ostream &stream, const halide_dimension_t &d) {
54
+ stream << "[" << d.min << "," << d.extent << "," << d.stride << "]";
55
+ return stream;
56
+ }
57
+
58
+ // Standard stream output for vector<halide_dimension_t>
59
+ inline std::ostream &operator<<(std::ostream &stream, const Shape &shape) {
60
+ stream << "[";
61
+ bool need_comma = false;
62
+ for (auto &d : shape) {
63
+ if (need_comma) {
64
+ stream << ",";
65
+ }
66
+ stream << d;
67
+ need_comma = true;
68
+ }
69
+ stream << "]";
70
+ return stream;
71
+ }
72
+
73
// Bottleneck all our logging so that client code can override any/all of the
// four channels: ordinary output, informational output, warnings, and
// unrecoverable failures.
struct Logger {
    using LogFn = std::function<void(const std::string &)>;

    const LogFn out, info, warn, fail;

    // Defaults: stdout for `out`, stderr for `info` and `warn`,
    // stderr-then-abort for `fail`.
    Logger()
        : out(default_out), info(default_cerr), warn(default_cerr), fail(default_fail) {
    }
    Logger(LogFn o, LogFn i, LogFn w, LogFn f)
        : out(std::move(o)), info(std::move(i)), warn(std::move(w)), fail(std::move(f)) {
    }

private:
    static void default_out(const std::string &s) {
        std::cout << s;
    }

    static void default_cerr(const std::string &s) {
        std::cerr << s;
    }

    static void default_fail(const std::string &s) {
        default_cerr(s);
        abort();
    }
};

// Client code must provide a definition of this log() function;
// it is sufficient to merely return a default Logger instance.
extern Logger log();
104
+
105
+ // Gather up all output in a stringstream, emit in the dtor
106
+ struct LogEmitter {
107
+ template<typename T>
108
+ LogEmitter &operator<<(const T &x) {
109
+ msg << x;
110
+ return *this;
111
+ }
112
+
113
+ ~LogEmitter() {
114
+ std::string s = msg.str();
115
+ if (s.back() != '\n') {
116
+ s += '\n';
117
+ }
118
+ f(s);
119
+ }
120
+
121
+ protected:
122
+ explicit LogEmitter(Logger::LogFn f)
123
+ : f(std::move(f)) {
124
+ }
125
+
126
+ private:
127
+ const Logger::LogFn f;
128
+ std::ostringstream msg;
129
+ };
130
+
131
// Emit ordinary non-error output that should never be suppressed (ie, stdout)
struct out : LogEmitter {
    out()
        : LogEmitter(log().out) {
    }
};

// Log detailed informational output (routed to the Logger's info channel,
// stderr by default)
struct info : LogEmitter {
    info()
        : LogEmitter(log().info) {
    }
};

// Log warnings (routed to the Logger's warn channel, stderr by default)
struct warn : LogEmitter {
    warn()
        : LogEmitter(log().warn) {
    }
};

// Log unrecoverable errors (routed to the Logger's fail channel,
// which by default logs and then aborts)
struct fail : LogEmitter {
    fail()
        : LogEmitter(log().fail) {
    }
};
158
+
159
// Replace the failure handlers from halide_image_io to fail().
// Used as the check-failure callback for Halide::Tools::load/save; when
// `condition` is false this logs via fail() (which aborts by default).
inline bool IOCheckFail(bool condition, const char *msg) {
    if (!condition) {
        fail() << "Error in I/O: " << msg;
    }
    return condition;
}
166
+
167
// Split `source` on every occurrence of `delim`, returning the pieces in
// order. A trailing delimiter produces a trailing empty element; an empty
// source yields a single empty element.
inline std::vector<std::string> split_string(const std::string &source,
                                             const std::string &delim) {
    std::vector<std::string> pieces;
    size_t pos = 0;
    for (;;) {
        const size_t next = source.find(delim, pos);
        if (next == std::string::npos) {
            break;
        }
        pieces.emplace_back(source, pos, next - pos);
        pos = next + delim.size();
    }
    // Everything after the final delimiter (possibly empty) is the last piece.
    if (pos <= source.size()) {
        pieces.emplace_back(source, pos);
    }
    return pieces;
}
184
+
185
// dynamic_type_dispatch is a utility for functors that want to be able
// to dynamically dispatch a halide_type_t to type-specialized code.
// To use it, a functor must be a *templated* class, e.g.
//
//     template<typename T> class MyFunctor { int operator()(arg1, arg2...); };
//
// dynamic_type_dispatch() is called with a halide_type_t as the first argument,
// followed by the arguments to the Functor's operator():
//
//     auto result = dynamic_type_dispatch<MyFunctor>(some_halide_type, arg1, arg2);
//
// Note that this means that the functor must be able to instantiate its
// operator() for all the Halide scalar types; it also means that all those
// variants *will* be instantiated (increasing code size), so this approach
// should only be used when strictly necessary.
template<template<typename> class Functor, typename... Args>
auto dynamic_type_dispatch(const halide_type_t &type, Args &&...args) -> decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...)) {

// Each case key is the 32-bit encoding of a (code, bits) pair; lanes are
// stripped via element_of() in the switch below, so vector types dispatch
// to their scalar element type.
#define HANDLE_CASE(CODE, BITS, TYPE)        \
    case halide_type_t(CODE, BITS).as_u32(): \
        return Functor<TYPE>()(std::forward<Args>(args)...);

    switch (type.element_of().as_u32()) {
        HANDLE_CASE(halide_type_float, 32, float)
        HANDLE_CASE(halide_type_float, 64, double)
        HANDLE_CASE(halide_type_int, 8, int8_t)
        HANDLE_CASE(halide_type_int, 16, int16_t)
        HANDLE_CASE(halide_type_int, 32, int32_t)
        HANDLE_CASE(halide_type_int, 64, int64_t)
        HANDLE_CASE(halide_type_uint, 1, bool)
        HANDLE_CASE(halide_type_uint, 8, uint8_t)
        HANDLE_CASE(halide_type_uint, 16, uint16_t)
        HANDLE_CASE(halide_type_uint, 32, uint32_t)
        HANDLE_CASE(halide_type_uint, 64, uint64_t)
        HANDLE_CASE(halide_type_handle, 64, void *)
    default:
        // fail() aborts by default, but a client-supplied Logger may not;
        // return a default-constructed result so the function stays well-formed.
        fail() << "Unsupported type: " << type << "\n";
        using ReturnType = decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...));
        return ReturnType();
    }

#undef HANDLE_CASE
}
228
+
229
+ // Functor to parse a string into one of the known Halide scalar types.
230
+ template<typename T>
231
+ struct ScalarParser {
232
+ bool operator()(const std::string &str, halide_scalar_value_t *v) {
233
+ std::istringstream iss(str);
234
+ // std::setbase(0) means "infer base from input", and allows hex and octal constants
235
+ iss >> std::setbase(0) >> *(T *)v;
236
+ return !iss.fail() && iss.get() == EOF;
237
+ }
238
+ };
239
+
240
+ // Override for int8 and uint8, to avoid parsing as char variants
241
+ template<>
242
+ inline bool ScalarParser<int8_t>::operator()(const std::string &str, halide_scalar_value_t *v) {
243
+ std::istringstream iss(str);
244
+ int i;
245
+ iss >> std::setbase(0) >> i;
246
+ if (!(!iss.fail() && iss.get() == EOF) || i < -128 || i > 127) {
247
+ return false;
248
+ }
249
+ v->u.i8 = (int8_t)i;
250
+ return true;
251
+ }
252
+
253
+ template<>
254
+ inline bool ScalarParser<uint8_t>::operator()(const std::string &str, halide_scalar_value_t *v) {
255
+ std::istringstream iss(str);
256
+ unsigned int u;
257
+ iss >> std::setbase(0) >> u;
258
+ if (!(!iss.fail() && iss.get() == EOF) || u > 255) {
259
+ return false;
260
+ }
261
+ v->u.u8 = (uint8_t)u;
262
+ return true;
263
+ }
264
+
265
+ // Override for bool, since istream just expects '1' or '0'.
266
+ template<>
267
+ inline bool ScalarParser<bool>::operator()(const std::string &str, halide_scalar_value_t *v) {
268
+ if (str == "true") {
269
+ v->u.b = true;
270
+ return true;
271
+ }
272
+ if (str == "false") {
273
+ v->u.b = false;
274
+ return true;
275
+ }
276
+ return false;
277
+ }
278
+
279
+ // Override for handle, since we only accept "nullptr".
280
+ template<>
281
+ inline bool ScalarParser<void *>::operator()(const std::string &str, halide_scalar_value_t *v) {
282
+ if (str == "nullptr") {
283
+ v->u.handle = nullptr;
284
+ return true;
285
+ }
286
+ return false;
287
+ }
288
+
289
// Parse a scalar when we know the corresponding C++ type at compile time.
// Writes through `scalar` reinterpreted as a halide_scalar_value_t, so T
// must be one of the types representable by that union. Returns false if
// `str` is not a valid representation of T.
template<typename T>
inline bool parse_scalar(const std::string &str, T *scalar) {
    return ScalarParser<T>()(str, (halide_scalar_value_t *)scalar);
}
294
+
295
// Dynamic-dispatch wrapper around ScalarParser: parses `str` as a value of
// the runtime-specified `type`, storing the result in `*scalar`.
// Returns false on parse failure (or for unsupported types).
inline bool parse_scalar(const halide_type_t &type,
                         const std::string &str,
                         halide_scalar_value_t *scalar) {
    return dynamic_type_dispatch<ScalarParser>(type, str, scalar);
}
301
+
302
+ // Parse an extent list, which should be of the form
303
+ //
304
+ // [extent0, extent1...]
305
+ //
306
+ // Return a vector<halide_dimension_t> (aka a "shape") with the extents filled in,
307
+ // but with the min of each dimension set to zero and the stride set to the
308
+ // planar-default value.
309
+ inline Shape parse_extents(const std::string &extent_list) {
310
+ if (extent_list.empty() || extent_list[0] != '[' || extent_list.back() != ']') {
311
+ fail() << "Invalid format for extents: " << extent_list;
312
+ }
313
+ Shape result;
314
+ if (extent_list == "[]") {
315
+ return result;
316
+ }
317
+ std::vector<std::string> extents = split_string(extent_list.substr(1, extent_list.size() - 2), ",");
318
+ for (size_t i = 0; i < extents.size(); i++) {
319
+ const std::string &s = extents[i];
320
+ const int stride = (i == 0) ? 1 : result[i - 1].stride * result[i - 1].extent;
321
+ halide_dimension_t d = {0, 0, stride};
322
+ if (!parse_scalar(s, &d.extent)) {
323
+ fail() << "Invalid value for extents: " << s << " (" << extent_list << ")";
324
+ }
325
+ result.push_back(d);
326
+ }
327
+ return result;
328
+ }
329
+
330
// Parse the buffer_estimate list from a given argument's metadata into a Shape.
// If no valid buffer_estimate exists, return false.
// The resulting shape uses the estimates' min/extent and planar-default strides.
inline bool try_parse_metadata_buffer_estimates(const halide_filter_argument_t *md, Shape *shape) {
    if (!md->buffer_estimates) {
        // zero-dimensional buffers don't have (or need) estimates, so don't fail.
        if (md->dimensions == 0) {
            *shape = Shape();
            return true;
        }
        return false;
    }
    Shape result(md->dimensions);
    int32_t stride = 1;
    for (int i = 0; i < md->dimensions; i++) {
        // buffer_estimates is laid out as interleaved (min, extent) pointer
        // pairs; a null pointer means "no estimate for this dimension".
        const int64_t *min = md->buffer_estimates[i * 2];
        const int64_t *extent = md->buffer_estimates[i * 2 + 1];
        if (!min || !extent) {
            return false;
        }
        result[i] = halide_dimension_t{(int32_t)*min, (int32_t)*extent, stride};
        stride *= result[i].extent;
    }
    *shape = result;
    return true;
};
355
+
356
// Parse the buffer_estimate list from a given argument's metadata into a Shape.
// If no valid buffer_estimate exists, fail (which logs and, by default, aborts).
inline Shape parse_metadata_buffer_estimates(const halide_filter_argument_t *md) {
    Shape shape;
    if (!try_parse_metadata_buffer_estimates(md, &shape)) {
        fail() << "Argument " << md->name << " was specified as 'estimate', but no valid estimates were provided.";
    }
    return shape;
};
365
+
366
+ // Given a Buffer<>, return its shape in the form of a vector<halide_dimension_t>.
367
+ // (Oddly, Buffer<> has no API to do this directly.)
368
+ inline Shape get_shape(const Buffer<> &b) {
369
+ Shape s;
370
+ for (int i = 0; i < b.dimensions(); ++i) {
371
+ s.push_back(b.raw_buffer()->dim[i]);
372
+ }
373
+ return s;
374
+ }
375
+
376
+ // Given a type and shape, create a new Buffer<> but *don't* allocate allocate storage for it.
377
+ inline Buffer<> make_with_shape(const halide_type_t &type, const Shape &shape) {
378
+ return Buffer<>(type, nullptr, (int)shape.size(), &shape[0]);
379
+ }
380
+
381
// Given a type and shape, create a new Buffer<> and allocate storage for it.
// (Oddly, Buffer<> has an API to do this with vector-of-extent, but not vector-of-halide_dimension_t.)
inline Buffer<> allocate_buffer(const halide_type_t &type, const Shape &shape) {
    Buffer<> b = make_with_shape(type, shape);
    if (b.number_of_elements() > 0) {
        // check_overflow() guards against overflow in the size computation;
        // set_host_dirty() marks the host data as authoritative, so it will
        // be copied to a device if the pipeline runs there.
        b.check_overflow();
        b.allocate();
        b.set_host_dirty();
    }
    return b;
}
392
+
393
+ inline Shape choose_output_extents(int dimensions, const Shape &defaults) {
394
+ Shape s(dimensions);
395
+ for (int i = 0; i < dimensions; ++i) {
396
+ if ((size_t)i < defaults.size()) {
397
+ s[i] = defaults[i];
398
+ } else {
399
+ // If the defaults don't provide enough dimensions, make a guess.
400
+ s[i].extent = (i < 2 ? 1000 : 4);
401
+ }
402
+ }
403
+ return s;
404
+ }
405
+
406
// Adjust *new_shape in place so that "chunky" (interleaved) layouts produced
// by a bounds query come out self-consistent.
inline void fix_chunky_strides(const Shape &constrained_shape, Shape *new_shape) {
    // Special-case Chunky: most "chunky" generators tend to constrain stride[0]
    // and stride[2] to exact values, leaving stride[1] unconstrained;
    // in practice, we must ensure that stride[1] == stride[0] * extent[0]
    // and stride[0] = extent[2] to get results that are not garbled.
    // This is unpleasantly hacky and will likely need additional enhancements.
    // (Note that there are, theoretically, other stride combinations that might
    // need fixing; in practice, ~all generators that aren't planar tend
    // to be classically chunky.)
    if (new_shape->size() >= 3 &&
        (*new_shape)[0].extent > 1 &&
        (*new_shape)[1].extent > 1) {
        if (constrained_shape[2].stride == 1) {
            if (constrained_shape[0].stride >= 1) {
                // If we have stride[0] and stride[2] set to obviously-chunky,
                // then force extent[2] to match stride[0].
                (*new_shape)[2].extent = constrained_shape[0].stride;
            } else {
                // If we have stride[2] == 1 but stride[0] < 1,
                // force stride[0] = extent[2]
                (*new_shape)[0].stride = (*new_shape)[2].extent;
            }
            // Ensure stride[1] is reasonable.
            (*new_shape)[1].stride = (*new_shape)[0].extent * (*new_shape)[0].stride;
        }
    }
}
433
+
434
+ // Return true iff all of the dimensions in the range [first, last] have an extent of <= 1.
435
+ inline bool dims_in_range_are_trivial(const Buffer<> &b, int first, int last) {
436
+ for (int d = first; d <= last; ++d) {
437
+ if (b.dim(d).extent() > 1) {
438
+ return false;
439
+ }
440
+ }
441
+ return true;
442
+ }
443
+
444
// Add or subtract dimensions to the given buffer to match dims_needed,
// emitting warnings if we do so. Extra dimensions are dropped via sliced();
// missing dimensions are added as dummy extent-1 dimensions via embedded().
inline Buffer<> adjust_buffer_dims(const std::string &title, const std::string &name,
                                   const int dims_needed, Buffer<> b) {
    const int dims_actual = b.dimensions();
    if (dims_actual > dims_needed) {
        // Warn that we are ignoring dimensions, but only if at least one of the
        // ignored dimensions has extent > 1
        if (!dims_in_range_are_trivial(b, dims_needed, dims_actual - 1)) {
            warn() << "Image for " << title << " \"" << name << "\" has "
                   << dims_actual << " dimensions, but only the first "
                   << dims_needed << " were used; data loss may have occurred.";
        }
        auto old_shape = get_shape(b);
        // Drop the topmost dimension until the dimensionality matches.
        while (b.dimensions() > dims_needed) {
            b = b.sliced(dims_needed);
        }
        info() << "Shape for " << name << " changed: " << old_shape << " -> " << get_shape(b);
    } else if (dims_actual < dims_needed) {
        warn() << "Image for " << title << " \"" << name << "\" has "
               << dims_actual << " dimensions, but this argument requires at least "
               << dims_needed << " dimensions: adding dummy dimensions of extent 1.";
        auto old_shape = get_shape(b);
        // Append extent-1 dimensions (at min 0) until the dimensionality matches.
        while (b.dimensions() < dims_needed) {
            b = b.embedded(b.dimensions(), 0);
        }
        info() << "Shape for " << name << " changed: " << old_shape << " -> " << get_shape(b);
    }
    return b;
}
474
+
475
// Load a buffer from a pathname, adjusting the type and dimensions to
// fit the metadata's requirements as needed. Mismatched dimensionality is
// fixed via adjust_buffer_dims(); mismatched element type is converted
// (with a warning, since that can lose data).
inline Buffer<> load_input_from_file(const std::string &pathname,
                                     const halide_filter_argument_t &metadata) {
    Buffer<> b = Buffer<>(metadata.type, 0);
    info() << "Loading input " << metadata.name << " from " << pathname << " ...";
    // IOCheckFail routes load errors through our fail() logger.
    if (!Halide::Tools::load<Buffer<>, IOCheckFail>(pathname, &b)) {
        fail() << "Unable to load input: " << pathname;
    }
    if (b.dimensions() != metadata.dimensions) {
        b = adjust_buffer_dims("Input", metadata.name, metadata.dimensions, b);
    }
    if (b.type() != metadata.type) {
        warn() << "Image loaded for argument \"" << metadata.name << "\" is type "
               << b.type() << " but this argument expects type "
               << metadata.type << "; data loss may have occurred.";
        b = Halide::Tools::ImageTypeConversion::convert_image(b, metadata.type);
    }
    return b;
}
495
+
496
// Functor (for use with dynamic_type_dispatch) that fills a Buffer<> with
// deterministic pseudorandom values of element type T, seeded by `seed`.
// The SFINAE-selected fill() overloads below exist because
// std::uniform_int_distribution is only defined for certain integer types.
template<typename T>
struct FillWithRandom {
public:
    void operator()(Buffer<> &b_dynamic, int seed) {
        Buffer<T> b = b_dynamic;
        std::mt19937 rng(seed);
        fill(b, rng);
    }

private:
    // General integral case (excluding bool and the char variants, which
    // are invalid template arguments for uniform_int_distribution).
    template<typename T2 = T,
             typename std::enable_if<std::is_integral<T2>::value && !std::is_same<T2, bool>::value && !std::is_same<T2, char>::value && !std::is_same<T2, signed char>::value && !std::is_same<T2, unsigned char>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<T2> dis;
        b.for_each_value([&rng, &dis](T2 &value) {
            value = dis(rng);
        });
    }

    // Floating-point case: uniform values in [0, 1).
    template<typename T2 = T, typename std::enable_if<std::is_floating_point<T2>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_real_distribution<T2> dis(0.0, 1.0);
        b.for_each_value([&rng, &dis](T2 &value) {
            value = dis(rng);
        });
    }

    // bool case: draw 0 or 1 from an int distribution.
    template<typename T2 = T, typename std::enable_if<std::is_same<T2, bool>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<int> dis(0, 1);
        b.for_each_value([&rng, &dis](T2 &value) {
            value = static_cast<T2>(dis(rng));
        });
    }

    // std::uniform_int_distribution<char> is UB in C++11,
    // so special-case to avoid compiler variation
    template<typename T2 = T, typename std::enable_if<std::is_same<T2, char>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<int> dis(-128, 127);
        b.for_each_value([&rng, &dis](T2 &value) {
            value = static_cast<T2>(dis(rng));
        });
    }

    // std::uniform_int_distribution<signed char> is UB in C++11,
    // so special-case to avoid compiler variation
    template<typename T2 = T, typename std::enable_if<std::is_same<T2, signed char>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<int> dis(-128, 127);
        b.for_each_value([&rng, &dis](T2 &value) {
            value = static_cast<T2>(dis(rng));
        });
    }

    // std::uniform_int_distribution<unsigned char> is UB in C++11,
    // so special-case to avoid compiler variation
    template<typename T2 = T, typename std::enable_if<std::is_same<T2, unsigned char>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<int> dis(0, 255);
        b.for_each_value([&rng, &dis](T2 &value) {
            value = static_cast<T2>(dis(rng));
        });
    }

    // Pointer (handle) case: fill with random bit patterns reinterpreted
    // as pointers; these values are only ever treated as opaque handles.
    template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>::type * = nullptr>
    void fill(Buffer<T2> &b, std::mt19937 &rng) {
        std::uniform_int_distribution<intptr_t> dis;
        b.for_each_value([&rng, &dis](T2 &value) {
            value = reinterpret_cast<T2>(dis(rng));
        });
    }
};
569
+
570
// Functor (for use with dynamic_type_dispatch) that fills a Buffer<> with a
// single constant value, converting the dynamically-typed
// halide_scalar_value_t into the buffer's element type T.
template<typename T>
struct FillWithScalar {
public:
    void operator()(Buffer<> &b_dynamic, const halide_scalar_value_t &value) {
        Buffer<T> b = b_dynamic;
        b.fill(as_T(value));
    }

private:
    // Segregate into pointer and non-pointer clauses to avoid compiler warnings
    // about casting from (e.g.) int8 to void*
    template<typename T2 = T, typename std::enable_if<!std::is_pointer<T2>::value>::type * = nullptr>
    T as_T(const halide_scalar_value_t &value) {
        // Select the union member matching T's halide type, then cast to T.
        constexpr halide_type_t type = halide_type_of<T>();
        switch (type.element_of().as_u32()) {
        case halide_type_t(halide_type_int, 8).as_u32():
            return (T)value.u.i8;
        case halide_type_t(halide_type_int, 16).as_u32():
            return (T)value.u.i16;
        case halide_type_t(halide_type_int, 32).as_u32():
            return (T)value.u.i32;
        case halide_type_t(halide_type_int, 64).as_u32():
            return (T)value.u.i64;
        case halide_type_t(halide_type_uint, 1).as_u32():
            return (T)value.u.b;
        case halide_type_t(halide_type_uint, 8).as_u32():
            return (T)value.u.u8;
        case halide_type_t(halide_type_uint, 16).as_u32():
            return (T)value.u.u16;
        case halide_type_t(halide_type_uint, 32).as_u32():
            return (T)value.u.u32;
        case halide_type_t(halide_type_uint, 64).as_u32():
            return (T)value.u.u64;
        case halide_type_t(halide_type_float, 32).as_u32():
            return (T)value.u.f32;
        case halide_type_t(halide_type_float, 64).as_u32():
            return (T)value.u.f64;
        default:
            fail() << "Can't convert value with type: " << (int)type.code << "bits: " << type.bits;
            return (T)0;
        }
    }

    template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>::type * = nullptr>
    T as_T(const halide_scalar_value_t &value) {
        constexpr halide_type_t type = halide_type_of<T>();
        switch (type.element_of().as_u32()) {
        case halide_type_t(halide_type_handle, 64).as_u32():
            return (T)value.u.handle;
        default:
            fail() << "Can't convert value with type: " << (int)type.code << "bits: " << type.bits;
            return (T)0;
        }
    }
};
625
+
626
+ // This logic exists in Halide::Tools, but is Internal; we're going to replicate
627
+ // it here for now since we may want slightly different logic in some cases
628
+ // for this tool.
629
+ inline Halide::Tools::FormatInfo best_save_format(const Buffer<> &b, const std::set<Halide::Tools::FormatInfo> &info) {
630
+ // Perfect score is zero (exact match).
631
+ // The larger the score, the worse the match.
632
+ int best_score = 0x7fffffff;
633
+ Halide::Tools::FormatInfo best{};
634
+ const halide_type_t type = b.type();
635
+ const int dimensions = b.dimensions();
636
+ for (auto &f : info) {
637
+ int score = 0;
638
+ // If format has too-few dimensions, that's very bad.
639
+ score += std::abs(f.dimensions - dimensions) * 128;
640
+ // If format has too-few bits, that's pretty bad.
641
+ score += std::abs(f.type.bits - type.bits);
642
+ // If format has different code, that's a little bad.
643
+ score += (f.type.code != type.code) ? 1 : 0;
644
+ if (score < best_score) {
645
+ best_score = score;
646
+ best = f;
647
+ }
648
+ }
649
+
650
+ return best;
651
+ }
652
+
653
// Render a dynamically-typed scalar value as a string, selecting the union
// member that matches `type`. Handles are printed as their integer value;
// bools as "true"/"false". Unsupported types fail (log and abort by default).
inline std::string scalar_to_string(const halide_type_t &type,
                                    const halide_scalar_value_t &value) {
    std::ostringstream o;
    switch (type.element_of().as_u32()) {
    case halide_type_t(halide_type_float, 32).as_u32():
        o << value.u.f32;
        break;
    case halide_type_t(halide_type_float, 64).as_u32():
        o << value.u.f64;
        break;
    case halide_type_t(halide_type_int, 8).as_u32():
        // Cast to int so the value prints as a number, not a character.
        o << (int)value.u.i8;
        break;
    case halide_type_t(halide_type_int, 16).as_u32():
        o << value.u.i16;
        break;
    case halide_type_t(halide_type_int, 32).as_u32():
        o << value.u.i32;
        break;
    case halide_type_t(halide_type_int, 64).as_u32():
        o << value.u.i64;
        break;
    case halide_type_t(halide_type_uint, 1).as_u32():
        o << (value.u.b ? "true" : "false");
        break;
    case halide_type_t(halide_type_uint, 8).as_u32():
        // Cast to int so the value prints as a number, not a character.
        o << (int)value.u.u8;
        break;
    case halide_type_t(halide_type_uint, 16).as_u32():
        o << value.u.u16;
        break;
    case halide_type_t(halide_type_uint, 32).as_u32():
        o << value.u.u32;
        break;
    case halide_type_t(halide_type_uint, 64).as_u32():
        o << value.u.u64;
        break;
    case halide_type_t(halide_type_handle, 64).as_u32():
        o << (uint64_t)value.u.handle;
        break;
    default:
        fail() << "Unsupported type: " << type << "\n";
        break;
    }
    return o.str();
}
699
+
700
+ struct ArgData {
701
+ size_t index{0};
702
+ std::string name;
703
+ const halide_filter_argument_t *metadata{nullptr};
704
+ std::string raw_string;
705
+ halide_scalar_value_t scalar_value;
706
+ Buffer<> buffer_value;
707
+
708
+ ArgData() = default;
709
+
710
+ ArgData(size_t index, const std::string &name, const halide_filter_argument_t *metadata)
711
+ : index(index), name(name), metadata(metadata) {
712
+ }
713
+
714
+ Buffer<> load_buffer(ShapePromise shape_promise, const halide_filter_argument_t *argument_metadata) {
715
+ const auto parse_optional_extents = [&](const std::string &s) -> Shape {
716
+ if (s == "auto") {
717
+ return shape_promise();
718
+ }
719
+ if (s == "estimate") {
720
+ return parse_metadata_buffer_estimates(argument_metadata);
721
+ }
722
+ if (s == "estimate_then_auto") {
723
+ Shape shape;
724
+ if (!try_parse_metadata_buffer_estimates(argument_metadata, &shape)) {
725
+ info() << "Input " << argument_metadata->name << " has no estimates; using bounds-query result instead.";
726
+ shape = shape_promise();
727
+ }
728
+ return shape;
729
+ }
730
+ return parse_extents(s);
731
+ };
732
+
733
+ std::vector<std::string> v = split_string(raw_string, ":");
734
+ if (v[0] == "zero") {
735
+ if (v.size() != 2) {
736
+ fail() << "Invalid syntax: " << raw_string;
737
+ }
738
+ auto shape = parse_optional_extents(v[1]);
739
+ Buffer<> b = allocate_buffer(metadata->type, shape);
740
+ memset(b.data(), 0, b.size_in_bytes());
741
+ return b;
742
+ } else if (v[0] == "constant") {
743
+ if (v.size() != 3) {
744
+ fail() << "Invalid syntax: " << raw_string;
745
+ }
746
+ halide_scalar_value_t value;
747
+ if (!parse_scalar(metadata->type, v[1], &value)) {
748
+ fail() << "Invalid value for constant value";
749
+ }
750
+ auto shape = parse_optional_extents(v[2]);
751
+ Buffer<> b = allocate_buffer(metadata->type, shape);
752
+ dynamic_type_dispatch<FillWithScalar>(metadata->type, b, value);
753
+ return b;
754
+ } else if (v[0] == "identity") {
755
+ if (v.size() != 2) {
756
+ fail() << "Invalid syntax: " << raw_string;
757
+ }
758
+ auto shape = parse_optional_extents(v[1]);
759
+ // Make a binary buffer with diagonal elements set to true. Diagonal
760
+ // elements are those whose first two dimensions are equal.
761
+ Buffer<bool> b = allocate_buffer(halide_type_of<bool>(), shape);
762
+ b.for_each_element([&b](const int *pos) {
763
+ b(pos) = (b.dimensions() >= 2) ? (pos[0] == pos[1]) : (pos[0] == 0);
764
+ });
765
+ // Convert the binary buffer to the required type, so true becomes 1.
766
+ return Halide::Tools::ImageTypeConversion::convert_image(b, metadata->type);
767
+ } else if (v[0] == "random") {
768
+ if (v.size() != 3) {
769
+ fail() << "Invalid syntax: " << raw_string;
770
+ }
771
+ int seed;
772
+ if (!parse_scalar(v[1], &seed)) {
773
+ fail() << "Invalid value for seed";
774
+ }
775
+ auto shape = parse_optional_extents(v[2]);
776
+ Buffer<> b = allocate_buffer(metadata->type, shape);
777
+ dynamic_type_dispatch<FillWithRandom>(metadata->type, b, seed);
778
+ return b;
779
+ } else {
780
+ return load_input_from_file(v[0], *metadata);
781
+ }
782
+ }
783
+
784
+ Buffer<> load_buffer(const Shape &shape, const halide_filter_argument_t *argument_metadata) {
785
+ ShapePromise promise = [shape]() -> Shape { return shape; };
786
+ return load_buffer(promise, argument_metadata);
787
+ }
788
+
789
+ void adapt_input_buffer(const Shape &constrained_shape) {
790
+ if (metadata->kind != halide_argument_kind_input_buffer) {
791
+ return;
792
+ }
793
+
794
+ // Ensure that the input Buffer meets our constraints; if it doesn't, allcoate
795
+ // and copy into a new Buffer.
796
+ bool updated = false;
797
+ Shape new_shape = get_shape(buffer_value);
798
+ info() << "Input " << name << ": Shape is " << new_shape;
799
+ if (new_shape.size() != constrained_shape.size()) {
800
+ fail() << "Dimension mismatch; expected " << constrained_shape.size() << "dimensions";
801
+ }
802
+ for (size_t i = 0; i < constrained_shape.size(); ++i) {
803
+ // If the constrained shape is not in bounds of the
804
+ // buffer's current shape we need to use the constrained
805
+ // shape.
806
+ int current_min = new_shape[i].min;
807
+ int current_max = new_shape[i].min + new_shape[i].extent - 1;
808
+ int constrained_min = constrained_shape[i].min;
809
+ int constrained_max = constrained_shape[i].min + constrained_shape[i].extent - 1;
810
+ if (constrained_min < current_min || constrained_max > current_max) {
811
+ new_shape[i].min = constrained_shape[i].min;
812
+ new_shape[i].extent = constrained_shape[i].extent;
813
+ updated = true;
814
+ }
815
+ // stride of nonzero means "required stride", stride of zero means "no constraints"
816
+ if (constrained_shape[i].stride != 0 && new_shape[i].stride != constrained_shape[i].stride) {
817
+ new_shape[i].stride = constrained_shape[i].stride;
818
+ updated = true;
819
+ }
820
+ }
821
+ if (updated) {
822
+ fix_chunky_strides(constrained_shape, &new_shape);
823
+ Buffer<> new_buf = allocate_buffer(buffer_value.type(), new_shape);
824
+ new_buf.copy_from(buffer_value);
825
+ buffer_value = new_buf;
826
+ }
827
+
828
+ info() << "Input " << name << ": BoundsQuery result is " << constrained_shape;
829
+ if (updated) {
830
+ info() << "Input " << name << ": Updated Shape is " << get_shape(buffer_value);
831
+ }
832
+ }
833
+
834
+ void allocate_output_buffer(const Shape &constrained_shape) {
835
+ if (metadata->kind != halide_argument_kind_output_buffer) {
836
+ return;
837
+ }
838
+
839
+ // Given a constraint Shape (generally produced by a bounds query), create a new
840
+ // Shape that can legally be used to create and allocate a new Buffer:
841
+ // ensure that extents/strides aren't zero, do some reality checking
842
+ // on planar vs interleaved, and generally try to guess at a reasonable result.
843
+ Shape new_shape = constrained_shape;
844
+
845
+ // Make sure that the extents and strides for these are nonzero.
846
+ for (size_t i = 0; i < new_shape.size(); ++i) {
847
+ if (!new_shape[i].extent) {
848
+ // A bit of a hack: fill in unconstrained dimensions to 1... except
849
+ // for probably-the-channels dimension, which we'll special-case to
850
+ // fill in to 4 when possible (unless it appears to be chunky).
851
+ // Stride will be fixed below.
852
+ if (i == 2) {
853
+ if (constrained_shape[0].stride >= 1 && constrained_shape[2].stride == 1) {
854
+ // Definitely chunky, so make extent[2] match the chunk size
855
+ new_shape[i].extent = constrained_shape[0].stride;
856
+ } else {
857
+ // Not obviously chunky; let's go with 4 channels.
858
+ new_shape[i].extent = 4;
859
+ }
860
+ } else {
861
+ new_shape[i].extent = 1;
862
+ }
863
+ }
864
+ }
865
+
866
+ fix_chunky_strides(constrained_shape, &new_shape);
867
+
868
+ // If anything else is zero, just set strides to planar and hope for the best.
869
+ bool any_strides_zero = false;
870
+ for (size_t i = 0; i < new_shape.size(); ++i) {
871
+ if (!new_shape[i].stride) {
872
+ any_strides_zero = true;
873
+ }
874
+ }
875
+ if (any_strides_zero) {
876
+ // Planar
877
+ new_shape[0].stride = 1;
878
+ for (size_t i = 1; i < new_shape.size(); ++i) {
879
+ new_shape[i].stride = new_shape[i - 1].stride * new_shape[i - 1].extent;
880
+ }
881
+ }
882
+
883
+ buffer_value = allocate_buffer(metadata->type, new_shape);
884
+
885
+ // allocate_buffer conservatively sets host dirty. Don't waste
886
+ // time copying output buffers to device.
887
+ buffer_value.set_host_dirty(false);
888
+
889
+ info() << "Output " << name << ": BoundsQuery result is " << constrained_shape;
890
+ info() << "Output " << name << ": Shape is " << get_shape(buffer_value);
891
+ }
892
+ };
893
+
894
// Drives a single AOT-compiled Halide filter from a command-line harness:
// parses argument strings, runs bounds queries to size buffers, invokes the
// filter via its argv-style entry point, benchmarks it, and saves outputs.
class RunGen {
public:
    // The filter's argv-style entry point: takes an array of pointers
    // (scalar values and halide_buffer_t*) in metadata argument order.
    using ArgvCall = int (*)(void **);

    // Builds the name->ArgData map from the filter's metadata and installs
    // the RunGen error/print handlers. Fails if the metadata version is
    // not the one this code was built against.
    RunGen(ArgvCall halide_argv_call,
           const struct halide_filter_metadata_t *halide_metadata)
        : halide_argv_call(halide_argv_call), md(halide_metadata) {
        if (md->version != halide_filter_metadata_t::VERSION) {
            fail() << "Unexpected metadata version " << md->version;
        }
        for (size_t i = 0; i < (size_t)md->num_arguments; ++i) {
            std::string name = md->arguments[i].name;
            // NOTE(review): isdigit() is given a raw char; for names this is
            // presumably always ASCII, but an (unsigned char) cast would be safer.
            if (name.size() > 2 && name[name.size() - 2] == '$' && isdigit(name[name.size() - 1])) {
                // If it ends in "$3" or similar, just lop it off
                name = name.substr(0, name.size() - 2);
            }
            ArgData arg(i, name, &md->arguments[i]);
            args[name] = arg;
        }
        halide_set_error_handler(rungen_halide_error);
        halide_set_custom_print(rungen_halide_print);
    }

    ArgvCall get_halide_argv_call() const {
        return halide_argv_call;
    }
    const struct halide_filter_metadata_t *get_halide_metadata() const {
        return md;
    }

    // Returns the halide_argument_kind for the named argument, or -1 if unknown.
    int argument_kind(const std::string &name) const {
        auto it = args.find(name);
        if (it == args.end()) {
            return -1;
        }
        return it->second.metadata->kind;
    }

    // Record one name=value pair from the command line. Unknown names are
    // recorded in seen_args (validate() reports them) but otherwise ignored;
    // duplicate or empty values are fatal.
    void parse_one(const std::string &name,
                   const std::string &value,
                   std::set<std::string> *seen_args) {
        if (value.empty()) {
            fail() << "Argument value is empty for: " << name;
        }
        seen_args->insert(name);
        auto it = args.find(name);
        if (it == args.end()) {
            // Don't fail, just return.
            return;
        }
        if (!it->second.raw_string.empty()) {
            fail() << "Argument value specified multiple times for: " << name;
        }
        it->second.raw_string = value;
    }

    // Verify that every filter argument has a value, applying the given
    // defaults for unspecified input buffers/scalars; collects all problems
    // (unknown names, missing values) and fails with the combined message.
    void validate(const std::set<std::string> &seen_args,
                  const std::string &default_input_buffers,
                  const std::string &default_input_scalars,
                  bool ok_to_omit_outputs) {
        std::ostringstream o;
        for (auto &s : seen_args) {
            if (args.find(s) == args.end()) {
                o << "Unknown argument name: " << s << "\n";
            }
        }
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            if (arg.raw_string.empty()) {
                if (ok_to_omit_outputs && arg.metadata->kind == halide_argument_kind_output_buffer) {
                    continue;
                }
                if (!default_input_buffers.empty() &&
                    arg.metadata->kind == halide_argument_kind_input_buffer) {
                    arg.raw_string = default_input_buffers;
                    info() << "Using value of '" << arg.raw_string << "' for: " << arg.metadata->name;
                    continue;
                }
                if (!default_input_scalars.empty() &&
                    arg.metadata->kind == halide_argument_kind_input_scalar) {
                    arg.raw_string = default_input_scalars;
                    info() << "Using value of '" << arg.raw_string << "' for: " << arg.metadata->name;
                    continue;
                }
                o << "Argument value missing for: " << arg.metadata->name << "\n";
            }
        }
        if (!o.str().empty()) {
            fail() << o.str();
        }
    }

    // Parse all the input arguments, loading images as necessary.
    // (Don't handle outputs yet.)
    // Scalars are parsed first (the bounds query needs them), then output
    // shapes are resolved if the user specified any, then input buffers are
    // loaded, and finally unspecified output shapes default to the first
    // input's shape.
    void load_inputs(const std::string &user_specified_output_shape_string) {
        assert(output_shapes.empty());

        Shape first_input_shape;
        std::map<std::string, ShapePromise> auto_input_shape_promises;

        // First, set all the scalar inputs: we need those to be correct
        // in order to get useful values from the bound-query for input buffers.
        for (auto &arg_pair : args) {
            auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_scalar: {
                // __user_context is always passed as nullptr regardless of input.
                if (!strcmp(arg.metadata->name, "__user_context")) {
                    arg.scalar_value.u.handle = nullptr;
                    info() << "Argument value for: __user_context is special-cased as: nullptr";
                    break;
                }
                // Candidate (value, source-label) pairs to try in order, for the
                // "default"/"estimate" keyword forms.
                std::vector<std::pair<const halide_scalar_value_t *, const char *>> values;
                // If this gets any more complex, smarten it up, but for now,
                // simpleminded code is fine.
                if (arg.raw_string == "default") {
                    values.emplace_back(arg.metadata->scalar_def, "default");
                } else if (arg.raw_string == "estimate") {
                    values.emplace_back(arg.metadata->scalar_estimate, "estimate");
                } else if (arg.raw_string == "default,estimate") {
                    values.emplace_back(arg.metadata->scalar_def, "default");
                    values.emplace_back(arg.metadata->scalar_estimate, "estimate");
                } else if (arg.raw_string == "estimate,default") {
                    values.emplace_back(arg.metadata->scalar_estimate, "estimate");
                    values.emplace_back(arg.metadata->scalar_def, "default");
                }
                if (!values.empty()) {
                    // Use the first candidate actually present in the metadata.
                    bool set = false;
                    for (auto &v : values) {
                        if (!v.first) {
                            continue;
                        }
                        info() << "Argument value for: " << arg.metadata->name << " is parsed from metadata (" << v.second << ") as: "
                               << scalar_to_string(arg.metadata->type, *v.first);
                        arg.scalar_value = *v.first;
                        set = true;
                        break;
                    }
                    if (!set) {
                        fail() << "Argument value for: " << arg.metadata->name << " was specified as '" << arg.raw_string << "', "
                               << "but no default and/or estimate was found in the metadata.";
                    }
                } else {
                    // Otherwise parse the raw string as a literal of the argument's type.
                    if (!parse_scalar(arg.metadata->type, arg.raw_string, &arg.scalar_value)) {
                        fail() << "Argument value for: " << arg_name << " could not be parsed as type "
                               << arg.metadata->type << ": "
                               << arg.raw_string;
                    }
                }
                break;
            }
            case halide_argument_kind_input_buffer:
            case halide_argument_kind_output_buffer:
                // Nothing yet
                break;
            }
        }

        if (!user_specified_output_shape_string.empty()) {
            // For now, we set all output shapes to be identical -- there's no
            // way on the command line to specify different shapes for each
            // output. Would be nice to try?
            for (auto &arg_pair : args) {
                auto &arg = arg_pair.second;
                if (arg.metadata->kind == halide_argument_kind_output_buffer) {
                    auto &arg_name = arg_pair.first;
                    if (user_specified_output_shape_string == "estimate") {
                        output_shapes[arg_name] = parse_metadata_buffer_estimates(arg.metadata);
                        info() << "Output " << arg_name << " is parsed from metadata as: " << output_shapes[arg_name];
                    } else {
                        output_shapes[arg_name] = parse_extents(user_specified_output_shape_string);
                        info() << "Output " << arg_name << " has user-specified Shape: " << output_shapes[arg_name];
                    }
                }
            }
            auto_input_shape_promises = bounds_query_input_shapes();
        }

        for (auto &arg_pair : args) {
            auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_buffer:
                // NOTE(review): if no output shape string was given,
                // operator[] yields a default-constructed (empty) ShapePromise;
                // presumably "auto" shapes are only used together with an
                // output shape — confirm, since invoking an empty promise
                // would throw std::bad_function_call.
                arg.buffer_value = arg.load_buffer(auto_input_shape_promises[arg_name], arg.metadata);
                info() << "Input " << arg_name << ": Shape is " << get_shape(arg.buffer_value);
                if (first_input_shape.empty()) {
                    first_input_shape = get_shape(arg.buffer_value);
                }
                break;
            case halide_argument_kind_input_scalar:
                // Already handled.
                break;
            case halide_argument_kind_output_buffer:
                // Nothing yet
                break;
            }
        }

        if (user_specified_output_shape_string.empty() && !first_input_shape.empty()) {
            // If there was no output shape specified by the user, use the shape of
            // the first input buffer (if any). (This is a better-than-nothing guess
            // that is definitely not always correct, but is convenient and useful enough
            // to be worth doing.)
            for (auto &arg_pair : args) {
                auto &arg = arg_pair.second;
                if (arg.metadata->kind == halide_argument_kind_output_buffer) {
                    auto &arg_name = arg_pair.first;
                    output_shapes[arg_name] = first_input_shape;
                    info() << "Output " << arg_name << " assumes the shape of first input: " << first_input_shape;
                }
            }
        }
    }

    // Write each output buffer to the file named by its raw_string (if any),
    // adjusting dimension count and converting pixel type (with a warning)
    // to match what the chosen image format can store.
    void save_outputs() {
        // Save the output(s), if necessary.
        for (auto &arg_pair : args) {
            auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            if (arg.metadata->kind != halide_argument_kind_output_buffer) {
                continue;
            }
            if (arg.raw_string.empty()) {
                info() << "(Output " << arg_name << " was not saved.)";
                continue;
            }

            info() << "Saving output " << arg_name << " to " << arg.raw_string << " ...";
            Buffer<> &b = arg.buffer_value;

            std::set<Halide::Tools::FormatInfo> savable_types;
            if (!Halide::Tools::save_query<Buffer<>, IOCheckFail>(arg.raw_string, &savable_types)) {
                fail() << "Unable to save output: " << arg.raw_string;
            }
            const Halide::Tools::FormatInfo best = best_save_format(b, savable_types);
            if (best.dimensions != b.dimensions()) {
                b = adjust_buffer_dims("Output", arg_name, best.dimensions, b);
            }
            if (best.type != b.type()) {
                warn() << "Image for argument \"" << arg_name << "\" is of type "
                       << b.type() << " but is being saved as type "
                       << best.type << "; data loss may have occurred.";
                b = Halide::Tools::ImageTypeConversion::convert_image(b, best.type);
            }
            if (!Halide::Tools::save<Buffer<const void>, IOCheckFail>(b.as<const void>(), arg.raw_string)) {
                fail() << "Unable to save output: " << arg.raw_string;
            }
        }
    }

    // Block until any pending device-side work on the outputs has finished.
    void device_sync_outputs() {
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            if (arg.metadata->kind == halide_argument_kind_output_buffer) {
                Buffer<> &b = arg.buffer_value;
                b.device_sync();
            }
        }
    }

    // Copy all output buffers back to host memory; returns the first
    // halide_error_code encountered, or halide_error_code_success.
    int copy_outputs_to_host() {
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            if (arg.metadata->kind == halide_argument_kind_output_buffer) {
                Buffer<> &b = arg.buffer_value;
                if (auto err = b.copy_to_host(); err != halide_error_code_success) {
                    return err;
                }
            }
        }
        return halide_error_code_success;
    }

    // Approximate "pixels" produced across all outputs (product of the first
    // two extents per output), used for throughput reporting.
    uint64_t pixels_out() const {
        uint64_t pixels_out = 0;
        for (const auto &arg_pair : args) {
            const auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_output_buffer: {
                // TODO: this assumes that most output is "pixel-ish", and counting the size of the first
                // two dimensions approximates the "pixel size". This is not, in general, a valid assumption,
                // but is a useful metric for benchmarking.
                Shape shape = get_shape(arg.buffer_value);
                if (shape.size() >= 2) {
                    pixels_out += shape[0].extent * shape[1].extent;
                } else if (!shape.empty()) {
                    pixels_out += shape[0].extent;
                } else {
                    pixels_out += 1;
                }
                break;
            }
            }
        }
        return pixels_out;
    }

    double megapixels_out() const {
        return (double)pixels_out() / (1024.0 * 1024.0);
    }

    // Total element count across all output buffers.
    uint64_t elements_out() const {
        uint64_t elements_out = 0;
        for (const auto &arg_pair : args) {
            const auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_output_buffer: {
                elements_out += arg.buffer_value.number_of_elements();
                break;
            }
            }
        }
        return elements_out;
    }

    // Total payload bytes across all output buffers (elements * type size,
    // deliberately excluding padding).
    uint64_t bytes_out() const {
        uint64_t bytes_out = 0;
        for (const auto &arg_pair : args) {
            const auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_output_buffer: {
                // size_in_bytes() is not necessarily the same, since
                // it may include unused space for padding.
                bytes_out += arg.buffer_value.number_of_elements() * arg.buffer_value.type().bytes();
                break;
            }
            }
        }
        return bytes_out;
    }

    // Run a bounds-query call with the given args, and return the shapes
    // to which we are constrained.
    std::vector<Shape> run_bounds_query() const {
        std::vector<void *> filter_argv(args.size(), nullptr);
        // These vectors are larger than needed, but simplifies logic downstream.
        std::vector<Buffer<>> bounds_query_buffers(args.size());
        std::vector<Shape> constrained_shapes(args.size());
        for (const auto &arg_pair : args) {
            const auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_scalar:
                filter_argv[arg.index] = const_cast<halide_scalar_value_t *>(&arg.scalar_value);
                break;
            case halide_argument_kind_input_buffer:
            case halide_argument_kind_output_buffer:
                // Inputs use their actual shape; outputs use the chosen output extents.
                Shape shape = (arg.metadata->kind == halide_argument_kind_input_buffer) ?
                                  get_shape(arg.buffer_value) :
                                  choose_output_extents(arg.metadata->dimensions, output_shapes.at(arg_name));
                bounds_query_buffers[arg.index] = make_with_shape(arg.metadata->type, shape);
                filter_argv[arg.index] = bounds_query_buffers[arg.index].raw_buffer();
                break;
            }
        }

        info() << "Running bounds query...";
        // Ignore result since our halide_error() should catch everything.
        (void)halide_argv_call(&filter_argv[0]);

        // The filter mutates the query buffers in place; read back their shapes.
        for (const auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_scalar:
                break;
            case halide_argument_kind_input_buffer:
            case halide_argument_kind_output_buffer:
                constrained_shapes[arg.index] = get_shape(bounds_query_buffers[arg.index]);
                break;
            }
        }
        return constrained_shapes;
    }

    // Apply bounds-query constraints to every input buffer (see ArgData).
    void adapt_input_buffers(const std::vector<Shape> &constrained_shapes) {
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            arg.adapt_input_buffer(constrained_shapes[arg.index]);
        }
    }

    // Allocate every output buffer per its bounds-query constraints (see ArgData).
    void allocate_output_buffers(const std::vector<Shape> &constrained_shapes) {
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            arg.allocate_output_buffer(constrained_shapes[arg.index]);
        }
    }

    // Repeatedly run the filter under Halide::Tools::benchmark and report
    // timing/throughput, either human-readable or machine-parsable.
    void run_for_benchmark(double benchmark_min_time) {
        std::vector<void *> filter_argv = build_filter_argv();

        const auto benchmark_inner = [this, &filter_argv]() {
            // Ignore result since our halide_error() should catch everything.
            (void)halide_argv_call(&filter_argv[0]);
            // Ensure that all outputs are finished, otherwise we may just be
            // measuring how long it takes to do a kernel launch for GPU code.
            this->device_sync_outputs();
        };

        info() << "Benchmarking filter...";

        Halide::Tools::BenchmarkConfig config;
        config.min_time = benchmark_min_time;
        config.max_time = benchmark_min_time * 4;
        auto result = Halide::Tools::benchmark(benchmark_inner, config);

        if (!parsable_output) {
            out() << "Benchmark for " << md->name << " produces best case of " << result.wall_time << " sec/iter (over "
                  << result.samples << " samples, "
                  << result.iterations << " iterations, "
                  << "accuracy " << std::setprecision(2) << (result.accuracy * 100.0) << "%).\n"
                  << "Best output throughput is " << (megapixels_out() / result.wall_time) << " mpix/sec.\n";
        } else {
            out() << md->name << " BEST_TIME_MSEC_PER_ITER " << result.wall_time * 1000.f << "\n"
                  << md->name << " SAMPLES " << result.samples << "\n"
                  << md->name << " ITERATIONS " << result.iterations << "\n"
                  << md->name << " TIMING_ACCURACY " << result.accuracy << "\n"
                  << md->name << " THROUGHPUT_MPIX_PER_SEC " << (megapixels_out() / result.wall_time) << "\n"
                  << md->name << " HALIDE_TARGET " << md->target << "\n";
        }
    }

    // One named output buffer produced by run_for_output().
    struct Output {
        std::string name;
        Buffer<> actual;
    };
    // Run the filter once and collect all output buffers.
    std::vector<Output> run_for_output() {
        std::vector<void *> filter_argv = build_filter_argv();

        info() << "Running filter...";
        // Ignore result since our halide_error() should catch everything.
        (void)halide_argv_call(&filter_argv[0]);

        std::vector<Output> v;
        for (auto &arg_pair : args) {
            const auto &arg_name = arg_pair.first;
            const auto &arg = arg_pair.second;
            if (arg.metadata->kind != halide_argument_kind_output_buffer) {
                continue;
            }
            v.push_back({arg_name, arg.buffer_value});
        }
        return v;
    }

    // Load the "expected" golden data for the named output (from its
    // raw_string spec) with the same shape recorded in output_shapes.
    Buffer<> get_expected_output(const std::string &output) {
        auto it = args.find(output);
        if (it == args.end()) {
            fail() << "Unable to find output: " << output;
        }
        const auto &arg = it->second;
        return args.at(output).load_buffer(output_shapes.at(output), arg.metadata);
    }

    // Print a human-readable description of the filter's signature.
    void describe() const {
        out() << "Filter name: \"" << md->name << "\"\n";
        for (size_t i = 0; i < (size_t)md->num_arguments; ++i) {
            std::ostringstream o;
            auto &a = md->arguments[i];
            bool is_input = a.kind != halide_argument_kind_output_buffer;
            bool is_scalar = a.kind == halide_argument_kind_input_scalar;
            o << "  " << (is_input ? "Input" : "Output") << " \"" << a.name << "\" is of type ";
            if (is_scalar) {
                o << a.type;
            } else {
                o << "Buffer<" << a.type << "> with " << a.dimensions << " dimensions";
            }
            out() << o.str();
        }
    }

    // Assemble the argv array for the real filter call: scalar pointers for
    // scalars, raw halide_buffer_t* for buffers, in metadata index order.
    std::vector<void *> build_filter_argv() {
        std::vector<void *> filter_argv(args.size(), nullptr);
        for (auto &arg_pair : args) {
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_scalar:
                filter_argv[arg.index] = &arg.scalar_value;
                break;
            case halide_argument_kind_input_buffer:
            case halide_argument_kind_output_buffer:
                filter_argv[arg.index] = arg.buffer_value.raw_buffer();
                break;
            }
        }
        return filter_argv;
    }

    std::string name() const {
        return md->name;
    }

    // Suppress (or restore) halide_print output from the filter.
    void set_quiet(bool quiet = true) {
        halide_set_custom_print(quiet ? rungen_halide_print_quiet : rungen_halide_print);
    }

    // Switch benchmark reporting to the machine-parsable format.
    void set_parsable_output(bool parsable_output = true) {
        this->parsable_output = parsable_output;
    }

private:
    // Error handler used while probing with a bounds query: errors there are
    // expected and handled by the caller, so swallow them.
    static void rungen_ignore_error(void *user_context, const char *message) {
        // nothing
    }

    // Run a bounds query (with outputs at their chosen shapes and inputs at
    // zero extent) to discover each input's required shape. Returns one
    // ShapePromise per input; if the query failed, the promise fails with a
    // diagnostic when (and only when) it is actually invoked.
    std::map<std::string, ShapePromise> bounds_query_input_shapes() const {
        assert(!output_shapes.empty());
        std::vector<void *> filter_argv(args.size(), nullptr);
        std::vector<Buffer<>> bounds_query_buffers(args.size());
        for (const auto &arg_pair : args) {
            auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            switch (arg.metadata->kind) {
            case halide_argument_kind_input_scalar:
                filter_argv[arg.index] = const_cast<halide_scalar_value_t *>(&arg.scalar_value);
                break;
            case halide_argument_kind_input_buffer:
                // Make a Buffer<> that has the right dimension count and extent=0 for all of them
                bounds_query_buffers[arg.index] = Buffer<>(arg.metadata->type, std::vector<int>(arg.metadata->dimensions, 0));
                filter_argv[arg.index] = bounds_query_buffers[arg.index].raw_buffer();
                break;
            case halide_argument_kind_output_buffer:
                bounds_query_buffers[arg.index] = make_with_shape(arg.metadata->type, output_shapes.at(arg_name));
                filter_argv[arg.index] = bounds_query_buffers[arg.index].raw_buffer();
                break;
            }
        }

        // Temporarily ignore errors: a failed bounds query is reported lazily
        // via the returned promises instead of aborting here.
        auto previous_error_handler = halide_set_error_handler(rungen_ignore_error);
        int result = halide_argv_call(&filter_argv[0]);
        halide_set_error_handler(previous_error_handler);

        std::map<std::string, ShapePromise> input_shape_promises;
        for (const auto &arg_pair : args) {
            auto &arg_name = arg_pair.first;
            auto &arg = arg_pair.second;
            if (arg.metadata->kind == halide_argument_kind_input_buffer) {
                if (result == 0) {
                    Shape shape = get_shape(bounds_query_buffers[arg.index]);
                    input_shape_promises[arg_name] = [shape]() -> Shape { return shape; };
                    info() << "Input " << arg_name << " has a bounds-query shape of " << shape;
                } else {
                    input_shape_promises[arg_name] = [arg_name]() -> Shape {
                        fail() << "Input " << arg_name << " could not calculate a shape satisfying bounds-query constraints.\n"
                               << "Try relaxing the constraints, or providing an explicit estimate for the input.\n";
                        return Shape();
                    };
                    info() << "Input " << arg_name << " failed bounds-query\n";
                }
            }
        }
        return input_shape_promises;
    }

    // Replace the standard Halide runtime function to capture print output to stdout
    static void rungen_halide_print(void *user_context, const char *message) {
        out() << "halide_print: " << message;
    }

    static void rungen_halide_print_quiet(void *user_context, const char *message) {
        // nothing
    }

    // Replace the standard Halide runtime function to capture Halide errors to fail()
    static void rungen_halide_error(void *user_context, const char *message) {
        fail() << "halide_error: " << message;
    }

    ArgvCall halide_argv_call;                  // the filter's argv entry point
    const struct halide_filter_metadata_t *const md;  // filter metadata (never null)
    std::map<std::string, ArgData> args;        // per-argument state, keyed by name
    std::map<std::string, Shape> output_shapes; // resolved shape for each output
    bool parsable_output = false;               // machine-readable benchmark output?
};
1468
+
1469
+ } // namespace RunGen
1470
+ } // namespace Halide