halide 19.0.0__cp38-cp38-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. halide/__init__.py +39 -0
  2. halide/_generator_helpers.py +835 -0
  3. halide/bin/adams2019_retrain_cost_model +0 -0
  4. halide/bin/adams2019_weightsdir_to_weightsfile +0 -0
  5. halide/bin/anderson2021_retrain_cost_model +0 -0
  6. halide/bin/anderson2021_weightsdir_to_weightsfile +0 -0
  7. halide/bin/featurization_to_sample +0 -0
  8. halide/bin/gengen +0 -0
  9. halide/bin/get_host_target +0 -0
  10. halide/halide_.cpython-38-darwin.so +0 -0
  11. halide/imageio.py +60 -0
  12. halide/include/Halide.h +35293 -0
  13. halide/include/HalideBuffer.h +2618 -0
  14. halide/include/HalidePyTorchCudaHelpers.h +64 -0
  15. halide/include/HalidePyTorchHelpers.h +120 -0
  16. halide/include/HalideRuntime.h +2221 -0
  17. halide/include/HalideRuntimeCuda.h +89 -0
  18. halide/include/HalideRuntimeD3D12Compute.h +91 -0
  19. halide/include/HalideRuntimeHexagonDma.h +104 -0
  20. halide/include/HalideRuntimeHexagonHost.h +157 -0
  21. halide/include/HalideRuntimeMetal.h +112 -0
  22. halide/include/HalideRuntimeOpenCL.h +119 -0
  23. halide/include/HalideRuntimeQurt.h +32 -0
  24. halide/include/HalideRuntimeVulkan.h +137 -0
  25. halide/include/HalideRuntimeWebGPU.h +44 -0
  26. halide/lib/cmake/Halide/FindHalide_LLVM.cmake +152 -0
  27. halide/lib/cmake/Halide/FindV8.cmake +33 -0
  28. halide/lib/cmake/Halide/Halide-shared-deps.cmake +0 -0
  29. halide/lib/cmake/Halide/Halide-shared-targets-release.cmake +29 -0
  30. halide/lib/cmake/Halide/Halide-shared-targets.cmake +154 -0
  31. halide/lib/cmake/Halide/HalideConfig.cmake +162 -0
  32. halide/lib/cmake/Halide/HalideConfigVersion.cmake +65 -0
  33. halide/lib/cmake/HalideHelpers/FindHalide_WebGPU.cmake +27 -0
  34. halide/lib/cmake/HalideHelpers/Halide-Interfaces-release.cmake +116 -0
  35. halide/lib/cmake/HalideHelpers/Halide-Interfaces.cmake +236 -0
  36. halide/lib/cmake/HalideHelpers/HalideGeneratorHelpers.cmake +1056 -0
  37. halide/lib/cmake/HalideHelpers/HalideHelpersConfig.cmake +28 -0
  38. halide/lib/cmake/HalideHelpers/HalideHelpersConfigVersion.cmake +54 -0
  39. halide/lib/cmake/HalideHelpers/HalideTargetHelpers.cmake +99 -0
  40. halide/lib/cmake/HalideHelpers/MutexCopy.ps1 +31 -0
  41. halide/lib/cmake/HalideHelpers/TargetExportScript.cmake +55 -0
  42. halide/lib/cmake/Halide_Python/Halide_Python-targets-release.cmake +30 -0
  43. halide/lib/cmake/Halide_Python/Halide_Python-targets.cmake +125 -0
  44. halide/lib/cmake/Halide_Python/Halide_PythonConfig.cmake +26 -0
  45. halide/lib/cmake/Halide_Python/Halide_PythonConfigVersion.cmake +65 -0
  46. halide/lib/libHalide.dylib +0 -0
  47. halide/lib/libHalidePyStubs.a +0 -0
  48. halide/lib/libHalide_GenGen.a +0 -0
  49. halide/lib/libautoschedule_adams2019.so +0 -0
  50. halide/lib/libautoschedule_anderson2021.so +0 -0
  51. halide/lib/libautoschedule_li2018.so +0 -0
  52. halide/lib/libautoschedule_mullapudi2016.so +0 -0
  53. halide/share/doc/Halide/LICENSE.txt +233 -0
  54. halide/share/doc/Halide/README.md +439 -0
  55. halide/share/doc/Halide/doc/BuildingHalideWithCMake.md +626 -0
  56. halide/share/doc/Halide/doc/CodeStyleCMake.md +393 -0
  57. halide/share/doc/Halide/doc/FuzzTesting.md +104 -0
  58. halide/share/doc/Halide/doc/HalideCMakePackage.md +812 -0
  59. halide/share/doc/Halide/doc/Hexagon.md +73 -0
  60. halide/share/doc/Halide/doc/Python.md +844 -0
  61. halide/share/doc/Halide/doc/RunGen.md +283 -0
  62. halide/share/doc/Halide/doc/Testing.md +125 -0
  63. halide/share/doc/Halide/doc/Vulkan.md +287 -0
  64. halide/share/doc/Halide/doc/WebAssembly.md +228 -0
  65. halide/share/doc/Halide/doc/WebGPU.md +128 -0
  66. halide/share/tools/RunGen.h +1470 -0
  67. halide/share/tools/RunGenMain.cpp +642 -0
  68. halide/share/tools/adams2019_autotune_loop.sh +227 -0
  69. halide/share/tools/anderson2021_autotune_loop.sh +591 -0
  70. halide/share/tools/halide_benchmark.h +240 -0
  71. halide/share/tools/halide_image.h +31 -0
  72. halide/share/tools/halide_image_info.h +318 -0
  73. halide/share/tools/halide_image_io.h +2794 -0
  74. halide/share/tools/halide_malloc_trace.h +102 -0
  75. halide/share/tools/halide_thread_pool.h +161 -0
  76. halide/share/tools/halide_trace_config.h +559 -0
  77. halide-19.0.0.data/data/share/cmake/Halide/HalideConfig.cmake +6 -0
  78. halide-19.0.0.data/data/share/cmake/Halide/HalideConfigVersion.cmake +65 -0
  79. halide-19.0.0.data/data/share/cmake/HalideHelpers/HalideHelpersConfig.cmake +6 -0
  80. halide-19.0.0.data/data/share/cmake/HalideHelpers/HalideHelpersConfigVersion.cmake +54 -0
  81. halide-19.0.0.dist-info/METADATA +301 -0
  82. halide-19.0.0.dist-info/RECORD +84 -0
  83. halide-19.0.0.dist-info/WHEEL +5 -0
  84. halide-19.0.0.dist-info/licenses/LICENSE.txt +233 -0
