xnd 0.2.0dev6 → 0.2.0dev7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -0
  3. data/Rakefile +1 -1
  4. data/ext/ruby_xnd/GPATH +0 -0
  5. data/ext/ruby_xnd/GRTAGS +0 -0
  6. data/ext/ruby_xnd/GTAGS +0 -0
  7. data/ext/ruby_xnd/extconf.rb +8 -5
  8. data/ext/ruby_xnd/gc_guard.c +53 -2
  9. data/ext/ruby_xnd/gc_guard.h +8 -2
  10. data/ext/ruby_xnd/include/overflow.h +147 -0
  11. data/ext/ruby_xnd/include/ruby_xnd.h +62 -0
  12. data/ext/ruby_xnd/include/xnd.h +590 -0
  13. data/ext/ruby_xnd/lib/libxnd.a +0 -0
  14. data/ext/ruby_xnd/lib/libxnd.so +1 -0
  15. data/ext/ruby_xnd/lib/libxnd.so.0 +1 -0
  16. data/ext/ruby_xnd/lib/libxnd.so.0.2.0dev3 +0 -0
  17. data/ext/ruby_xnd/ruby_xnd.c +556 -47
  18. data/ext/ruby_xnd/ruby_xnd.h +2 -1
  19. data/ext/ruby_xnd/xnd/Makefile +80 -0
  20. data/ext/ruby_xnd/xnd/config.h +26 -0
  21. data/ext/ruby_xnd/xnd/config.h.in +3 -0
  22. data/ext/ruby_xnd/xnd/config.log +421 -0
  23. data/ext/ruby_xnd/xnd/config.status +1023 -0
  24. data/ext/ruby_xnd/xnd/configure +376 -8
  25. data/ext/ruby_xnd/xnd/configure.ac +48 -7
  26. data/ext/ruby_xnd/xnd/doc/xnd/index.rst +3 -1
  27. data/ext/ruby_xnd/xnd/doc/xnd/{types.rst → xnd.rst} +3 -18
  28. data/ext/ruby_xnd/xnd/libxnd/Makefile +142 -0
  29. data/ext/ruby_xnd/xnd/libxnd/Makefile.in +43 -3
  30. data/ext/ruby_xnd/xnd/libxnd/Makefile.vc +19 -3
  31. data/ext/ruby_xnd/xnd/libxnd/bitmaps.c +42 -3
  32. data/ext/ruby_xnd/xnd/libxnd/bitmaps.o +0 -0
  33. data/ext/ruby_xnd/xnd/libxnd/bounds.c +366 -0
  34. data/ext/ruby_xnd/xnd/libxnd/bounds.o +0 -0
  35. data/ext/ruby_xnd/xnd/libxnd/contrib.h +98 -0
  36. data/ext/ruby_xnd/xnd/libxnd/contrib/bfloat16.h +213 -0
  37. data/ext/ruby_xnd/xnd/libxnd/copy.c +155 -4
  38. data/ext/ruby_xnd/xnd/libxnd/copy.o +0 -0
  39. data/ext/ruby_xnd/xnd/libxnd/cuda/cuda_memory.cu +121 -0
  40. data/ext/ruby_xnd/xnd/libxnd/cuda/cuda_memory.h +58 -0
  41. data/ext/ruby_xnd/xnd/libxnd/equal.c +195 -7
  42. data/ext/ruby_xnd/xnd/libxnd/equal.o +0 -0
  43. data/ext/ruby_xnd/xnd/libxnd/inline.h +32 -0
  44. data/ext/ruby_xnd/xnd/libxnd/libxnd.a +0 -0
  45. data/ext/ruby_xnd/xnd/libxnd/libxnd.so +1 -0
  46. data/ext/ruby_xnd/xnd/libxnd/libxnd.so.0 +1 -0
  47. data/ext/ruby_xnd/xnd/libxnd/libxnd.so.0.2.0dev3 +0 -0
  48. data/ext/ruby_xnd/xnd/libxnd/shape.c +207 -0
  49. data/ext/ruby_xnd/xnd/libxnd/shape.o +0 -0
  50. data/ext/ruby_xnd/xnd/libxnd/split.c +2 -2
  51. data/ext/ruby_xnd/xnd/libxnd/split.o +0 -0
  52. data/ext/ruby_xnd/xnd/libxnd/tests/Makefile +39 -0
  53. data/ext/ruby_xnd/xnd/libxnd/xnd.c +613 -91
  54. data/ext/ruby_xnd/xnd/libxnd/xnd.h +145 -4
  55. data/ext/ruby_xnd/xnd/libxnd/xnd.o +0 -0
  56. data/ext/ruby_xnd/xnd/python/test_xnd.py +1125 -50
  57. data/ext/ruby_xnd/xnd/python/xnd/__init__.py +609 -124
  58. data/ext/ruby_xnd/xnd/python/xnd/_version.py +1 -0
  59. data/ext/ruby_xnd/xnd/python/xnd/_xnd.c +1652 -101
  60. data/ext/ruby_xnd/xnd/python/xnd/libxnd.a +0 -0
  61. data/ext/ruby_xnd/xnd/python/xnd/libxnd.so +1 -0
  62. data/ext/ruby_xnd/xnd/python/xnd/libxnd.so.0 +1 -0
  63. data/ext/ruby_xnd/xnd/python/xnd/libxnd.so.0.2.0dev3 +0 -0
  64. data/ext/ruby_xnd/xnd/python/xnd/pyxnd.h +1 -1
  65. data/ext/ruby_xnd/xnd/python/xnd/util.h +25 -0
  66. data/ext/ruby_xnd/xnd/python/xnd/xnd.h +590 -0
  67. data/ext/ruby_xnd/xnd/python/xnd_randvalue.py +106 -6
  68. data/ext/ruby_xnd/xnd/python/xnd_support.py +4 -0
  69. data/ext/ruby_xnd/xnd/setup.py +46 -4
  70. data/lib/ruby_xnd.so +0 -0
  71. data/lib/xnd.rb +39 -3
  72. data/lib/xnd/version.rb +2 -2
  73. data/xnd.gemspec +2 -1
  74. metadata +58 -5
@@ -80,6 +80,23 @@ _var_dim_next(const xnd_t *x, const int64_t start, const int64_t step,
  return next;
  }
 
