@isidorus/cpu 0.0.0-alpha.2 → 0.0.0-alpha.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +1 -1
- package/dist/graph.d.ts +19 -12
- package/dist/graph.d.ts.map +1 -1
- package/dist/graph.js +23 -16
- package/dist/graph.js.map +1 -1
- package/dist/inference-pool.d.ts +35 -12
- package/dist/inference-pool.d.ts.map +1 -1
- package/dist/inference-pool.js +226 -259
- package/dist/inference-pool.js.map +1 -1
- package/dist/tsconfig.tsbuildinfo +1 -1
- package/package.json +2 -3
- package/prebuilds/darwin-arm64/@isidorus+cpu.node +0 -0
- package/prebuilds/linux-x64/@isidorus+cpu.node +0 -0
- package/prebuilds/win32-x64/@isidorus+cpu.node +0 -0
- package/scripts/test-install.js +1 -1
- package/src/native/graph.cc +356 -255
- package/src/native/graph.h +2 -0
package/src/native/graph.cc
CHANGED
|
@@ -1,19 +1,7 @@
|
|
|
1
1
|
#include "graph.h"
|
|
2
2
|
#include <cstring>
|
|
3
|
-
#include <stdexcept>
|
|
4
3
|
#include <unordered_map>
|
|
5
|
-
|
|
6
|
-
// -------------------------------------------------------------------
|
|
7
|
-
// GraphWrap - N-API ObjectWrap around TF_Graph
|
|
8
|
-
//
|
|
9
|
-
// Exposes graph construction primitives:
|
|
10
|
-
// addOp(type, inputs, attrs) -> {opName, numOutputs}
|
|
11
|
-
// hasOp(name) -> boolean
|
|
12
|
-
// opOutputType(name, index) -> DType integer
|
|
13
|
-
// opOutputShape(name, index) -> number[] (-1 = unknown dim)
|
|
14
|
-
// toGraphDef() -> Buffer (serialized GraphDef proto)
|
|
15
|
-
// numOps() -> number
|
|
16
|
-
//-------------------------------------------------------------------
|
|
4
|
+
#include <unordered_set>
|
|
17
5
|
|
|
18
6
|
Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
|
|
19
7
|
{
|
|
@@ -26,6 +14,8 @@ Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
|
|
|
26
14
|
InstanceMethod<&GraphWrap::NumOps>("numOps"),
|
|
27
15
|
InstanceMethod<&GraphWrap::ImportGraphDef>("importGraphDef"),
|
|
28
16
|
InstanceMethod<&GraphWrap::AddGradients>("addGradients"),
|
|
17
|
+
InstanceMethod<&GraphWrap::ListOpsOfType>("listOpsOfType"),
|
|
18
|
+
InstanceMethod<&GraphWrap::ListSinkOps>("listSinkOps"),
|
|
29
19
|
});
|
|
30
20
|
auto *ctor = new Napi::FunctionReference(Napi::Persistent(func));
|
|
31
21
|
env.SetInstanceData<Napi::FunctionReference>(ctor);
|
|
@@ -33,11 +23,13 @@ Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
|
|
|
33
23
|
return exports;
|
|
34
24
|
}
|
|
35
25
|
|
|
36
|
-
GraphWrap::GraphWrap(const Napi::CallbackInfo &info)
|
|
26
|
+
GraphWrap::GraphWrap(const Napi::CallbackInfo &info)
|
|
27
|
+
: Napi::ObjectWrap<GraphWrap>(info)
|
|
37
28
|
{
|
|
38
29
|
graph_ = TF_NewGraph();
|
|
39
30
|
if (!graph_)
|
|
40
|
-
Napi::Error::New(info.Env(), "Failed to create TF_Graph")
|
|
31
|
+
Napi::Error::New(info.Env(), "Failed to create TF_Graph")
|
|
32
|
+
.ThrowAsJavaScriptException();
|
|
41
33
|
}
|
|
42
34
|
|
|
43
35
|
GraphWrap::~GraphWrap()
|
|
@@ -49,109 +41,138 @@ GraphWrap::~GraphWrap()
|
|
|
49
41
|
}
|
|
50
42
|
}
|
|
51
43
|
|
|
52
|
-
//
|
|
53
|
-
// addOp
|
|
54
|
-
//
|
|
44
|
+
// ---------------------------------------------------------------------------
|
|
45
|
+
// addOp — JS signature:
|
|
46
|
+
// addOp(
|
|
47
|
+
// type: string,
|
|
48
|
+
// inputs: { opName: string; index: number }[],
|
|
49
|
+
// attrs?: Record<string, AttrValue>,
|
|
50
|
+
// name?: string,
|
|
51
|
+
// controlInputs?: string[], // ← NEW: op names that must run first
|
|
52
|
+
// ) -> { opName: string; numOutputs: number }
|
|
55
53
|
//
|
|
56
|
-
//
|
|
57
|
-
//
|
|
58
|
-
//
|
|
59
|
-
//
|
|
60
|
-
//
|
|
61
|
-
//
|
|
62
|
-
// { kind: "shape", value: { dtype, shape, data: Buffer } }
|
|
63
|
-
// { kind: "list_type", value: number[] }
|
|
64
|
-
// { kind: "list_shape", value: number[] }
|
|
65
|
-
// ---------------------------------------------------------------
|
|
66
|
-
|
|
54
|
+
// Control inputs are wired via TF_AddControlInput. TF guarantees the
|
|
55
|
+
// described op will not execute until every control-input op has completed.
|
|
56
|
+
// This is the mechanism behind globalVariablesInitializer: a NoOp with
|
|
57
|
+
// control edges to all init AssignVariableOps — running the NoOp forces all
|
|
58
|
+
// assignments to execute first.
|
|
59
|
+
// ---------------------------------------------------------------------------
|
|
67
60
|
Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
68
61
|
{
|
|
69
62
|
Napi::Env env = info.Env();
|
|
70
63
|
|
|
71
64
|
if (!graph_)
|
|
72
65
|
{
|
|
73
|
-
Napi::Error::New(env, "Graph has been destroyed")
|
|
66
|
+
Napi::Error::New(env, "Graph has been destroyed")
|
|
67
|
+
.ThrowAsJavaScriptException();
|
|
74
68
|
return env.Undefined();
|
|
75
69
|
}
|
|
76
70
|
|
|
77
71
|
if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray())
|
|
78
72
|
{
|
|
79
|
-
Napi::TypeError::New(env,
|
|
73
|
+
Napi::TypeError::New(env,
|
|
74
|
+
"addOp(type, inputs, attrs?, name?, controlInputs?)")
|
|
75
|
+
.ThrowAsJavaScriptException();
|
|
80
76
|
return env.Undefined();
|
|
81
77
|
}
|
|
82
78
|
|
|
83
79
|
std::string op_type = info[0].As<Napi::String>().Utf8Value();
|
|
84
|
-
|
|
85
|
-
// Auto-generate a unique op name: type + "_" + counter.
|
|
86
80
|
std::string op_name = op_type + "_" + std::to_string(op_counter_++);
|
|
87
|
-
// Allow caller to override the name via optional 3rd arg.
|
|
88
81
|
if (info.Length() >= 4 && info[3].IsString())
|
|
89
82
|
op_name = info[3].As<Napi::String>().Utf8Value();
|
|
90
83
|
|
|
91
|
-
// Resolve inputs
|
|
92
|
-
// for list-input ops (e.g. IdentityN) when needed.
|
|
84
|
+
// ── Resolve data inputs ─────────────────────────────────────────────────
|
|
93
85
|
std::vector<TF_Output> resolved_inputs;
|
|
94
86
|
auto inputs_arr = info[1].As<Napi::Array>();
|
|
95
87
|
resolved_inputs.reserve(inputs_arr.Length());
|
|
96
88
|
for (uint32_t i = 0; i < inputs_arr.Length(); i++)
|
|
97
89
|
{
|
|
98
|
-
auto
|
|
99
|
-
std::string
|
|
100
|
-
int
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
if (!input_op)
|
|
90
|
+
auto obj = inputs_arr.Get(i).As<Napi::Object>();
|
|
91
|
+
std::string dep_name = obj.Get("opName").As<Napi::String>().Utf8Value();
|
|
92
|
+
int dep_idx = obj.Get("index").As<Napi::Number>().Int32Value();
|
|
93
|
+
TF_Operation *dep_op = TF_GraphOperationByName(graph_, dep_name.c_str());
|
|
94
|
+
if (!dep_op)
|
|
104
95
|
{
|
|
105
|
-
Napi::Error::New(env, "Input op not found: " +
|
|
96
|
+
Napi::Error::New(env, "Input op not found: " + dep_name)
|
|
97
|
+
.ThrowAsJavaScriptException();
|
|
106
98
|
return env.Undefined();
|
|
107
99
|
}
|
|
108
|
-
resolved_inputs.push_back({
|
|
100
|
+
resolved_inputs.push_back({dep_op, dep_idx});
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
// ── Resolve control inputs (arg index 4) ───────────────────────────────
|
|
104
|
+
std::vector<TF_Operation *> ctrl_ops;
|
|
105
|
+
if (info.Length() >= 5 && info[4].IsArray())
|
|
106
|
+
{
|
|
107
|
+
auto ctrl_arr = info[4].As<Napi::Array>();
|
|
108
|
+
ctrl_ops.reserve(ctrl_arr.Length());
|
|
109
|
+
for (uint32_t i = 0; i < ctrl_arr.Length(); i++)
|
|
110
|
+
{
|
|
111
|
+
std::string ctrl_name =
|
|
112
|
+
ctrl_arr.Get(i).As<Napi::String>().Utf8Value();
|
|
113
|
+
TF_Operation *ctrl_op =
|
|
114
|
+
TF_GraphOperationByName(graph_, ctrl_name.c_str());
|
|
115
|
+
if (!ctrl_op)
|
|
116
|
+
{
|
|
117
|
+
Napi::Error::New(env,
|
|
118
|
+
"Control input op not found: " + ctrl_name)
|
|
119
|
+
.ThrowAsJavaScriptException();
|
|
120
|
+
return env.Undefined();
|
|
121
|
+
}
|
|
122
|
+
ctrl_ops.push_back(ctrl_op);
|
|
123
|
+
}
|
|
109
124
|
}
|
|
110
125
|
|
|
126
|
+
// ── Attr dispatch ───────────────────────────────────────────────────────
|
|
127
|
+
enum class AttrKind
|
|
128
|
+
{
|
|
129
|
+
Int,
|
|
130
|
+
Float,
|
|
131
|
+
Bool,
|
|
132
|
+
Type,
|
|
133
|
+
Shape,
|
|
134
|
+
ListType,
|
|
135
|
+
ListInt,
|
|
136
|
+
Tensor,
|
|
137
|
+
String,
|
|
138
|
+
Unknown
|
|
139
|
+
};
|
|
140
|
+
static const std::unordered_map<std::string, AttrKind> kind_map = {
|
|
141
|
+
{"int", AttrKind::Int},
|
|
142
|
+
{"float", AttrKind::Float},
|
|
143
|
+
{"bool", AttrKind::Bool},
|
|
144
|
+
{"type", AttrKind::Type},
|
|
145
|
+
{"shape", AttrKind::Shape},
|
|
146
|
+
{"list_type", AttrKind::ListType},
|
|
147
|
+
{"list_int", AttrKind::ListInt},
|
|
148
|
+
{"tensor", AttrKind::Tensor},
|
|
149
|
+
{"string", AttrKind::String},
|
|
150
|
+
};
|
|
151
|
+
|
|
111
152
|
auto apply_attrs = [&](TF_OperationDescription *desc) -> bool
|
|
112
153
|
{
|
|
113
154
|
if (!(info.Length() >= 3 && info[2].IsObject()))
|
|
114
155
|
return true;
|
|
115
156
|
|
|
116
|
-
enum class AttrKind
|
|
117
|
-
{
|
|
118
|
-
Int,
|
|
119
|
-
Float,
|
|
120
|
-
Bool,
|
|
121
|
-
Type,
|
|
122
|
-
Shape,
|
|
123
|
-
ListType,
|
|
124
|
-
ListInt,
|
|
125
|
-
Tensor,
|
|
126
|
-
String,
|
|
127
|
-
Unknown
|
|
128
|
-
};
|
|
129
|
-
|
|
130
|
-
static const std::unordered_map<std::string, AttrKind> kind_map = {
|
|
131
|
-
{"int", AttrKind::Int},
|
|
132
|
-
{"float", AttrKind::Float},
|
|
133
|
-
{"bool", AttrKind::Bool},
|
|
134
|
-
{"type", AttrKind::Type},
|
|
135
|
-
{"shape", AttrKind::Shape},
|
|
136
|
-
{"list_type", AttrKind::ListType},
|
|
137
|
-
{"list_int", AttrKind::ListInt},
|
|
138
|
-
{"tensor", AttrKind::Tensor},
|
|
139
|
-
{"string", AttrKind::String}};
|
|
140
|
-
|
|
141
157
|
auto attrs = info[2].As<Napi::Object>();
|
|
142
158
|
auto attrs_keys = attrs.GetPropertyNames();
|
|
143
159
|
for (uint32_t i = 0; i < attrs_keys.Length(); i++)
|
|
144
160
|
{
|
|
145
|
-
std::string attr_name =
|
|
161
|
+
std::string attr_name =
|
|
162
|
+
attrs_keys.Get(i).As<Napi::String>().Utf8Value();
|
|
146
163
|
auto attr_val = attrs.Get(attr_name).As<Napi::Object>();
|
|
147
|
-
std::string kind_str =
|
|
164
|
+
std::string kind_str =
|
|
165
|
+
attr_val.Get("kind").As<Napi::String>().Utf8Value();
|
|
166
|
+
|
|
148
167
|
auto it = kind_map.find(kind_str);
|
|
149
|
-
AttrKind kind = (it != kind_map.end()) ? it->second
|
|
168
|
+
AttrKind kind = (it != kind_map.end()) ? it->second
|
|
169
|
+
: AttrKind::Unknown;
|
|
150
170
|
switch (kind)
|
|
151
171
|
{
|
|
152
172
|
case AttrKind::Int:
|
|
153
173
|
{
|
|
154
|
-
int64_t v = static_cast<int64_t>(
|
|
174
|
+
int64_t v = static_cast<int64_t>(
|
|
175
|
+
attr_val.Get("value").As<Napi::Number>().Int64Value());
|
|
155
176
|
TF_SetAttrInt(desc, attr_name.c_str(), v);
|
|
156
177
|
break;
|
|
157
178
|
}
|
|
@@ -163,8 +184,8 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
163
184
|
}
|
|
164
185
|
case AttrKind::Bool:
|
|
165
186
|
{
|
|
166
|
-
unsigned char b =
|
|
167
|
-
|
|
187
|
+
unsigned char b =
|
|
188
|
+
attr_val.Get("value").As<Napi::Boolean>().Value() ? 1 : 0;
|
|
168
189
|
TF_SetAttrBool(desc, attr_name.c_str(), b);
|
|
169
190
|
break;
|
|
170
191
|
}
|
|
@@ -180,7 +201,8 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
180
201
|
auto dims_arr = attr_val.Get("value").As<Napi::Array>();
|
|
181
202
|
std::vector<int64_t> dims(dims_arr.Length());
|
|
182
203
|
for (uint32_t j = 0; j < dims_arr.Length(); ++j)
|
|
183
|
-
dims[j] = static_cast<int64_t>(
|
|
204
|
+
dims[j] = static_cast<int64_t>(
|
|
205
|
+
dims_arr.Get(j).As<Napi::Number>().Int64Value());
|
|
184
206
|
TF_SetAttrShape(desc, attr_name.c_str(),
|
|
185
207
|
dims.data(), static_cast<int>(dims.size()));
|
|
186
208
|
break;
|
|
@@ -190,8 +212,11 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
190
212
|
auto vals = attr_val.Get("value").As<Napi::Array>();
|
|
191
213
|
std::vector<TF_DataType> types(vals.Length());
|
|
192
214
|
for (uint32_t j = 0; j < vals.Length(); ++j)
|
|
193
|
-
types[j] = static_cast<TF_DataType>(
|
|
194
|
-
|
|
215
|
+
types[j] = static_cast<TF_DataType>(
|
|
216
|
+
vals.Get(j).As<Napi::Number>().Int32Value());
|
|
217
|
+
TF_SetAttrTypeList(desc, attr_name.c_str(),
|
|
218
|
+
types.data(),
|
|
219
|
+
static_cast<int>(types.size()));
|
|
195
220
|
break;
|
|
196
221
|
}
|
|
197
222
|
case AttrKind::ListInt:
|
|
@@ -199,83 +224,88 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
199
224
|
auto vals_int = attr_val.Get("value").As<Napi::Array>();
|
|
200
225
|
std::vector<int64_t> ints(vals_int.Length());
|
|
201
226
|
for (uint32_t j = 0; j < vals_int.Length(); ++j)
|
|
202
|
-
ints[j] = static_cast<int64_t>(
|
|
203
|
-
|
|
227
|
+
ints[j] = static_cast<int64_t>(
|
|
228
|
+
vals_int.Get(j).As<Napi::Number>().Int64Value());
|
|
229
|
+
TF_SetAttrIntList(desc, attr_name.c_str(),
|
|
230
|
+
ints.data(),
|
|
231
|
+
static_cast<int>(ints.size()));
|
|
204
232
|
break;
|
|
205
233
|
}
|
|
206
234
|
case AttrKind::Tensor:
|
|
207
235
|
{
|
|
208
|
-
// Inline constant tensor
|
|
209
236
|
auto tv = attr_val.Get("value").As<Napi::Object>();
|
|
210
|
-
TF_DataType dtype = static_cast<TF_DataType>(
|
|
237
|
+
TF_DataType dtype = static_cast<TF_DataType>(
|
|
238
|
+
tv.Get("dtype").As<Napi::Number>().Int32Value());
|
|
211
239
|
auto data_buf = tv.Get("data").As<Napi::Buffer<uint8_t>>();
|
|
212
240
|
auto dims_arr = tv.Get("shape").As<Napi::Array>();
|
|
213
|
-
|
|
214
241
|
std::vector<int64_t> dims(dims_arr.Length());
|
|
215
242
|
for (uint32_t j = 0; j < dims_arr.Length(); ++j)
|
|
216
|
-
dims[j] = static_cast<int64_t>(
|
|
217
|
-
|
|
243
|
+
dims[j] = static_cast<int64_t>(
|
|
244
|
+
dims_arr.Get(j).As<Napi::Number>().Int64Value());
|
|
218
245
|
StatusGuard ts;
|
|
219
|
-
TF_Tensor *tensor = TF_AllocateTensor(
|
|
246
|
+
TF_Tensor *tensor = TF_AllocateTensor(
|
|
247
|
+
dtype, dims.data(), static_cast<int>(dims.size()),
|
|
248
|
+
data_buf.Length());
|
|
220
249
|
if (tensor)
|
|
221
250
|
{
|
|
222
|
-
std::memcpy(TF_TensorData(tensor),
|
|
251
|
+
std::memcpy(TF_TensorData(tensor),
|
|
252
|
+
data_buf.Data(), data_buf.ByteLength());
|
|
223
253
|
TF_SetAttrTensor(desc, attr_name.c_str(), tensor, ts.s);
|
|
224
|
-
TF_DeleteTensor(tensor);
|
|
254
|
+
TF_DeleteTensor(tensor);
|
|
225
255
|
}
|
|
226
256
|
break;
|
|
227
257
|
}
|
|
228
258
|
case AttrKind::String:
|
|
229
259
|
{
|
|
230
|
-
|
|
231
|
-
|
|
260
|
+
// TF_SetAttrString expects raw bytes + length.
|
|
261
|
+
// The JS value is a JS string — convert via Utf8Value().
|
|
262
|
+
std::string sv;
|
|
263
|
+
auto vnode = attr_val.Get("value");
|
|
264
|
+
if (vnode.IsString())
|
|
265
|
+
sv = vnode.As<Napi::String>().Utf8Value();
|
|
266
|
+
TF_SetAttrString(desc, attr_name.c_str(),
|
|
267
|
+
sv.data(), sv.size());
|
|
232
268
|
break;
|
|
233
269
|
}
|
|
234
|
-
case AttrKind::Unknown:
|
|
235
270
|
default:
|
|
236
|
-
Napi::Error::New(env, "Unsupported attr kind: " + kind_str)
|
|
271
|
+
Napi::Error::New(env, "Unsupported attr kind: " + kind_str)
|
|
272
|
+
.ThrowAsJavaScriptException();
|
|
237
273
|
return false;
|
|
238
274
|
}
|
|
239
275
|
}
|
|
240
276
|
return true;
|
|
241
277
|
};
|
|
242
278
|
|
|
243
|
-
|
|
279
|
+
// ── Build the op description ────────────────────────────────────────────
|
|
280
|
+
auto finish_op = [&](bool list_input,
|
|
281
|
+
std::string &err) -> TF_Operation *
|
|
244
282
|
{
|
|
245
283
|
StatusGuard status;
|
|
246
|
-
TF_OperationDescription *desc =
|
|
284
|
+
TF_OperationDescription *desc =
|
|
285
|
+
TF_NewOperation(graph_, op_type.c_str(), op_name.c_str());
|
|
247
286
|
if (!desc)
|
|
248
287
|
{
|
|
249
|
-
|
|
288
|
+
err = "TF_NewOperation failed for " + op_type;
|
|
250
289
|
return nullptr;
|
|
251
290
|
}
|
|
252
291
|
|
|
253
|
-
if (
|
|
292
|
+
if (list_input)
|
|
254
293
|
{
|
|
255
294
|
if (!resolved_inputs.empty())
|
|
256
|
-
TF_AddInputList(desc, resolved_inputs.data(),
|
|
295
|
+
TF_AddInputList(desc, resolved_inputs.data(),
|
|
296
|
+
static_cast<int>(resolved_inputs.size()));
|
|
257
297
|
}
|
|
258
298
|
else
|
|
259
299
|
{
|
|
260
|
-
for (const auto &
|
|
261
|
-
TF_AddInput(desc,
|
|
300
|
+
for (const auto &inp : resolved_inputs)
|
|
301
|
+
TF_AddInput(desc, inp);
|
|
262
302
|
}
|
|
263
303
|
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
std::string ctrl_name = ctrl_arr.Get(i).As<Napi::String>().Utf8Value();
|
|
270
|
-
TF_Operation *ctrl_op = TF_GraphOperationByName(graph_, ctrl_name.c_str());
|
|
271
|
-
if (!ctrl_op)
|
|
272
|
-
{
|
|
273
|
-
Napi::Error::New(env, "Control input op not found: " + ctrl_name).ThrowAsJavaScriptException();
|
|
274
|
-
return nullptr;
|
|
275
|
-
}
|
|
276
|
-
TF_AddControlInput(desc, ctrl_op);
|
|
277
|
-
}
|
|
278
|
-
}
|
|
304
|
+
// Wire control dependencies — these guarantee ordering without
|
|
305
|
+
// passing tensor data. The described op will not run until all
|
|
306
|
+
// control-input ops have finished executing.
|
|
307
|
+
for (TF_Operation *ctrl : ctrl_ops)
|
|
308
|
+
TF_AddControlInput(desc, ctrl);
|
|
279
309
|
|
|
280
310
|
if (!apply_attrs(desc))
|
|
281
311
|
return nullptr;
|
|
@@ -283,7 +313,7 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
283
313
|
TF_Operation *op = TF_FinishOperation(desc, status.s);
|
|
284
314
|
if (!status.ok() || !op)
|
|
285
315
|
{
|
|
286
|
-
|
|
316
|
+
err = status.message();
|
|
287
317
|
return nullptr;
|
|
288
318
|
}
|
|
289
319
|
return op;
|
|
@@ -292,9 +322,9 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
292
322
|
std::string first_error;
|
|
293
323
|
TF_Operation *op = finish_op(false, first_error);
|
|
294
324
|
|
|
295
|
-
// Some ops have a single list-valued input arg
|
|
296
|
-
|
|
297
|
-
|
|
325
|
+
// Some ops have a single list-valued input arg — retry with AddInputList.
|
|
326
|
+
if (!op &&
|
|
327
|
+
first_error.find("expected list") != std::string::npos &&
|
|
298
328
|
!resolved_inputs.empty())
|
|
299
329
|
{
|
|
300
330
|
std::string retry_error;
|
|
@@ -306,24 +336,141 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
|
|
|
306
336
|
if (!op)
|
|
307
337
|
{
|
|
308
338
|
if (!env.IsExceptionPending())
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
339
|
+
Napi::Error::New(env,
|
|
340
|
+
"TF_FinishOperation failed for " + op_type +
|
|
341
|
+
": " + first_error)
|
|
342
|
+
.ThrowAsJavaScriptException();
|
|
312
343
|
return env.Undefined();
|
|
313
344
|
}
|
|
314
345
|
|
|
315
346
|
Napi::Object result = Napi::Object::New(env);
|
|
316
347
|
result.Set("opName", Napi::String::New(env, op_name));
|
|
317
|
-
result.Set("numOutputs", Napi::Number::New(env,
|
|
348
|
+
result.Set("numOutputs", Napi::Number::New(env,
|
|
349
|
+
static_cast<double>(TF_OperationNumOutputs(op))));
|
|
350
|
+
return result;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
// ---------------------------------------------------------------------------
|
|
354
|
+
// addGradients — JS signature:
|
|
355
|
+
// addGradients(
|
|
356
|
+
// y: { opName: string; index: number }[], // loss outputs
|
|
357
|
+
// x: { opName: string; index: number }[], // inputs to diff w.r.t.
|
|
358
|
+
// dx?: { opName: string; index: number }[], // initial upstream grads
|
|
359
|
+
// ) -> { opName: string; index: number }[] // gradient tensors, len = |x|
|
|
360
|
+
//
|
|
361
|
+
// Wraps TF_AddGradients. TF injects gradient ops directly into the graph.
|
|
362
|
+
// The returned outputs are the dL/dx_i tensors, one per x_i.
|
|
363
|
+
//
|
|
364
|
+
// dx defaults to ones (i.e. dL/dy = 1 for scalar loss, standard convention).
|
|
365
|
+
// Pass explicit dx when you need to chain gradients or scale the loss.
|
|
366
|
+
//
|
|
367
|
+
// Throws if any op in the path is non-differentiable (e.g. ArgMax).
|
|
368
|
+
// ---------------------------------------------------------------------------
|
|
369
|
+
Napi::Value GraphWrap::AddGradients(const Napi::CallbackInfo &info)
|
|
370
|
+
{
|
|
371
|
+
Napi::Env env = info.Env();
|
|
372
|
+
|
|
373
|
+
if (!graph_)
|
|
374
|
+
{
|
|
375
|
+
Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
|
|
376
|
+
return env.Undefined();
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
if (info.Length() < 2 || !info[0].IsArray() || !info[1].IsArray())
|
|
380
|
+
{
|
|
381
|
+
Napi::TypeError::New(env, "addGradients(y, x, dx?)")
|
|
382
|
+
.ThrowAsJavaScriptException();
|
|
383
|
+
return env.Undefined();
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
auto resolve_outputs = [&](Napi::Array arr,
|
|
387
|
+
std::vector<TF_Output> &out,
|
|
388
|
+
const std::string &label) -> bool
|
|
389
|
+
{
|
|
390
|
+
for (uint32_t i = 0; i < arr.Length(); ++i)
|
|
391
|
+
{
|
|
392
|
+
auto obj = arr.Get(i).As<Napi::Object>();
|
|
393
|
+
std::string name = obj.Get("opName").As<Napi::String>().Utf8Value();
|
|
394
|
+
int idx = obj.Get("index").As<Napi::Number>().Int32Value();
|
|
395
|
+
TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
|
|
396
|
+
if (!op)
|
|
397
|
+
{
|
|
398
|
+
Napi::Error::New(env, label + " op not found: " + name)
|
|
399
|
+
.ThrowAsJavaScriptException();
|
|
400
|
+
return false;
|
|
401
|
+
}
|
|
402
|
+
out.push_back({op, idx});
|
|
403
|
+
}
|
|
404
|
+
return true;
|
|
405
|
+
};
|
|
406
|
+
|
|
407
|
+
std::vector<TF_Output> y_vec, x_vec, dx_vec;
|
|
408
|
+
|
|
409
|
+
if (!resolve_outputs(info[0].As<Napi::Array>(), y_vec, "y"))
|
|
410
|
+
return env.Undefined();
|
|
411
|
+
if (!resolve_outputs(info[1].As<Napi::Array>(), x_vec, "x"))
|
|
412
|
+
return env.Undefined();
|
|
413
|
+
|
|
414
|
+
// Optional initial gradients.
|
|
415
|
+
TF_Output *dx_ptr = nullptr;
|
|
416
|
+
if (info.Length() >= 3 && info[2].IsArray())
|
|
417
|
+
{
|
|
418
|
+
if (!resolve_outputs(info[2].As<Napi::Array>(), dx_vec, "dx"))
|
|
419
|
+
return env.Undefined();
|
|
420
|
+
if (dx_vec.size() != y_vec.size())
|
|
421
|
+
{
|
|
422
|
+
Napi::Error::New(env,
|
|
423
|
+
"addGradients: dx length must equal y length")
|
|
424
|
+
.ThrowAsJavaScriptException();
|
|
425
|
+
return env.Undefined();
|
|
426
|
+
}
|
|
427
|
+
dx_ptr = dx_vec.data();
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
// Allocate output array — TF_AddGradients writes one gradient per x.
|
|
431
|
+
std::vector<TF_Output> dy(x_vec.size());
|
|
432
|
+
|
|
433
|
+
StatusGuard status;
|
|
434
|
+
TF_AddGradients(
|
|
435
|
+
graph_,
|
|
436
|
+
y_vec.data(), static_cast<int>(y_vec.size()),
|
|
437
|
+
x_vec.data(), static_cast<int>(x_vec.size()),
|
|
438
|
+
dx_ptr,
|
|
439
|
+
status.s,
|
|
440
|
+
dy.data());
|
|
441
|
+
|
|
442
|
+
if (!status.ok())
|
|
443
|
+
{
|
|
444
|
+
Napi::Error::New(env,
|
|
445
|
+
"TF_AddGradients failed: " + status.message())
|
|
446
|
+
.ThrowAsJavaScriptException();
|
|
447
|
+
return env.Undefined();
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
// Return array of { opName, index } — same shape as x.
|
|
451
|
+
Napi::Array result = Napi::Array::New(env, dy.size());
|
|
452
|
+
for (size_t i = 0; i < dy.size(); ++i)
|
|
453
|
+
{
|
|
454
|
+
const char *raw_name = TF_OperationName(dy[i].oper);
|
|
455
|
+
Napi::Object obj = Napi::Object::New(env);
|
|
456
|
+
obj.Set("opName", Napi::String::New(env, raw_name ? raw_name : ""));
|
|
457
|
+
obj.Set("index", Napi::Number::New(env, dy[i].index));
|
|
458
|
+
result.Set(static_cast<uint32_t>(i), obj);
|
|
459
|
+
}
|
|
318
460
|
return result;
|
|
319
461
|
}
|
|
320
462
|
|
|
463
|
+
// ---------------------------------------------------------------------------
|
|
464
|
+
// Remaining methods — unchanged from previous iteration
|
|
465
|
+
// ---------------------------------------------------------------------------
|
|
466
|
+
|
|
321
467
|
Napi::Value GraphWrap::HasOp(const Napi::CallbackInfo &info)
|
|
322
468
|
{
|
|
323
469
|
if (!info[0].IsString())
|
|
324
470
|
return Napi::Boolean::New(info.Env(), false);
|
|
325
471
|
std::string name = info[0].As<Napi::String>().Utf8Value();
|
|
326
|
-
return Napi::Boolean::New(info.Env(),
|
|
472
|
+
return Napi::Boolean::New(info.Env(),
|
|
473
|
+
graph_ && TF_GraphOperationByName(graph_, name.c_str()) != nullptr);
|
|
327
474
|
}
|
|
328
475
|
|
|
329
476
|
Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
|
|
@@ -333,13 +480,11 @@ Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
|
|
|
333
480
|
return env.Null();
|
|
334
481
|
std::string name = info[0].As<Napi::String>().Utf8Value();
|
|
335
482
|
int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;
|
|
336
|
-
|
|
337
483
|
TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
|
|
338
484
|
if (!op)
|
|
339
485
|
return env.Null();
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
return Napi::Number::New(env, static_cast<double>(dtype));
|
|
486
|
+
return Napi::Number::New(env,
|
|
487
|
+
static_cast<double>(TF_OperationOutputType({op, idx})));
|
|
343
488
|
}
|
|
344
489
|
|
|
345
490
|
Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
|
|
@@ -349,20 +494,19 @@ Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
|
|
|
349
494
|
return env.Null();
|
|
350
495
|
std::string name = info[0].As<Napi::String>().Utf8Value();
|
|
351
496
|
int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;
|
|
352
|
-
|
|
353
497
|
TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
|
|
354
498
|
if (!op)
|
|
355
499
|
return env.Null();
|
|
356
500
|
|
|
357
501
|
TF_Output out{op, idx};
|
|
358
|
-
StatusGuard
|
|
359
|
-
int ndims = TF_GraphGetTensorNumDims(graph_, out,
|
|
360
|
-
if (!
|
|
502
|
+
StatusGuard s1;
|
|
503
|
+
int ndims = TF_GraphGetTensorNumDims(graph_, out, s1.s);
|
|
504
|
+
if (!s1.ok() || ndims < 0)
|
|
361
505
|
return env.Null();
|
|
362
506
|
|
|
363
507
|
std::vector<int64_t> dims(ndims, -1);
|
|
364
|
-
StatusGuard
|
|
365
|
-
TF_GraphGetTensorShape(graph_, out, dims.data(), ndims,
|
|
508
|
+
StatusGuard s2;
|
|
509
|
+
TF_GraphGetTensorShape(graph_, out, dims.data(), ndims, s2.s);
|
|
366
510
|
|
|
367
511
|
Napi::Array arr = Napi::Array::New(env, ndims);
|
|
368
512
|
for (int i = 0; i < ndims; ++i)
|
|
@@ -379,19 +523,17 @@ Napi::Value GraphWrap::ToGraphDef(const Napi::CallbackInfo &info)
|
|
|
379
523
|
StatusGuard status;
|
|
380
524
|
TF_Buffer *buf = TF_NewBuffer();
|
|
381
525
|
TF_GraphToGraphDef(graph_, buf, status.s);
|
|
382
|
-
|
|
383
526
|
if (!status.ok())
|
|
384
527
|
{
|
|
385
528
|
TF_DeleteBuffer(buf);
|
|
386
|
-
Napi::Error::New(env,
|
|
529
|
+
Napi::Error::New(env,
|
|
530
|
+
"TF_GraphToGraphDef failed: " + status.message())
|
|
387
531
|
.ThrowAsJavaScriptException();
|
|
388
532
|
return env.Undefined();
|
|
389
533
|
}
|
|
390
|
-
|
|
391
534
|
auto node_buf = Napi::Buffer<uint8_t>::Copy(
|
|
392
535
|
env,
|
|
393
|
-
reinterpret_cast<const uint8_t *>(buf->data),
|
|
394
|
-
buf->length);
|
|
536
|
+
reinterpret_cast<const uint8_t *>(buf->data), buf->length);
|
|
395
537
|
TF_DeleteBuffer(buf);
|
|
396
538
|
return node_buf;
|
|
397
539
|
}
|
|
@@ -407,31 +549,14 @@ Napi::Value GraphWrap::NumOps(const Napi::CallbackInfo &info)
|
|
|
407
549
|
return Napi::Number::New(info.Env(), count);
|
|
408
550
|
}
|
|
409
551
|
|
|
410
|
-
// ---------------------------------------------------------------
|
|
411
|
-
// importGraphDef(buffer: Buffer): void
|
|
412
|
-
//
|
|
413
|
-
// Deserialises a binary GraphDef proto into this graph.
|
|
414
|
-
// Intended use: load a frozen .pb model so the graph can be
|
|
415
|
-
// executed via the native Session (with ConfigProto + affinity).
|
|
416
|
-
//
|
|
417
|
-
// The graph must be empty before calling this — importing into a
|
|
418
|
-
// non-empty graph will produce op name collisions.
|
|
419
|
-
//
|
|
420
|
-
// TF_GraphImportGraphDef uses a prefix option (default "")
|
|
421
|
-
// so all imported op names are used as-is, matching the op
|
|
422
|
-
// names in the original frozen graph.
|
|
423
|
-
// ---------------------------------------------------------------
|
|
424
552
|
Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
|
|
425
553
|
{
|
|
426
554
|
Napi::Env env = info.Env();
|
|
427
|
-
|
|
428
555
|
if (!graph_)
|
|
429
556
|
{
|
|
430
|
-
Napi::Error::New(env, "Graph destroyed")
|
|
431
|
-
.ThrowAsJavaScriptException();
|
|
557
|
+
Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
|
|
432
558
|
return env.Undefined();
|
|
433
559
|
}
|
|
434
|
-
|
|
435
560
|
if (info.Length() < 1 || !info[0].IsBuffer())
|
|
436
561
|
{
|
|
437
562
|
Napi::TypeError::New(env, "importGraphDef(buffer: Buffer)")
|
|
@@ -440,138 +565,114 @@ Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
|
|
|
440
565
|
}
|
|
441
566
|
|
|
442
567
|
auto buf = info[0].As<Napi::Buffer<uint8_t>>();
|
|
443
|
-
|
|
444
|
-
// TF_NewBufferFromString copies the bytes into TF-owned memory.
|
|
445
|
-
// The JS Buffer does not need to outlive this call.
|
|
446
|
-
TF_Buffer *graphdef =
|
|
447
|
-
TF_NewBufferFromString(buf.Data(), buf.ByteLength());
|
|
448
|
-
|
|
568
|
+
TF_Buffer *graphdef = TF_NewBufferFromString(buf.Data(), buf.ByteLength());
|
|
449
569
|
TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();
|
|
450
|
-
// Default prefix is "" — op names in the imported graph are
|
|
451
|
-
// used verbatim, matching the frozen graph's checkpoint keys.
|
|
452
570
|
TF_ImportGraphDefOptionsSetPrefix(opts, "");
|
|
453
571
|
|
|
454
572
|
StatusGuard status;
|
|
455
573
|
TF_GraphImportGraphDef(graph_, graphdef, opts, status.s);
|
|
456
|
-
|
|
457
574
|
TF_DeleteImportGraphDefOptions(opts);
|
|
458
575
|
TF_DeleteBuffer(graphdef);
|
|
459
576
|
|
|
460
577
|
if (!status.ok())
|
|
461
|
-
{
|
|
462
578
|
Napi::Error::New(env,
|
|
463
579
|
"TF_GraphImportGraphDef failed: " + status.message())
|
|
464
580
|
.ThrowAsJavaScriptException();
|
|
465
|
-
}
|
|
466
581
|
return env.Undefined();
|
|
467
582
|
}
|
|
468
583
|
|
|
469
|
-
|
|
470
|
-
// addGradients — JS signature:
|
|
471
|
-
// addGradients(
|
|
472
|
-
// y: { opName: string; index: number }[], // loss outputs
|
|
473
|
-
// x: { opName: string; index: number }[], // inputs to diff w.r.t.
|
|
474
|
-
// dx?: { opName: string; index: number }[], // initial upstream grads
|
|
475
|
-
// ) -> { opName: string; index: number }[] // gradient tensors, len = |x|
|
|
476
|
-
//
|
|
477
|
-
// Wraps TF_AddGradients. TF injects gradient ops directly into the graph.
|
|
478
|
-
// The returned outputs are the dL/dx_i tensors, one per x_i.
|
|
479
|
-
//
|
|
480
|
-
// dx defaults to ones (i.e. dL/dy = 1 for scalar loss, standard convention).
|
|
481
|
-
// Pass explicit dx when you need to chain gradients or scale the loss.
|
|
482
|
-
//
|
|
483
|
-
// Throws if any op in the path is non-differentiable (e.g. ArgMax).
|
|
484
|
-
// ---------------------------------------------------------------------------
|
|
485
|
-
Napi::Value GraphWrap::AddGradients(const Napi::CallbackInfo &info)
|
|
584
|
+
// ListOpsOfType(type: string) -> string[]
//
// Collects the names of every operation in the graph whose op type string
// equals `type`, in the order TF_GraphNextOperation yields them.
// An empty array is returned when the graph has been destroyed or when the
// first argument is missing or not a string.
Napi::Value GraphWrap::ListOpsOfType(const Napi::CallbackInfo &info)
{
  Napi::Env env = info.Env();

  // Short-circuit keeps info[0] untouched when the graph is gone.
  const bool usable = graph_ && info[0].IsString();
  if (!usable)
    return Napi::Array::New(env, 0);

  const std::string wanted = info[0].As<Napi::String>().Utf8Value();

  std::vector<std::string> names;
  size_t cursor = 0;
  for (TF_Operation *cur = TF_GraphNextOperation(graph_, &cursor);
       cur != nullptr;
       cur = TF_GraphNextOperation(graph_, &cursor))
  {
    const char *kind = TF_OperationOpType(cur);
    // Skip ops with no type string or a type that does not match.
    if (kind == nullptr || wanted != kind)
      continue;
    names.emplace_back(TF_OperationName(cur));
  }

  Napi::Array out = Napi::Array::New(env, names.size());
  uint32_t slot = 0;
  for (const std::string &name : names)
    out.Set(slot++, Napi::String::New(env, name));
  return out;
}
|
|
524
610
|
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
611
|
+
// ListSinkOps() -> string[]
//
// Returns the names of "sink" operations. These are ops whose outputs are not
// consumed as a data input by any other op in the graph, i.e. candidate output
// nodes of a loaded/frozen graph. Names are reported in TF_GraphNextOperation
// iteration order. Returns an empty array if the graph has been destroyed.
//
// Only data edges (TF_OperationInput) are inspected. An op whose outputs feed
// other ops solely through control dependencies is still reported as a sink.
Napi::Value GraphWrap::ListSinkOps(const Napi::CallbackInfo &info)
{
  Napi::Env env = info.Env();
  if (!graph_)
    return Napi::Array::New(env, 0);

  // Op types that are never reported, even when nothing consumes them.
  // These are control/infrastructure nodes that can appear in frozen
  // SavedModel graphs rather than tensor-producing model outputs:
  //   NoOp              - control-dependency fence. It has zero outputs, so
  //                       the output-count check below already drops it; it
  //                       is listed here for clarity only.
  //   Placeholder       - graph inputs, explicitly not outputs.
  //   Const             - constant initializer values, not inference results.
  //   VarHandleOp, AssignVariableOp, ReadVariableOp
  //                     - resource-variable handles and lifecycle ops.
  //   SaveV2, RestoreV2, MergeV2Checkpoints, StringJoin, ShardedFilename
  //                     - checkpoint save/restore machinery.
  //   _Arg, _Retval     - function argument / return marker nodes.
  // NOTE(review): StringJoin is also a general-purpose string op. A model
  // whose genuine output is a StringJoin will be hidden by this filter.
  static const std::unordered_set<std::string> excluded_types = {
      "NoOp",
      "VarHandleOp",
      "Placeholder",
      "Const",
      "AssignVariableOp",
      "ReadVariableOp",
      "SaveV2",
      "RestoreV2",
      "MergeV2Checkpoints",
      "StringJoin",
      "ShardedFilename",
      "_Arg",
      "_Retval",
  };

  // Pass 1: record every op that produces a tensor consumed as a data
  // input by some other op. Such producers cannot be sinks.
  std::unordered_set<TF_Operation *> consumed;
  size_t pos = 0;
  TF_Operation *op = nullptr;
  while ((op = TF_GraphNextOperation(graph_, &pos)) != nullptr)
  {
    const int num_inputs = TF_OperationNumInputs(op);
    for (int i = 0; i < num_inputs; ++i)
    {
      const TF_Output src = TF_OperationInput({op, i});
      consumed.insert(src.oper);
    }
  }

  // Pass 2: a sink op is one that
  //   1. is not consumed as a data input by any other op,
  //   2. has at least one output (drops NoOp and other zero-output nodes),
  //   3. is not one of the excluded infrastructure op types above.
  std::vector<std::string> sinks;
  pos = 0;
  while ((op = TF_GraphNextOperation(graph_, &pos)) != nullptr)
  {
    if (consumed.count(op) != 0)
      continue;
    if (TF_OperationNumOutputs(op) < 1)
      continue;
    const char *op_type = TF_OperationOpType(op);
    if (op_type && excluded_types.count(op_type) != 0)
      continue;
    sinks.emplace_back(TF_OperationName(op));
  }

  Napi::Array result = Napi::Array::New(env, sinks.size());
  for (size_t i = 0; i < sinks.size(); ++i)
    result.Set(static_cast<uint32_t>(i), Napi::String::New(env, sinks[i]));
  return result;
}
|