@isidorus/cpu 0.0.0-alpha.0 → 0.0.0-alpha.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +47 -47
- package/binding.gyp +115 -102
- package/dist/_native.d.ts.map +1 -0
- package/dist/_native.js.map +1 -0
- package/dist/{ts/graph.d.ts → graph.d.ts} +25 -1
- package/dist/graph.d.ts.map +1 -0
- package/dist/{ts/graph.js → graph.js} +30 -2
- package/dist/graph.js.map +1 -0
- package/dist/{ts/index.d.ts → index.d.ts} +3 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/{ts/index.js → index.js} +3 -0
- package/dist/index.js.map +1 -0
- package/dist/inference-pool.d.ts.map +1 -0
- package/dist/inference-pool.js.map +1 -0
- package/dist/install-libtensorflow.d.ts.map +1 -0
- package/dist/{ts/install-libtensorflow.js → install-libtensorflow.js} +6 -6
- package/dist/install-libtensorflow.js.map +1 -0
- package/dist/model/index.d.ts +5 -0
- package/dist/model/index.d.ts.map +1 -0
- package/dist/model/index.js +3 -0
- package/dist/model/index.js.map +1 -0
- package/dist/model/layer.d.ts +25 -0
- package/dist/model/layer.d.ts.map +1 -0
- package/dist/model/layer.js +2 -0
- package/dist/model/layer.js.map +1 -0
- package/dist/model/layers.d.ts +47 -0
- package/dist/model/layers.d.ts.map +1 -0
- package/dist/model/layers.js +191 -0
- package/dist/model/layers.js.map +1 -0
- package/dist/model/sequential.d.ts +91 -0
- package/dist/model/sequential.d.ts.map +1 -0
- package/dist/model/sequential.js +248 -0
- package/dist/model/sequential.js.map +1 -0
- package/dist/ops/array_ops.d.ts.map +1 -0
- package/dist/ops/array_ops.js.map +1 -0
- package/dist/ops/index.d.ts.map +1 -0
- package/dist/ops/index.js.map +1 -0
- package/dist/ops/math_ops.d.ts.map +1 -0
- package/dist/{ts/ops → ops}/math_ops.js +1 -1
- package/dist/ops/math_ops.js.map +1 -0
- package/dist/ops/nn_ops.d.ts.map +1 -0
- package/dist/{ts/ops → ops}/nn_ops.js +9 -9
- package/dist/ops/nn_ops.js.map +1 -0
- package/dist/ops/variable_ops.d.ts.map +1 -0
- package/dist/{ts/ops → ops}/variable_ops.js +7 -9
- package/dist/ops/variable_ops.js.map +1 -0
- package/dist/optimizers/adam.d.ts +26 -0
- package/dist/optimizers/adam.d.ts.map +1 -0
- package/dist/optimizers/adam.js +97 -0
- package/dist/optimizers/adam.js.map +1 -0
- package/dist/optimizers/index.d.ts +5 -0
- package/dist/optimizers/index.d.ts.map +1 -0
- package/dist/optimizers/index.js +4 -0
- package/dist/optimizers/index.js.map +1 -0
- package/dist/optimizers/rmsprop.d.ts +22 -0
- package/dist/optimizers/rmsprop.d.ts.map +1 -0
- package/dist/optimizers/rmsprop.js +65 -0
- package/dist/optimizers/rmsprop.js.map +1 -0
- package/dist/optimizers/sgd.d.ts +53 -0
- package/dist/optimizers/sgd.d.ts.map +1 -0
- package/dist/optimizers/sgd.js +76 -0
- package/dist/optimizers/sgd.js.map +1 -0
- package/dist/session.d.ts.map +1 -0
- package/dist/session.js.map +1 -0
- package/dist/tsconfig.tsbuildinfo +1 -0
- package/package.json +63 -63
- package/scripts/install.js +100 -100
- package/scripts/test-install.js +82 -82
- package/scripts/test.js +45 -45
- package/src/native/addon.cc +11 -11
- package/src/native/graph.cc +577 -442
- package/src/native/graph.h +41 -51
- package/src/native/platform_tf.h +7 -7
- package/src/native/session.cc +796 -715
- package/src/native/session.h +91 -91
- package/dist/ts/_native.d.ts.map +0 -1
- package/dist/ts/_native.js.map +0 -1
- package/dist/ts/graph.d.ts.map +0 -1
- package/dist/ts/graph.js.map +0 -1
- package/dist/ts/index.d.ts.map +0 -1
- package/dist/ts/index.js.map +0 -1
- package/dist/ts/inference-pool.d.ts.map +0 -1
- package/dist/ts/inference-pool.js.map +0 -1
- package/dist/ts/inference_pool.d.ts +0 -99
- package/dist/ts/inference_pool.d.ts.map +0 -1
- package/dist/ts/inference_pool.js +0 -370
- package/dist/ts/inference_pool.js.map +0 -1
- package/dist/ts/install-libtensorflow.d.ts.map +0 -1
- package/dist/ts/install-libtensorflow.js.map +0 -1
- package/dist/ts/ops/array_ops.d.ts.map +0 -1
- package/dist/ts/ops/array_ops.js.map +0 -1
- package/dist/ts/ops/index.d.ts.map +0 -1
- package/dist/ts/ops/index.js.map +0 -1
- package/dist/ts/ops/math_ops.d.ts.map +0 -1
- package/dist/ts/ops/math_ops.js.map +0 -1
- package/dist/ts/ops/nn_ops.d.ts.map +0 -1
- package/dist/ts/ops/nn_ops.js.map +0 -1
- package/dist/ts/ops/variable_ops.d.ts.map +0 -1
- package/dist/ts/ops/variable_ops.js.map +0 -1
- package/dist/ts/session.d.ts.map +0 -1
- package/dist/ts/session.js.map +0 -1
- /package/dist/{ts/_native.d.ts → _native.d.ts} +0 -0
- /package/dist/{ts/_native.js → _native.js} +0 -0
- /package/dist/{ts/inference-pool.d.ts → inference-pool.d.ts} +0 -0
- /package/dist/{ts/inference-pool.js → inference-pool.js} +0 -0
- /package/dist/{ts/install-libtensorflow.d.ts → install-libtensorflow.d.ts} +0 -0
- /package/dist/{ts/ops → ops}/array_ops.d.ts +0 -0
- /package/dist/{ts/ops → ops}/array_ops.js +0 -0
- /package/dist/{ts/ops → ops}/index.d.ts +0 -0
- /package/dist/{ts/ops → ops}/index.js +0 -0
- /package/dist/{ts/ops → ops}/math_ops.d.ts +0 -0
- /package/dist/{ts/ops → ops}/nn_ops.d.ts +0 -0
- /package/dist/{ts/ops → ops}/variable_ops.d.ts +0 -0
- /package/dist/{ts/session.d.ts → session.d.ts} +0 -0
- /package/dist/{ts/session.js → session.js} +0 -0
package/src/native/graph.cc
CHANGED
|
@@ -1,442 +1,577 @@
|
|
|
1
|
-
#include "graph.h"
|
|
2
|
-
#include <cstring>
|
|
3
|
-
#include <stdexcept>
|
|
4
|
-
#include <unordered_map>
|
|
5
|
-
|
|
6
|
-
// -------------------------------------------------------------------
|
|
7
|
-
// GraphWrap - N-API ObjectWrap around TF_Graph
|
|
8
|
-
//
|
|
9
|
-
// Exposes graph construction primitives:
|
|
10
|
-
// addOp(type, inputs, attrs) -> {opName, numOutputs}
|
|
11
|
-
// hasOp(name) -> boolean
|
|
12
|
-
// opOutputType(name, index) -> DType integer
|
|
13
|
-
// opOutputShape(name, index) -> number[] (-1 = unknown dim)
|
|
14
|
-
// toGraphDef() -> Buffer (serialized GraphDef proto)
|
|
15
|
-
// numOps() -> number
|
|
16
|
-
//-------------------------------------------------------------------
|
|
17
|
-
|
|
18
|
-
Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
|
|
19
|
-
{
|
|
20
|
-
Napi::Function func = DefineClass(env, "Graph", {
|
|
21
|
-
InstanceMethod<&GraphWrap::AddOp>("addOp"),
|
|
22
|
-
InstanceMethod<&GraphWrap::HasOp>("hasOp"),
|
|
23
|
-
InstanceMethod<&GraphWrap::OpOutputType>("opOutputType"),
|
|
24
|
-
InstanceMethod<&GraphWrap::OpOutputShape>("opOutputShape"),
|
|
25
|
-
InstanceMethod<&GraphWrap::ToGraphDef>("toGraphDef"),
|
|
26
|
-
InstanceMethod<&GraphWrap::NumOps>("numOps"),
|
|
27
|
-
InstanceMethod<&GraphWrap::ImportGraphDef>("importGraphDef"),
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
graph_
|
|
48
|
-
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
//
|
|
53
|
-
//
|
|
54
|
-
//
|
|
55
|
-
//
|
|
56
|
-
//
|
|
57
|
-
// { kind: "
|
|
58
|
-
// { kind: "
|
|
59
|
-
// { kind: "
|
|
60
|
-
// { kind: "
|
|
61
|
-
// { kind: "shape", value:
|
|
62
|
-
// { kind: "
|
|
63
|
-
// { kind: "
|
|
64
|
-
//
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
//
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
{"
|
|
132
|
-
{"
|
|
133
|
-
{"
|
|
134
|
-
{"
|
|
135
|
-
{"
|
|
136
|
-
{"
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
auto
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
if (!
|
|
324
|
-
return
|
|
325
|
-
std::string name = info[0].As<Napi::String>().Utf8Value();
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
if (!graph_)
|
|
377
|
-
return
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
Napi::Value GraphWrap::
|
|
400
|
-
{
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
if (!
|
|
436
|
-
{
|
|
437
|
-
Napi::
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
}
|
|
441
|
-
|
|
442
|
-
|
|
1
|
+
#include "graph.h"
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
5
|
+
|
|
6
|
+
// -------------------------------------------------------------------
|
|
7
|
+
// GraphWrap - N-API ObjectWrap around TF_Graph
|
|
8
|
+
//
|
|
9
|
+
// Exposes graph construction primitives:
|
|
10
|
+
// addOp(type, inputs, attrs) -> {opName, numOutputs}
|
|
11
|
+
// hasOp(name) -> boolean
|
|
12
|
+
// opOutputType(name, index) -> DType integer
|
|
13
|
+
// opOutputShape(name, index) -> number[] (-1 = unknown dim)
|
|
14
|
+
// toGraphDef() -> Buffer (serialized GraphDef proto)
|
|
15
|
+
// numOps() -> number
// importGraphDef(buffer) -> void
// addGradients(y, x, dx?) -> {opName, index}[]
|
|
16
|
+
//-------------------------------------------------------------------
|
|
17
|
+
|
|
18
|
+
// Defines the JS `Graph` class, keeps a persistent reference to its
// constructor as per-environment instance data, and exports it as `Graph`.
Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
{
    Napi::Function graphClass = DefineClass(
        env,
        "Graph",
        {
            InstanceMethod<&GraphWrap::AddOp>("addOp"),
            InstanceMethod<&GraphWrap::HasOp>("hasOp"),
            InstanceMethod<&GraphWrap::OpOutputType>("opOutputType"),
            InstanceMethod<&GraphWrap::OpOutputShape>("opOutputShape"),
            InstanceMethod<&GraphWrap::ToGraphDef>("toGraphDef"),
            InstanceMethod<&GraphWrap::NumOps>("numOps"),
            InstanceMethod<&GraphWrap::ImportGraphDef>("importGraphDef"),
            InstanceMethod<&GraphWrap::AddGradients>("addGradients"),
        });

    // The heap-allocated reference is handed to the environment via
    // SetInstanceData; node-addon-api's default finalizer deletes it when
    // the environment is torn down.
    Napi::FunctionReference *constructorRef =
        new Napi::FunctionReference(Napi::Persistent(graphClass));
    env.SetInstanceData<Napi::FunctionReference>(constructorRef);

    exports.Set("Graph", graphClass);
    return exports;
}
|
|
35
|
+
|
|
36
|
+
// Creates the TF_Graph owned by this wrapper. If TensorFlow cannot allocate
// one, a JS Error is raised and graph_ stays null (every method checks it).
GraphWrap::GraphWrap(const Napi::CallbackInfo &info)
    : Napi::ObjectWrap<GraphWrap>(info)
{
    graph_ = TF_NewGraph();
    if (graph_ != nullptr)
        return;

    Napi::Env env = info.Env();
    Napi::Error::New(env, "Failed to create TF_Graph").ThrowAsJavaScriptException();
}
|
|
42
|
+
|
|
43
|
+
// Frees the owned TF_Graph (if one was created) and clears the handle.
GraphWrap::~GraphWrap()
{
    if (graph_ == nullptr)
        return;

    TF_DeleteGraph(graph_);
    graph_ = nullptr;
}
|
|
51
|
+
|
|
52
|
+
// ---------------------------------------------------------------
|
|
53
|
+
// addOp(type: string, inputs: {opName: string, index: number}[],
|
|
54
|
+
//        attrs?: Record<string, AttrValue>, name?: string,
//        controlInputs?: string[]) -> { opName: string, numOutputs: number }
|
|
55
|
+
//
|
|
56
|
+
// AttrValue is one of:
|
|
57
|
+
// { kind: "int", value: number }
|
|
58
|
+
// { kind: "float", value: number }
|
|
59
|
+
// { kind: "bool", value: boolean }
|
|
60
|
+
// { kind: "type", value: number } <- DType integer
|
|
61
|
+
// { kind: "shape", value: number[] } <- -1 for unknown dim
|
|
62
|
+
// { kind: "tensor", value: { dtype, shape, data: Buffer } }
|
|
63
|
+
// { kind: "list_type", value: number[] }
|
|
64
|
+
// { kind: "list_int", value: number[] }
// { kind: "string", value: string }
|
|
65
|
+
// ---------------------------------------------------------------
|
|
66
|
+
|
|
67
|
+
// addOp(type, inputs, attrs?, name?, controlInputs?) -> { opName, numOutputs }
//
// Positional arguments:
//   info[0] type          - TF op type string (required)
//   info[1] inputs        - array of { opName, index } data inputs (required)
//   info[2] attrs         - optional object: attr name -> { kind, value }
//   info[3] name          - optional explicit op name
//   info[4] controlInputs - optional array of op names used as control inputs
//
// Everything that depends on JS input (input/control op lookup, attr kind
// validation, tensor allocation) is resolved BEFORE TF_NewOperation is
// called. A TF_OperationDescription is released only by TF_FinishOperation,
// so returning early after creating one would leak it.
//
// Throws a JS Error if the graph is destroyed, an input or control op is
// missing, an attr is invalid, or TF_FinishOperation rejects the op.
Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();

    if (!graph_)
    {
        Napi::Error::New(env, "Graph has been destroyed").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray())
    {
        Napi::TypeError::New(env, "addOp(type: string, inputs: TFOutput[], attrs?)").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    const std::string op_type = info[0].As<Napi::String>().Utf8Value();

    // Auto-generated name: type + "_" + counter. One counter value is consumed
    // on every call, even when the caller overrides the name via the optional
    // 4th argument (info[3]).
    std::string op_name = op_type + "_" + std::to_string(op_counter_++);
    if (info.Length() >= 4 && info[3].IsString())
    {
        op_name = info[3].As<Napi::String>().Utf8Value();
    }
    else
    {
        // Skip generated names already taken in the graph (e.g. by ops brought
        // in through importGraphDef); they would fail as duplicates.
        while (TF_GraphOperationByName(graph_, op_name.c_str()) != nullptr)
            op_name = op_type + "_" + std::to_string(op_counter_++);
    }

    // --- Data inputs -------------------------------------------------------
    // Resolved once so op construction can be retried with TF_AddInputList
    // for list-input ops (e.g. IdentityN).
    Napi::Array inputs_arr = info[1].As<Napi::Array>();
    std::vector<TF_Output> resolved_inputs;
    resolved_inputs.reserve(inputs_arr.Length());
    for (uint32_t i = 0; i < inputs_arr.Length(); ++i)
    {
        Napi::Object input_obj = inputs_arr.Get(i).As<Napi::Object>();
        std::string input_op_name = input_obj.Get("opName").As<Napi::String>().Utf8Value();
        int input_idx = input_obj.Get("index").As<Napi::Number>().Int32Value();

        TF_Operation *input_op = TF_GraphOperationByName(graph_, input_op_name.c_str());
        if (!input_op)
        {
            Napi::Error::New(env, "Input op not found: " + input_op_name).ThrowAsJavaScriptException();
            return env.Undefined();
        }
        resolved_inputs.push_back({input_op, input_idx});
    }

    // --- Control inputs (optional info[4]) --------------------------------
    std::vector<TF_Operation *> control_ops;
    if (info.Length() >= 5 && info[4].IsArray())
    {
        Napi::Array ctrl_arr = info[4].As<Napi::Array>();
        control_ops.reserve(ctrl_arr.Length());
        for (uint32_t i = 0; i < ctrl_arr.Length(); ++i)
        {
            std::string ctrl_name = ctrl_arr.Get(i).As<Napi::String>().Utf8Value();
            TF_Operation *ctrl_op = TF_GraphOperationByName(graph_, ctrl_name.c_str());
            if (!ctrl_op)
            {
                Napi::Error::New(env, "Control input op not found: " + ctrl_name).ThrowAsJavaScriptException();
                return env.Undefined();
            }
            control_ops.push_back(ctrl_op);
        }
    }

    // --- Attributes (optional info[2]) -------------------------------------
    enum class AttrKind
    {
        Int,
        Float,
        Bool,
        Type,
        Shape,
        ListType,
        ListInt,
        Tensor,
        String,
        Unknown
    };

    static const std::unordered_map<std::string, AttrKind> kind_map = {
        {"int", AttrKind::Int},
        {"float", AttrKind::Float},
        {"bool", AttrKind::Bool},
        {"type", AttrKind::Type},
        {"shape", AttrKind::Shape},
        {"list_type", AttrKind::ListType},
        {"list_int", AttrKind::ListInt},
        {"tensor", AttrKind::Tensor},
        {"string", AttrKind::String}};

    // Lets std::unique_ptr own a TF_Tensor.
    struct TensorDeleter
    {
        void operator()(TF_Tensor *t) const { TF_DeleteTensor(t); }
    };

    // One attr decoded from JS, ready to be applied to a description.
    struct PendingAttr
    {
        std::string name;
        AttrKind kind = AttrKind::Unknown;
        int64_t int_value = 0;
        float float_value = 0.0f;
        unsigned char bool_value = 0;
        TF_DataType type_value{};
        std::vector<int64_t> ints;                        // Shape dims / ListInt values
        std::vector<TF_DataType> types;                   // ListType values
        std::string str;                                  // String value
        std::unique_ptr<TF_Tensor, TensorDeleter> tensor; // Tensor value (owned here)
    };

    // Reads a JS number[] into int64 values.
    auto read_int64s = [](Napi::Array arr)
    {
        std::vector<int64_t> out(arr.Length());
        for (uint32_t j = 0; j < arr.Length(); ++j)
            out[j] = arr.Get(j).As<Napi::Number>().Int64Value();
        return out;
    };

    std::vector<PendingAttr> pending_attrs;
    if (info.Length() >= 3 && info[2].IsObject())
    {
        Napi::Object attrs = info[2].As<Napi::Object>();
        Napi::Array attr_keys = attrs.GetPropertyNames();
        pending_attrs.reserve(attr_keys.Length());
        for (uint32_t i = 0; i < attr_keys.Length(); ++i)
        {
            PendingAttr attr;
            attr.name = attr_keys.Get(i).As<Napi::String>().Utf8Value();
            Napi::Object attr_val = attrs.Get(attr.name).As<Napi::Object>();
            const std::string kind_str = attr_val.Get("kind").As<Napi::String>().Utf8Value();
            const auto kind_it = kind_map.find(kind_str);
            attr.kind = (kind_it != kind_map.end()) ? kind_it->second : AttrKind::Unknown;

            switch (attr.kind)
            {
            case AttrKind::Int:
                attr.int_value = attr_val.Get("value").As<Napi::Number>().Int64Value();
                break;
            case AttrKind::Float:
                attr.float_value = attr_val.Get("value").As<Napi::Number>().FloatValue();
                break;
            case AttrKind::Bool:
                attr.bool_value = attr_val.Get("value").As<Napi::Boolean>().Value() ? 1 : 0;
                break;
            case AttrKind::Type:
                attr.type_value = static_cast<TF_DataType>(
                    attr_val.Get("value").As<Napi::Number>().Int32Value());
                break;
            case AttrKind::Shape:
            case AttrKind::ListInt:
                attr.ints = read_int64s(attr_val.Get("value").As<Napi::Array>());
                break;
            case AttrKind::ListType:
            {
                Napi::Array vals = attr_val.Get("value").As<Napi::Array>();
                attr.types.resize(vals.Length());
                for (uint32_t j = 0; j < vals.Length(); ++j)
                    attr.types[j] = static_cast<TF_DataType>(vals.Get(j).As<Napi::Number>().Int32Value());
                break;
            }
            case AttrKind::String:
                attr.str = attr_val.Get("value").As<Napi::String>().Utf8Value();
                break;
            case AttrKind::Tensor:
            {
                // Inline constant tensor: { dtype, shape, data: Buffer }.
                Napi::Object tv = attr_val.Get("value").As<Napi::Object>();
                TF_DataType dtype = static_cast<TF_DataType>(tv.Get("dtype").As<Napi::Number>().Int32Value());
                Napi::Buffer<uint8_t> data_buf = tv.Get("data").As<Napi::Buffer<uint8_t>>();
                std::vector<int64_t> dims = read_int64s(tv.Get("shape").As<Napi::Array>());

                attr.tensor.reset(TF_AllocateTensor(dtype, dims.data(), static_cast<int>(dims.size()),
                                                    data_buf.ByteLength()));
                if (!attr.tensor)
                {
                    // Report the failure instead of silently dropping the attr.
                    Napi::Error::New(env, "Failed to allocate tensor for attr " + attr.name)
                        .ThrowAsJavaScriptException();
                    return env.Undefined();
                }
                // Guard: memcpy with a possibly-null pointer is UB even for 0 bytes.
                if (data_buf.ByteLength() > 0)
                    std::memcpy(TF_TensorData(attr.tensor.get()), data_buf.Data(), data_buf.ByteLength());
                break;
            }
            case AttrKind::Unknown:
            default:
                Napi::Error::New(env, "Unsupported attr kind: " + kind_str).ThrowAsJavaScriptException();
                return env.Undefined();
            }
            pending_attrs.push_back(std::move(attr));
        }
    }

    // Applies the pre-decoded attrs to `desc`. Only a tensor attr can fail
    // here; on failure error_out is set and false is returned.
    auto apply_attrs = [&](TF_OperationDescription *desc, std::string &error_out) -> bool
    {
        for (const PendingAttr &attr : pending_attrs)
        {
            const char *name = attr.name.c_str();
            switch (attr.kind)
            {
            case AttrKind::Int:
                TF_SetAttrInt(desc, name, attr.int_value);
                break;
            case AttrKind::Float:
                TF_SetAttrFloat(desc, name, attr.float_value);
                break;
            case AttrKind::Bool:
                TF_SetAttrBool(desc, name, attr.bool_value);
                break;
            case AttrKind::Type:
                TF_SetAttrType(desc, name, attr.type_value);
                break;
            case AttrKind::Shape:
                TF_SetAttrShape(desc, name, attr.ints.data(), static_cast<int>(attr.ints.size()));
                break;
            case AttrKind::ListType:
                TF_SetAttrTypeList(desc, name, attr.types.data(), static_cast<int>(attr.types.size()));
                break;
            case AttrKind::ListInt:
                TF_SetAttrIntList(desc, name, attr.ints.data(), static_cast<int>(attr.ints.size()));
                break;
            case AttrKind::String:
                TF_SetAttrString(desc, name, attr.str.data(), attr.str.length());
                break;
            case AttrKind::Tensor:
            {
                // TF copies the tensor value; attr.tensor remains owned (and
                // later freed) by this function.
                StatusGuard ts;
                TF_SetAttrTensor(desc, name, attr.tensor.get(), ts.s);
                if (!ts.ok())
                {
                    // NOTE(review): `desc` cannot be discarded without
                    // TF_FinishOperation, so this (rare) path still leaks it.
                    error_out = "Invalid tensor for attr " + attr.name + ": " + ts.message();
                    return false;
                }
                break;
            }
            case AttrKind::Unknown:
            default:
                break; // rejected while decoding above
            }
        }
        return true;
    };

    // Builds and finishes the op once; inputs are added either one by one or
    // as a single input list.
    auto finish_op = [&](bool add_as_input_list, std::string &error_out) -> TF_Operation *
    {
        TF_OperationDescription *desc = TF_NewOperation(graph_, op_type.c_str(), op_name.c_str());
        if (!desc)
        {
            error_out = "TF_NewOperation failed for type " + op_type;
            return nullptr;
        }

        if (add_as_input_list)
        {
            if (!resolved_inputs.empty())
                TF_AddInputList(desc, resolved_inputs.data(), static_cast<int>(resolved_inputs.size()));
        }
        else
        {
            for (const TF_Output &input : resolved_inputs)
                TF_AddInput(desc, input);
        }

        for (TF_Operation *ctrl_op : control_ops)
            TF_AddControlInput(desc, ctrl_op);

        if (!apply_attrs(desc, error_out))
            return nullptr;

        // TF_FinishOperation always consumes `desc`, on success or failure.
        StatusGuard status;
        TF_Operation *op = TF_FinishOperation(desc, status.s);
        if (!status.ok() || !op)
        {
            error_out = status.message();
            return nullptr;
        }
        return op;
    };

    std::string first_error;
    TF_Operation *op = finish_op(false, first_error);

    // Some ops have a single list-valued input arg and reject repeated
    // TF_AddInput calls even for one element; retry with TF_AddInputList.
    if (!op && first_error.find("expected list") != std::string::npos &&
        !resolved_inputs.empty())
    {
        std::string retry_error;
        op = finish_op(true, retry_error);
        if (!op)
            first_error = retry_error;
    }

    if (!op)
    {
        if (!env.IsExceptionPending())
        {
            Napi::Error::New(env, "TF_FinishOperation failed for " + op_type + ": " + first_error).ThrowAsJavaScriptException();
        }
        return env.Undefined();
    }

    Napi::Object result = Napi::Object::New(env);
    result.Set("opName", Napi::String::New(env, op_name));
    result.Set("numOutputs", Napi::Number::New(env, static_cast<double>(TF_OperationNumOutputs(op))));
    return result;
}
|
|
320
|
+
|
|
321
|
+
// hasOp(name) -> boolean
// True when an operation with `name` exists in the graph. A non-string
// argument or a destroyed graph yields false.
Napi::Value GraphWrap::HasOp(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    bool exists = false;

    if (info[0].IsString())
    {
        const std::string opName = info[0].As<Napi::String>().Utf8Value();
        exists = (graph_ != nullptr) &&
                 (TF_GraphOperationByName(graph_, opName.c_str()) != nullptr);
    }

    return Napi::Boolean::New(env, exists);
}
|
|
328
|
+
|
|
329
|
+
// opOutputType(name: string, index = 0) -> number | null
//
// Returns the TF_DataType of output `index` of the named op, as a number.
// Returns null when the graph is destroyed, `name` is not a string, no op
// has that name, or `index` is outside [0, numOutputs).
Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (!graph_ || !info[0].IsString())
        return env.Null();
    std::string name = info[0].As<Napi::String>().Utf8Value();
    int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;

    TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
    if (!op)
        return env.Null();

    // The C API does not document a bounds check on the output index, so
    // validate it here rather than risk an out-of-range read.
    if (idx < 0 || idx >= TF_OperationNumOutputs(op))
        return env.Null();

    TF_DataType dtype = TF_OperationOutputType({op, idx});
    return Napi::Number::New(env, static_cast<double>(dtype));
}
|
|
344
|
+
|
|
345
|
+
// opOutputShape(name: string, index = 0) -> number[] | null
//
// Returns the statically inferred shape of output `index` of the named op;
// unknown dimensions are reported as -1. Returns null when the graph is
// destroyed, `name` is not a string, the op does not exist, `index` is out
// of range, the rank is unknown, or TensorFlow reports an error.
Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (!graph_ || !info[0].IsString())
        return env.Null();
    std::string name = info[0].As<Napi::String>().Utf8Value();
    int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;

    TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
    if (!op)
        return env.Null();

    // Validate the index before handing it to the shape-inference API, which
    // does not document any bounds check.
    if (idx < 0 || idx >= TF_OperationNumOutputs(op))
        return env.Null();

    TF_Output out{op, idx};
    StatusGuard status;
    int ndims = TF_GraphGetTensorNumDims(graph_, out, status.s);
    // ndims < 0 means the rank itself is unknown.
    if (!status.ok() || ndims < 0)
        return env.Null();

    std::vector<int64_t> dims(ndims, -1);
    StatusGuard shape_status;
    TF_GraphGetTensorShape(graph_, out, dims.data(), ndims, shape_status.s);
    if (!shape_status.ok())
        return env.Null();

    Napi::Array arr = Napi::Array::New(env, ndims);
    for (int i = 0; i < ndims; ++i)
        arr.Set(i, Napi::Number::New(env, static_cast<double>(dims[i])));
    return arr;
}
|
|
372
|
+
|
|
373
|
+
// toGraphDef() -> Buffer | null
// Serialises the graph as a binary GraphDef proto and returns a copy of the
// bytes in a Node Buffer. Returns null if the graph has been destroyed and
// throws a JS Error if TF_GraphToGraphDef fails.
Napi::Value GraphWrap::ToGraphDef(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (graph_ == nullptr)
        return env.Null();

    StatusGuard status;
    TF_Buffer *serialized = TF_NewBuffer();
    TF_GraphToGraphDef(graph_, serialized, status.s);

    if (status.ok())
    {
        // Copy out of the TF-owned buffer before releasing it.
        const uint8_t *bytes = static_cast<const uint8_t *>(serialized->data);
        Napi::Buffer<uint8_t> copy = Napi::Buffer<uint8_t>::Copy(env, bytes, serialized->length);
        TF_DeleteBuffer(serialized);
        return copy;
    }

    TF_DeleteBuffer(serialized);
    Napi::Error::New(env, "TF_GraphToGraphDef failed: " + status.message())
        .ThrowAsJavaScriptException();
    return env.Undefined();
}
|
|
398
|
+
|
|
399
|
+
// numOps() -> number
// Counts the operations in the graph by iterating TF_GraphNextOperation.
// A destroyed graph reports 0.
Napi::Value GraphWrap::NumOps(const Napi::CallbackInfo &info)
{
    int total = 0;
    if (graph_ != nullptr)
    {
        size_t cursor = 0;
        while (TF_GraphNextOperation(graph_, &cursor) != nullptr)
            total += 1;
    }
    return Napi::Number::New(info.Env(), total);
}
|
|
409
|
+
|
|
410
|
+
// ---------------------------------------------------------------
|
|
411
|
+
// importGraphDef(buffer: Buffer): void
|
|
412
|
+
//
|
|
413
|
+
// Deserialises a binary GraphDef proto into this graph.
|
|
414
|
+
// Intended use: load a frozen .pb model so the graph can be
|
|
415
|
+
// executed via the native Session (with ConfigProto + affinity).
|
|
416
|
+
//
|
|
417
|
+
// The graph must be empty before calling this — importing into a
|
|
418
|
+
// non-empty graph will produce op name collisions.
|
|
419
|
+
//
|
|
420
|
+
// TF_GraphImportGraphDef uses a prefix option (default "")
|
|
421
|
+
// so all imported op names are used as-is, matching the op
|
|
422
|
+
// names in the original frozen graph.
|
|
423
|
+
// ---------------------------------------------------------------
|
|
424
|
+
// importGraphDef(buffer: Buffer): void
// Imports a serialized GraphDef into this graph with an empty name prefix,
// so imported op names are used verbatim. Throws a JS Error if the graph is
// destroyed or the import fails, and a TypeError if the argument is not a Buffer.
Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();

    if (graph_ == nullptr)
    {
        Napi::Error::New(env, "Graph destroyed")
            .ThrowAsJavaScriptException();
        return env.Undefined();
    }

    const bool hasBufferArg = info.Length() >= 1 && info[0].IsBuffer();
    if (!hasBufferArg)
    {
        Napi::TypeError::New(env, "importGraphDef(buffer: Buffer)")
            .ThrowAsJavaScriptException();
        return env.Undefined();
    }

    Napi::Buffer<uint8_t> protoBytes = info[0].As<Napi::Buffer<uint8_t>>();

    // TF_NewBufferFromString copies the bytes, so the JS Buffer only needs
    // to live for the duration of this call.
    TF_Buffer *graphDef = TF_NewBufferFromString(protoBytes.Data(), protoBytes.ByteLength());

    TF_ImportGraphDefOptions *importOpts = TF_NewImportGraphDefOptions();
    TF_ImportGraphDefOptionsSetPrefix(importOpts, "");

    StatusGuard status;
    TF_GraphImportGraphDef(graph_, graphDef, importOpts, status.s);

    TF_DeleteImportGraphDefOptions(importOpts);
    TF_DeleteBuffer(graphDef);

    if (!status.ok())
        Napi::Error::New(env, "TF_GraphImportGraphDef failed: " + status.message())
            .ThrowAsJavaScriptException();

    return env.Undefined();
}
|
|
468
|
+
|
|
469
|
+
// ---------------------------------------------------------------------------
|
|
470
|
+
// addGradients — JS signature:
|
|
471
|
+
// addGradients(
|
|
472
|
+
// y: { opName: string; index: number }[], // loss outputs
|
|
473
|
+
// x: { opName: string; index: number }[], // inputs to diff w.r.t.
|
|
474
|
+
// dx?: { opName: string; index: number }[], // initial upstream grads
|
|
475
|
+
// ) -> { opName: string; index: number }[] // gradient tensors, len = |x|
|
|
476
|
+
//
|
|
477
|
+
// Wraps TF_AddGradients. TF injects gradient ops directly into the graph.
|
|
478
|
+
// The returned outputs are the dL/dx_i tensors, one per x_i.
|
|
479
|
+
//
|
|
480
|
+
// dx defaults to ones (i.e. dL/dy = 1 for scalar loss, standard convention).
|
|
481
|
+
// Pass explicit dx when you need to chain gradients or scale the loss.
|
|
482
|
+
//
|
|
483
|
+
// Throws if any op in the path is non-differentiable (e.g. ArgMax).
|
|
484
|
+
// ---------------------------------------------------------------------------
|
|
485
|
+
Napi::Value GraphWrap::AddGradients(const Napi::CallbackInfo &info)
|
|
486
|
+
{
|
|
487
|
+
Napi::Env env = info.Env();
|
|
488
|
+
|
|
489
|
+
if (!graph_)
|
|
490
|
+
{
|
|
491
|
+
Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
|
|
492
|
+
return env.Undefined();
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
if (info.Length() < 2 || !info[0].IsArray() || !info[1].IsArray())
|
|
496
|
+
{
|
|
497
|
+
Napi::TypeError::New(env, "addGradients(y, x, dx?)")
|
|
498
|
+
.ThrowAsJavaScriptException();
|
|
499
|
+
return env.Undefined();
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
auto resolve_outputs = [&](Napi::Array arr,
|
|
503
|
+
std::vector<TF_Output> &out,
|
|
504
|
+
const std::string &label) -> bool
|
|
505
|
+
{
|
|
506
|
+
for (uint32_t i = 0; i < arr.Length(); ++i)
|
|
507
|
+
{
|
|
508
|
+
auto obj = arr.Get(i).As<Napi::Object>();
|
|
509
|
+
std::string name = obj.Get("opName").As<Napi::String>().Utf8Value();
|
|
510
|
+
int idx = obj.Get("index").As<Napi::Number>().Int32Value();
|
|
511
|
+
TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
|
|
512
|
+
if (!op)
|
|
513
|
+
{
|
|
514
|
+
Napi::Error::New(env, label + " op not found: " + name)
|
|
515
|
+
.ThrowAsJavaScriptException();
|
|
516
|
+
return false;
|
|
517
|
+
}
|
|
518
|
+
out.push_back({op, idx});
|
|
519
|
+
}
|
|
520
|
+
return true;
|
|
521
|
+
};
|
|
522
|
+
|
|
523
|
+
std::vector<TF_Output> y_vec, x_vec, dx_vec;
|
|
524
|
+
|
|
525
|
+
if (!resolve_outputs(info[0].As<Napi::Array>(), y_vec, "y"))
|
|
526
|
+
return env.Undefined();
|
|
527
|
+
if (!resolve_outputs(info[1].As<Napi::Array>(), x_vec, "x"))
|
|
528
|
+
return env.Undefined();
|
|
529
|
+
|
|
530
|
+
// Optional initial gradients.
|
|
531
|
+
TF_Output *dx_ptr = nullptr;
|
|
532
|
+
if (info.Length() >= 3 && info[2].IsArray())
|
|
533
|
+
{
|
|
534
|
+
if (!resolve_outputs(info[2].As<Napi::Array>(), dx_vec, "dx"))
|
|
535
|
+
return env.Undefined();
|
|
536
|
+
if (dx_vec.size() != y_vec.size())
|
|
537
|
+
{
|
|
538
|
+
Napi::Error::New(env,
|
|
539
|
+
"addGradients: dx length must equal y length")
|
|
540
|
+
.ThrowAsJavaScriptException();
|
|
541
|
+
return env.Undefined();
|
|
542
|
+
}
|
|
543
|
+
dx_ptr = dx_vec.data();
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
// Allocate output array — TF_AddGradients writes one gradient per x.
|
|
547
|
+
std::vector<TF_Output> dy(x_vec.size());
|
|
548
|
+
|
|
549
|
+
StatusGuard status;
|
|
550
|
+
TF_AddGradients(
|
|
551
|
+
graph_,
|
|
552
|
+
y_vec.data(), static_cast<int>(y_vec.size()),
|
|
553
|
+
x_vec.data(), static_cast<int>(x_vec.size()),
|
|
554
|
+
dx_ptr,
|
|
555
|
+
status.s,
|
|
556
|
+
dy.data());
|
|
557
|
+
|
|
558
|
+
if (!status.ok())
|
|
559
|
+
{
|
|
560
|
+
Napi::Error::New(env,
|
|
561
|
+
"TF_AddGradients failed: " + status.message())
|
|
562
|
+
.ThrowAsJavaScriptException();
|
|
563
|
+
return env.Undefined();
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
// Return array of { opName, index } — same shape as x.
|
|
567
|
+
Napi::Array result = Napi::Array::New(env, dy.size());
|
|
568
|
+
for (size_t i = 0; i < dy.size(); ++i)
|
|
569
|
+
{
|
|
570
|
+
const char *raw_name = TF_OperationName(dy[i].oper);
|
|
571
|
+
Napi::Object obj = Napi::Object::New(env);
|
|
572
|
+
obj.Set("opName", Napi::String::New(env, raw_name ? raw_name : ""));
|
|
573
|
+
obj.Set("index", Napi::Number::New(env, dy[i].index));
|
|
574
|
+
result.Set(static_cast<uint32_t>(i), obj);
|
|
575
|
+
}
|
|
576
|
+
return result;
|
|
577
|
+
}
|