+ static inline xnd_t
+ _array_next(const xnd_t *x, const int64_t i)
+ {
+ const ndt_t *t = x->type;
+ const ndt_t *u = t->Array.type;
+ xnd_t next;
+
+ assert(t->tag == Array);
+
+ next.bitmap = xnd_bitmap_empty;
+ next.index = 0;
+ next.type = u;
+ next.ptr = XND_ARRAY_DATA(x->ptr) + i * next.type->datasize;
+
+ return next;
+ }
+
  static inline xnd_t
  _tuple_next(const xnd_t *x, const int64_t i)
  {
@@ -108,6 +125,21 @@ _record_next(const xnd_t *x, const int64_t i)
  return next;
  }
 
+ static inline xnd_t
+ _union_next(const xnd_t *x)
+ {
+ uint8_t i = XND_UNION_TAG(x->ptr);
+ const ndt_t *t = x->type;
+ xnd_t next;
+
+ next.bitmap = xnd_bitmap_empty;
+ next.index = 0;
+ next.type = t->Union.types[i];
+ next.ptr = x->ptr+1;
+
+ return next;
+ }
+
  static inline xnd_t
  _ref_next(const xnd_t *x)
  {
@@ -0,0 +1 @@
+ ext/ruby_xnd/xnd/libxnd/libxnd.so.0.2.0dev3
@@ -0,0 +1 @@
+ ext/ruby_xnd/xnd/libxnd/libxnd.so.0.2.0dev3
@@ -0,0 +1,207 @@
+ /*
+ * BSD 3-Clause License
+ *
+ * Copyright (c) 2017-2018, plures
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+ #include <stdlib.h>
+ #include <stdint.h>
+ #include <string.h>
+ #include <inttypes.h>
+ #include "ndtypes.h"
+ #include "xnd.h"
+ #include "contrib.h"
+ #include "overflow.h"
+
+
+ static bool
+ shape_equal(const ndt_ndarray_t *dest, const ndt_ndarray_t *src)
+ {
+ if (dest->ndim != src->ndim) {
+ return false;
+ }
+
+ for (int i = 0; i < src->ndim; i++) {
+ if (dest->shape[i] != src->shape[i]) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ static int64_t
+ prod(const int64_t shape[], int N)
+ {
+ bool overflow = false;
+ int64_t p = 1;
+
+ for (int64_t i = 0; i < N; i++) {
+ p = MULi64(p, shape[i], &overflow);
+ if (overflow) {
+ return -1;
+ }
+ }
+
+ return p;
+ }
+
+ static inline bool
+ zero_in_shape(const ndt_ndarray_t *x)
+ {
+ for (int i = 0; i < x->ndim; i++) {
+ if (x->shape[i] == 0) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ static void
+ init_contiguous_c_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
+ {
+ int64_t q;
+ int64_t i;
+
+ if (src->ndim == 0 && dest->ndim == 0) {
+ return;
+ }
+
+ q = 1;
+ for (i = dest->ndim-1; i >= 0; i--) {
+ dest->steps[i] = q;
+ q *= dest->shape[i];
+ }
+ }
+
+ static void
+ init_contiguous_f_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
+ {
+ int64_t q;
+ int64_t i;
+
+ if (src->ndim == 0 && dest->ndim == 0) {
+ return;
+ }
+
+ q = 1;
+ for (i = 0; i < dest->ndim; i++) {
+ dest->steps[i] = q;
+ q *= dest->shape[i];
+ }
+ }
+
+ xnd_t
+ xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, char order,
+ ndt_context_t *ctx)
+ {
+ const ndt_t *t = x->type;
+ ndt_ndarray_t src, dest;
+ int64_t p, q;
+ int ret;
+ int use_fortran = 0;
+
+ if (order == 'F') {
+ use_fortran = 1;
+ }
+ else if (order == 'A') {
+ use_fortran = ndt_is_f_contiguous(t);
+ }
+ else if (order != 'C') {
+ ndt_err_format(ctx, NDT_ValueError, "'order' must be 'C', 'F' or 'A'");
+ return xnd_error;
+ }
+
+ if (ndt_as_ndarray(&src, t, ctx) < 0) {
+ return xnd_error;
+ }
+
+ dest.ndim = ndim;
+ dest.itemsize = src.itemsize;
+ for (int i = 0; i < ndim; i++) {
+ dest.shape[i] = shape[i];
+ dest.steps[i] = 0;
+ dest.strides[i] = 0;
+ }
+
+ p = prod(src.shape, src.ndim);
+ q = prod(dest.shape, dest.ndim);
+ if (p < 0 || q < 0) {
+ ndt_err_format(ctx, NDT_ValueError,
+ "reshaped array has too many elements");
+ return xnd_error;
+ }
+ if (p != q) {
+ ndt_err_format(ctx, NDT_ValueError,
+ "shapes do not have the same number of elements");
+ return xnd_error;
+ }
+
+ if (shape_equal(&dest, &src)) {
+ dest = src;
+ }
+ else if (zero_in_shape(&dest)) {
+ ;
+ }
+ else if (!use_fortran && ndt_is_c_contiguous(t)) {
+ init_contiguous_c_strides(&dest, &src);
+ }
+ else if (use_fortran && ndt_is_f_contiguous(t)) {
+ init_contiguous_f_strides(&dest, &src);
+ }
+ else {
+ ret = xnd_nocopy_reshape(dest.shape, dest.steps, dest.ndim,
+ src.shape, src.steps, src.ndim, use_fortran);
+ if (!ret) {
+ ndt_err_format(ctx, NDT_ValueError, "inplace reshape not possible");
+ return xnd_error;
+ }
+ }
+
+ xnd_t res = *x;
+
+ const ndt_t *u = ndt_copy(ndt_dtype(t), ctx);
+ if (u == NULL) {
+ return xnd_error;
+ }
+
+ for (int i = dest.ndim-1; i >= 0; i--) {
+ const ndt_t *v = ndt_fixed_dim(u, dest.shape[i], dest.steps[i], ctx);
+ ndt_decref(u);
+ if (v == NULL) {
+ return xnd_error;
+ }
+ u = v;
+ }
+
+ res.type = u;
+ return res;
+ }
@@ -56,7 +56,7 @@ static void
  free_slices(xnd_t *lst, int64_t len)
  {
  for (int64_t i = 0; i < len; i++) {
- ndt_del((ndt_t *)lst[i].type);
+ ndt_decref(lst[i].type);
  }
 
  ndt_free(lst);
@@ -269,7 +269,7 @@ xnd_split(const xnd_t *x, int64_t *nparts, int max_outer, ndt_context_t *ctx)
  }
 
  for (int64_t i = 0; i < nrows; i++) {
- result[i] = xnd_multikey(x, indices+(i*ncols), nindices[i], ctx);
+ result[i] = xnd_subscript(x, indices+(i*ncols), nindices[i], ctx);
  if (ndt_err_occurred(ctx)) {
  ndt_free(nindices);
  ndt_free(indices);
@@ -0,0 +1,39 @@
+
+ SRCDIR = ..
+
+ CC = gcc
+ LIBSTATIC = libxnd.a
+ LIBSHARED = libxnd.so.0.2.0dev3
+
+ INCLUDES = /home/sameer/.rvm/gems/ruby-2.4.1/gems/ndtypes-0.2.0dev6/ext/ruby_ndtypes/include
+ LIBS = ../../ndtypes/libndtypes
+
+ CONFIGURE_CFLAGS = -Wall -Wextra -std=c11 -pedantic -O2 -g
+ XND_CFLAGS = $(strip $(CONFIGURE_CFLAGS) $(CFLAGS))
+
+
+ default: runtest runtest_shared
+
+
+ runtest:\
+ Makefile runtest.c test_fixed.c test.h $(SRCDIR)/xnd.h $(SRCDIR)/$(LIBSTATIC)
+ $(CC) -I$(SRCDIR) -I$(INCLUDES) $(XND_CFLAGS) \
+ -o runtest runtest.c test_fixed.c $(SRCDIR)/libxnd.a \
+ $(LIBS)/libndtypes.a
+
+ runtest_shared:\
+ Makefile runtest.c test_fixed.c test.h $(SRCDIR)/xnd.h $(SRCDIR)/$(LIBSHARED)
+ $(CC) -I$(SRCDIR) -I$(INCLUDES) -L$(SRCDIR) -L$(LIBS) \
+ $(XND_CFLAGS) -o runtest_shared runtest.c test_fixed.c -lxnd -lndtypes
+
+
+ FORCE:
+
+ clean: FORCE
+ rm -f *.o *.gch *.gcda *.gcno *.gcov *.dyn *.dpi *.lock
+ rm -f runtest runtest_shared
+
+ distclean: clean
+ rm -rf Makefile
+
+
35
35
  #include <stdint.h>
36
36
  #include <inttypes.h>
37
37
  #include <string.h>
38
+ #include <math.h>
38
39
  #include <assert.h>
39
40
  #include "ndtypes.h"
40
41
  #include "xnd.h"
41
42
  #include "inline.h"
42
43
  #include "contrib.h"
44
+ #include "contrib/bfloat16.h"
45
+ #include "cuda/cuda_memory.h"
46
+ #ifndef _MSC_VER
47
+ #include "config.h"
48
+ #endif
43
49
 
44
50
 
45
51
  static int xnd_init(xnd_t * const x, const uint32_t flags, ndt_context_t *ctx);
46
- static void xnd_clear(xnd_t * const x, const uint32_t flags);
47
52
 
48
53
 
49
54
  /*****************************************************************************/
@@ -72,32 +77,104 @@ xnd_err_occurred(const xnd_t *x)
72
77
  static bool
73
78
  requires_init(const ndt_t * const t)
74
79
  {
75
- const ndt_t *dtype = ndt_dtype(t);
80
+ return !ndt_is_ref_free(t);
81
+ }
76
82
 
77
- switch (dtype->tag) {
78
- case Categorical:
79
- case Bool:
80
- case Int8: case Int16: case Int32: case Int64:
81
- case Uint8: case Uint16: case Uint32: case Uint64:
82
- case Float16: case Float32: case Float64:
83
- case Complex32: case Complex64: case Complex128:
84
- case FixedString: case FixedBytes:
85
- case String: case Bytes:
83
+ static bool
84
+ is_primary_type(const ndt_t * const t, ndt_context_t *ctx)
85
+ {
86
+ if (ndt_is_abstract(t)) {
87
+ ndt_err_format(ctx, NDT_ValueError,
88
+ "cannot create xnd container from abstract type");
89
+ return false;
90
+ }
91
+
92
+ if (t->flags & NDT_CHAR) {
93
+ ndt_err_format(ctx, NDT_NotImplementedError, "char is not implemented");
86
94
  return false;
95
+ }
96
+
97
+ switch (t->tag) {
98
+ case FixedDim: {
99
+ if (!ndt_is_c_contiguous(t) && !ndt_is_f_contiguous(t)) {
100
+ ndt_err_format(ctx, NDT_ValueError,
101
+ "cannot create xnd container from non-contiguous type");
102
+ return false;
103
+ }
104
+ return true;
105
+ }
106
+ case VarDim: case VarDimElem: {
107
+ if (!ndt_is_var_contiguous(t)) {
108
+ ndt_err_format(ctx, NDT_ValueError,
109
+ "cannot create xnd container from non-contiguous type");
110
+ return false;
111
+ }
112
+ return true;
113
+ }
114
+ case Array: {
115
+ if (requires_init(t)) {
116
+ ndt_err_format(ctx, NDT_ValueError,
117
+ "flexible arrays cannot have dtypes that require "
118
+ "initialization");
119
+ return false;
120
+ }
121
+ return true;
122
+ }
87
123
  default:
88
124
  return true;
89
125
  }
126
+
127
+ ndt_err_format(ctx, NDT_ValueError,
128
+ "cannot create xnd container from non-contiguous type");
129
+ return false;
90
130
  }
91
131
 
132
+
92
133
  /* Create and initialize memory with type 't'. */
134
+ #ifdef HAVE_CUDA
135
+ static char *
136
+ xnd_cuda_new(const ndt_t * const t, ndt_context_t *ctx)
137
+ {
138
+ void *ptr;
139
+
140
+ if (!is_primary_type(t, ctx)) {
141
+ return NULL;
142
+ }
143
+
144
+ if (!ndt_is_pointer_free(t)) {
145
+ ndt_err_format(ctx, NDT_ValueError,
146
+ "only pointer-free types are supported on cuda");
147
+ return NULL;
148
+ }
149
+
150
+ ptr = xnd_cuda_calloc_managed(t->align, t->datasize, ctx);
151
+ if (ptr == NULL) {
152
+ return NULL;
153
+ }
154
+
155
+ return ptr;
156
+ }
157
+ #else
158
+ static char *
159
+ xnd_cuda_new(const ndt_t * const t, ndt_context_t *ctx)
160
+ {
161
+ (void)t;
162
+
163
+ ndt_err_format(ctx, NDT_ValueError, "xnd compiled without cuda support");
164
+ return NULL;
165
+ }
166
+ #endif
167
+
93
168
  static char *
94
169
  xnd_new(const ndt_t * const t, const uint32_t flags, ndt_context_t *ctx)
95
170
  {
96
171
  xnd_t x;
97
172
 
98
- if (ndt_is_abstract(t)) {
99
- ndt_err_format(ctx, NDT_ValueError,
100
- "cannot create xnd container from abstract type");
173
+ if (flags & XND_CUDA_MANAGED) {
174
+ return xnd_cuda_new(t, ctx);
175
+ }
176
+
177
+ if (!is_primary_type(t, ctx)) {
101
178
  return NULL;
102
179
  }
103
180
 
@@ -136,6 +213,13 @@ xnd_init(xnd_t * const x, const uint32_t flags, ndt_context_t *ctx)
  {
  const ndt_t * const t = x->type;
 
+ if (flags & XND_CUDA_MANAGED) {
+ ndt_err_format(ctx, NDT_RuntimeError,
+ "internal error: cannot initialize cuda memory with a type "
+ "that contains pointers");
+ return -1;
+ }
+
  if (ndt_is_abstract(t)) {
  ndt_err_format(ctx, NDT_ValueError,
  "cannot initialize concrete memory from abstract type");
@@ -199,6 +283,16 @@ xnd_init(xnd_t * const x, const uint32_t flags, ndt_context_t *ctx)
  return 0;
  }
 
+ case Union: {
+ xnd_t next = _union_next(x);
+ if (xnd_init(&next, flags, ctx) < 0) {
+ xnd_clear(&next, flags);
+ return -1;
+ }
+
+ return 0;
+ }
+
  /*
  * Ref represents a pointer to an explicit type. If XND_OWN_POINTERS
  * is set, allocate memory for that type and set the pointer.
@@ -247,10 +341,19 @@ xnd_init(xnd_t * const x, const uint32_t flags, ndt_context_t *ctx)
  return 0;
  }
 
+ /* Array is already initialized by calloc(). */
+ case Array:
+ return 0;
+
  /* Categorical is already initialized by calloc(). */
  case Categorical:
  return 0;
 
+ case VarDimElem:
+ ndt_err_format(ctx, NDT_ValueError,
+ "cannot initialize var elem dimension");
+ return -1;
+
  case Char:
  ndt_err_format(ctx, NDT_NotImplementedError, "char not implemented");
  return -1;
@@ -259,8 +362,8 @@ xnd_init(xnd_t * const x, const uint32_t flags, ndt_context_t *ctx)
  case Bool:
  case Int8: case Int16: case Int32: case Int64:
  case Uint8: case Uint16: case Uint32: case Uint64:
- case Float16: case Float32: case Float64:
- case Complex32: case Complex64: case Complex128:
+ case BFloat16: case Float16: case Float32: case Float64:
+ case BComplex32: case Complex32: case Complex64: case Complex128:
  case FixedString: case FixedBytes:
  case String: case Bytes:
  return 0;
@@ -288,7 +391,7 @@ xnd_empty_from_string(const char *s, uint32_t flags, ndt_context_t *ctx)
  {
  xnd_bitmap_t b = {.data=NULL, .size=0, .next=NULL};
  xnd_master_t *x;
- ndt_t *t;
+ const ndt_t *t;
  char *ptr;
 
  if (!(flags & XND_OWN_TYPE)) {
@@ -310,13 +413,13 @@ xnd_empty_from_string(const char *s, uint32_t flags, ndt_context_t *ctx)
 
  if (!ndt_is_concrete(t)) {
  ndt_err_format(ctx, NDT_ValueError, "type must be concrete");
- ndt_del(t);
+ ndt_decref(t);
  ndt_free(x);
  return NULL;
  }
 
  if (xnd_bitmap_init(&b, t,ctx) < 0) {
- ndt_del(t);
+ ndt_decref(t);
  ndt_free(x);
  return NULL;
  }
@@ -324,7 +427,7 @@ xnd_empty_from_string(const char *s, uint32_t flags, ndt_context_t *ctx)
  ptr = xnd_new(t, flags, ctx);
  if (ptr == NULL) {
  xnd_bitmap_clear(&b);
- ndt_del(t);
+ ndt_decref(t);
  ndt_free(x);
  return NULL;
  }
@@ -401,10 +504,13 @@ xnd_from_xnd(xnd_t *src, uint32_t flags, ndt_context_t *ctx)
  {
  xnd_master_t *x;
 
+ /* XXX xnd_from_xnd() will probably be replaced. */
+ assert(!(flags & XND_CUDA_MANAGED));
+
  x = ndt_alloc(1, sizeof *x);
  if (x == NULL) {
  xnd_clear(src, XND_OWN_ALL);
- ndt_del((ndt_t *)src->type);
+ ndt_decref(src->type);
  ndt_aligned_free(src->ptr);
  xnd_bitmap_clear(&src->bitmap);
  return ndt_memory_error(ctx);
@@ -424,6 +530,10 @@ xnd_from_xnd(xnd_t *src, uint32_t flags, ndt_context_t *ctx)
  static bool
  requires_clear(const ndt_t * const t)
  {
+ if (t->tag == Array) {
+ return true;
+ }
+
  const ndt_t *dtype = ndt_dtype(t);
 
  switch (dtype->tag) {
@@ -431,8 +541,8 @@ requires_clear(const ndt_t * const t)
  case Bool:
  case Int8: case Int16: case Int32: case Int64:
  case Uint8: case Uint16: case Uint32: case Uint64:
- case Float16: case Float32: case Float64:
- case Complex32: case Complex64: case Complex128:
+ case BFloat16: case Float16: case Float32: case Float64:
+ case BComplex32: case Complex32: case Complex64: case Complex128:
  case FixedString: case FixedBytes:
  return false;
  default:
@@ -445,6 +555,7 @@ static void
  xnd_clear_ref(xnd_t *x, const uint32_t flags)
  {
  assert(x->type->tag == Ref);
+ assert(!(flags & XND_CUDA_MANAGED));
 
  if (flags & XND_OWN_POINTERS) {
  ndt_aligned_free(XND_POINTER_DATA(x->ptr));
@@ -457,6 +568,7 @@ static void
  xnd_clear_string(xnd_t *x, const uint32_t flags)
  {
  assert(x->type->tag == String);
+ assert(!(flags & XND_CUDA_MANAGED));
 
  if (flags & XND_OWN_STRINGS) {
  ndt_free(XND_POINTER_DATA(x->ptr));
@@ -469,21 +581,38 @@ static void
  xnd_clear_bytes(xnd_t *x, const uint32_t flags)
  {
  assert(x->type->tag == Bytes);
+ assert(!(flags & XND_CUDA_MANAGED));
 
  if (flags & XND_OWN_BYTES) {
  ndt_aligned_free(XND_BYTES_DATA(x->ptr));
+ XND_BYTES_SIZE(x->ptr) = 0;
  XND_BYTES_DATA(x->ptr) = NULL;
  }
  }
 
- /* Clear embedded pointers in the data according to flags. */
+ /* Flexible array data must always be allocated by aligned allocators. */
  static void
+ xnd_clear_array(xnd_t *x, const uint32_t flags)
+ {
+ assert(x->type->tag == Array);
+ assert(!(flags & XND_CUDA_MANAGED));
+
+ if (flags & XND_OWN_ARRAYS) {
+ ndt_aligned_free(XND_ARRAY_DATA(x->ptr));
+ XND_ARRAY_SHAPE(x->ptr) = 0;
+ XND_ARRAY_DATA(x->ptr) = NULL;
+ }
+ }
+
+ /* Clear embedded pointers in the data according to flags. */
+ void
  xnd_clear(xnd_t * const x, const uint32_t flags)
  {
  NDT_STATIC_CONTEXT(ctx);
  const ndt_t * const t = x->type;
 
  assert(ndt_is_concrete(t));
+ assert(!(flags & XND_CUDA_MANAGED));
 
  switch (t->tag) {
  case FixedDim: {
@@ -516,6 +645,23 @@ xnd_clear(xnd_t * const x, const uint32_t flags)
  return;
  }
 
+ case VarDimElem: {
+ fprintf(stderr, "xnd_clear: internal error: unexpected var elem dimension\n");
+ return;
+ }
+
+ case Array: {
+ const int64_t shape = XND_ARRAY_SHAPE(x->ptr);
+
+ for (int64_t i = 0; i < shape; i++) {
+ xnd_t next = _array_next(x, i);
+ xnd_clear(&next, flags);
+ }
+
+ xnd_clear_array(x, flags);
+ return;
+ }
+
  case Tuple: {
  for (int64_t i = 0; i < t->Tuple.shape; i++) {
  xnd_t next = _tuple_next(x, i);
@@ -534,6 +680,12 @@ xnd_clear(xnd_t * const x, const uint32_t flags)
  return;
  }
 
+ case Union: {
+ xnd_t next = _union_next(x);
+ xnd_clear(&next, flags);
+ return;
+ }
+
  case Ref: {
  if (flags & XND_OWN_POINTERS) {
  xnd_t next = _ref_next(x);
@@ -559,8 +711,8 @@ xnd_clear(xnd_t * const x, const uint32_t flags)
  case Bool:
  case Int8: case Int16: case Int32: case Int64:
  case Uint8: case Uint16: case Uint32: case Uint64:
- case Float16: case Float32: case Float64:
- case Complex32: case Complex64: case Complex128:
+ case BFloat16: case Float16: case Float32: case Float64:
+ case BComplex32: case Complex32: case Complex64: case Complex128:
  case FixedString: case FixedBytes:
  return;
 
@@ -603,11 +755,22 @@ xnd_del_buffer(xnd_t *x, uint32_t flags)
  }
 
  if (flags & XND_OWN_TYPE) {
- ndt_del((ndt_t *)x->type);
+ ndt_decref(x->type);
  }
 
  if (flags & XND_OWN_DATA) {
- ndt_aligned_free(x->ptr);
+ if (flags & XND_CUDA_MANAGED) {
+ #ifdef HAVE_CUDA
+ xnd_cuda_free(x->ptr);
+ #else
+ fprintf(stderr,
+ "xnd_del_buffer: internal error: XND_CUDA_MANAGED set "
+ "without cuda support\n");
+ #endif
+ }
+ else {
+ ndt_aligned_free(x->ptr);
+ }
  }
  }
 
@@ -632,23 +795,48 @@ xnd_del(xnd_master_t *x)
 
 
  /*****************************************************************************/
- /* Subtrees (single elements are a special case) */
+ /* Index checks */
  /*****************************************************************************/
 
  static int64_t
  get_index(const xnd_index_t *key, int64_t shape, ndt_context_t *ctx)
+ {
+ switch (key->tag) {
+ case Index:
+ return adjust_index(key->Index, shape, ctx);
+
+ case FieldName:
+ ndt_err_format(ctx, NDT_ValueError,
+ "expected integer index, got field name: '%s'", key->FieldName);
+ return -1;
+
+ case Slice:
+ ndt_err_format(ctx, NDT_ValueError,
+ "expected integer index, got slice");
+ return -1;
+ }
+
+ /* NOT REACHED: tags should be exhaustive */
+ ndt_err_format(ctx, NDT_RuntimeError, "invalid index tag");
+ return -1;
+ }
+
+ /*
+ * Ragged arrays have multiple shapes in a single dimension that are not known
+ * when a VarDimElem is created. Adjusting the index must be done when the
+ * VarDimElem is accessed and the slices have been applied.
+ */
+ static int64_t
+ get_index_var_elem(const xnd_index_t *key, ndt_context_t *ctx)
  {
  switch (key->tag) {
  case Index: {
  int64_t i = key->Index;
- if (i < 0) {
- i += shape;
- }
 
- if (i < 0 || i >= shape || i > XND_SSIZE_MAX) {
+ if (i < INT32_MIN || i > INT32_MAX) {
  ndt_err_format(ctx, NDT_IndexError,
  "index with value %" PRIi64 " out of bounds", key->Index);
- return -1;
+ return INT64_MIN;
  }
 
  return i;
@@ -657,17 +845,17 @@ get_index(const xnd_index_t *key, int64_t shape, ndt_context_t *ctx)
  case FieldName:
  ndt_err_format(ctx, NDT_ValueError,
  "expected integer index, got field name: '%s'", key->FieldName);
- return -1;
+ return INT64_MIN;
 
  case Slice:
  ndt_err_format(ctx, NDT_ValueError,
  "expected integer index, got slice");
- return -1;
+ return INT64_MIN;
  }
 
  /* NOT REACHED: tags should be exhaustive */
  ndt_err_format(ctx, NDT_RuntimeError, "invalid index tag");
- return -1;
+ return INT64_MIN;
  }
 
  static int64_t
@@ -698,6 +886,34 @@ get_index_record(const ndt_t *t, const xnd_index_t *key, ndt_context_t *ctx)
  return -1;
  }
 
+ static int64_t
+ get_index_union(const ndt_t *t, const xnd_index_t *key, ndt_context_t *ctx)
+ {
+ assert(t->tag == Union);
+
+ switch (key->tag) {
+ case FieldName: {
+ int64_t i;
+
+ for (i = 0; i < t->Union.ntags; i++) {
+ if (strcmp(key->FieldName, t->Union.tags[i]) == 0) {
+ return i;
+ }
+ }
+
+ ndt_err_format(ctx, NDT_ValueError,
+ "invalid field name '%s'", key->FieldName);
+ return -1;
+ }
+ case Index: case Slice:
+ return get_index(key, t->Union.ntags, ctx);
+ }
+
+ /* NOT REACHED: tags should be exhaustive */
+ ndt_err_format(ctx, NDT_RuntimeError, "invalid index tag");
+ return -1;
+ }
+
  static void
  set_index_exception(bool indexable, ndt_context_t *ctx)
  {
@@ -709,11 +925,72 @@ set_index_exception(bool indexable, ndt_context_t *ctx)
  }
  }
 
- /* Return a typed subtree of a memory block */
+
+ /*****************************************************************************/
+ /* Stored indices */
+ /*****************************************************************************/
+
+ bool
+ have_stored_index(const ndt_t *t)
+ {
+ return t->tag == VarDimElem;
+ }
+
+ int64_t
+ get_stored_index(const ndt_t *t)
+ {
+ return t->VarDimElem.index;
+ }
+
+ /* skip stored indices */
  xnd_t
- xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t *ctx)
+ apply_stored_index(const xnd_t *x, ndt_context_t *ctx)
  {
  const ndt_t * const t = x->type;
+ int64_t start, step, shape;
+
+ if (t->tag != VarDimElem) {
+ ndt_err_format(ctx, NDT_RuntimeError,
+ "apply_stored_index: need VarDimElem");
+ return xnd_error;
+ }
+
+ shape = ndt_var_indices(&start, &step, t, x->index, ctx);
+ if (shape < 0) {
+ return xnd_error;
+ }
+
+ const int64_t i = adjust_index(t->VarDimElem.index, shape, ctx);
+ if (i < 0) {
+ return xnd_error;
+ }
+
+ return xnd_var_dim_next(x, start, step, i);
+ }
+
+ xnd_t
+ apply_stored_indices(const xnd_t *x, ndt_context_t *ctx)
+ {
+ xnd_t tl = *x;
+
+ while (tl.type->tag == VarDimElem) {
+ tl = apply_stored_index(&tl, ctx);
+ }
+
+ return tl;
+ }
+
+
+ /*****************************************************************************/
+ /* Subtrees (single elements are a special case) */
+ /*****************************************************************************/
+
+ /* Return a typed subtree of a memory block */
+ static xnd_t
+ _xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t *ctx)
+ {
+ APPLY_STORED_INDICES_XND(x)
+ const ndt_t * const t = x->type;
 
  assert(ndt_is_concrete(t));
 
@@ -731,14 +1008,13 @@ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t
 
  switch (t->tag) {
  case FixedDim: {
- if (i < 0 || i >= t->FixedDim.shape) {
- ndt_err_format(ctx, NDT_ValueError,
- "fixed dim index out of bounds");
+ const int64_t k = adjust_index(i, t->FixedDim.shape, ctx);
+ if (k < 0) {
  return xnd_error;
  }
 
- const xnd_t next = xnd_fixed_dim_next(x, i);
- return xnd_subtree_index(&next, indices+1, len-1, ctx);
+ const xnd_t next = xnd_fixed_dim_next(x, k);
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
  }
 
  case VarDim: {
@@ -749,41 +1025,74 @@ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t
  return xnd_error;
  }
 
- if (i < 0 || i >= shape) {
- ndt_err_format(ctx, NDT_ValueError, "var dim index out of bounds");
+ const int64_t k = adjust_index(i, shape, ctx);
+ if (k < 0) {
  return xnd_error;
  }
 
- const xnd_t next = xnd_var_dim_next(x, start, step, i);
- return xnd_subtree_index(&next, indices+1, len-1, ctx);
+ const xnd_t next = xnd_var_dim_next(x, start, step, k);
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
  }
 
  case Tuple: {
- if (i < 0 || i >= t->Tuple.shape) {
- ndt_err_format(ctx, NDT_ValueError, "tuple index out of bounds");
+ const int64_t k = adjust_index(i, t->Tuple.shape, ctx);
+ if (k < 0) {
  return xnd_error;
  }
 
- const xnd_t next = xnd_tuple_next(x, i, ctx);
+ const xnd_t next = xnd_tuple_next(x, k, ctx);
  if (next.ptr == NULL) {
  return xnd_error;
  }
 
- return xnd_subtree_index(&next, indices+1, len-1, ctx);
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
  }
 
  case Record: {
- if (i < 0 || i >= t->Record.shape) {
- ndt_err_format(ctx, NDT_ValueError, "record index out of bounds");
+ const int64_t k = adjust_index(i, t->Record.shape, ctx);
+ if (k < 0) {
  return xnd_error;
  }
 
- const xnd_t next = xnd_record_next(x, i, ctx);
+ const xnd_t next = xnd_record_next(x, k, ctx);
+ if (next.ptr == NULL) {
+ return xnd_error;
+ }
+
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
+ }
+
+ case Union: {
+ const int64_t k = adjust_index(i, t->Union.ntags, ctx);
+ if (k < 0) {
+ return xnd_error;
+ }
+
+ const uint8_t l = XND_UNION_TAG(x->ptr);
+ if (k != l) {
+ ndt_err_format(ctx, NDT_ValueError,
+ "tag mismatch in union addressing: expected '%s', got '%s'",
+ t->Union.tags[l], t->Union.tags[k]);
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_union_next(x, ctx);
  if (next.ptr == NULL) {
  return xnd_error;
  }
 
- return xnd_subtree_index(&next, indices+1, len-1, ctx);
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
+ }
+
+ case Array: {
+ const int64_t shape = XND_ARRAY_SHAPE(x->ptr);
+ const int64_t k = adjust_index(i, shape, ctx);
+ if (k < 0) {
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_array_next(x, k);
+ return _xnd_subtree_index(&next, indices+1, len-1, ctx);
  }
 
  case Ref: {
@@ -792,7 +1101,7 @@ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t
  return xnd_error;
  }
 
- return xnd_subtree_index(&next, indices, len, ctx);
+ return _xnd_subtree_index(&next, indices, len, ctx);
  }
 
  case Constr: {
@@ -801,16 +1110,16 @@ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t
  return xnd_error;
  }
 
- return xnd_subtree_index(&next, indices, len, ctx);
+ return _xnd_subtree_index(&next, indices, len, ctx);
  }
 
- case Nominal: {
+ case Nominal: {
  const xnd_t next = xnd_nominal_next(x, ctx);
  if (next.ptr == NULL) {
  return xnd_error;
  }
 
- return xnd_subtree_index(&next, indices, len, ctx);
+ return _xnd_subtree_index(&next, indices, len, ctx);
  }
 
  default:
@@ -819,6 +1128,17 @@ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t
  }
  }
 
+ xnd_t
+ xnd_subtree_index(const xnd_t *x, const int64_t *indices, int len, ndt_context_t *ctx)
+ {
+ if (len < 0 || len > NDT_MAX_DIM) {
+ ndt_err_format(ctx, NDT_IndexError, "too many indices");
+ return xnd_error;
+ }
+
+ return _xnd_subtree_index(x, indices, len, ctx);
+ }
+
  /*
  * Return a zero copy view of an xnd object. If a dtype is indexable,
  * descend into the dtype.
@@ -827,6 +1147,7 @@ static xnd_t
  _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexable,
  ndt_context_t *ctx)
  {
+ APPLY_STORED_INDICES_XND(x)
  const ndt_t *t = x->type;
  const xnd_index_t *key;
 
@@ -846,7 +1167,7 @@ _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexabl
 
  switch (t->tag) {
  case FixedDim: {
- int64_t i = get_index(key, t->FixedDim.shape, ctx);
+ const int64_t i = get_index(key, t->FixedDim.shape, ctx);
  if (i < 0) {
  return xnd_error;
  }
@@ -857,14 +1178,13 @@ _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexabl
 
  case VarDim: {
  int64_t start, step, shape;
- int64_t i;
 
  shape = ndt_var_indices(&start, &step, t, x->index, ctx);
  if (shape < 0) {
  return xnd_error;
  }
 
- i = get_index(key, shape, ctx);
+ const int64_t i = get_index(key, shape, ctx);
  if (i < 0) {
  return xnd_error;
  }
@@ -888,7 +1208,7 @@ _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexabl
  }
 
  case Record: {
- int64_t i = get_index_record(t, key, ctx);
+ const int64_t i = get_index_record(t, key, ctx);
  if (i < 0) {
  return xnd_error;
  }
@@ -901,6 +1221,39 @@ _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexabl
  return _xnd_subtree(&next, indices+1, len-1, true, ctx);
  }
 
+ case Union: {
+ const int64_t i = get_index_union(t, key, ctx);
+ if (i < 0) {
+ return xnd_error;
+ }
+
+ const uint8_t k = XND_UNION_TAG(x->ptr);
+ if (i != k) {
+ ndt_err_format(ctx, NDT_ValueError,
+ "tag mismatch in union addressing: expected '%s', got '%s'",
+ t->Union.tags[k], t->Union.tags[i]);
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_union_next(x, ctx);
+ if (next.ptr == NULL) {
+ return xnd_error;
+ }
+
+ return _xnd_subtree(&next, indices+1, len-1, true, ctx);
+ }
+
+ case Array: {
+ const int64_t shape = XND_ARRAY_SHAPE(x->ptr);
+ const int64_t i = get_index(key, shape, ctx);
+ if (i < 0) {
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_array_next(x, i);
+ return _xnd_subtree(&next, indices+1, len-1, true, ctx);
+ }
+
  case Ref: {
  const xnd_t next = xnd_ref_next(x, ctx);
  if (next.ptr == NULL) {
@@ -941,13 +1294,18 @@ _xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, bool indexabl
  xnd_t
  xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx)
  {
+ if (len < 0 || len > NDT_MAX_DIM) {
+ ndt_err_format(ctx, NDT_IndexError, "too many indices");
+ return xnd_error;
+ }
+
  return _xnd_subtree(x, indices, len, false, ctx);
  }
 
  static xnd_t xnd_index(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx);
  static xnd_t xnd_slice(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx);
 
- xnd_t
+ static xnd_t
  xnd_multikey(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx)
  {
  const ndt_t *t = x->type;
@@ -957,18 +1315,14 @@ xnd_multikey(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t
  assert(ndt_is_concrete(t));
  assert(x->ptr != NULL);
 
- if (len > t->ndim) {
+ if (len > ndt_logical_ndim(t)) {
  ndt_err_format(ctx, NDT_IndexError, "too many indices");
  return xnd_error;
  }
 
  if (len == 0) {
  xnd_t next = *x;
- next.type = ndt_copy(t, ctx);
- if (next.type == NULL) {
- return xnd_error;
- }
-
+ ndt_incref(next.type);
  return next;
  }
 
@@ -997,6 +1351,7 @@ xnd_multikey(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t
  static xnd_t
  xnd_index(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx)
  {
+ xnd_index_t xindices[NDT_MAX_DIM+1];
  const ndt_t *t = x->type;
  const xnd_index_t *key;
 
@@ -1004,6 +1359,17 @@ xnd_index(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
  assert(ndt_is_concrete(t));
  assert(x->ptr != NULL);
 
+ /* Hidden element type, insert the stored index. */
+ if (have_stored_index(t)) {
+ xindices[0].tag = Index;
+ xindices[0].Index = get_stored_index(t);
+ for (int k = 0; k < len; k++) {
+ xindices[k+1] = indices[k];
+ }
+ indices = xindices;
+ len = len+1;
+ }
+
  key = &indices[0];
  assert(key->tag == Index);
 
@@ -1018,10 +1384,35 @@ xnd_index(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
  return xnd_multikey(&next, indices+1, len-1, ctx);
  }
 
- case VarDim: {
- ndt_err_format(ctx, NDT_IndexError,
- "mixed indexing and slicing is not supported for var dimensions");
- return xnd_error;
+ case VarDim: case VarDimElem: {
+ const ndt_t *u;
+
+ if (ndt_is_optional(t)) {
+ ndt_err_format(ctx, NDT_NotImplementedError,
+ "optional dimensions are temporarily disabled");
+ return xnd_error;
+ }
+
+ const int64_t i = get_index_var_elem(key, ctx);
+ if (i == INT64_MIN) {
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_var_dim_next(x, 0, 1, 0);
+ const xnd_t tail = xnd_multikey(&next, indices+1, len-1, ctx);
+ if (xnd_err_occurred(&tail)) {
+ return xnd_error;
+ }
+
+ u = ndt_convert_to_var_elem(t, tail.type, i, ctx);
+ ndt_decref(tail.type);
+ if (u == NULL) {
+ return xnd_error;
+ }
+
+ xnd_t ret = *x;
+ ret.type = u;
+ return ret;
  }
 
  default:
@@ -1059,9 +1450,10 @@ xnd_slice(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
  }
 
  xnd_t ret = *x;
- ret.type = ndt_fixed_dim((ndt_t *)sliced.type, shape,
+ ret.type = ndt_fixed_dim(sliced.type, shape,
  t->Concrete.FixedDim.step * step,
  ctx);
+ ndt_decref(sliced.type);
  if (ret.type == NULL) {
  return xnd_error;
  }
@@ -1093,15 +1485,14 @@ xnd_slice(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
 
  slices = ndt_var_add_slice(&nslices, t, start, stop, step, ctx);
  if (slices == NULL) {
+ ndt_decref(next.type);
  return xnd_error;
  }
 
  xnd_t ret = *x;
- ret.type = ndt_var_dim((ndt_t *)next.type,
- ExternalOffsets,
- t->Concrete.VarDim.noffsets, t->Concrete.VarDim.offsets,
- nslices, slices,
- ctx);
+ ret.type = ndt_var_dim(next.type, t->Concrete.VarDim.offsets,
+ nslices, slices, false, ctx);
+ ndt_decref(next.type);
  if (ret.type == NULL) {
  return xnd_error;
  }
@@ -1111,6 +1502,32 @@ xnd_slice(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
  return ret;
  }
 
+ case VarDimElem: {
+ int64_t i = t->VarDimElem.index;
+
+ if (ndt_is_optional(t)) {
+ ndt_err_format(ctx, NDT_NotImplementedError,
+ "optional dimensions are temporarily disabled");
+ return xnd_error;
+ }
+
+ const xnd_t next = xnd_var_dim_next(x, 0, 1, 0);
+ const xnd_t tail = xnd_multikey(&next, indices, len, ctx);
+ if (xnd_err_occurred(&tail)) {
+ return xnd_error;
+ }
+
+ const ndt_t *u = ndt_convert_to_var_elem(t, tail.type, i, ctx);
+ ndt_decref(tail.type);
+ if (u == NULL) {
+ return xnd_error;
+ }
+
+ xnd_t ret = *x;
+ ret.type = u;
+ return ret;
+ }
+
  case Tuple: {
  ndt_err_format(ctx, NDT_NotImplementedError,
  "slicing tuples is not supported");
@@ -1123,42 +1540,108 @@ xnd_slice(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *c
  return xnd_error;
  }
 
+ case Union: {
+ ndt_err_format(ctx, NDT_NotImplementedError,
+ "slicing unions is not supported");
+ return xnd_error;
+ }
+
  default:
  ndt_err_format(ctx, NDT_IndexError, "type not sliceable");
  return xnd_error;
  }
  }
 
+ /* Validate indices for mixed indexed/sliced var dimensions. */
+ static bool
+ validate_indices(const xnd_t *x, ndt_context_t *ctx)
+ {
+ const ndt_t * const t = x->type;
+
+ assert(ndt_is_concrete(t));
+
+ switch (t->tag) {
+ case VarDim: {
+ int64_t start, step, shape;
+
+ shape = ndt_var_indices_non_empty(&start, &step, t, x->index, ctx);
+ if (shape < 0) {
+ return false;
+ }
+
+ for (int64_t i = 0; i < shape; i++) {
+ const xnd_t next = xnd_var_dim_next(x, start, step, i);
+ if (!validate_indices(&next, ctx)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ case VarDimElem: {
+ int64_t start, step, shape;
+
+ shape = ndt_var_indices(&start, &step, t, x->index, ctx);
+ if (shape < 0) {
+ return false;
+ }
+
+ const int64_t k = adjust_index(t->VarDimElem.index, shape, ctx);
+ if (k < 0) {
+ return false;
+ }
+
+ const xnd_t next = xnd_var_dim_next(x, start, step, k);
+ return validate_indices(&next, ctx);
+ }
+
+ default:
+ return true;
+ }
+ }
+
  xnd_t
  xnd_subscript(const xnd_t *x, const xnd_index_t indices[], int len,
  ndt_context_t *ctx)
  {
+ bool have_index = false;
  bool have_slice = false;
 
+ if (len < 0 || len > NDT_MAX_DIM) {
+ ndt_err_format(ctx, NDT_IndexError, "too many indices");
+ return xnd_error;
+ }
+
  for (int i = 0; i < len; i++) {
+ if (indices[i].tag == Index) {
+ have_index = true;
+ }
  if (indices[i].tag == Slice) {
  have_slice = true;
- break;
  }
  }
 
  if (have_slice) {
- return xnd_multikey(x, indices, len, ctx);
- }
- else {
- xnd_t res = xnd_subtree(x, indices, len, ctx);
- const ndt_t *t;
+ xnd_t res = xnd_multikey(x, indices, len, ctx);
+ if (xnd_err_occurred(&res)) {
+ return xnd_error;
+ }
 
- if (res.ptr == NULL) {
+ if (have_index && !validate_indices(&res, ctx)) {
+ ndt_decref(res.type);
  return xnd_error;
  }
 
- t = ndt_copy(res.type, ctx);
- if (t == NULL) {
+ return res;
+ }
+ else {
+ xnd_t res = xnd_subtree(x, indices, len, ctx);
+ if (res.ptr == NULL) {
  return xnd_error;
  }
 
- res.type = t;
+ ndt_incref(res.type);
  return res;
  }
  }
@@ -1302,3 +1785,42 @@ xnd_double_is_big_endian(void)
  {
  return xnd_double_format==IEEE_BIG_ENDIAN;
  }
+
+ static float
+ bfloat16_to_float(uint16_t b)
+ {
+ float f = 0;
+ uint16_t *p = (uint16_t *)((char *)&f);
+
+ if (xnd_float_is_big_endian()) {
+ p[0] = b;
+ }
+ else {
+ p[1] = b;
+ }
+
+ return f;
+ }
+
+ /*
+ * Unlike the corresponding Python conversion functions, Tensorflow does
+ * not raise OverflowError.
+ */
+ void
+ xnd_bfloat_pack(char *p, double x)
+ {
+ float f = (float)x;
+ uint16_t u16;
+
+ u16 = xnd_round_to_bfloat16(f);
+ PACK_SINGLE(p, u16, uint16_t, 0);
+ }
+
+ double
+ xnd_bfloat_unpack(char *p)
+ {
+ uint16_t u16;
+
+ UNPACK_SINGLE(u16, p, uint16_t, 0);
+ return bfloat16_to_float(u16);
+ }