@@ -0,0 +1,2794 @@
1
+ // This simple IO library works with the Halide::Buffer<T> type or any
2
+ // other image type with the same API.
3
+
4
+ #ifndef HALIDE_IMAGE_IO_H
5
+ #define HALIDE_IMAGE_IO_H
6
+
7
+ #include <algorithm>
8
+ #include <cctype>
9
+ #include <cmath>
10
+ #include <cstdarg>
11
+ #include <cstddef>
12
+ #include <cstdio>
13
+ #include <cstdlib>
14
+ #include <functional>
15
+ #include <map>
16
+ #include <set>
17
+ #include <string>
18
+ #include <vector>
19
+
20
+ #ifndef HALIDE_NO_PNG
21
+ #include "png.h"
22
+ #endif
23
+
24
+ #ifndef HALIDE_NO_JPEG
25
+ #ifdef _WIN32
26
+ #ifndef NOMINMAX
27
+ #define NOMINMAX
28
+ #endif
29
+ #include <windows.h>
30
+ #endif
31
+ #include "jpeglib.h"
32
+ #endif
33
+
34
+ #include "HalideRuntime.h" // for halide_type_t
35
+
36
+ namespace Halide {
37
+ namespace Tools {
38
+
39
+ struct FormatInfo {
40
+ halide_type_t type;
41
+ int dimensions;
42
+
43
+ bool operator<(const FormatInfo &other) const {
44
+ if (type.code < other.type.code) {
45
+ return true;
46
+ } else if (type.code > other.type.code) {
47
+ return false;
48
+ }
49
+ if (type.bits < other.type.bits) {
50
+ return true;
51
+ } else if (type.bits > other.type.bits) {
52
+ return false;
53
+ }
54
+ if (type.lanes < other.type.lanes) {
55
+ return true;
56
+ } else if (type.lanes > other.type.lanes) {
57
+ return false;
58
+ }
59
+ return (dimensions < other.dimensions);
60
+ }
61
+ };
62
+
63
+ namespace Internal {
64
+
65
+ typedef bool (*CheckFunc)(bool condition, const char *msg);
66
+
67
// Check handler that treats a failed condition as fatal: when `condition` is
// false, `msg` is printed to stderr (followed by a newline) and the process
// aborts. Returns true when the condition holds.
inline bool CheckFail(bool condition, const char *msg) {
    if (condition) {
        return true;
    }
    fprintf(stderr, "%s\n", msg);
    abort();
}
74
+
75
// Check handler that reports failure to the caller instead of aborting:
// the condition is handed back untouched and the message is ignored.
inline bool CheckReturn(bool condition, const char * /*msg*/) {
    const bool passed = condition;
    return passed;
}
78
+
79
+ template<typename To, typename From>
80
+ To convert(const From &from);
81
+
82
+ // Convert to bool
83
+ template<>
84
+ inline bool convert(const bool &in) {
85
+ return in;
86
+ }
87
+ template<>
88
+ inline bool convert(const uint8_t &in) {
89
+ return in != 0;
90
+ }
91
+ template<>
92
+ inline bool convert(const uint16_t &in) {
93
+ return in != 0;
94
+ }
95
+ template<>
96
+ inline bool convert(const uint32_t &in) {
97
+ return in != 0;
98
+ }
99
+ template<>
100
+ inline bool convert(const uint64_t &in) {
101
+ return in != 0;
102
+ }
103
+ template<>
104
+ inline bool convert(const int8_t &in) {
105
+ return in != 0;
106
+ }
107
+ template<>
108
+ inline bool convert(const int16_t &in) {
109
+ return in != 0;
110
+ }
111
+ template<>
112
+ inline bool convert(const int32_t &in) {
113
+ return in != 0;
114
+ }
115
+ template<>
116
+ inline bool convert(const int64_t &in) {
117
+ return in != 0;
118
+ }
119
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
120
+ template<>
121
+ inline bool convert(const _Float16 &in) {
122
+ return (float)in != 0;
123
+ }
124
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
125
+ template<>
126
+ inline bool convert(const float &in) {
127
+ return in != 0;
128
+ }
129
+ template<>
130
+ inline bool convert(const double &in) {
131
+ return in != 0;
132
+ }
133
+
134
+ // Convert to u8
135
+ template<>
136
+ inline uint8_t convert(const bool &in) {
137
+ return in;
138
+ }
139
+ template<>
140
+ inline uint8_t convert(const uint8_t &in) {
141
+ return in;
142
+ }
143
+ template<>
144
+ inline uint8_t convert(const uint16_t &in) {
145
+ uint32_t tmp = (uint32_t)(in) + 0x80;
146
+ // Fast approximation of div-by-257: see http://research.swtch.com/divmult
147
+ return ((tmp * 255 + 255) >> 16);
148
+ }
149
+ template<>
150
+ inline uint8_t convert(const uint32_t &in) {
151
+ return (uint8_t)((((uint64_t)in) + 0x00808080) / 0x01010101);
152
+ }
153
+ // uint64 -> 8 just discards the lower 32 bits: if you were expecting more precision, well, sorry
154
+ template<>
155
+ inline uint8_t convert(const uint64_t &in) {
156
+ return convert<uint8_t, uint32_t>(uint32_t(in >> 32));
157
+ }
158
+ template<>
159
+ inline uint8_t convert(const int8_t &in) {
160
+ return convert<uint8_t, uint8_t>(in);
161
+ }
162
+ template<>
163
+ inline uint8_t convert(const int16_t &in) {
164
+ return convert<uint8_t, uint16_t>(in);
165
+ }
166
+ template<>
167
+ inline uint8_t convert(const int32_t &in) {
168
+ return convert<uint8_t, uint32_t>(in);
169
+ }
170
+ template<>
171
+ inline uint8_t convert(const int64_t &in) {
172
+ return convert<uint8_t, uint64_t>(in);
173
+ }
174
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
175
+ template<>
176
+ inline uint8_t convert(const _Float16 &in) {
177
+ return (uint8_t)std::lround((float)in * 255.0f);
178
+ }
179
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
180
+ template<>
181
+ inline uint8_t convert(const float &in) {
182
+ return (uint8_t)std::lround(in * 255.0f);
183
+ }
184
+ template<>
185
+ inline uint8_t convert(const double &in) {
186
+ return (uint8_t)std::lround(in * 255.0);
187
+ }
188
+
189
+ // Convert to u16
190
+ template<>
191
+ inline uint16_t convert(const bool &in) {
192
+ return in;
193
+ }
194
+ template<>
195
+ inline uint16_t convert(const uint8_t &in) {
196
+ return uint16_t(in) * 0x0101;
197
+ }
198
+ template<>
199
+ inline uint16_t convert(const uint16_t &in) {
200
+ return in;
201
+ }
202
+ template<>
203
+ inline uint16_t convert(const uint32_t &in) {
204
+ return in >> 16;
205
+ }
206
+ template<>
207
+ inline uint16_t convert(const uint64_t &in) {
208
+ return in >> 48;
209
+ }
210
+ template<>
211
+ inline uint16_t convert(const int8_t &in) {
212
+ return convert<uint16_t, uint8_t>(in);
213
+ }
214
+ template<>
215
+ inline uint16_t convert(const int16_t &in) {
216
+ return convert<uint16_t, uint16_t>(in);
217
+ }
218
+ template<>
219
+ inline uint16_t convert(const int32_t &in) {
220
+ return convert<uint16_t, uint32_t>(in);
221
+ }
222
+ template<>
223
+ inline uint16_t convert(const int64_t &in) {
224
+ return convert<uint16_t, uint64_t>(in);
225
+ }
226
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
227
+ template<>
228
+ inline uint16_t convert(const _Float16 &in) {
229
+ return (uint16_t)std::lround((float)in * 65535.0f);
230
+ }
231
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
232
+ template<>
233
+ inline uint16_t convert(const float &in) {
234
+ return (uint16_t)std::lround(in * 65535.0f);
235
+ }
236
+ template<>
237
+ inline uint16_t convert(const double &in) {
238
+ return (uint16_t)std::lround(in * 65535.0);
239
+ }
240
+
241
+ // Convert to u32
242
+ template<>
243
+ inline uint32_t convert(const bool &in) {
244
+ return in;
245
+ }
246
+ template<>
247
+ inline uint32_t convert(const uint8_t &in) {
248
+ return uint32_t(in) * 0x01010101;
249
+ }
250
+ template<>
251
+ inline uint32_t convert(const uint16_t &in) {
252
+ return uint32_t(in) * 0x00010001;
253
+ }
254
+ template<>
255
+ inline uint32_t convert(const uint32_t &in) {
256
+ return in;
257
+ }
258
+ template<>
259
+ inline uint32_t convert(const uint64_t &in) {
260
+ return (uint32_t)(in >> 32);
261
+ }
262
+ template<>
263
+ inline uint32_t convert(const int8_t &in) {
264
+ return convert<uint32_t, uint8_t>(in);
265
+ }
266
+ template<>
267
+ inline uint32_t convert(const int16_t &in) {
268
+ return convert<uint32_t, uint16_t>(in);
269
+ }
270
+ template<>
271
+ inline uint32_t convert(const int32_t &in) {
272
+ return convert<uint32_t, uint32_t>(in);
273
+ }
274
+ template<>
275
+ inline uint32_t convert(const int64_t &in) {
276
+ return convert<uint32_t, uint64_t>(in);
277
+ }
278
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
279
+ template<>
280
+ inline uint32_t convert(const _Float16 &in) {
281
+ return (uint32_t)std::llround((float)in * 4294967295.0);
282
+ }
283
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
284
+ template<>
285
+ inline uint32_t convert(const float &in) {
286
+ return (uint32_t)std::llround(in * 4294967295.0);
287
+ }
288
+ template<>
289
+ inline uint32_t convert(const double &in) {
290
+ return (uint32_t)std::llround(in * 4294967295.0);
291
+ }
292
+
293
// Convert to u64
//
//  - Narrower unsigned integers are expanded by replicating their bit pattern
//    so that the maximum input maps to the maximum output.
//  - Signed integers are forwarded to the unsigned specialization of the same
//    width (int32_t previously went to the uint64_t identity, which
//    sign-extended the value instead of range-expanding it).
//  - Floating-point inputs are first converted to a 32-bit value and then
//    expanded.

// Redundant (but legal) re-declaration of the primary template; it keeps this
// group of specializations well-formed when read or compiled on its own.
template<typename To, typename From>
To convert(const From &from);

template<>
inline uint64_t convert(const bool &in) {
    return in;
}
template<>
inline uint64_t convert(const uint8_t &in) {
    return uint64_t(in) * 0x0101010101010101ULL;
}
template<>
inline uint64_t convert(const uint16_t &in) {
    return uint64_t(in) * 0x0001000100010001ULL;
}
template<>
inline uint64_t convert(const uint32_t &in) {
    return uint64_t(in) * 0x0000000100000001ULL;
}
template<>
inline uint64_t convert(const uint64_t &in) {
    return in;
}
template<>
inline uint64_t convert(const int8_t &in) {
    return convert<uint64_t, uint8_t>(in);
}
template<>
inline uint64_t convert(const int16_t &in) {
    return convert<uint64_t, uint16_t>(in);
}
template<>
inline uint64_t convert(const int32_t &in) {
    // Same-width unsigned specialization, matching int8_t and int16_t above.
    return convert<uint64_t, uint32_t>(in);
}
template<>
inline uint64_t convert(const int64_t &in) {
    return convert<uint64_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint64_t convert(const _Float16 &in) {
    return convert<uint64_t, uint32_t>((uint32_t)std::llround((float)in * 4294967295.0));
}
#endif  // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint64_t convert(const float &in) {
    return convert<uint64_t, uint32_t>((uint32_t)std::llround(in * 4294967295.0));
}
template<>
inline uint64_t convert(const double &in) {
    return convert<uint64_t, uint32_t>((uint32_t)std::llround(in * 4294967295.0));
}
344
+
345
+ // Convert to i8
346
+ template<>
347
+ inline int8_t convert(const bool &in) {
348
+ return in;
349
+ }
350
+ template<>
351
+ inline int8_t convert(const uint8_t &in) {
352
+ return convert<uint8_t, uint8_t>(in);
353
+ }
354
+ template<>
355
+ inline int8_t convert(const uint16_t &in) {
356
+ return convert<uint8_t, uint16_t>(in);
357
+ }
358
+ template<>
359
+ inline int8_t convert(const uint32_t &in) {
360
+ return convert<uint8_t, uint32_t>(in);
361
+ }
362
+ template<>
363
+ inline int8_t convert(const uint64_t &in) {
364
+ return convert<uint8_t, uint64_t>(in);
365
+ }
366
+ template<>
367
+ inline int8_t convert(const int8_t &in) {
368
+ return convert<uint8_t, int8_t>(in);
369
+ }
370
+ template<>
371
+ inline int8_t convert(const int16_t &in) {
372
+ return convert<uint8_t, int16_t>(in);
373
+ }
374
+ template<>
375
+ inline int8_t convert(const int32_t &in) {
376
+ return convert<uint8_t, int32_t>(in);
377
+ }
378
+ template<>
379
+ inline int8_t convert(const int64_t &in) {
380
+ return convert<uint8_t, int64_t>(in);
381
+ }
382
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
383
+ template<>
384
+ inline int8_t convert(const _Float16 &in) {
385
+ return convert<uint8_t, float>((float)in);
386
+ }
387
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
388
+ template<>
389
+ inline int8_t convert(const float &in) {
390
+ return convert<uint8_t, float>(in);
391
+ }
392
+ template<>
393
+ inline int8_t convert(const double &in) {
394
+ return convert<uint8_t, double>(in);
395
+ }
396
+
397
+ // Convert to i16
398
+ template<>
399
+ inline int16_t convert(const bool &in) {
400
+ return in;
401
+ }
402
+ template<>
403
+ inline int16_t convert(const uint8_t &in) {
404
+ return convert<uint16_t, uint8_t>(in);
405
+ }
406
+ template<>
407
+ inline int16_t convert(const uint16_t &in) {
408
+ return convert<uint16_t, uint16_t>(in);
409
+ }
410
+ template<>
411
+ inline int16_t convert(const uint32_t &in) {
412
+ return convert<uint16_t, uint32_t>(in);
413
+ }
414
+ template<>
415
+ inline int16_t convert(const uint64_t &in) {
416
+ return convert<uint16_t, uint64_t>(in);
417
+ }
418
+ template<>
419
+ inline int16_t convert(const int8_t &in) {
420
+ return convert<uint16_t, int8_t>(in);
421
+ }
422
+ template<>
423
+ inline int16_t convert(const int16_t &in) {
424
+ return convert<uint16_t, int16_t>(in);
425
+ }
426
+ template<>
427
+ inline int16_t convert(const int32_t &in) {
428
+ return convert<uint16_t, int32_t>(in);
429
+ }
430
+ template<>
431
+ inline int16_t convert(const int64_t &in) {
432
+ return convert<uint16_t, int64_t>(in);
433
+ }
434
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
435
+ template<>
436
+ inline int16_t convert(const _Float16 &in) {
437
+ return convert<uint16_t, float>((float)in);
438
+ }
439
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
440
+ template<>
441
+ inline int16_t convert(const float &in) {
442
+ return convert<uint16_t, float>(in);
443
+ }
444
+ template<>
445
+ inline int16_t convert(const double &in) {
446
+ return convert<uint16_t, double>(in);
447
+ }
448
+
449
+ // Convert to i32
450
+ template<>
451
+ inline int32_t convert(const bool &in) {
452
+ return in;
453
+ }
454
+ template<>
455
+ inline int32_t convert(const uint8_t &in) {
456
+ return convert<uint32_t, uint8_t>(in);
457
+ }
458
+ template<>
459
+ inline int32_t convert(const uint16_t &in) {
460
+ return convert<uint32_t, uint16_t>(in);
461
+ }
462
+ template<>
463
+ inline int32_t convert(const uint32_t &in) {
464
+ return convert<uint32_t, uint32_t>(in);
465
+ }
466
+ template<>
467
+ inline int32_t convert(const uint64_t &in) {
468
+ return convert<uint32_t, uint64_t>(in);
469
+ }
470
+ template<>
471
+ inline int32_t convert(const int8_t &in) {
472
+ return convert<uint32_t, int8_t>(in);
473
+ }
474
+ template<>
475
+ inline int32_t convert(const int16_t &in) {
476
+ return convert<uint32_t, int16_t>(in);
477
+ }
478
+ template<>
479
+ inline int32_t convert(const int32_t &in) {
480
+ return convert<uint32_t, int32_t>(in);
481
+ }
482
+ template<>
483
+ inline int32_t convert(const int64_t &in) {
484
+ return convert<uint32_t, int64_t>(in);
485
+ }
486
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
487
+ template<>
488
+ inline int32_t convert(const _Float16 &in) {
489
+ return convert<uint32_t, float>((float)in);
490
+ }
491
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
492
+ template<>
493
+ inline int32_t convert(const float &in) {
494
+ return convert<uint32_t, float>(in);
495
+ }
496
+ template<>
497
+ inline int32_t convert(const double &in) {
498
+ return convert<uint32_t, double>(in);
499
+ }
500
+
501
+ // Convert to i64
502
+ template<>
503
+ inline int64_t convert(const bool &in) {
504
+ return in;
505
+ }
506
+ template<>
507
+ inline int64_t convert(const uint8_t &in) {
508
+ return convert<uint64_t, uint8_t>(in);
509
+ }
510
+ template<>
511
+ inline int64_t convert(const uint16_t &in) {
512
+ return convert<uint64_t, uint16_t>(in);
513
+ }
514
+ template<>
515
+ inline int64_t convert(const uint32_t &in) {
516
+ return convert<uint64_t, uint32_t>(in);
517
+ }
518
+ template<>
519
+ inline int64_t convert(const uint64_t &in) {
520
+ return convert<uint64_t, uint64_t>(in);
521
+ }
522
+ template<>
523
+ inline int64_t convert(const int8_t &in) {
524
+ return convert<uint64_t, int8_t>(in);
525
+ }
526
+ template<>
527
+ inline int64_t convert(const int16_t &in) {
528
+ return convert<uint64_t, int16_t>(in);
529
+ }
530
+ template<>
531
+ inline int64_t convert(const int32_t &in) {
532
+ return convert<uint64_t, int32_t>(in);
533
+ }
534
+ template<>
535
+ inline int64_t convert(const int64_t &in) {
536
+ return convert<uint64_t, int64_t>(in);
537
+ }
538
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
539
+ template<>
540
+ inline int64_t convert(const _Float16 &in) {
541
+ return convert<uint64_t, float>((float)in);
542
+ }
543
+ #endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
544
+ template<>
545
+ inline int64_t convert(const float &in) {
546
+ return convert<uint64_t, float>(in);
547
+ }
548
+ template<>
549
+ inline int64_t convert(const double &in) {
550
+ return convert<uint64_t, double>(in);
551
+ }
552
+
553
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
// Convert to f16
//
//  - Unsigned integers are divided by their maximum value; uint64_t uses only
//    its upper 32 bits.
//  - Signed integers are forwarded to the unsigned specialization of the same
//    width (int32_t previously went through uint64_t, whose ">> 32" yields 0
//    for every non-negative 32-bit value).
//  - Floating-point inputs are narrowed with a plain cast.
template<>
inline _Float16 convert(const bool &in) {
    return in;
}
template<>
inline _Float16 convert(const uint8_t &in) {
    return (_Float16)(in / 255.0f);
}
template<>
inline _Float16 convert(const uint16_t &in) {
    return (_Float16)(in / 65535.0f);
}
template<>
inline _Float16 convert(const uint32_t &in) {
    return (_Float16)(in / 4294967295.0);
}
template<>
inline _Float16 convert(const uint64_t &in) {
    return convert<_Float16, uint32_t>(uint32_t(in >> 32));
}
template<>
inline _Float16 convert(const int8_t &in) {
    return convert<_Float16, uint8_t>(in);
}
template<>
inline _Float16 convert(const int16_t &in) {
    return convert<_Float16, uint16_t>(in);
}
template<>
inline _Float16 convert(const int32_t &in) {
    // Same-width unsigned specialization, matching int8_t and int16_t above.
    return convert<_Float16, uint32_t>(in);
}
template<>
inline _Float16 convert(const int64_t &in) {
    return convert<_Float16, uint64_t>(in);
}
template<>
inline _Float16 convert(const _Float16 &in) {
    return in;
}
template<>
inline _Float16 convert(const float &in) {
    return (_Float16)in;
}
template<>
inline _Float16 convert(const double &in) {
    return (_Float16)in;
}
#endif  // HALIDE_CPP_COMPILER_HAS_FLOAT16
604
+
605
// Convert to f32
//
//  - Unsigned integers are divided by their maximum value, so the full input
//    range maps onto [0, 1]; uint64_t uses only its upper 32 bits.
//  - Signed integers are forwarded to the unsigned specialization of the same
//    width (int32_t previously went through uint64_t, whose ">> 32" yields 0
//    for every non-negative 32-bit value).
//  - Floating-point inputs are converted with a plain cast.

// Redundant (but legal) re-declaration of the primary template; it keeps this
// group of specializations well-formed when read or compiled on its own.
template<typename To, typename From>
To convert(const From &from);

template<>
inline float convert(const bool &in) {
    return in;
}
template<>
inline float convert(const uint8_t &in) {
    return in / 255.0f;
}
template<>
inline float convert(const uint16_t &in) {
    return in / 65535.0f;
}
template<>
inline float convert(const uint32_t &in) {
    return (float)(in / 4294967295.0);
}
template<>
inline float convert(const uint64_t &in) {
    return convert<float, uint32_t>(uint32_t(in >> 32));
}
template<>
inline float convert(const int8_t &in) {
    return convert<float, uint8_t>(in);
}
template<>
inline float convert(const int16_t &in) {
    return convert<float, uint16_t>(in);
}
template<>
inline float convert(const int32_t &in) {
    // Same-width unsigned specialization, matching int8_t and int16_t above.
    return convert<float, uint32_t>(in);
}
template<>
inline float convert(const int64_t &in) {
    return convert<float, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline float convert(const _Float16 &in) {
    return (float)in;
}
#endif  // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline float convert(const float &in) {
    return in;
}
template<>
inline float convert(const double &in) {
    return (float)in;
}
656
+
657
// Convert to f64
//
//  - Unsigned integers are divided by their maximum value, so the full input
//    range maps onto [0, 1]; uint64_t uses only its upper 32 bits. The
//    divisions are done in double precision (the 8- and 16-bit cases used to
//    divide in float and widen afterwards, which threw away precision).
//  - Signed integers are forwarded to the unsigned specialization of the same
//    width (int32_t previously went through uint64_t, whose ">> 32" yields 0
//    for every non-negative 32-bit value).
//  - Floating-point inputs are widened with a plain cast.

// Redundant (but legal) re-declaration of the primary template; it keeps this
// group of specializations well-formed when read or compiled on its own.
template<typename To, typename From>
To convert(const From &from);

template<>
inline double convert(const bool &in) {
    return in;
}
template<>
inline double convert(const uint8_t &in) {
    return in / 255.0;
}
template<>
inline double convert(const uint16_t &in) {
    return in / 65535.0;
}
template<>
inline double convert(const uint32_t &in) {
    return (double)(in / 4294967295.0);
}
template<>
inline double convert(const uint64_t &in) {
    return convert<double, uint32_t>(uint32_t(in >> 32));
}
template<>
inline double convert(const int8_t &in) {
    return convert<double, uint8_t>(in);
}
template<>
inline double convert(const int16_t &in) {
    return convert<double, uint16_t>(in);
}
template<>
inline double convert(const int32_t &in) {
    // Same-width unsigned specialization, matching int8_t and int16_t above.
    return convert<double, uint32_t>(in);
}
template<>
inline double convert(const int64_t &in) {
    return convert<double, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline double convert(const _Float16 &in) {
    return (double)in;
}
#endif  // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline double convert(const float &in) {
    return (double)in;
}
template<>
inline double convert(const double &in) {
    return in;
}
708
+
709
// Returns a copy of `s` with every character passed through std::tolower.
//
// Each character is routed through `unsigned char` first: calling tolower
// with a negative `char` value (any byte >= 0x80 where char is signed) is
// undefined behavior.
inline std::string to_lowercase(const std::string &s) {
    std::string r = s;
    std::transform(r.begin(), r.end(), r.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return r;
}
714
+
715
+ inline std::string get_lowercase_extension(const std::string &path) {
716
+ size_t last_dot = path.rfind('.');
717
+ if (last_dot == std::string::npos) {
718
+ return "";
719
+ }
720
+ return to_lowercase(path.substr(last_dot + 1));
721
+ }
722
+
723
// Decodes one ElemType value stored most-significant-byte first at `src`.
// Specialized below for uint8_t and uint16_t.
template<typename ElemType>
ElemType read_big_endian(const uint8_t *src);

// One byte: nothing to reorder.
template<>
inline uint8_t read_big_endian<uint8_t>(const uint8_t *src) {
    return src[0];
}

// Two bytes: src[0] is the high byte, src[1] the low byte.
template<>
inline uint16_t read_big_endian<uint16_t>(const uint8_t *src) {
    const uint16_t high_byte = src[0];
    const uint16_t low_byte = src[1];
    return static_cast<uint16_t>((high_byte << 8) | low_byte);
}
735
+
736
// Encodes `src` most-significant-byte first into the bytes starting at `dst`.
// Specialized below for uint8_t and uint16_t.
template<typename ElemType>
void write_big_endian(const ElemType &src, uint8_t *dst);

// One byte: stored as-is.
template<>
inline void write_big_endian<uint8_t>(const uint8_t &src, uint8_t *dst) {
    dst[0] = src;
}

// Two bytes: high byte goes to dst[0], low byte to dst[1].
template<>
inline void write_big_endian<uint16_t>(const uint16_t &src, uint8_t *dst) {
    dst[0] = static_cast<uint8_t>(src >> 8);
    dst[1] = static_cast<uint8_t>(src & 0xff);
}
749
+
750
// Small owner of a C FILE*: opens the file on construction and closes it on
// destruction.
//
// The constructor does not report failure; `f` is nullptr when fopen() fails.
// The member functions use `f` unconditionally, so callers must check
// `f != nullptr` before calling any of them.
struct FileOpener {
    // Opens `filename` with the given fopen() mode string.
    FileOpener(const std::string &filename, const char *mode)
        : f(fopen(filename.c_str(), mode)) {
        // nothing
    }

    ~FileOpener() {
        if (f != nullptr) {
            fclose(f);
        }
    }

    // The handle is closed exactly once, by the destructor. A copy would share
    // the same FILE* and close it a second time, so copying is disabled.
    FileOpener(const FileOpener &) = delete;
    FileOpener &operator=(const FileOpener &) = delete;

    // Reads a line of data into `buf` (at most maxlen - 1 characters, as with
    // fgets), skipping lines that begin with '#'.
    // Returns `buf` on success, or nullptr on end-of-file or read error.
    char *read_line(char *buf, int maxlen) {
        char *status;
        do {
            status = fgets(buf, maxlen, f);
        } while (status && buf[0] == '#');
        return (status);
    }

    // Calls read_line() and does a sscanf() on the result.
    // Returns the number of fields converted, or 0 when no line could be read.
    int scan_line(const char *fmt, ...) {
        char buf[1024];
        if (!read_line(buf, 1024)) {
            return 0;
        }
        va_list args;
        va_start(args, fmt);
        int result = vsscanf(buf, fmt, args);
        va_end(args);
        return result;
    }

    // Reads exactly `count` bytes into `data`; false on a short read.
    bool read_bytes(void *data, size_t count) {
        return fread(data, 1, count, f) == count;
    }

    // Fills a fixed-size array from the file; false on a short read.
    template<typename T, size_t N>
    bool read_array(T (&data)[N]) {
        return read_bytes(&data[0], sizeof(T) * N);
    }

    // Fills the vector's current size() elements from the file (the vector is
    // not resized); false on a short read.
    template<typename T>
    bool read_vector(std::vector<T> *v) {
        return read_bytes(v->data(), v->size() * sizeof(T));
    }

    // Writes exactly `count` bytes from `data`; false on a short write.
    bool write_bytes(const void *data, size_t count) {
        return fwrite(data, 1, count, f) == count;
    }

    // Writes all elements of the vector; false on a short write.
    template<typename T>
    bool write_vector(const std::vector<T> &v) {
        return write_bytes(v.data(), v.size() * sizeof(T));
    }

    // Writes all elements of a fixed-size array; false on a short write.
    template<typename T, size_t N>
    bool write_array(const T (&data)[N]) {
        return write_bytes(&data[0], sizeof(T) * N);
    }

    FILE *const f;
};
814
+
815
+ constexpr int AnyDims = -1;
816
+
817
+ // Read a row of ElemTypes from a byte buffer and copy them into a specific image row.
818
+ // Multibyte elements are assumed to be big-endian.
819
+ template<typename ElemType, typename ImageType>
820
+ void read_big_endian_row(const uint8_t *src, int y, ImageType *im) {
821
+ auto im_typed = im->template as<ElemType, AnyDims>();
822
+ const int xmin = im_typed.dim(0).min();
823
+ const int xmax = im_typed.dim(0).max();
824
+ if (im_typed.dimensions() > 2) {
825
+ const int cmin = im_typed.dim(2).min();
826
+ const int cmax = im_typed.dim(2).max();
827
+ for (int x = xmin; x <= xmax; x++) {
828
+ for (int c = cmin; c <= cmax; c++) {
829
+ im_typed(x, y, c + cmin) = read_big_endian<ElemType>(src);
830
+ src += sizeof(ElemType);
831
+ }
832
+ }
833
+ } else {
834
+ for (int x = xmin; x <= xmax; x++) {
835
+ im_typed(x, y) = read_big_endian<ElemType>(src);
836
+ src += sizeof(ElemType);
837
+ }
838
+ }
839
+ }
840
+
841
+ // Copy a row from an image into a byte buffer.
842
+ // Multibyte elements are written in big-endian layout.
843
+ template<typename ElemType, typename ImageType>
844
+ void write_big_endian_row(const ImageType &im, int y, uint8_t *dst) {
845
+ auto im_typed = im.template as<typename std::add_const<ElemType>::type, AnyDims>();
846
+ const int xmin = im_typed.dim(0).min();
847
+ const int xmax = im_typed.dim(0).max();
848
+ if (im_typed.dimensions() > 2) {
849
+ const int cmin = im_typed.dim(2).min();
850
+ const int cmax = im_typed.dim(2).max();
851
+ for (int x = xmin; x <= xmax; x++) {
852
+ for (int c = cmin; c <= cmax; c++) {
853
+ write_big_endian<ElemType>(im_typed(x, y, c), dst);
854
+ dst += sizeof(ElemType);
855
+ }
856
+ }
857
+ } else {
858
+ for (int x = xmin; x <= xmax; x++) {
859
+ write_big_endian<ElemType>(im_typed(x, y), dst);
860
+ dst += sizeof(ElemType);
861
+ }
862
+ }
863
+ }
864
+
865
+ #ifndef HALIDE_NO_PNG
866
+
867
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
868
+ bool load_png(const std::string &filename, ImageType *im) {
869
+ static_assert(!ImageType::has_static_halide_type, "");
870
+
871
+ /* open file and test for it being a png */
872
+ Internal::FileOpener f(filename, "rb");
873
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
874
+ return false;
875
+ }
876
+ png_byte header[8];
877
+ if (!check(f.read_array(header), "File ended before end of header")) {
878
+ return false;
879
+ }
880
+ if (!check(!png_sig_cmp(header, 0, 8), "File is not recognized as a PNG file")) {
881
+ return false;
882
+ }
883
+
884
+ /* initialize stuff */
885
+ png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
886
+ if (!check(png_ptr != nullptr, "png_create_read_struct failed")) {
887
+ return false;
888
+ }
889
+
890
+ png_infop info_ptr = png_create_info_struct(png_ptr);
891
+ if (!check(info_ptr != nullptr, "png_create_info_struct failed")) {
892
+ return false;
893
+ }
894
+
895
+ if (!check(!setjmp(png_jmpbuf(png_ptr)), "Error loading PNG")) {
896
+ return false;
897
+ }
898
+
899
+ png_init_io(png_ptr, f.f);
900
+ png_set_sig_bytes(png_ptr, 8);
901
+
902
+ png_read_info(png_ptr, info_ptr);
903
+
904
+ const int width = png_get_image_width(png_ptr, info_ptr);
905
+ const int height = png_get_image_height(png_ptr, info_ptr);
906
+ const int channels = png_get_channels(png_ptr, info_ptr);
907
+ const int bit_depth = png_get_bit_depth(png_ptr, info_ptr);
908
+
909
+ const halide_type_t im_type(halide_type_uint, bit_depth);
910
+ std::vector<int> im_dimensions = {width, height};
911
+ if (channels != 1) {
912
+ im_dimensions.push_back(channels);
913
+ }
914
+
915
+ *im = ImageType(im_type, im_dimensions);
916
+
917
+ png_read_update_info(png_ptr, info_ptr);
918
+
919
+ auto copy_to_image = bit_depth == 8 ?
920
+ Internal::read_big_endian_row<uint8_t, ImageType> :
921
+ Internal::read_big_endian_row<uint16_t, ImageType>;
922
+
923
+ std::vector<uint8_t> row(png_get_rowbytes(png_ptr, info_ptr));
924
+ const int ymin = im->dim(1).min();
925
+ const int ymax = im->dim(1).max();
926
+ for (int y = ymin; y <= ymax; ++y) {
927
+ png_read_row(png_ptr, row.data(), nullptr);
928
+ copy_to_image(row.data(), y, im);
929
+ }
930
+
931
+ png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
932
+
933
+ return true;
934
+ }
935
+
936
+ inline const std::set<FormatInfo> &query_png() {
937
+ static std::set<FormatInfo> info = {
938
+ {halide_type_t(halide_type_uint, 8), 2},
939
+ {halide_type_t(halide_type_uint, 16), 2},
940
+ {halide_type_t(halide_type_uint, 8), 3},
941
+ {halide_type_t(halide_type_uint, 16), 3}};
942
+ return info;
943
+ }
944
+
945
+ // "im" is not const-ref because copy_to_host() is not const.
946
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
947
+ bool save_png(ImageType &im, const std::string &filename) {
948
+ static_assert(!ImageType::has_static_halide_type, "");
949
+
950
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
951
+ return false;
952
+ }
953
+
954
+ const int width = im.width();
955
+ const int height = im.height();
956
+ const int channels = im.channels();
957
+
958
+ if (!check(channels >= 1 && channels <= 4,
959
+ "Can't write PNG files that have other than 1, 2, 3, or 4 channels")) {
960
+ return false;
961
+ }
962
+
963
+ const png_byte color_types[4] = {
964
+ PNG_COLOR_TYPE_GRAY,
965
+ PNG_COLOR_TYPE_GRAY_ALPHA,
966
+ PNG_COLOR_TYPE_RGB,
967
+ PNG_COLOR_TYPE_RGB_ALPHA};
968
+ png_byte color_type = color_types[channels - 1];
969
+
970
+ // open file
971
+ Internal::FileOpener f(filename, "wb");
972
+ if (!check(f.f != nullptr, "[write_png_file] File could not be opened for writing")) {
973
+ return false;
974
+ }
975
+
976
+ // initialize stuff
977
+ png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
978
+ if (!check(png_ptr != nullptr, "[write_png_file] png_create_write_struct failed")) {
979
+ return false;
980
+ }
981
+
982
+ png_infop info_ptr = png_create_info_struct(png_ptr);
983
+ if (!check(info_ptr != nullptr, "[write_png_file] png_create_info_struct failed")) {
984
+ return false;
985
+ }
986
+
987
+ if (!check(!setjmp(png_jmpbuf(png_ptr)), "Error saving PNG")) {
988
+ return false;
989
+ }
990
+
991
+ png_init_io(png_ptr, f.f);
992
+
993
+ const halide_type_t im_type = im.type();
994
+ const int bit_depth = im_type.bits;
995
+
996
+ png_set_IHDR(png_ptr, info_ptr, width, height,
997
+ bit_depth, color_type, PNG_INTERLACE_NONE,
998
+ PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);
999
+
1000
+ png_write_info(png_ptr, info_ptr);
1001
+
1002
+ auto copy_from_image = bit_depth == 8 ?
1003
+ Internal::write_big_endian_row<uint8_t, ImageType> :
1004
+ Internal::write_big_endian_row<uint16_t, ImageType>;
1005
+
1006
+ std::vector<uint8_t> row(png_get_rowbytes(png_ptr, info_ptr));
1007
+ const int ymin = im.dim(1).min();
1008
+ const int ymax = im.dim(1).max();
1009
+ for (int y = ymin; y <= ymax; ++y) {
1010
+ copy_from_image(im, y, row.data());
1011
+ png_write_row(png_ptr, row.data());
1012
+ }
1013
+ png_write_end(png_ptr, nullptr);
1014
+ png_destroy_write_struct(&png_ptr, &info_ptr);
1015
+
1016
+ return true;
1017
+ }
1018
+
1019
+ #endif // not HALIDE_NO_PNG
1020
+
1021
+ template<Internal::CheckFunc check>
1022
+ bool read_pnm_header(Internal::FileOpener &f, const std::string &hdr_fmt, int *width, int *height, int *bit_depth) {
1023
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
1024
+ return false;
1025
+ }
1026
+
1027
+ char header[256];
1028
+ if (!check(f.scan_line("%255s", header) == 1, "Could not read header")) {
1029
+ return false;
1030
+ }
1031
+
1032
+ if (!check(to_lowercase(hdr_fmt) == to_lowercase(header), "Unexpected file header")) {
1033
+ return false;
1034
+ }
1035
+
1036
+ if (!check(f.scan_line("%d %d\n", width, height) == 2, "Could not read width and height")) {
1037
+ return false;
1038
+ }
1039
+
1040
+ int maxval;
1041
+ if (!check(f.scan_line("%d", &maxval) == 1, "Could not read max value")) {
1042
+ return false;
1043
+ }
1044
+ if (maxval == 255) {
1045
+ *bit_depth = 8;
1046
+ } else if (maxval == 65535) {
1047
+ *bit_depth = 16;
1048
+ } else {
1049
+ *bit_depth = 0;
1050
+ return check(false, "Invalid bit depth");
1051
+ }
1052
+
1053
+ return true;
1054
+ }
1055
+
1056
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1057
+ bool load_pnm(const std::string &filename, int channels, ImageType *im) {
1058
+ static_assert(!ImageType::has_static_halide_type, "");
1059
+
1060
+ const char *hdr_fmt = channels == 3 ? "P6" : "P5";
1061
+
1062
+ Internal::FileOpener f(filename, "rb");
1063
+ int width, height, bit_depth;
1064
+ if (!Internal::read_pnm_header<check>(f, hdr_fmt, &width, &height, &bit_depth)) {
1065
+ return false;
1066
+ }
1067
+
1068
+ const halide_type_t im_type(halide_type_uint, bit_depth);
1069
+ std::vector<int> im_dimensions = {width, height};
1070
+ if (channels > 1) {
1071
+ im_dimensions.push_back(channels);
1072
+ }
1073
+ *im = ImageType(im_type, im_dimensions);
1074
+
1075
+ auto copy_to_image = bit_depth == 8 ?
1076
+ Internal::read_big_endian_row<uint8_t, ImageType> :
1077
+ Internal::read_big_endian_row<uint16_t, ImageType>;
1078
+
1079
+ std::vector<uint8_t> row(width * channels * (bit_depth / 8));
1080
+ const int ymin = im->dim(1).min();
1081
+ const int ymax = im->dim(1).max();
1082
+ for (int y = ymin; y <= ymax; ++y) {
1083
+ if (!check(f.read_vector(&row), "Could not read data")) {
1084
+ return false;
1085
+ }
1086
+ copy_to_image(row.data(), y, im);
1087
+ }
1088
+
1089
+ return true;
1090
+ }
1091
+
1092
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1093
+ bool save_pnm(ImageType &im, const int channels, const std::string &filename) {
1094
+ static_assert(!ImageType::has_static_halide_type, "");
1095
+
1096
+ if (!check(im.channels() == channels, "Wrong number of channels")) {
1097
+ return false;
1098
+ }
1099
+
1100
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
1101
+ return false;
1102
+ }
1103
+
1104
+ const halide_type_t im_type = im.type();
1105
+ const int width = im.width();
1106
+ const int height = im.height();
1107
+ const int bit_depth = im_type.bits;
1108
+
1109
+ Internal::FileOpener f(filename, "wb");
1110
+ if (!check(f.f != nullptr, "File could not be opened for writing")) {
1111
+ return false;
1112
+ }
1113
+ const char *hdr_fmt = channels == 3 ? "P6" : "P5";
1114
+ fprintf(f.f, "%s\n%d %d\n%d\n", hdr_fmt, width, height, (1 << bit_depth) - 1);
1115
+
1116
+ auto copy_from_image = bit_depth == 8 ?
1117
+ Internal::write_big_endian_row<uint8_t, ImageType> :
1118
+ Internal::write_big_endian_row<uint16_t, ImageType>;
1119
+
1120
+ std::vector<uint8_t> row(width * channels * (bit_depth / 8));
1121
+ const int ymin = im.dim(1).min();
1122
+ const int ymax = im.dim(1).max();
1123
+ for (int y = ymin; y <= ymax; ++y) {
1124
+ copy_from_image(im, y, row.data());
1125
+ if (!check(f.write_vector(row), "Could not write data")) {
1126
+ return false;
1127
+ }
1128
+ }
1129
+
1130
+ return true;
1131
+ }
1132
+
1133
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1134
+ bool load_pgm(const std::string &filename, ImageType *im) {
1135
+ return Internal::load_pnm<ImageType, check>(filename, 1, im);
1136
+ }
1137
+
1138
+ inline const std::set<FormatInfo> &query_pgm() {
1139
+ static std::set<FormatInfo> info = {
1140
+ {halide_type_t(halide_type_uint, 8), 2},
1141
+ {halide_type_t(halide_type_uint, 16), 2}};
1142
+ return info;
1143
+ }
1144
+
1145
+ // "im" is not const-ref because copy_to_host() is not const.
1146
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1147
+ bool save_pgm(ImageType &im, const std::string &filename) {
1148
+ return Internal::save_pnm<ImageType, check>(im, 1, filename);
1149
+ }
1150
+
1151
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1152
+ bool load_ppm(const std::string &filename, ImageType *im) {
1153
+ return Internal::load_pnm<ImageType, check>(filename, 3, im);
1154
+ }
1155
+
1156
+ inline const std::set<FormatInfo> &query_ppm() {
1157
+ static std::set<FormatInfo> info = {
1158
+ {halide_type_t(halide_type_uint, 8), 3},
1159
+ {halide_type_t(halide_type_uint, 16), 3}};
1160
+ return info;
1161
+ }
1162
+
1163
+ // "im" is not const-ref because copy_to_host() is not const.
1164
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1165
+ bool save_ppm(ImageType &im, const std::string &filename) {
1166
+ return Internal::save_pnm<ImageType, check>(im, 3, filename);
1167
+ }
1168
+
1169
+ // -------------- .npy file format
1170
+ // Based on documentation at https://numpy.org/devdocs/reference/generated/numpy.lib.format.html
1171
+ // and elsewhere
1172
+
1173
// Compile-time host byte order.
//
// __BYTE_ORDER/__BIG_ENDIAN are only defined after <endian.h>-style headers
// have been included, so the compiler-predefined __BYTE_ORDER__ /
// __ORDER_BIG_ENDIAN__ pair is consulted as well. HALIDE_FORCE_BIG_ENDIAN
// overrides detection.
#if (defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN) ||                                              \
    (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    defined(HALIDE_FORCE_BIG_ENDIAN)
constexpr bool host_is_big_endian = true;
#else
constexpr bool host_is_big_endian = false;
#endif

// Byte-order prefixes used in .npy 'descr' strings.
constexpr char little_endian_char = '<';
constexpr char big_endian_char = '>';
constexpr char no_endian_char = '|';  // single-byte types: order is irrelevant
constexpr char host_endian_char = (host_is_big_endian ? big_endian_char : little_endian_char);
1183
+
1184
// Describes one NumPy dtype as the three pieces of its 'descr' string.
struct npy_dtype_info_t {
    char byte_order;  // byte-order prefix character
    char type_code;   // kind character, e.g. 'f', 'i', 'u'
    char type_bytes;  // element size in bytes (stored as a small integer)

    // Builds the descr string: byte order, type code, then the byte count
    // rendered in decimal (e.g. "<f4").
    std::string descr() const {
        std::string text;
        text += byte_order;
        text += type_code;
        text += std::to_string(static_cast<int>(type_bytes));
        return text;
    }
};
1193
+
1194
+ inline static const std::array<std::pair<halide_type_t, npy_dtype_info_t>, 11> npy_dtypes = {{
1195
+ {halide_type_t(halide_type_float, 16), {host_endian_char, 'f', 2}},
1196
+ {halide_type_of<float>(), {host_endian_char, 'f', sizeof(float)}},
1197
+ {halide_type_of<double>(), {host_endian_char, 'f', sizeof(double)}},
1198
+ {halide_type_of<int8_t>(), {no_endian_char, 'i', sizeof(int8_t)}},
1199
+ {halide_type_of<int16_t>(), {host_endian_char, 'i', sizeof(int16_t)}},
1200
+ {halide_type_of<int32_t>(), {host_endian_char, 'i', sizeof(int32_t)}},
1201
+ {halide_type_of<int64_t>(), {host_endian_char, 'i', sizeof(int64_t)}},
1202
+ {halide_type_of<uint8_t>(), {no_endian_char, 'u', sizeof(uint8_t)}},
1203
+ {halide_type_of<uint16_t>(), {host_endian_char, 'u', sizeof(uint16_t)}},
1204
+ {halide_type_of<uint32_t>(), {host_endian_char, 'u', sizeof(uint32_t)}},
1205
+ {halide_type_of<uint64_t>(), {host_endian_char, 'u', sizeof(uint64_t)}},
1206
+ }};
1207
+
1208
+ inline static const std::array<char, 6> npy_magic_string = {'\x93', 'N', 'U', 'M', 'P', 'Y'};
1209
+ inline static const std::array<char, 2> npy_v1_bytes = {'\x01', '\x00'};
1210
+
1211
// Returns `s` with leading and trailing spaces, tabs and newlines removed.
// A string consisting only of those characters yields an empty string.
inline std::string trim_whitespace(const std::string &s) {
    static const char *const kBlanks = " \t\n";

    const std::string::size_type begin = s.find_first_not_of(kBlanks);
    if (begin == std::string::npos) {
        return std::string();
    }
    // A non-blank character exists, so find_last_not_of cannot fail here.
    const std::string::size_type end = s.find_last_not_of(kBlanks) + 1;
    return s.substr(begin, end - begin);
}
1219
+
1220
// Parsed contents of the Python-dict header of a .npy file.
struct NpyHeader {
    char type_code = 0;        // dtype kind character from 'descr'
    int type_bytes = 0;        // element size in bytes from 'descr'
    std::vector<int> extents;  // entries of 'shape', in file order

    // Parses a header of the form
    //   {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4), }
    // Keys may appear in any order. Only little-endian ('<') or
    // byte-order-free ('|') data with 'fortran_order': False is accepted.
    //
    // Returns true once the closing '}' is reached; returns false on any
    // malformed or unsupported input. Fields are reset on entry, so a
    // NpyHeader may be reused.
    bool parse(const std::string &header) {
        type_code = 0;
        type_bytes = 0;
        extents.clear();

        // c_str() guarantees a terminating NUL, which is what stops the scan.
        const char *ptr = header.c_str();
        if (*ptr++ != '{') {
            return false;
        }
        while (true) {
            char endian = 0;
            // %n is only stored if the whole pattern (including the closing
            // quote) matched, while the return value counts just the three
            // conversions. Start from a sentinel so a partial match is seen.
            int consumed = -1;
            if (std::sscanf(ptr, "'descr': '%c%c%d'%n", &endian, &type_code, &type_bytes, &consumed) == 3) {
                if (consumed < 0) {
                    return false;
                }
                if (endian != '<' && endian != '|') {
                    return false;
                }
                ptr += consumed;
            } else if (std::strncmp(ptr, "'fortran_order': False", 22) == 0) {
                ptr += 22;
            } else if (std::strncmp(ptr, "'shape': (", 10) == 0) {
                ptr += 10;
                int n = 0;
                consumed = -1;
                while (std::sscanf(ptr, "%d%n", &n, &consumed) == 1 && consumed > 0) {
                    extents.push_back(n);
                    ptr += consumed;
                    consumed = -1;
                    if (*ptr == ',') {
                        ptr++;
                    }
                    if (*ptr == ' ') {
                        ptr++;
                    }
                }
                if (*ptr++ != ')') {
                    return false;
                }
            } else if (*ptr == '}') {
                return true;
            } else {
                // Unknown key, unsupported value, or end of string before '}'.
                return false;
            }
            if (*ptr == ',') {
                ptr++;
            }
            if (*ptr == ' ') {
                ptr++;
            }
        }
    }
};
1271
+
1272
+ // return true iff the buffer storage has no padding between
1273
+ // any elements, and is in strictly planar order.
1274
+ template<typename ImageType>
1275
+ bool buffer_is_compact_planar(ImageType &im) {
1276
+ const halide_type_t im_type = im.type();
1277
+ const size_t elem_size = (im_type.bits / 8);
1278
+ if (((const uint8_t *)im.begin() + (im.number_of_elements() * elem_size)) != (const uint8_t *)im.end()) {
1279
+ return false;
1280
+ }
1281
+ for (int d = 1; d < im.dimensions(); ++d) {
1282
+ if (im.dim(d - 1).stride() > im.dim(d).stride()) {
1283
+ return false;
1284
+ }
1285
+ // Strides can only match if the previous dimension has extent 1
1286
+ // (this can happen when artificially adding dimension(s), e.g.
1287
+ // to write a .tmp file)
1288
+ if (im.dim(d - 1).stride() == im.dim(d).stride() && im.dim(d - 1).extent() != 1) {
1289
+ return false;
1290
+ }
1291
+ }
1292
+ return true;
1293
+ }
1294
+
1295
+ template<typename ImageType, CheckFunc check = CheckReturn>
1296
+ bool load_npy(const std::string &filename, ImageType *im) {
1297
+ static_assert(!ImageType::has_static_halide_type, "");
1298
+
1299
+ FileOpener f(filename, "rb");
1300
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
1301
+ return false;
1302
+ }
1303
+
1304
+ char magic_and_version[8];
1305
+ if (!check(f.read_bytes(magic_and_version, 8), "Could not read .npy header")) {
1306
+ return false;
1307
+ }
1308
+ if (memcmp(magic_and_version, npy_magic_string.data(), npy_magic_string.size()) != 0) {
1309
+ return check(false, "Bad .npy magic string");
1310
+ }
1311
+ if ((magic_and_version[6] != 1 && magic_and_version[6] != 2 && magic_and_version[6] != 3) || magic_and_version[7] != 0) {
1312
+ return check(false, "Bad .npy version");
1313
+ }
1314
+ size_t header_len;
1315
+ uint8_t header_len_le[4];
1316
+ if (magic_and_version[6] == 1) {
1317
+ if (!check(f.read_bytes(header_len_le, 2), "Could not read .npy header")) {
1318
+ return false;
1319
+ }
1320
+ header_len = (header_len_le[0] << 0) | (header_len_le[1] << 8);
1321
+ if (!check((6 + 2 + 2 + header_len) % 64 == 0, ".npy header is not aligned properly")) {
1322
+ return false;
1323
+ }
1324
+ } else {
1325
+ if (!check(f.read_bytes(header_len_le, 4), "Could not read .npy header")) {
1326
+ return false;
1327
+ }
1328
+ header_len = (header_len_le[0] << 0) | (header_len_le[1] << 8) | (header_len_le[2] << 16) | (header_len_le[3] << 24);
1329
+ if (!check((6 + 2 + 4 + header_len) % 64 == 0, ".npy header is not aligned properly")) {
1330
+ return false;
1331
+ }
1332
+ }
1333
+
1334
+ std::string header(header_len + 1, ' ');
1335
+ if (!check(f.read_bytes(header.data(), header_len), "Could not read .npy header string")) {
1336
+ return false;
1337
+ }
1338
+
1339
+ NpyHeader h;
1340
+ if (!check(h.parse(header), "Could not parse .npy header dict")) {
1341
+ return false;
1342
+ }
1343
+
1344
+ halide_type_t im_type((halide_type_code_t)0, 0, 0);
1345
+ for (const auto &d : npy_dtypes) {
1346
+ if (h.type_code == d.second.type_code && h.type_bytes == d.second.type_bytes) {
1347
+ im_type = d.first;
1348
+ break;
1349
+ }
1350
+ }
1351
+ if (!check(im_type.bits != 0, "Unsupported type in load_npy")) {
1352
+ return false;
1353
+ }
1354
+
1355
+ *im = ImageType(im_type, h.extents);
1356
+
1357
+ // This should never fail unless the default Buffer<> constructor behavior changes.
1358
+ if (!check(buffer_is_compact_planar(*im), "load_npy() requires compact planar images")) {
1359
+ return false;
1360
+ }
1361
+
1362
+ if (!check(f.read_bytes(im->begin(), im->size_in_bytes()), "Count not read .npy payload")) {
1363
+ return false;
1364
+ }
1365
+
1366
+ im->set_host_dirty();
1367
+ return true;
1368
+ }
1369
+
1370
+ template<typename ImageType, CheckFunc check = CheckReturn>
1371
+ bool write_planar_payload(ImageType &im, FileOpener &f) {
1372
+ if (im.dimensions() == 0 || buffer_is_compact_planar(im)) {
1373
+ // Contiguous buffer! Write it all in one swell foop.
1374
+ if (!check(f.write_bytes(im.begin(), im.size_in_bytes()), "Count not write planar payload")) {
1375
+ return false;
1376
+ }
1377
+ } else {
1378
+ // We have to do this the hard way.
1379
+ int d = im.dimensions() - 1;
1380
+ for (int i = im.dim(d).min(); i <= im.dim(d).max(); i++) {
1381
+ auto slice = im.sliced(d, i);
1382
+ if (!write_planar_payload(slice, f)) {
1383
+ return false;
1384
+ }
1385
+ }
1386
+ }
1387
+ return true;
1388
+ }
1389
+
1390
+ template<typename ImageType, CheckFunc check = CheckReturn>
1391
+ bool save_npy(ImageType &im, const std::string &filename) {
1392
+ static_assert(!ImageType::has_static_halide_type, "");
1393
+
1394
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
1395
+ return false;
1396
+ }
1397
+
1398
+ const halide_type_t im_type = im.type();
1399
+ npy_dtype_info_t di = {0, 0, 0};
1400
+ for (const auto &d : npy_dtypes) {
1401
+ if (d.first == im_type) {
1402
+ di = d.second;
1403
+ break;
1404
+ }
1405
+ }
1406
+ if (!check(di.byte_order != 0, "Unsupported type in save_npy")) {
1407
+ return false;
1408
+ }
1409
+
1410
+ std::string shape = "(";
1411
+ for (int d = 0; d < im.dimensions(); ++d) {
1412
+ if (d > 0) {
1413
+ shape += ",";
1414
+ }
1415
+ shape += std::to_string(im.dim(d).extent());
1416
+ if (im.dimensions() == 1) {
1417
+ shape += ","; // special-case for single-element tuples
1418
+ }
1419
+ }
1420
+ shape += ")";
1421
+
1422
+ std::string header_dict_str = "{'descr': '" + di.descr() + "', 'fortran_order': False, 'shape': " + shape + "}\n";
1423
+
1424
+ const size_t unpadded_length = npy_magic_string.size() + npy_v1_bytes.size() + 2 + header_dict_str.size();
1425
+ const size_t padded_length = (unpadded_length + 64 - 1) & ~(64 - 1);
1426
+ const size_t padding = padded_length - unpadded_length;
1427
+ header_dict_str += std::string(padding, ' ');
1428
+
1429
+ if (!check(header_dict_str.size() <= 65535, "Header is too large for v1 .npy file")) {
1430
+ return false;
1431
+ }
1432
+ const uint16_t header_len = (uint16_t)(header_dict_str.size());
1433
+ const uint8_t header_len_le[2] = {
1434
+ (uint8_t)((header_len >> 0) & 0xff),
1435
+ (uint8_t)((header_len >> 8) & 0xff)};
1436
+
1437
+ FileOpener f(filename, "wb");
1438
+ if (!check(f.write_bytes(npy_magic_string.data(), npy_magic_string.size()), ".npy write failed")) {
1439
+ return false;
1440
+ }
1441
+ if (!check(f.write_bytes(npy_v1_bytes.data(), npy_v1_bytes.size()), ".npy write failed")) {
1442
+ return false;
1443
+ }
1444
+ if (!check(f.write_bytes(header_len_le, 2), ".npy write failed")) {
1445
+ return false;
1446
+ }
1447
+ if (!check(f.write_bytes(header_dict_str.data(), header_dict_str.size()), ".npy write failed")) {
1448
+ return false;
1449
+ }
1450
+
1451
+ if (!write_planar_payload<ImageType, check>(im, f)) {
1452
+ return false;
1453
+ }
1454
+
1455
+ return true;
1456
+ }
1457
+
1458
+ inline const std::set<FormatInfo> &query_npy() {
1459
+ auto build_set = []() -> std::set<FormatInfo> {
1460
+ // NumPy doesn't support bfloat16, not sure if they plan to,
1461
+ // so we don't attempt to support it here
1462
+ std::set<FormatInfo> s;
1463
+ for (halide_type_code_t code : {halide_type_int, halide_type_uint, halide_type_float}) {
1464
+ for (int bits : {8, 16, 32, 64}) {
1465
+ if (code == halide_type_float && bits < 16) {
1466
+ continue;
1467
+ }
1468
+ for (int dims : {1, 2, 3, 4}) {
1469
+ s.insert({halide_type_t(code, bits), dims});
1470
+ }
1471
+ }
1472
+ }
1473
+ return s;
1474
+ };
1475
+
1476
+ static std::set<FormatInfo> info = build_set();
1477
+ return info;
1478
+ }
1479
+
1480
+ #ifndef HALIDE_NO_JPEG
1481
+
1482
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1483
+ bool load_jpg(const std::string &filename, ImageType *im) {
1484
+ static_assert(!ImageType::has_static_halide_type, "");
1485
+
1486
+ Internal::FileOpener f(filename, "rb");
1487
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
1488
+ return false;
1489
+ }
1490
+
1491
+ struct jpeg_decompress_struct cinfo;
1492
+ struct jpeg_error_mgr jerr;
1493
+ cinfo.err = jpeg_std_error(&jerr);
1494
+ jpeg_create_decompress(&cinfo);
1495
+ jpeg_stdio_src(&cinfo, f.f);
1496
+ jpeg_read_header(&cinfo, TRUE);
1497
+ jpeg_start_decompress(&cinfo);
1498
+
1499
+ const int width = cinfo.output_width;
1500
+ const int height = cinfo.output_height;
1501
+ const int channels = cinfo.output_components;
1502
+
1503
+ const halide_type_t im_type(halide_type_uint, 8);
1504
+ std::vector<int> im_dimensions = {width, height};
1505
+ if (channels > 1) {
1506
+ im_dimensions.push_back(channels);
1507
+ }
1508
+ *im = ImageType(im_type, im_dimensions);
1509
+
1510
+ auto copy_to_image = Internal::read_big_endian_row<uint8_t, ImageType>;
1511
+
1512
+ std::vector<uint8_t> row(width * channels);
1513
+ const int ymin = im->dim(1).min();
1514
+ const int ymax = im->dim(1).max();
1515
+ for (int y = ymin; y <= ymax; ++y) {
1516
+ uint8_t *src = row.data();
1517
+ jpeg_read_scanlines(&cinfo, &src, 1);
1518
+ copy_to_image(row.data(), y, im);
1519
+ }
1520
+
1521
+ jpeg_finish_decompress(&cinfo);
1522
+ jpeg_destroy_decompress(&cinfo);
1523
+
1524
+ return true;
1525
+ }
1526
+
1527
+ inline const std::set<FormatInfo> &query_jpg() {
1528
+ static std::set<FormatInfo> info = {
1529
+ {halide_type_t(halide_type_uint, 8), 2},
1530
+ {halide_type_t(halide_type_uint, 8), 3},
1531
+ };
1532
+ return info;
1533
+ }
1534
+
1535
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
1536
+ bool save_jpg(ImageType &im, const std::string &filename) {
1537
+ static_assert(!ImageType::has_static_halide_type, "");
1538
+
1539
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
1540
+ return false;
1541
+ }
1542
+
1543
+ const int width = im.width();
1544
+ const int height = im.height();
1545
+ const int channels = im.channels();
1546
+ if (!check(channels == 1 || channels == 3, "Wrong number of channels")) {
1547
+ return false;
1548
+ }
1549
+
1550
+ Internal::FileOpener f(filename, "wb");
1551
+ if (!check(f.f != nullptr, "File could not be opened for writing")) {
1552
+ return false;
1553
+ }
1554
+
1555
+ // TODO: Make this an argument?
1556
+ constexpr int quality = 99;
1557
+
1558
+ struct jpeg_compress_struct cinfo;
1559
+ struct jpeg_error_mgr jerr;
1560
+ cinfo.err = jpeg_std_error(&jerr);
1561
+ jpeg_create_compress(&cinfo);
1562
+ jpeg_stdio_dest(&cinfo, f.f);
1563
+ cinfo.image_width = width;
1564
+ cinfo.image_height = height;
1565
+ cinfo.input_components = channels;
1566
+ cinfo.in_color_space = (channels == 3) ? JCS_RGB : JCS_GRAYSCALE;
1567
+ jpeg_set_defaults(&cinfo);
1568
+ jpeg_set_quality(&cinfo, quality, TRUE);
1569
+ jpeg_start_compress(&cinfo, TRUE);
1570
+
1571
+ auto copy_from_image = Internal::write_big_endian_row<uint8_t, ImageType>;
1572
+
1573
+ std::vector<uint8_t> row(width * channels);
1574
+ const int ymin = im.dim(1).min();
1575
+ const int ymax = im.dim(1).max();
1576
+ for (int y = ymin; y <= ymax; ++y) {
1577
+ uint8_t *dst = row.data();
1578
+ copy_from_image(im, y, dst);
1579
+ jpeg_write_scanlines(&cinfo, &dst, 1);
1580
+ }
1581
+
1582
+ jpeg_finish_compress(&cinfo);
1583
+ jpeg_destroy_compress(&cinfo);
1584
+
1585
+ return true;
1586
+ }
1587
+
1588
+ #endif // not HALIDE_NO_JPEG
1589
+
1590
+ constexpr int kNumTmpCodes = 10;
1591
+
1592
+ inline const halide_type_t *tmp_code_to_halide_type() {
1593
+ static const halide_type_t tmp_code_to_halide_type_[kNumTmpCodes] = {
1594
+ {halide_type_float, 32},
1595
+ {halide_type_float, 64},
1596
+ {halide_type_uint, 8},
1597
+ {halide_type_int, 8},
1598
+ {halide_type_uint, 16},
1599
+ {halide_type_int, 16},
1600
+ {halide_type_uint, 32},
1601
+ {halide_type_int, 32},
1602
+ {halide_type_uint, 64},
1603
+ {halide_type_int, 64}};
1604
+ return tmp_code_to_halide_type_;
1605
+ }
1606
+
1607
+ // ".tmp" is a file format used by the ImageStack tool (see https://github.com/abadams/ImageStack)
1608
+ template<typename ImageType, CheckFunc check = CheckReturn>
1609
+ bool load_tmp(const std::string &filename, ImageType *im) {
1610
+ static_assert(!ImageType::has_static_halide_type, "");
1611
+
1612
+ FileOpener f(filename, "rb");
1613
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
1614
+ return false;
1615
+ }
1616
+
1617
+ int32_t header[5];
1618
+ if (!check(f.read_array(header), "Count not read .tmp header")) {
1619
+ return false;
1620
+ }
1621
+
1622
+ if (!check(header[0] > 0 && header[1] > 0 && header[2] > 0 && header[3] > 0 &&
1623
+ header[4] >= 0 && header[4] < kNumTmpCodes,
1624
+ "Bad header on .tmp file")) {
1625
+ return false;
1626
+ }
1627
+
1628
+ const halide_type_t im_type = tmp_code_to_halide_type()[header[4]];
1629
+ std::vector<int> im_dimensions = {header[0], header[1], header[2], header[3]};
1630
+ *im = ImageType(im_type, im_dimensions);
1631
+
1632
+ // This should never fail unless the default Buffer<> constructor behavior changes.
1633
+ if (!check(buffer_is_compact_planar(*im), "load_tmp() requires compact planar images")) {
1634
+ return false;
1635
+ }
1636
+
1637
+ if (!check(f.read_bytes(im->begin(), im->size_in_bytes()), "Count not read .tmp payload")) {
1638
+ return false;
1639
+ }
1640
+
1641
+ im->set_host_dirty();
1642
+ return true;
1643
+ }
1644
+
1645
+ inline const std::set<FormatInfo> &query_tmp() {
1646
+ // TMP files require exactly 4 dimensions.
1647
+ static std::set<FormatInfo> info = {
1648
+ {halide_type_t(halide_type_float, 32), 4},
1649
+ {halide_type_t(halide_type_float, 64), 4},
1650
+ {halide_type_t(halide_type_uint, 8), 4},
1651
+ {halide_type_t(halide_type_int, 8), 4},
1652
+ {halide_type_t(halide_type_uint, 16), 4},
1653
+ {halide_type_t(halide_type_int, 16), 4},
1654
+ {halide_type_t(halide_type_uint, 32), 4},
1655
+ {halide_type_t(halide_type_int, 32), 4},
1656
+ {halide_type_t(halide_type_uint, 64), 4},
1657
+ {halide_type_t(halide_type_int, 64), 4},
1658
+ };
1659
+ return info;
1660
+ }
1661
+
1662
+ // ".tmp" is a file format used by the ImageStack tool (see https://github.com/abadams/ImageStack)
1663
+ template<typename ImageType, CheckFunc check = CheckReturn>
1664
+ bool save_tmp(ImageType &im, const std::string &filename) {
1665
+ static_assert(!ImageType::has_static_halide_type, "");
1666
+
1667
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
1668
+ return false;
1669
+ }
1670
+
1671
+ int32_t header[5] = {1, 1, 1, 1, -1};
1672
+ for (int i = 0; i < im.dimensions(); ++i) {
1673
+ header[i] = im.dim(i).extent();
1674
+ }
1675
+ const auto *table = tmp_code_to_halide_type();
1676
+ for (int i = 0; i < kNumTmpCodes; i++) {
1677
+ if (im.type() == table[i]) {
1678
+ header[4] = i;
1679
+ break;
1680
+ }
1681
+ }
1682
+ if (!check(header[4] >= 0, "Unsupported type for .tmp file")) {
1683
+ return false;
1684
+ }
1685
+
1686
+ FileOpener f(filename, "wb");
1687
+ if (!check(f.f != nullptr, "File could not be opened for writing")) {
1688
+ return false;
1689
+ }
1690
+ if (!check(f.write_array(header), "Could not write .tmp header")) {
1691
+ return false;
1692
+ }
1693
+
1694
+ if (!write_planar_payload<ImageType, check>(im, f)) {
1695
+ return false;
1696
+ }
1697
+
1698
+ return true;
1699
+ }
1700
+
1701
+ // ".mat" is the matlab level 5 format documented here:
1702
+ // http://www.mathworks.com/help/pdf_doc/matlab/matfile_format.pdf
1703
+
1704
// Data-element type tags used in MAT level 5 files.
enum MatlabTypeCode {
    // integer element types
    miINT8 = 1,
    miUINT8 = 2,
    miINT16 = 3,
    miUINT16 = 4,
    miINT32 = 5,
    miUINT32 = 6,
    // floating-point element types
    miSINGLE = 7,
    miDOUBLE = 9,
    // 64-bit integer element types
    miINT64 = 12,
    miUINT64 = 13,
    // container / compressed elements
    miMATRIX = 14,
    miCOMPRESSED = 15,
    // text encodings
    miUTF8 = 16,
    miUTF16 = 17,
    miUTF32 = 18
};
1721
+
1722
// Array class tags used in the array-flags element of MAT level 5 files.
enum MatlabClassCode {
    mxCHAR_CLASS = 3,
    // floating-point arrays
    mxDOUBLE_CLASS = 6,
    mxSINGLE_CLASS = 7,
    // integer arrays, narrowest to widest, signed before unsigned
    mxINT8_CLASS = 8,
    mxUINT8_CLASS = 9,
    mxINT16_CLASS = 10,
    mxUINT16_CLASS = 11,
    mxINT32_CLASS = 12,
    mxUINT32_CLASS = 13,
    mxINT64_CLASS = 14,
    mxUINT64_CLASS = 15
};
1735
+
1736
+ template<typename ImageType, CheckFunc check = CheckReturn>
1737
+ bool load_mat(const std::string &filename, ImageType *im) {
1738
+ static_assert(!ImageType::has_static_halide_type, "");
1739
+
1740
+ FileOpener f(filename, "rb");
1741
+ if (!check(f.f != nullptr, "File could not be opened for reading")) {
1742
+ return false;
1743
+ }
1744
+
1745
+ uint8_t header[128];
1746
+ if (!check(f.read_array(header), "Could not read .mat header\n")) {
1747
+ return false;
1748
+ }
1749
+
1750
+ // Matrix header
1751
+ uint32_t matrix_header[2];
1752
+ if (!check(f.read_array(matrix_header), "Could not read .mat header\n")) {
1753
+ return false;
1754
+ }
1755
+ if (!check(matrix_header[0] == miMATRIX, "Could not parse this .mat file: bad matrix header\n")) {
1756
+ return false;
1757
+ }
1758
+
1759
+ // Array flags
1760
+ uint32_t flags[4];
1761
+ if (!check(f.read_array(flags), "Could not read .mat header\n")) {
1762
+ return false;
1763
+ }
1764
+ if (!check(flags[0] == miUINT32 && flags[1] == 8, "Could not parse this .mat file: bad flags\n")) {
1765
+ return false;
1766
+ }
1767
+
1768
+ // Shape
1769
+ uint32_t shape_header[2];
1770
+ if (!check(f.read_array(shape_header), "Could not read .mat header\n")) {
1771
+ return false;
1772
+ }
1773
+ if (!check(shape_header[0] == miINT32, "Could not parse this .mat file: bad shape header\n")) {
1774
+ return false;
1775
+ }
1776
+ int dims = shape_header[1] / 4;
1777
+ std::vector<int> extents(dims);
1778
+ if (!check(f.read_vector(&extents), "Could not read .mat header\n")) {
1779
+ return false;
1780
+ }
1781
+ if (dims & 1) {
1782
+ uint32_t padding;
1783
+ if (!check(f.read_bytes(&padding, 4), "Could not read .mat header\n")) {
1784
+ return false;
1785
+ }
1786
+ }
1787
+
1788
+ // Skip over the name
1789
+ uint32_t name_header[2];
1790
+ if (!check(f.read_array(name_header), "Could not read .mat header\n")) {
1791
+ return false;
1792
+ }
1793
+
1794
+ if (name_header[0] >> 16) {
1795
+ // Name must be fewer than 4 chars, and so the whole name
1796
+ // field was stored packed into 8 bytes
1797
+ } else {
1798
+ if (!check(name_header[0] == miINT8, "Could not parse this .mat file: bad name header\n")) {
1799
+ return false;
1800
+ }
1801
+ std::vector<uint64_t> scratch((name_header[1] + 7) / 8);
1802
+ if (!check(f.read_vector(&scratch), "Could not read .mat header\n")) {
1803
+ return false;
1804
+ }
1805
+ }
1806
+
1807
+ // Payload header
1808
+ uint32_t payload_header[2];
1809
+ if (!check(f.read_array(payload_header), "Could not read .mat header\n")) {
1810
+ return false;
1811
+ }
1812
+ halide_type_t type;
1813
+ switch (payload_header[0]) {
1814
+ case miINT8:
1815
+ type = halide_type_of<int8_t>();
1816
+ break;
1817
+ case miINT16:
1818
+ type = halide_type_of<int16_t>();
1819
+ break;
1820
+ case miINT32:
1821
+ type = halide_type_of<int32_t>();
1822
+ break;
1823
+ case miINT64:
1824
+ type = halide_type_of<int64_t>();
1825
+ break;
1826
+ case miUINT8:
1827
+ type = halide_type_of<uint8_t>();
1828
+ break;
1829
+ case miUINT16:
1830
+ type = halide_type_of<uint16_t>();
1831
+ break;
1832
+ case miUINT32:
1833
+ type = halide_type_of<uint32_t>();
1834
+ break;
1835
+ case miUINT64:
1836
+ type = halide_type_of<uint64_t>();
1837
+ break;
1838
+ case miSINGLE:
1839
+ type = halide_type_of<float>();
1840
+ break;
1841
+ case miDOUBLE:
1842
+ type = halide_type_of<double>();
1843
+ break;
1844
+ default:
1845
+ check(false, "Unknown header");
1846
+ return false;
1847
+ }
1848
+
1849
+ *im = ImageType(type, extents);
1850
+
1851
+ // This should never fail unless the default Buffer<> constructor behavior changes.
1852
+ if (!check(buffer_is_compact_planar(*im), "load_mat() requires compact planar images")) {
1853
+ return false;
1854
+ }
1855
+
1856
+ if (!check(f.read_bytes(im->begin(), im->size_in_bytes()), "Could not read .tmp payload")) {
1857
+ return false;
1858
+ }
1859
+
1860
+ im->set_host_dirty();
1861
+ return true;
1862
+ }
1863
+
1864
+ inline const std::set<FormatInfo> &query_mat() {
1865
+ // MAT files must have at least 2 dimensions, but there's no upper
1866
+ // bound. Our support arbitrarily stops at 16 dimensions.
1867
+ static std::set<FormatInfo> info = []() {
1868
+ std::set<FormatInfo> s;
1869
+ for (int i = 2; i < 16; i++) {
1870
+ s.insert({halide_type_t(halide_type_float, 32), i});
1871
+ s.insert({halide_type_t(halide_type_float, 64), i});
1872
+ s.insert({halide_type_t(halide_type_uint, 8), i});
1873
+ s.insert({halide_type_t(halide_type_int, 8), i});
1874
+ s.insert({halide_type_t(halide_type_uint, 16), i});
1875
+ s.insert({halide_type_t(halide_type_int, 16), i});
1876
+ s.insert({halide_type_t(halide_type_uint, 32), i});
1877
+ s.insert({halide_type_t(halide_type_int, 32), i});
1878
+ s.insert({halide_type_t(halide_type_uint, 64), i});
1879
+ s.insert({halide_type_t(halide_type_int, 64), i});
1880
+ }
1881
+ return s;
1882
+ }();
1883
+ return info;
1884
+ }
1885
+
1886
+ template<typename ImageType, CheckFunc check = CheckReturn>
1887
+ bool save_mat(ImageType &im, const std::string &filename) {
1888
+ static_assert(!ImageType::has_static_halide_type, "");
1889
+
1890
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
1891
+ return false;
1892
+ }
1893
+
1894
+ uint32_t class_code = 0, type_code = 0;
1895
+ switch (im.raw_buffer()->type.code) {
1896
+ case halide_type_int:
1897
+ switch (im.raw_buffer()->type.bits) {
1898
+ case 8:
1899
+ class_code = mxINT8_CLASS;
1900
+ type_code = miINT8;
1901
+ break;
1902
+ case 16:
1903
+ class_code = mxINT16_CLASS;
1904
+ type_code = miINT16;
1905
+ break;
1906
+ case 32:
1907
+ class_code = mxINT32_CLASS;
1908
+ type_code = miINT32;
1909
+ break;
1910
+ case 64:
1911
+ class_code = mxINT64_CLASS;
1912
+ type_code = miINT64;
1913
+ break;
1914
+ default:
1915
+ check(false, "unreachable");
1916
+ };
1917
+ break;
1918
+ case halide_type_uint:
1919
+ switch (im.raw_buffer()->type.bits) {
1920
+ case 8:
1921
+ class_code = mxUINT8_CLASS;
1922
+ type_code = miUINT8;
1923
+ break;
1924
+ case 16:
1925
+ class_code = mxUINT16_CLASS;
1926
+ type_code = miUINT16;
1927
+ break;
1928
+ case 32:
1929
+ class_code = mxUINT32_CLASS;
1930
+ type_code = miUINT32;
1931
+ break;
1932
+ case 64:
1933
+ class_code = mxUINT64_CLASS;
1934
+ type_code = miUINT64;
1935
+ break;
1936
+ default:
1937
+ check(false, "unreachable");
1938
+ };
1939
+ break;
1940
+ case halide_type_float:
1941
+ switch (im.raw_buffer()->type.bits) {
1942
+ case 16:
1943
+ check(false, "float16 not supported by .mat");
1944
+ break;
1945
+ case 32:
1946
+ class_code = mxSINGLE_CLASS;
1947
+ type_code = miSINGLE;
1948
+ break;
1949
+ case 64:
1950
+ class_code = mxDOUBLE_CLASS;
1951
+ type_code = miDOUBLE;
1952
+ break;
1953
+ default:
1954
+ check(false, "unreachable");
1955
+ };
1956
+ break;
1957
+ case halide_type_bfloat:
1958
+ check(false, "bfloat not supported by .mat");
1959
+ break;
1960
+ default:
1961
+ check(false, "unreachable");
1962
+ }
1963
+
1964
+ FileOpener f(filename, "wb");
1965
+ if (!check(f.f != nullptr, "File could not be opened for writing")) {
1966
+ return false;
1967
+ }
1968
+
1969
+ // Pick a name for the array
1970
+ size_t idx = filename.rfind('.');
1971
+ std::string name = filename.substr(0, idx);
1972
+ idx = filename.rfind('/');
1973
+ if (idx != std::string::npos) {
1974
+ name = name.substr(idx + 1);
1975
+ }
1976
+
1977
+ // Matlab variable names conform to similar rules as C
1978
+ if (name.empty() || !std::isalpha(name[0])) {
1979
+ name = "v" + name;
1980
+ }
1981
+ for (char &c : name) {
1982
+ if (!std::isalnum(c)) {
1983
+ c = '_';
1984
+ }
1985
+ }
1986
+
1987
+ uint32_t name_size = (int)name.size();
1988
+ while (name.size() & 0x7) {
1989
+ name += '\0';
1990
+ }
1991
+
1992
+ char header[128] = "MATLAB 5.0 MAT-file, produced by Halide";
1993
+ int len = strlen(header);
1994
+ memset(header + len, ' ', sizeof(header) - len);
1995
+
1996
+ // Version
1997
+ *((uint16_t *)(header + 124)) = 0x0100;
1998
+
1999
+ // Endianness check
2000
+ header[126] = 'I';
2001
+ header[127] = 'M';
2002
+
2003
+ uint64_t payload_bytes = im.size_in_bytes();
2004
+
2005
+ if (!check((payload_bytes >> 32) == 0, "Buffer too large to save as .mat")) {
2006
+ return false;
2007
+ }
2008
+
2009
+ int dims = im.dimensions();
2010
+ if (dims < 2) {
2011
+ dims = 2;
2012
+ }
2013
+ int padded_dims = dims + (dims & 1);
2014
+
2015
+ uint32_t padding_bytes = 7 - ((payload_bytes - 1) & 7);
2016
+
2017
+ // Matrix header
2018
+ uint32_t matrix_header[2] = {
2019
+ miMATRIX, 40 + padded_dims * 4 + (uint32_t)name.size() + (uint32_t)payload_bytes + padding_bytes};
2020
+
2021
+ // Array flags
2022
+ uint32_t flags[4] = {
2023
+ miUINT32, 8, class_code, 1};
2024
+
2025
+ // Shape
2026
+ int32_t shape[2] = {
2027
+ miINT32,
2028
+ im.dimensions() * 4,
2029
+ };
2030
+ std::vector<int> extents(im.dimensions());
2031
+ for (int d = 0; d < im.dimensions(); d++) {
2032
+ extents[d] = im.dim(d).extent();
2033
+ }
2034
+ while ((int)extents.size() < dims) {
2035
+ extents.push_back(1);
2036
+ }
2037
+ while ((int)extents.size() < padded_dims) {
2038
+ extents.push_back(0);
2039
+ }
2040
+
2041
+ // Name
2042
+ uint32_t name_header[2] = {
2043
+ miINT8, name_size};
2044
+
2045
+ // Payload header
2046
+ uint32_t payload_header[2] = {
2047
+ type_code, (uint32_t)payload_bytes};
2048
+
2049
+ bool success =
2050
+ f.write_array(header) &&
2051
+ f.write_array(matrix_header) &&
2052
+ f.write_array(flags) &&
2053
+ f.write_array(shape) &&
2054
+ f.write_vector(extents) &&
2055
+ f.write_array(name_header) &&
2056
+ f.write_bytes(&name[0], name.size()) &&
2057
+ f.write_array(payload_header);
2058
+
2059
+ if (!check(success, "Could not write .mat header")) {
2060
+ return false;
2061
+ }
2062
+
2063
+ if (!write_planar_payload<ImageType, check>(im, f)) {
2064
+ return false;
2065
+ }
2066
+
2067
+ // Padding
2068
+ if (!check(padding_bytes < 8, "Too much padding!\n")) {
2069
+ return false;
2070
+ }
2071
+ uint64_t padding = 0;
2072
+ if (!f.write_bytes(&padding, padding_bytes)) {
2073
+ return false;
2074
+ }
2075
+
2076
+ return true;
2077
+ }
2078
+
2079
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
2080
+ bool load_tiff(const std::string &filename, ImageType *im) {
2081
+ static_assert(!ImageType::has_static_halide_type, "");
2082
+ check(false, "Reading TIFF is not yet supported");
2083
+ return false;
2084
+ }
2085
+
2086
+ inline const std::set<FormatInfo> &query_tiff() {
2087
+ auto build_set = []() -> std::set<FormatInfo> {
2088
+ std::set<FormatInfo> s;
2089
+ for (halide_type_code_t code : {halide_type_int, halide_type_uint, halide_type_float}) {
2090
+ for (int bits : {8, 16, 32, 64}) {
2091
+ for (int dims : {1, 2, 3, 4}) {
2092
+ if (code == halide_type_float && bits < 32) {
2093
+ continue;
2094
+ }
2095
+ s.insert({halide_type_t(code, bits), dims});
2096
+ }
2097
+ }
2098
+ }
2099
+ return s;
2100
+ };
2101
+
2102
+ static std::set<FormatInfo> info = build_set();
2103
+ return info;
2104
+ }
2105
+
2106
+ #pragma pack(push)
2107
+ #pragma pack(2)
2108
+
2109
// One 12-byte entry of the TIFF image file directory written by save_tiff():
// a tag code, the type of its value, the number of values, and a 4-byte
// slot holding either the value itself or a file offset.
struct halide_tiff_tag {
    uint16_t tag_code;
    int16_t type_code;
    int32_t count;
    union {
        int8_t i8;
        int16_t i16;
        int32_t i32;
    } value;

    // Fill in the entry as a SHORT (type 3). Only the 16-bit member of the
    // value slot is written.
    void assign16(uint16_t tag_code, int32_t count, int16_t value) {
        this->value.i16 = value;
        this->count = count;
        this->type_code = 3;  // SHORT
        this->tag_code = tag_code;
    }

    // Fill in the entry as a LONG (type 4).
    void assign32(uint16_t tag_code, int32_t count, int32_t value) {
        const int16_t long_type = 4;  // LONG
        assign32(tag_code, long_type, count, value);
    }

    // Fill in the entry with a caller-chosen type code and a 32-bit value.
    void assign32(uint16_t tag_code, int16_t type_code, int32_t count, int32_t value) {
        this->value.i32 = value;
        this->count = count;
        this->type_code = type_code;
        this->tag_code = tag_code;
    }
};
2140
+
2141
+ struct halide_tiff_header {
2142
+ int16_t byte_order_marker;
2143
+ int16_t version;
2144
+ int32_t ifd0_offset;
2145
+ int16_t entry_count;
2146
+ halide_tiff_tag entries[15];
2147
+ int32_t ifd0_end;
2148
+ int32_t width_resolution[2];
2149
+ int32_t height_resolution[2];
2150
+ };
2151
+
2152
+ #pragma pack(pop)
2153
+
2154
+ template<typename ElemType, int BUFFER_SIZE = 1024>
2155
+ struct ElemWriter {
2156
+ ElemWriter(FileOpener *f)
2157
+ : f(f), next(&buf[0]) {
2158
+ }
2159
+ ~ElemWriter() {
2160
+ flush();
2161
+ }
2162
+
2163
+ void operator()(const ElemType &elem) {
2164
+ if (!ok) {
2165
+ return;
2166
+ }
2167
+
2168
+ *next++ = elem;
2169
+ if (next == &buf[BUFFER_SIZE]) {
2170
+ flush();
2171
+ }
2172
+ }
2173
+
2174
+ void flush() {
2175
+ if (!ok) {
2176
+ return;
2177
+ }
2178
+
2179
+ if (next > buf) {
2180
+ if (!f->write_bytes(buf, (next - buf) * sizeof(ElemType))) {
2181
+ ok = false;
2182
+ }
2183
+ next = buf;
2184
+ }
2185
+ }
2186
+
2187
+ FileOpener *const f;
2188
+ ElemType buf[BUFFER_SIZE];
2189
+ ElemType *next;
2190
+ bool ok = true;
2191
+ };
2192
+
2193
+ // Note that this is a fairly simpleminded TIFF writer that doesn't
2194
+ // do any compression. It would be desirable to (optionally) support using libtiff
2195
+ // here instead, which would also allow us to provide a useful implementation
2196
+ // for TIFF reading.
2197
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
2198
+ bool save_tiff(ImageType &im, const std::string &filename) {
2199
+ static_assert(!ImageType::has_static_halide_type, "");
2200
+
2201
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
2202
+ return false;
2203
+ }
2204
+
2205
+ if (!check(im.dimensions() <= 4, "Can only save TIFF files with <= 4 dimensions")) {
2206
+ return false;
2207
+ }
2208
+
2209
+ FileOpener f(filename, "wb");
2210
+ if (!check(f.f != nullptr, "File could not be opened for writing")) {
2211
+ return false;
2212
+ }
2213
+
2214
+ const size_t elements = im.number_of_elements();
2215
+ halide_dimension_t shape[4];
2216
+ for (int i = 0; i < im.dimensions() && i < 4; i++) {
2217
+ const auto &d = im.dim(i);
2218
+ shape[i].min = d.min();
2219
+ shape[i].extent = d.extent();
2220
+ shape[i].stride = d.stride();
2221
+ }
2222
+ for (int i = im.dimensions(); i < 4; i++) {
2223
+ shape[i].min = 0;
2224
+ shape[i].extent = 1;
2225
+ shape[i].stride = 0;
2226
+ }
2227
+ const halide_type_t im_type = im.type();
2228
+ if (!check(im_type.code >= 0 && im_type.code < 3, "Unsupported image type")) {
2229
+ return false;
2230
+ }
2231
+ const int32_t bytes_per_element = im_type.bytes();
2232
+ const int32_t width = shape[0].extent;
2233
+ const int32_t height = shape[1].extent;
2234
+ int32_t depth = shape[2].extent;
2235
+ int32_t channels = shape[3].extent;
2236
+
2237
+ if ((channels == 0 || channels == 1) && (depth < 5)) {
2238
+ channels = depth;
2239
+ depth = 1;
2240
+ }
2241
+
2242
+ // TIFF sample type values are:
2243
+ // 0 => Signed int
2244
+ // 1 => Unsigned int
2245
+ // 2 => Floating-point
2246
+ static const int16_t type_code_to_tiff_sample_type[] = {
2247
+ 2, 1, 3};
2248
+
2249
+ struct halide_tiff_header header;
2250
+ memset(&header, 0, sizeof(header));
2251
+
2252
+ const int32_t MMII = 0x4d4d4949;
2253
+ // Select the appropriate two bytes signaling byte order automatically
2254
+ const char *c = (const char *)&MMII;
2255
+ header.byte_order_marker = (c[0] << 8) | c[1];
2256
+ header.version = 42;
2257
+ header.ifd0_offset = offsetof(halide_tiff_header, entry_count);
2258
+ header.entry_count = sizeof(header.entries) / sizeof(header.entries[0]);
2259
+
2260
+ static_assert(sizeof(halide_tiff_tag) == 12, "Unexpected halide_tiff_tag packing");
2261
+ halide_tiff_tag *tag = &header.entries[0];
2262
+ tag++->assign32(256, 1, width); // ImageWidth
2263
+ tag++->assign32(257, 1, height); // ImageLength
2264
+ tag++->assign16(258, 1, int16_t(bytes_per_element * 8)); // BitsPerSample
2265
+ tag++->assign16(259, 1, 1); // Compression -- none
2266
+ tag++->assign16(262, 1, channels >= 3 ? 2 : 1); // PhotometricInterpretation -- black is zero or RGB
2267
+ tag++->assign32(273, channels, sizeof(header)); // StripOffsets
2268
+ tag++->assign16(277, 1, int16_t(channels)); // SamplesPerPixel
2269
+ tag++->assign32(278, 1, height); // RowsPerStrip
2270
+ tag++->assign32(279, channels, // StripByteCounts
2271
+ (channels == 1) ?
2272
+ elements * bytes_per_element :
2273
+ sizeof(header) + channels * sizeof(int32_t)); // for channels > 1, this is an offset
2274
+ tag++->assign32(282, 5, 1,
2275
+ offsetof(halide_tiff_header, width_resolution)); // XResolution
2276
+ tag++->assign32(283, 5, 1,
2277
+ offsetof(halide_tiff_header, height_resolution)); // YResolution
2278
+ tag++->assign16(284, 1, channels == 1 ? 1 : 2); // PlanarConfiguration -- contig or planar
2279
+ tag++->assign16(296, 1, 1); // ResolutionUnit -- none
2280
+ tag++->assign16(339, 1, type_code_to_tiff_sample_type[im_type.code]); // SampleFormat
2281
+ tag++->assign32(32997, 1, depth); // Image depth
2282
+
2283
+ // Verify we used exactly the number we declared
2284
+ assert(tag == &header.entries[header.entry_count]);
2285
+
2286
+ header.ifd0_end = 0;
2287
+ header.width_resolution[0] = 1;
2288
+ header.width_resolution[1] = 1;
2289
+ header.height_resolution[0] = 1;
2290
+ header.height_resolution[1] = 1;
2291
+
2292
+ if (!check(f.write_bytes(&header, sizeof(header)), "TIFF write failed")) {
2293
+ return false;
2294
+ }
2295
+
2296
+ if (channels > 1) {
2297
+ // Fill in the values for StripOffsets
2298
+ int32_t offset = sizeof(header) + channels * sizeof(int32_t) * 2;
2299
+ for (int32_t i = 0; i < channels; i++) {
2300
+ if (!check(f.write_bytes(&offset, sizeof(offset)), "TIFF write failed")) {
2301
+ return false;
2302
+ }
2303
+ offset += width * height * depth * bytes_per_element;
2304
+ }
2305
+ // Fill in the values for StripByteCounts
2306
+ int32_t count = width * height * depth * bytes_per_element;
2307
+ for (int32_t i = 0; i < channels; i++) {
2308
+ if (!check(f.write_bytes(&count, sizeof(count)), "TIFF write failed")) {
2309
+ return false;
2310
+ }
2311
+ }
2312
+ }
2313
+
2314
+ // If image is dense, we can write it in one fell swoop
2315
+ if (elements * bytes_per_element == im.size_in_bytes()) {
2316
+ if (!check(f.write_bytes(im.data(), im.size_in_bytes()), "TIFF write failed")) {
2317
+ return false;
2318
+ }
2319
+ return true;
2320
+ }
2321
+
2322
+ // Otherwise, write it out via manual traversal.
2323
+ #define HANDLE_CASE(CODE, BITS, TYPE) \
2324
+ case halide_type_t(CODE, BITS).as_u32(): { \
2325
+ ElemWriter<TYPE> ew(&f); \
2326
+ im.template as<const TYPE, AnyDims>().for_each_value(ew); \
2327
+ if (!check(ew.ok, "TIFF write failed")) { \
2328
+ return false; \
2329
+ } \
2330
+ break; \
2331
+ }
2332
+
2333
+ switch (im_type.element_of().as_u32()) {
2334
+ HANDLE_CASE(halide_type_float, 32, float)
2335
+ HANDLE_CASE(halide_type_float, 64, double)
2336
+ HANDLE_CASE(halide_type_int, 8, int8_t)
2337
+ HANDLE_CASE(halide_type_int, 16, int16_t)
2338
+ HANDLE_CASE(halide_type_int, 32, int32_t)
2339
+ HANDLE_CASE(halide_type_int, 64, int64_t)
2340
+ HANDLE_CASE(halide_type_uint, 1, bool)
2341
+ HANDLE_CASE(halide_type_uint, 8, uint8_t)
2342
+ HANDLE_CASE(halide_type_uint, 16, uint16_t)
2343
+ HANDLE_CASE(halide_type_uint, 32, uint32_t)
2344
+ HANDLE_CASE(halide_type_uint, 64, uint64_t)
2345
+ // Note that we don't attempt to handle halide_type_handle here.
2346
+ default:
2347
+ assert(false && "Unsupported type");
2348
+ return false;
2349
+ }
2350
+ #undef HANDLE_CASE
2351
+
2352
+ return true;
2353
+ }
2354
+
2355
+ // Given something like ImageType<Foo, 2>, produce typedef ImageType<Foo, AnyDims>
2356
+ template<typename ImageType>
2357
+ struct ImageTypeWithDynamicDims {
2358
+ using type = decltype(std::declval<ImageType>().template as<typename ImageType::ElemType, AnyDims>());
2359
+ };
2360
+
2361
+ // Given something like ImageType<Foo>, produce typedef ImageType<Bar, AnyDims>
2362
+ template<typename ImageType, typename ElemType>
2363
+ struct ImageTypeWithElemType {
2364
+ using type = decltype(std::declval<ImageType>().template as<ElemType, AnyDims>());
2365
+ };
2366
+
2367
+ // Given something like ImageType<Foo>, produce typedef ImageType<const Bar, AnyDims>
2368
+ template<typename ImageType, typename ElemType>
2369
+ struct ImageTypeWithConstElemType {
2370
+ using type = decltype(std::declval<ImageType>().template as<typename std::add_const<ElemType>::type, AnyDims>());
2371
+ };
2372
+
2373
+ template<typename ImageType, Internal::CheckFunc check>
2374
+ struct ImageIO {
2375
+ using ConstImageType = typename ImageTypeWithConstElemType<ImageType, typename ImageType::ElemType>::type;
2376
+
2377
+ std::function<bool(const std::string &, ImageType *)> load;
2378
+ std::function<bool(ConstImageType &im, const std::string &)> save;
2379
+ std::function<const std::set<FormatInfo> &()> query;
2380
+ };
2381
+
2382
+ template<typename ImageType, Internal::CheckFunc check>
2383
+ bool find_imageio(const std::string &filename, ImageIO<ImageType, check> *result) {
2384
+ static_assert(!ImageType::has_static_halide_type, "");
2385
+ using ConstImageType = typename ImageTypeWithConstElemType<ImageType, typename ImageType::ElemType>::type;
2386
+
2387
+ const std::map<std::string, ImageIO<ImageType, check>> m = {
2388
+ #ifndef HALIDE_NO_JPEG
2389
+ {"jpeg", {load_jpg<ImageType, check>, save_jpg<ConstImageType, check>, query_jpg}},
2390
+ {"jpg", {load_jpg<ImageType, check>, save_jpg<ConstImageType, check>, query_jpg}},
2391
+ #endif
2392
+ {"npy", {load_npy<ImageType, check>, save_npy<ConstImageType, check>, query_npy}},
2393
+ {"pgm", {load_pgm<ImageType, check>, save_pgm<ConstImageType, check>, query_pgm}},
2394
+ #ifndef HALIDE_NO_PNG
2395
+ {"png", {load_png<ImageType, check>, save_png<ConstImageType, check>, query_png}},
2396
+ #endif
2397
+ {"ppm", {load_ppm<ImageType, check>, save_ppm<ConstImageType, check>, query_ppm}},
2398
+ {"tmp", {load_tmp<ImageType, check>, save_tmp<ConstImageType, check>, query_tmp}},
2399
+ {"mat", {load_mat<ImageType, check>, save_mat<ConstImageType, check>, query_mat}},
2400
+ {"tiff", {load_tiff<ImageType, check>, save_tiff<ConstImageType, check>, query_tiff}},
2401
+ };
2402
+ std::string ext = Internal::get_lowercase_extension(filename);
2403
+ auto it = m.find(ext);
2404
+ if (it != m.end()) {
2405
+ *result = it->second;
2406
+ return true;
2407
+ }
2408
+
2409
+ std::string err = "unsupported file extension \"" + ext + "\", supported are:";
2410
+ for (auto &it : m) {
2411
+ err += " " + it.first;
2412
+ }
2413
+ err += "\n";
2414
+ return check(false, err.c_str());
2415
+ }
2416
+
2417
+ template<typename ImageType>
2418
+ FormatInfo best_save_format(const ImageType &im, const std::set<FormatInfo> &info) {
2419
+ // A bit ad hoc, but will do for now:
2420
+ // Perfect score is zero (exact match).
2421
+ // The larger the score, the worse the match.
2422
+ int best_score = 0x7fffffff;
2423
+ FormatInfo best{};
2424
+ const halide_type_t im_type = im.type();
2425
+ const int im_dimensions = im.dimensions();
2426
+ for (const auto &f : info) {
2427
+ int score = 0;
2428
+ // If format has too-few dimensions, that's very bad.
2429
+ score += std::max(0, im_dimensions - f.dimensions) * 1024;
2430
+ // If format has too-few bits, that's pretty bad.
2431
+ score += std::max(0, im_type.bits - f.type.bits) * 8;
2432
+ // If format has too-many bits, that's a little bad.
2433
+ score += std::max(0, f.type.bits - im_type.bits);
2434
+ // If format has different code, that's a little bad.
2435
+ score += (f.type.code != im_type.code) ? 1 : 0;
2436
+ if (score < best_score) {
2437
+ best_score = score;
2438
+ best = f;
2439
+ }
2440
+ }
2441
+
2442
+ return best;
2443
+ }
2444
+
2445
+ } // namespace Internal
2446
+
2447
+ struct ImageTypeConversion {
2448
+ // Convert an Image from one ElemType to another, where the src and
2449
+ // dst types are statically known (e.g. Buffer<uint8_t> -> Buffer<float>).
2450
+ // Note that this does conversion with scaling -- intepreting integers
2451
+ // as fixed-point numbers between 0 and 1 -- not merely C-style casting.
2452
+ //
2453
+ // You'd normally call this with an explicit type for DstElemType and
2454
+ // allow ImageType to be inferred, e.g.
2455
+ // Buffer<uint8_t> src = ...;
2456
+ // Buffer<float> dst = convert_image<float>(src);
2457
+ template<typename DstElemType, typename ImageType,
2458
+ typename std::enable_if<ImageType::has_static_halide_type && !std::is_void<DstElemType>::value>::type * = nullptr>
2459
+ static auto convert_image(const ImageType &src) ->
2460
+ typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type {
2461
+ // The enable_if ensures this will never fire; this is here primarily
2462
+ // as documentation and a backstop against breakage.
2463
+ static_assert(ImageType::has_static_halide_type,
2464
+ "This variant of convert_image() requires a statically-typed image");
2465
+
2466
+ using SrcImageType = ImageType;
2467
+ using SrcElemType = typename SrcImageType::ElemType;
2468
+
2469
+ using DstImageType = typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type;
2470
+
2471
+ DstImageType dst = DstImageType::make_with_shape_of(src);
2472
+ const auto converter = [](DstElemType &dst_elem, SrcElemType src_elem) {
2473
+ dst_elem = Internal::convert<DstElemType>(src_elem);
2474
+ };
2475
+ dst.for_each_value(converter, src);
2476
+ dst.set_host_dirty();
2477
+
2478
+ return dst;
2479
+ }
2480
+
2481
+ // Convert an Image from one ElemType to another, where the dst type is statically
2482
+ // known but the src type is not (e.g. Buffer<> -> Buffer<float>).
2483
+ // You'd normally call this with an explicit type for DstElemType and
2484
+ // allow ImageType to be inferred, e.g.
2485
+ // Buffer<uint8_t> src = ...;
2486
+ // Buffer<float> dst = convert_image<float>(src);
2487
+ template<typename DstElemType, typename ImageType,
2488
+ typename std::enable_if<!ImageType::has_static_halide_type && !std::is_void<DstElemType>::value>::type * = nullptr>
2489
+ static auto convert_image(const ImageType &src) ->
2490
+ typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type {
2491
+ // The enable_if ensures this will never fire; this is here primarily
2492
+ // as documentation and a backstop against breakage.
2493
+ static_assert(!ImageType::has_static_halide_type,
2494
+ "This variant of convert_image() requires a dynamically-typed image");
2495
+ constexpr int AnyDims = Internal::AnyDims;
2496
+
2497
+ const halide_type_t src_type = src.type();
2498
+ switch (src_type.element_of().as_u32()) {
2499
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
2500
+ case halide_type_t(halide_type_float, 16).as_u32():
2501
+ return convert_image<DstElemType>(src.template as<_Float16, AnyDims>());
2502
+ #endif
2503
+ case halide_type_t(halide_type_float, 32).as_u32():
2504
+ return convert_image<DstElemType>(src.template as<float, AnyDims>());
2505
+ case halide_type_t(halide_type_float, 64).as_u32():
2506
+ return convert_image<DstElemType>(src.template as<double, AnyDims>());
2507
+ case halide_type_t(halide_type_int, 8).as_u32():
2508
+ return convert_image<DstElemType>(src.template as<int8_t, AnyDims>());
2509
+ case halide_type_t(halide_type_int, 16).as_u32():
2510
+ return convert_image<DstElemType>(src.template as<int16_t, AnyDims>());
2511
+ case halide_type_t(halide_type_int, 32).as_u32():
2512
+ return convert_image<DstElemType>(src.template as<int32_t, AnyDims>());
2513
+ case halide_type_t(halide_type_int, 64).as_u32():
2514
+ return convert_image<DstElemType>(src.template as<int64_t, AnyDims>());
2515
+ case halide_type_t(halide_type_uint, 1).as_u32():
2516
+ return convert_image<DstElemType>(src.template as<bool, AnyDims>());
2517
+ case halide_type_t(halide_type_uint, 8).as_u32():
2518
+ return convert_image<DstElemType>(src.template as<uint8_t, AnyDims>());
2519
+ case halide_type_t(halide_type_uint, 16).as_u32():
2520
+ return convert_image<DstElemType>(src.template as<uint16_t, AnyDims>());
2521
+ case halide_type_t(halide_type_uint, 32).as_u32():
2522
+ return convert_image<DstElemType>(src.template as<uint32_t, AnyDims>());
2523
+ case halide_type_t(halide_type_uint, 64).as_u32():
2524
+ return convert_image<DstElemType>(src.template as<uint64_t, AnyDims>());
2525
+ default:
2526
+ assert(false && "Unsupported type");
2527
+ using DstImageType = typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type;
2528
+ return DstImageType();
2529
+ }
2530
+ }
2531
+
2532
+ // Convert an Image from one ElemType to another, where the src type
2533
+ // is statically known but the dst type is not
2534
+ // (e.g. Buffer<uint8_t> -> Buffer<>(halide_type_t)).
2535
+ template<typename DstElemType = void,
2536
+ typename ImageType,
2537
+ typename std::enable_if<ImageType::has_static_halide_type && std::is_void<DstElemType>::value>::type * = nullptr>
2538
+ static auto convert_image(const ImageType &src, const halide_type_t &dst_type) ->
2539
+ typename Internal::ImageTypeWithElemType<ImageType, void>::type {
2540
+ // The enable_if ensures this will never fire; this is here primarily
2541
+ // as documentation and a backstop against breakage.
2542
+ static_assert(ImageType::has_static_halide_type,
2543
+ "This variant of convert_image() requires a statically-typed image");
2544
+
2545
+ // Call the appropriate static-to-static conversion routine
2546
+ // based on the desired dst type.
2547
+ switch (dst_type.element_of().as_u32()) {
2548
+ #ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
2549
+ case halide_type_t(halide_type_float, 16).as_u32():
2550
+ return convert_image<_Float16>(src);
2551
+ #endif
2552
+ case halide_type_t(halide_type_float, 32).as_u32():
2553
+ return convert_image<float>(src);
2554
+ case halide_type_t(halide_type_float, 64).as_u32():
2555
+ return convert_image<double>(src);
2556
+ case halide_type_t(halide_type_int, 8).as_u32():
2557
+ return convert_image<int8_t>(src);
2558
+ case halide_type_t(halide_type_int, 16).as_u32():
2559
+ return convert_image<int16_t>(src);
2560
+ case halide_type_t(halide_type_int, 32).as_u32():
2561
+ return convert_image<int32_t>(src);
2562
+ case halide_type_t(halide_type_int, 64).as_u32():
2563
+ return convert_image<int64_t>(src);
2564
+ case halide_type_t(halide_type_uint, 1).as_u32():
2565
+ return convert_image<bool>(src);
2566
+ case halide_type_t(halide_type_uint, 8).as_u32():
2567
+ return convert_image<uint8_t>(src);
2568
+ case halide_type_t(halide_type_uint, 16).as_u32():
2569
+ return convert_image<uint16_t>(src);
2570
+ case halide_type_t(halide_type_uint, 32).as_u32():
2571
+ return convert_image<uint32_t>(src);
2572
+ case halide_type_t(halide_type_uint, 64).as_u32():
2573
+ return convert_image<uint64_t>(src);
2574
+ default:
2575
+ assert(false && "Unsupported type");
2576
+ using RetImageType = typename Internal::ImageTypeWithDynamicDims<ImageType>::type;
2577
+ return RetImageType();
2578
+ }
2579
+ }
2580
+
2581
+ // Convert an Image from one ElemType to another, where neither src type
2582
+ // nor dst type are statically known
2583
+ // (e.g. Buffer<>(halide_type_t) -> Buffer<>(halide_type_t)).
2584
+ template<typename DstElemType = void,
2585
+ typename ImageType,
2586
+ typename std::enable_if<!ImageType::has_static_halide_type && std::is_void<DstElemType>::value>::type * = nullptr>
2587
+ static auto convert_image(const ImageType &src, const halide_type_t &dst_type) ->
2588
+ typename Internal::ImageTypeWithElemType<ImageType, void>::type {
2589
+ // The enable_if ensures this will never fire; this is here primarily
2590
+ // as documentation and a backstop against breakage.
2591
+ static_assert(!ImageType::has_static_halide_type,
2592
+ "This variant of convert_image() requires a dynamically-typed image");
2593
+ constexpr int AnyDims = Internal::AnyDims;
2594
+
2595
+ // Sniff the runtime type of src, coerce it to that type using as<>(),
2596
+ // and call the static-to-dynamic variant of this method. (Note that
2597
+ // this forces instantiation of the complete any-to-any conversion
2598
+ // matrix of code.)
2599
+ const halide_type_t src_type = src.type();
2600
+ switch (src_type.element_of().as_u32()) {
2601
+ case halide_type_t(halide_type_float, 32).as_u32():
2602
+ return convert_image(src.template as<float, AnyDims>(), dst_type);
2603
+ case halide_type_t(halide_type_float, 64).as_u32():
2604
+ return convert_image(src.template as<double, AnyDims>(), dst_type);
2605
+ case halide_type_t(halide_type_int, 8).as_u32():
2606
+ return convert_image(src.template as<int8_t, AnyDims>(), dst_type);
2607
+ case halide_type_t(halide_type_int, 16).as_u32():
2608
+ return convert_image(src.template as<int16_t, AnyDims>(), dst_type);
2609
+ case halide_type_t(halide_type_int, 32).as_u32():
2610
+ return convert_image(src.template as<int32_t, AnyDims>(), dst_type);
2611
+ case halide_type_t(halide_type_int, 64).as_u32():
2612
+ return convert_image(src.template as<int64_t, AnyDims>(), dst_type);
2613
+ case halide_type_t(halide_type_uint, 1).as_u32():
2614
+ return convert_image(src.template as<bool, AnyDims>(), dst_type);
2615
+ case halide_type_t(halide_type_uint, 8).as_u32():
2616
+ return convert_image(src.template as<uint8_t, AnyDims>(), dst_type);
2617
+ case halide_type_t(halide_type_uint, 16).as_u32():
2618
+ return convert_image(src.template as<uint16_t, AnyDims>(), dst_type);
2619
+ case halide_type_t(halide_type_uint, 32).as_u32():
2620
+ return convert_image(src.template as<uint32_t, AnyDims>(), dst_type);
2621
+ case halide_type_t(halide_type_uint, 64).as_u32():
2622
+ return convert_image(src.template as<uint64_t, AnyDims>(), dst_type);
2623
+ default:
2624
+ assert(false && "Unsupported type");
2625
+ using RetImageType = typename Internal::ImageTypeWithDynamicDims<ImageType>::type;
2626
+ return RetImageType();
2627
+ }
2628
+ }
2629
+ };
2630
+
2631
+ // Load the Image from the given file.
2632
+ // If output Image has a static type, and the loaded image cannot be stored
2633
+ // in such an image without losing data, fail.
2634
+ // Returns false upon failure.
2635
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
2636
+ bool load(const std::string &filename, ImageType *im) {
2637
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2638
+ Internal::ImageIO<DynamicImageType, check> imageio;
2639
+ if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2640
+ return false;
2641
+ }
2642
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2643
+ DynamicImageType im_d;
2644
+ if (!imageio.load(filename, &im_d)) {
2645
+ return false;
2646
+ }
2647
+ // Allow statically-typed images to be passed as the out-param, but do
2648
+ // a runtime check to ensure
2649
+ if (ImageType::has_static_halide_type) {
2650
+ const halide_type_t expected_type = ImageType::static_halide_type();
2651
+ if (!check(im_d.type() == expected_type, "Image loaded did not match the expected type")) {
2652
+ return false;
2653
+ }
2654
+ }
2655
+ *im = im_d.template as<typename ImageType::ElemType, Internal::AnyDims>();
2656
+ im->set_host_dirty();
2657
+ return true;
2658
+ }
2659
+
2660
+ // Save the Image in the format associated with the filename's extension.
2661
+ // If the format can't represent the Image without losing data, fail.
2662
+ // Returns false upon failure.
2663
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
2664
+ bool save(ImageType &im, const std::string &filename) {
2665
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2666
+ Internal::ImageIO<DynamicImageType, check> imageio;
2667
+ if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2668
+ return false;
2669
+ }
2670
+ if (!check(imageio.query().count({im.type(), im.dimensions()}) > 0, "Image cannot be saved in this format")) {
2671
+ return false;
2672
+ }
2673
+
2674
+ // Allow statically-typed images to be passed in, but quietly pass them on
2675
+ // as dynamically-typed images.
2676
+ auto im_d = im.template as<const void, Internal::AnyDims>();
2677
+ return imageio.save(im_d, filename);
2678
+ }
2679
+
2680
+ // Return a set of FormatInfo structs that contain the legal type-and-dimensions
2681
+ // that can be saved in this format. Most applications won't ever need to use
2682
+ // this call. Returns false upon failure.
2683
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
2684
+ bool save_query(const std::string &filename, std::set<FormatInfo> *info) {
2685
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2686
+ Internal::ImageIO<DynamicImageType, check> imageio;
2687
+ if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2688
+ return false;
2689
+ }
2690
+ *info = imageio.query();
2691
+ return true;
2692
+ }
2693
+
2694
+ // Fancy wrapper to call load() with CheckFail, inferring the return type;
2695
+ // this allows you to simply use
2696
+ //
2697
+ // Image im = load_image("filename");
2698
+ //
2699
+ // without bothering to check error results (all errors simply abort).
2700
+ //
2701
+ // Note that if the image being loaded doesn't match the static type and
2702
+ // dimensions of of the image on the LHS, a runtime error will occur.
2703
+ class load_image {
2704
+ public:
2705
+ load_image(const std::string &f)
2706
+ : filename(f) {
2707
+ }
2708
+
2709
+ template<typename ImageType>
2710
+ operator ImageType() {
2711
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2712
+ DynamicImageType im_d;
2713
+ Internal::CheckFail(load<DynamicImageType, Internal::CheckFail>(filename, &im_d), "load() failed");
2714
+ Internal::CheckFail(ImageType::can_convert_from(im_d),
2715
+ "Type mismatch assigning the result of load_image. "
2716
+ "Did you mean to use load_and_convert_image?");
2717
+ return im_d.template as<typename ImageType::ElemType, Internal::AnyDims>();
2718
+ }
2719
+
2720
+ private:
2721
+ const std::string filename;
2722
+ };
2723
+
2724
+ // Like load_image, but quietly convert the loaded image to the type of the LHS
2725
+ // if necessary, discarding information if necessary.
2726
+ class load_and_convert_image {
2727
+ public:
2728
+ load_and_convert_image(const std::string &f)
2729
+ : filename(f) {
2730
+ }
2731
+
2732
+ template<typename ImageType>
2733
+ inline operator ImageType() {
2734
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2735
+ DynamicImageType im_d;
2736
+ Internal::CheckFail(load<DynamicImageType, Internal::CheckFail>(filename, &im_d), "load() failed");
2737
+ const halide_type_t expected_type = ImageType::static_halide_type();
2738
+ if (im_d.type() == expected_type) {
2739
+ return im_d.template as<typename ImageType::ElemType, Internal::AnyDims>();
2740
+ } else {
2741
+ return ImageTypeConversion::convert_image<typename ImageType::ElemType>(im_d);
2742
+ }
2743
+ }
2744
+
2745
+ private:
2746
+ const std::string filename;
2747
+ };
2748
+
2749
+ // Fancy wrapper to call save() with CheckFail; this allows you to simply use
2750
+ //
2751
+ // save_image(im, "filename");
2752
+ //
2753
+ // without bothering to check error results (all errors simply abort).
2754
+ //
2755
+ // If the specified image file format cannot represent the image without
2756
+ // losing data (e.g, a float32 or 4-dimensional image saved as a JPEG),
2757
+ // a runtime error will occur.
2758
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckFail>
2759
+ void save_image(ImageType &im, const std::string &filename) {
2760
+ auto im_d = im.template as<const void, Internal::AnyDims>();
2761
+ (void)save<decltype(im_d), check>(im_d, filename);
2762
+ }
2763
+
2764
+ // Like save_image, but quietly convert the saved image to a type that the
2765
+ // specified image file format can hold, discarding information if necessary.
2766
+ // (Note that the input image is unaffected!)
2767
+ template<typename ImageType, Internal::CheckFunc check = Internal::CheckFail>
2768
+ void convert_and_save_image(ImageType &im, const std::string &filename) {
2769
+ // We'll be doing any conversion on the CPU
2770
+ if (!check(im.copy_to_host() == halide_error_code_success, "copy_to_host() failed.")) {
2771
+ return;
2772
+ }
2773
+
2774
+ std::set<FormatInfo> info;
2775
+ (void)save_query<typename Internal::ImageTypeWithDynamicDims<ImageType>::type, check>(filename, &info);
2776
+ const FormatInfo best = Internal::best_save_format(im, info);
2777
+ if (best.type == im.type() && best.dimensions == im.dimensions()) {
2778
+ // It's an exact match, we can save as-is.
2779
+ using DynamicImageDims = typename Internal::ImageTypeWithDynamicDims<ImageType>::type;
2780
+ (void)save<DynamicImageDims, check>(im.template as<typename ImageType::ElemType, Internal::AnyDims>(), filename);
2781
+ } else {
2782
+ using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2783
+ DynamicImageType im_converted = ImageTypeConversion::convert_image(im, best.type);
2784
+ while (im_converted.dimensions() < best.dimensions) {
2785
+ im_converted.add_dimension();
2786
+ }
2787
+ (void)save<DynamicImageType, check>(im_converted, filename);
2788
+ }
2789
+ }
2790
+
2791
+ } // namespace Tools
2792
+ } // namespace Halide
2793
+
2794
+ #endif // HALIDE_IMAGE_IO_H