@isidorus/cpu 0.0.0-alpha.2 → 0.0.0-alpha.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,7 @@
1
1
  #include "graph.h"
2
2
  #include <cstring>
3
- #include <stdexcept>
4
3
  #include <unordered_map>
5
-
6
- // -------------------------------------------------------------------
7
- // GraphWrap - N-API ObjectWrap around TF_Graph
8
- //
9
- // Exposes graph construction primitives:
10
- // addOp(type, inputs, attrs) -> {opName, numOutputs}
11
- // hasOp(name) -> boolean
12
- // opOutputType(name, index) -> DType integer
13
- // opOutputShape(name, index) -> number[] (-1 = unknown dim)
14
- // toGraphDef() -> Buffer (serialized GraphDef proto)
15
- // numOps() -> number
16
- //-------------------------------------------------------------------
4
+ #include <unordered_set>
17
5
 
18
6
  Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
19
7
  {
@@ -26,6 +14,8 @@ Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
26
14
  InstanceMethod<&GraphWrap::NumOps>("numOps"),
27
15
  InstanceMethod<&GraphWrap::ImportGraphDef>("importGraphDef"),
28
16
  InstanceMethod<&GraphWrap::AddGradients>("addGradients"),
17
+ InstanceMethod<&GraphWrap::ListOpsOfType>("listOpsOfType"),
18
+ InstanceMethod<&GraphWrap::ListSinkOps>("listSinkOps"),
29
19
  });
30
20
  auto *ctor = new Napi::FunctionReference(Napi::Persistent(func));
31
21
  env.SetInstanceData<Napi::FunctionReference>(ctor);
@@ -33,11 +23,13 @@ Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
33
23
  return exports;
34
24
  }
35
25
 
36
- GraphWrap::GraphWrap(const Napi::CallbackInfo &info) : Napi::ObjectWrap<GraphWrap>(info)
26
+ GraphWrap::GraphWrap(const Napi::CallbackInfo &info)
27
+ : Napi::ObjectWrap<GraphWrap>(info)
37
28
  {
38
29
  graph_ = TF_NewGraph();
39
30
  if (!graph_)
40
- Napi::Error::New(info.Env(), "Failed to create TF_Graph").ThrowAsJavaScriptException();
31
+ Napi::Error::New(info.Env(), "Failed to create TF_Graph")
32
+ .ThrowAsJavaScriptException();
41
33
  }
42
34
 
43
35
  GraphWrap::~GraphWrap()
@@ -49,109 +41,138 @@ GraphWrap::~GraphWrap()
49
41
  }
50
42
  }
51
43
 
52
- // ---------------------------------------------------------------
53
- // addOp(type: string, inputs: {opName: string, index: number}[],
54
- // attrs: Record<string, AttrValue>) -> { opName: string, numOutputs : number }
44
+ // ---------------------------------------------------------------------------
45
+ // addOp JS signature:
46
+ // addOp(
47
+ // type: string,
48
+ // inputs: { opName: string; index: number }[],
49
+ // attrs?: Record<string, AttrValue>,
50
+ // name?: string,
51
+ // controlInputs?: string[], // ← NEW: op names that must run first
52
+ // ) -> { opName: string; numOutputs: number }
55
53
  //
56
- // AttrValue is one of:
57
- // { kind: "int", value: number }
58
- // { kind: "float", value: number }
59
- // { kind: "bool", value: boolean }
60
- // { kind: "type", value: number } <- DType integer
61
- // { kind: "shape", value: number[] } <- -1 for unknown dim
62
- // { kind: "shape", value: { dtype, shape, data: Buffer } }
63
- // { kind: "list_type", value: number[] }
64
- // { kind: "list_shape", value: number[] }
65
- // ---------------------------------------------------------------
66
-
54
+ // Control inputs are wired via TF_AddControlInput. TF guarantees the
55
+ // described op will not execute until every control-input op has completed.
56
+ // This is the mechanism behind globalVariablesInitializer: a NoOp with
57
+ // control edges to all init AssignVariableOps — running the NoOp forces all
58
+ // assignments to execute first.
59
+ // ---------------------------------------------------------------------------
67
60
  Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
68
61
  {
69
62
  Napi::Env env = info.Env();
70
63
 
71
64
  if (!graph_)
72
65
  {
73
- Napi::Error::New(env, "Graph has been destroyed").ThrowAsJavaScriptException();
66
+ Napi::Error::New(env, "Graph has been destroyed")
67
+ .ThrowAsJavaScriptException();
74
68
  return env.Undefined();
75
69
  }
76
70
 
77
71
  if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray())
78
72
  {
79
- Napi::TypeError::New(env, "addOp(type: string, inputs: TFOutput[], attrs?)").ThrowAsJavaScriptException();
73
+ Napi::TypeError::New(env,
74
+ "addOp(type, inputs, attrs?, name?, controlInputs?)")
75
+ .ThrowAsJavaScriptException();
80
76
  return env.Undefined();
81
77
  }
82
78
 
83
79
  std::string op_type = info[0].As<Napi::String>().Utf8Value();
84
-
85
- // Auto-generate a unique op name: type + "_" + counter.
86
80
  std::string op_name = op_type + "_" + std::to_string(op_counter_++);
87
- // Allow caller to override the name via optional 3rd arg.
88
81
  if (info.Length() >= 4 && info[3].IsString())
89
82
  op_name = info[3].As<Napi::String>().Utf8Value();
90
83
 
91
- // Resolve inputs first so we can retry op construction with TF_AddInputList
92
- // for list-input ops (e.g. IdentityN) when needed.
84
+ // ── Resolve data inputs ─────────────────────────────────────────────────
93
85
  std::vector<TF_Output> resolved_inputs;
94
86
  auto inputs_arr = info[1].As<Napi::Array>();
95
87
  resolved_inputs.reserve(inputs_arr.Length());
96
88
  for (uint32_t i = 0; i < inputs_arr.Length(); i++)
97
89
  {
98
- auto input_obj = inputs_arr.Get(i).As<Napi::Object>();
99
- std::string input_op_name = input_obj.Get("opName").As<Napi::String>().Utf8Value();
100
- int input_idx = input_obj.Get("index").As<Napi::Number>().Int32Value();
101
-
102
- TF_Operation *input_op = TF_GraphOperationByName(graph_, input_op_name.c_str());
103
- if (!input_op)
90
+ auto obj = inputs_arr.Get(i).As<Napi::Object>();
91
+ std::string dep_name = obj.Get("opName").As<Napi::String>().Utf8Value();
92
+ int dep_idx = obj.Get("index").As<Napi::Number>().Int32Value();
93
+ TF_Operation *dep_op = TF_GraphOperationByName(graph_, dep_name.c_str());
94
+ if (!dep_op)
104
95
  {
105
- Napi::Error::New(env, "Input op not found: " + input_op_name).ThrowAsJavaScriptException();
96
+ Napi::Error::New(env, "Input op not found: " + dep_name)
97
+ .ThrowAsJavaScriptException();
106
98
  return env.Undefined();
107
99
  }
108
- resolved_inputs.push_back({input_op, input_idx});
100
+ resolved_inputs.push_back({dep_op, dep_idx});
101
+ }
102
+
103
+ // ── Resolve control inputs (arg index 4) ───────────────────────────────
104
+ std::vector<TF_Operation *> ctrl_ops;
105
+ if (info.Length() >= 5 && info[4].IsArray())
106
+ {
107
+ auto ctrl_arr = info[4].As<Napi::Array>();
108
+ ctrl_ops.reserve(ctrl_arr.Length());
109
+ for (uint32_t i = 0; i < ctrl_arr.Length(); i++)
110
+ {
111
+ std::string ctrl_name =
112
+ ctrl_arr.Get(i).As<Napi::String>().Utf8Value();
113
+ TF_Operation *ctrl_op =
114
+ TF_GraphOperationByName(graph_, ctrl_name.c_str());
115
+ if (!ctrl_op)
116
+ {
117
+ Napi::Error::New(env,
118
+ "Control input op not found: " + ctrl_name)
119
+ .ThrowAsJavaScriptException();
120
+ return env.Undefined();
121
+ }
122
+ ctrl_ops.push_back(ctrl_op);
123
+ }
109
124
  }
110
125
 
126
+ // ── Attr dispatch ───────────────────────────────────────────────────────
127
+ enum class AttrKind
128
+ {
129
+ Int,
130
+ Float,
131
+ Bool,
132
+ Type,
133
+ Shape,
134
+ ListType,
135
+ ListInt,
136
+ Tensor,
137
+ String,
138
+ Unknown
139
+ };
140
+ static const std::unordered_map<std::string, AttrKind> kind_map = {
141
+ {"int", AttrKind::Int},
142
+ {"float", AttrKind::Float},
143
+ {"bool", AttrKind::Bool},
144
+ {"type", AttrKind::Type},
145
+ {"shape", AttrKind::Shape},
146
+ {"list_type", AttrKind::ListType},
147
+ {"list_int", AttrKind::ListInt},
148
+ {"tensor", AttrKind::Tensor},
149
+ {"string", AttrKind::String},
150
+ };
151
+
111
152
  auto apply_attrs = [&](TF_OperationDescription *desc) -> bool
112
153
  {
113
154
  if (!(info.Length() >= 3 && info[2].IsObject()))
114
155
  return true;
115
156
 
116
- enum class AttrKind
117
- {
118
- Int,
119
- Float,
120
- Bool,
121
- Type,
122
- Shape,
123
- ListType,
124
- ListInt,
125
- Tensor,
126
- String,
127
- Unknown
128
- };
129
-
130
- static const std::unordered_map<std::string, AttrKind> kind_map = {
131
- {"int", AttrKind::Int},
132
- {"float", AttrKind::Float},
133
- {"bool", AttrKind::Bool},
134
- {"type", AttrKind::Type},
135
- {"shape", AttrKind::Shape},
136
- {"list_type", AttrKind::ListType},
137
- {"list_int", AttrKind::ListInt},
138
- {"tensor", AttrKind::Tensor},
139
- {"string", AttrKind::String}};
140
-
141
157
  auto attrs = info[2].As<Napi::Object>();
142
158
  auto attrs_keys = attrs.GetPropertyNames();
143
159
  for (uint32_t i = 0; i < attrs_keys.Length(); i++)
144
160
  {
145
- std::string attr_name = attrs_keys.Get(i).As<Napi::String>().Utf8Value();
161
+ std::string attr_name =
162
+ attrs_keys.Get(i).As<Napi::String>().Utf8Value();
146
163
  auto attr_val = attrs.Get(attr_name).As<Napi::Object>();
147
- std::string kind_str = attr_val.Get("kind").As<Napi::String>().Utf8Value();
164
+ std::string kind_str =
165
+ attr_val.Get("kind").As<Napi::String>().Utf8Value();
166
+
148
167
  auto it = kind_map.find(kind_str);
149
- AttrKind kind = (it != kind_map.end()) ? it->second : AttrKind::Unknown;
168
+ AttrKind kind = (it != kind_map.end()) ? it->second
169
+ : AttrKind::Unknown;
150
170
  switch (kind)
151
171
  {
152
172
  case AttrKind::Int:
153
173
  {
154
- int64_t v = static_cast<int64_t>(attr_val.Get("value").As<Napi::Number>().Int64Value());
174
+ int64_t v = static_cast<int64_t>(
175
+ attr_val.Get("value").As<Napi::Number>().Int64Value());
155
176
  TF_SetAttrInt(desc, attr_name.c_str(), v);
156
177
  break;
157
178
  }
@@ -163,8 +184,8 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
163
184
  }
164
185
  case AttrKind::Bool:
165
186
  {
166
- unsigned char b = attr_val.Get("value").As<Napi::Boolean>().Value() ? 1
167
- : 0;
187
+ unsigned char b =
188
+ attr_val.Get("value").As<Napi::Boolean>().Value() ? 1 : 0;
168
189
  TF_SetAttrBool(desc, attr_name.c_str(), b);
169
190
  break;
170
191
  }
@@ -180,7 +201,8 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
180
201
  auto dims_arr = attr_val.Get("value").As<Napi::Array>();
181
202
  std::vector<int64_t> dims(dims_arr.Length());
182
203
  for (uint32_t j = 0; j < dims_arr.Length(); ++j)
183
- dims[j] = static_cast<int64_t>(dims_arr.Get(j).As<Napi::Number>().Int64Value());
204
+ dims[j] = static_cast<int64_t>(
205
+ dims_arr.Get(j).As<Napi::Number>().Int64Value());
184
206
  TF_SetAttrShape(desc, attr_name.c_str(),
185
207
  dims.data(), static_cast<int>(dims.size()));
186
208
  break;
@@ -190,8 +212,11 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
190
212
  auto vals = attr_val.Get("value").As<Napi::Array>();
191
213
  std::vector<TF_DataType> types(vals.Length());
192
214
  for (uint32_t j = 0; j < vals.Length(); ++j)
193
- types[j] = static_cast<TF_DataType>(vals.Get(j).As<Napi::Number>().Int32Value());
194
- TF_SetAttrTypeList(desc, attr_name.c_str(), types.data(), static_cast<int>(types.size()));
215
+ types[j] = static_cast<TF_DataType>(
216
+ vals.Get(j).As<Napi::Number>().Int32Value());
217
+ TF_SetAttrTypeList(desc, attr_name.c_str(),
218
+ types.data(),
219
+ static_cast<int>(types.size()));
195
220
  break;
196
221
  }
197
222
  case AttrKind::ListInt:
@@ -199,83 +224,88 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
199
224
  auto vals_int = attr_val.Get("value").As<Napi::Array>();
200
225
  std::vector<int64_t> ints(vals_int.Length());
201
226
  for (uint32_t j = 0; j < vals_int.Length(); ++j)
202
- ints[j] = static_cast<int64_t>(vals_int.Get(j).As<Napi::Number>().Int64Value());
203
- TF_SetAttrIntList(desc, attr_name.c_str(), ints.data(), static_cast<int>(ints.size()));
227
+ ints[j] = static_cast<int64_t>(
228
+ vals_int.Get(j).As<Napi::Number>().Int64Value());
229
+ TF_SetAttrIntList(desc, attr_name.c_str(),
230
+ ints.data(),
231
+ static_cast<int>(ints.size()));
204
232
  break;
205
233
  }
206
234
  case AttrKind::Tensor:
207
235
  {
208
- // Inline constant tensor
209
236
  auto tv = attr_val.Get("value").As<Napi::Object>();
210
- TF_DataType dtype = static_cast<TF_DataType>(tv.Get("dtype").As<Napi::Number>().Int32Value());
237
+ TF_DataType dtype = static_cast<TF_DataType>(
238
+ tv.Get("dtype").As<Napi::Number>().Int32Value());
211
239
  auto data_buf = tv.Get("data").As<Napi::Buffer<uint8_t>>();
212
240
  auto dims_arr = tv.Get("shape").As<Napi::Array>();
213
-
214
241
  std::vector<int64_t> dims(dims_arr.Length());
215
242
  for (uint32_t j = 0; j < dims_arr.Length(); ++j)
216
- dims[j] = static_cast<int64_t>(dims_arr.Get(j).As<Napi::Number>().Int64Value());
217
-
243
+ dims[j] = static_cast<int64_t>(
244
+ dims_arr.Get(j).As<Napi::Number>().Int64Value());
218
245
  StatusGuard ts;
219
- TF_Tensor *tensor = TF_AllocateTensor(dtype, dims.data(), static_cast<int>(dims.size()), data_buf.Length());
246
+ TF_Tensor *tensor = TF_AllocateTensor(
247
+ dtype, dims.data(), static_cast<int>(dims.size()),
248
+ data_buf.Length());
220
249
  if (tensor)
221
250
  {
222
- std::memcpy(TF_TensorData(tensor), data_buf.Data(), data_buf.ByteLength());
251
+ std::memcpy(TF_TensorData(tensor),
252
+ data_buf.Data(), data_buf.ByteLength());
223
253
  TF_SetAttrTensor(desc, attr_name.c_str(), tensor, ts.s);
224
- TF_DeleteTensor(tensor); // Graph takes ownership, safe to delete here.
254
+ TF_DeleteTensor(tensor);
225
255
  }
226
256
  break;
227
257
  }
228
258
  case AttrKind::String:
229
259
  {
230
- std::string s = attr_val.Get("value").As<Napi::String>().Utf8Value();
231
- TF_SetAttrString(desc, attr_name.c_str(), s.data(), s.length());
260
+ // TF_SetAttrString expects raw bytes + length.
261
+ // The JS value is a JS string — convert via Utf8Value().
262
+ std::string sv;
263
+ auto vnode = attr_val.Get("value");
264
+ if (vnode.IsString())
265
+ sv = vnode.As<Napi::String>().Utf8Value();
266
+ TF_SetAttrString(desc, attr_name.c_str(),
267
+ sv.data(), sv.size());
232
268
  break;
233
269
  }
234
- case AttrKind::Unknown:
235
270
  default:
236
- Napi::Error::New(env, "Unsupported attr kind: " + kind_str).ThrowAsJavaScriptException();
271
+ Napi::Error::New(env, "Unsupported attr kind: " + kind_str)
272
+ .ThrowAsJavaScriptException();
237
273
  return false;
238
274
  }
239
275
  }
240
276
  return true;
241
277
  };
242
278
 
243
- auto finish_op = [&](bool add_as_input_list, std::string &error_out) -> TF_Operation *
279
+ // ── Build the op description ────────────────────────────────────────────
280
+ auto finish_op = [&](bool list_input,
281
+ std::string &err) -> TF_Operation *
244
282
  {
245
283
  StatusGuard status;
246
- TF_OperationDescription *desc = TF_NewOperation(graph_, op_type.c_str(), op_name.c_str());
284
+ TF_OperationDescription *desc =
285
+ TF_NewOperation(graph_, op_type.c_str(), op_name.c_str());
247
286
  if (!desc)
248
287
  {
249
- error_out = "TF_NewOperation failed for type " + op_type;
288
+ err = "TF_NewOperation failed for " + op_type;
250
289
  return nullptr;
251
290
  }
252
291
 
253
- if (add_as_input_list)
292
+ if (list_input)
254
293
  {
255
294
  if (!resolved_inputs.empty())
256
- TF_AddInputList(desc, resolved_inputs.data(), static_cast<int>(resolved_inputs.size()));
295
+ TF_AddInputList(desc, resolved_inputs.data(),
296
+ static_cast<int>(resolved_inputs.size()));
257
297
  }
258
298
  else
259
299
  {
260
- for (const auto &input : resolved_inputs)
261
- TF_AddInput(desc, input);
300
+ for (const auto &inp : resolved_inputs)
301
+ TF_AddInput(desc, inp);
262
302
  }
263
303
 
264
- if (info.Length() >= 5 && info[4].IsArray())
265
- {
266
- auto ctrl_arr = info[4].As<Napi::Array>();
267
- for (uint32_t i = 0; i < ctrl_arr.Length(); i++)
268
- {
269
- std::string ctrl_name = ctrl_arr.Get(i).As<Napi::String>().Utf8Value();
270
- TF_Operation *ctrl_op = TF_GraphOperationByName(graph_, ctrl_name.c_str());
271
- if (!ctrl_op)
272
- {
273
- Napi::Error::New(env, "Control input op not found: " + ctrl_name).ThrowAsJavaScriptException();
274
- return nullptr;
275
- }
276
- TF_AddControlInput(desc, ctrl_op);
277
- }
278
- }
304
+ // Wire control dependencies these guarantee ordering without
305
+ // passing tensor data. The described op will not run until all
306
+ // control-input ops have finished executing.
307
+ for (TF_Operation *ctrl : ctrl_ops)
308
+ TF_AddControlInput(desc, ctrl);
279
309
 
280
310
  if (!apply_attrs(desc))
281
311
  return nullptr;
@@ -283,7 +313,7 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
283
313
  TF_Operation *op = TF_FinishOperation(desc, status.s);
284
314
  if (!status.ok() || !op)
285
315
  {
286
- error_out = status.message();
316
+ err = status.message();
287
317
  return nullptr;
288
318
  }
289
319
  return op;
@@ -292,9 +322,9 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
292
322
  std::string first_error;
293
323
  TF_Operation *op = finish_op(false, first_error);
294
324
 
295
- // Some ops have a single list-valued input arg and reject repeated
296
- // TF_AddInput calls even for one element; retry with TF_AddInputList.
297
- if (!op && first_error.find("expected list") != std::string::npos &&
325
+ // Some ops have a single list-valued input arg retry with AddInputList.
326
+ if (!op &&
327
+ first_error.find("expected list") != std::string::npos &&
298
328
  !resolved_inputs.empty())
299
329
  {
300
330
  std::string retry_error;
@@ -306,24 +336,141 @@ Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
306
336
  if (!op)
307
337
  {
308
338
  if (!env.IsExceptionPending())
309
- {
310
- Napi::Error::New(env, "TF_FinishOperation failed for " + op_type + ": " + first_error).ThrowAsJavaScriptException();
311
- }
339
+ Napi::Error::New(env,
340
+ "TF_FinishOperation failed for " + op_type +
341
+ ": " + first_error)
342
+ .ThrowAsJavaScriptException();
312
343
  return env.Undefined();
313
344
  }
314
345
 
315
346
  Napi::Object result = Napi::Object::New(env);
316
347
  result.Set("opName", Napi::String::New(env, op_name));
317
- result.Set("numOutputs", Napi::Number::New(env, static_cast<double>(TF_OperationNumOutputs(op))));
348
+ result.Set("numOutputs", Napi::Number::New(env,
349
+ static_cast<double>(TF_OperationNumOutputs(op))));
350
+ return result;
351
+ }
352
+
353
+ // ---------------------------------------------------------------------------
354
+ // addGradients — JS signature:
355
+ // addGradients(
356
+ // y: { opName: string; index: number }[], // loss outputs
357
+ // x: { opName: string; index: number }[], // inputs to diff w.r.t.
358
+ // dx?: { opName: string; index: number }[], // initial upstream grads
359
+ // ) -> { opName: string; index: number }[] // gradient tensors, len = |x|
360
+ //
361
+ // Wraps TF_AddGradients. TF injects gradient ops directly into the graph.
362
+ // The returned outputs are the dL/dx_i tensors, one per x_i.
363
+ //
364
+ // dx defaults to ones (i.e. dL/dy = 1 for scalar loss, standard convention).
365
+ // Pass explicit dx when you need to chain gradients or scale the loss.
366
+ //
367
+ // Throws if any op in the path is non-differentiable (e.g. ArgMax).
368
+ // ---------------------------------------------------------------------------
369
+ Napi::Value GraphWrap::AddGradients(const Napi::CallbackInfo &info)
370
+ {
371
+ Napi::Env env = info.Env();
372
+
373
+ if (!graph_)
374
+ {
375
+ Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
376
+ return env.Undefined();
377
+ }
378
+
379
+ if (info.Length() < 2 || !info[0].IsArray() || !info[1].IsArray())
380
+ {
381
+ Napi::TypeError::New(env, "addGradients(y, x, dx?)")
382
+ .ThrowAsJavaScriptException();
383
+ return env.Undefined();
384
+ }
385
+
386
+ auto resolve_outputs = [&](Napi::Array arr,
387
+ std::vector<TF_Output> &out,
388
+ const std::string &label) -> bool
389
+ {
390
+ for (uint32_t i = 0; i < arr.Length(); ++i)
391
+ {
392
+ auto obj = arr.Get(i).As<Napi::Object>();
393
+ std::string name = obj.Get("opName").As<Napi::String>().Utf8Value();
394
+ int idx = obj.Get("index").As<Napi::Number>().Int32Value();
395
+ TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
396
+ if (!op)
397
+ {
398
+ Napi::Error::New(env, label + " op not found: " + name)
399
+ .ThrowAsJavaScriptException();
400
+ return false;
401
+ }
402
+ out.push_back({op, idx});
403
+ }
404
+ return true;
405
+ };
406
+
407
+ std::vector<TF_Output> y_vec, x_vec, dx_vec;
408
+
409
+ if (!resolve_outputs(info[0].As<Napi::Array>(), y_vec, "y"))
410
+ return env.Undefined();
411
+ if (!resolve_outputs(info[1].As<Napi::Array>(), x_vec, "x"))
412
+ return env.Undefined();
413
+
414
+ // Optional initial gradients.
415
+ TF_Output *dx_ptr = nullptr;
416
+ if (info.Length() >= 3 && info[2].IsArray())
417
+ {
418
+ if (!resolve_outputs(info[2].As<Napi::Array>(), dx_vec, "dx"))
419
+ return env.Undefined();
420
+ if (dx_vec.size() != y_vec.size())
421
+ {
422
+ Napi::Error::New(env,
423
+ "addGradients: dx length must equal y length")
424
+ .ThrowAsJavaScriptException();
425
+ return env.Undefined();
426
+ }
427
+ dx_ptr = dx_vec.data();
428
+ }
429
+
430
+ // Allocate output array — TF_AddGradients writes one gradient per x.
431
+ std::vector<TF_Output> dy(x_vec.size());
432
+
433
+ StatusGuard status;
434
+ TF_AddGradients(
435
+ graph_,
436
+ y_vec.data(), static_cast<int>(y_vec.size()),
437
+ x_vec.data(), static_cast<int>(x_vec.size()),
438
+ dx_ptr,
439
+ status.s,
440
+ dy.data());
441
+
442
+ if (!status.ok())
443
+ {
444
+ Napi::Error::New(env,
445
+ "TF_AddGradients failed: " + status.message())
446
+ .ThrowAsJavaScriptException();
447
+ return env.Undefined();
448
+ }
449
+
450
+ // Return array of { opName, index } — same shape as x.
451
+ Napi::Array result = Napi::Array::New(env, dy.size());
452
+ for (size_t i = 0; i < dy.size(); ++i)
453
+ {
454
+ const char *raw_name = TF_OperationName(dy[i].oper);
455
+ Napi::Object obj = Napi::Object::New(env);
456
+ obj.Set("opName", Napi::String::New(env, raw_name ? raw_name : ""));
457
+ obj.Set("index", Napi::Number::New(env, dy[i].index));
458
+ result.Set(static_cast<uint32_t>(i), obj);
459
+ }
318
460
  return result;
319
461
  }
320
462
 
463
+ // ---------------------------------------------------------------------------
464
+ // Remaining methods — unchanged from previous iteration
465
+ // ---------------------------------------------------------------------------
466
+
321
467
  Napi::Value GraphWrap::HasOp(const Napi::CallbackInfo &info)
322
468
  {
323
469
  if (!info[0].IsString())
324
470
  return Napi::Boolean::New(info.Env(), false);
325
471
  std::string name = info[0].As<Napi::String>().Utf8Value();
326
- return Napi::Boolean::New(info.Env(), graph_ && TF_GraphOperationByName(graph_, name.c_str()) != nullptr);
472
+ return Napi::Boolean::New(info.Env(),
473
+ graph_ && TF_GraphOperationByName(graph_, name.c_str()) != nullptr);
327
474
  }
328
475
 
329
476
  Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
@@ -333,13 +480,11 @@ Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
333
480
  return env.Null();
334
481
  std::string name = info[0].As<Napi::String>().Utf8Value();
335
482
  int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;
336
-
337
483
  TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
338
484
  if (!op)
339
485
  return env.Null();
340
-
341
- TF_DataType dtype = TF_OperationOutputType({op, idx});
342
- return Napi::Number::New(env, static_cast<double>(dtype));
486
+ return Napi::Number::New(env,
487
+ static_cast<double>(TF_OperationOutputType({op, idx})));
343
488
  }
344
489
 
345
490
  Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
@@ -349,20 +494,19 @@ Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
349
494
  return env.Null();
350
495
  std::string name = info[0].As<Napi::String>().Utf8Value();
351
496
  int idx = info.Length() >= 2 ? info[1].As<Napi::Number>().Int32Value() : 0;
352
-
353
497
  TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
354
498
  if (!op)
355
499
  return env.Null();
356
500
 
357
501
  TF_Output out{op, idx};
358
- StatusGuard status;
359
- int ndims = TF_GraphGetTensorNumDims(graph_, out, status.s);
360
- if (!status.ok() || ndims < 0)
502
+ StatusGuard s1;
503
+ int ndims = TF_GraphGetTensorNumDims(graph_, out, s1.s);
504
+ if (!s1.ok() || ndims < 0)
361
505
  return env.Null();
362
506
 
363
507
  std::vector<int64_t> dims(ndims, -1);
364
- StatusGuard shape_status;
365
- TF_GraphGetTensorShape(graph_, out, dims.data(), ndims, shape_status.s);
508
+ StatusGuard s2;
509
+ TF_GraphGetTensorShape(graph_, out, dims.data(), ndims, s2.s);
366
510
 
367
511
  Napi::Array arr = Napi::Array::New(env, ndims);
368
512
  for (int i = 0; i < ndims; ++i)
@@ -379,19 +523,17 @@ Napi::Value GraphWrap::ToGraphDef(const Napi::CallbackInfo &info)
379
523
  StatusGuard status;
380
524
  TF_Buffer *buf = TF_NewBuffer();
381
525
  TF_GraphToGraphDef(graph_, buf, status.s);
382
-
383
526
  if (!status.ok())
384
527
  {
385
528
  TF_DeleteBuffer(buf);
386
- Napi::Error::New(env, "TF_GraphToGraphDef failed: " + status.message())
529
+ Napi::Error::New(env,
530
+ "TF_GraphToGraphDef failed: " + status.message())
387
531
  .ThrowAsJavaScriptException();
388
532
  return env.Undefined();
389
533
  }
390
-
391
534
  auto node_buf = Napi::Buffer<uint8_t>::Copy(
392
535
  env,
393
- reinterpret_cast<const uint8_t *>(buf->data),
394
- buf->length);
536
+ reinterpret_cast<const uint8_t *>(buf->data), buf->length);
395
537
  TF_DeleteBuffer(buf);
396
538
  return node_buf;
397
539
  }
@@ -407,31 +549,14 @@ Napi::Value GraphWrap::NumOps(const Napi::CallbackInfo &info)
407
549
  return Napi::Number::New(info.Env(), count);
408
550
  }
409
551
 
410
- // ---------------------------------------------------------------
411
- // importGraphDef(buffer: Buffer): void
412
- //
413
- // Deserialises a binary GraphDef proto into this graph.
414
- // Intended use: load a frozen .pb model so the graph can be
415
- // executed via the native Session (with ConfigProto + affinity).
416
- //
417
- // The graph must be empty before calling this — importing into a
418
- // non-empty graph will produce op name collisions.
419
- //
420
- // TF_GraphImportGraphDef uses a prefix option (default "")
421
- // so all imported op names are used as-is, matching the op
422
- // names in the original frozen graph.
423
- // ---------------------------------------------------------------
424
552
  Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
425
553
  {
426
554
  Napi::Env env = info.Env();
427
-
428
555
  if (!graph_)
429
556
  {
430
- Napi::Error::New(env, "Graph destroyed")
431
- .ThrowAsJavaScriptException();
557
+ Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
432
558
  return env.Undefined();
433
559
  }
434
-
435
560
  if (info.Length() < 1 || !info[0].IsBuffer())
436
561
  {
437
562
  Napi::TypeError::New(env, "importGraphDef(buffer: Buffer)")
@@ -440,138 +565,114 @@ Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
440
565
  }
441
566
 
442
567
  auto buf = info[0].As<Napi::Buffer<uint8_t>>();
443
-
444
- // TF_NewBufferFromString copies the bytes into TF-owned memory.
445
- // The JS Buffer does not need to outlive this call.
446
- TF_Buffer *graphdef =
447
- TF_NewBufferFromString(buf.Data(), buf.ByteLength());
448
-
568
+ TF_Buffer *graphdef = TF_NewBufferFromString(buf.Data(), buf.ByteLength());
449
569
  TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();
450
- // Default prefix is "" — op names in the imported graph are
451
- // used verbatim, matching the frozen graph's checkpoint keys.
452
570
  TF_ImportGraphDefOptionsSetPrefix(opts, "");
453
571
 
454
572
  StatusGuard status;
455
573
  TF_GraphImportGraphDef(graph_, graphdef, opts, status.s);
456
-
457
574
  TF_DeleteImportGraphDefOptions(opts);
458
575
  TF_DeleteBuffer(graphdef);
459
576
 
460
577
  if (!status.ok())
461
- {
462
578
  Napi::Error::New(env,
463
579
  "TF_GraphImportGraphDef failed: " + status.message())
464
580
  .ThrowAsJavaScriptException();
465
- }
466
581
  return env.Undefined();
467
582
  }
468
583
 
469
- // ---------------------------------------------------------------------------
470
- // addGradients — JS signature:
471
- // addGradients(
472
- // y: { opName: string; index: number }[], // loss outputs
473
- // x: { opName: string; index: number }[], // inputs to diff w.r.t.
474
- // dx?: { opName: string; index: number }[], // initial upstream grads
475
- // ) -> { opName: string; index: number }[] // gradient tensors, len = |x|
476
- //
477
- // Wraps TF_AddGradients. TF injects gradient ops directly into the graph.
478
- // The returned outputs are the dL/dx_i tensors, one per x_i.
479
- //
480
- // dx defaults to ones (i.e. dL/dy = 1 for scalar loss, standard convention).
481
- // Pass explicit dx when you need to chain gradients or scale the loss.
482
- //
483
- // Throws if any op in the path is non-differentiable (e.g. ArgMax).
484
- // ---------------------------------------------------------------------------
485
- Napi::Value GraphWrap::AddGradients(const Napi::CallbackInfo &info)
584
+ Napi::Value GraphWrap::ListOpsOfType(const Napi::CallbackInfo &info)
486
585
  {
487
586
  Napi::Env env = info.Env();
488
-
489
- if (!graph_)
490
- {
491
- Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
492
- return env.Undefined();
493
- }
494
-
495
- if (info.Length() < 2 || !info[0].IsArray() || !info[1].IsArray())
587
+ if (!graph_ || !info[0].IsString())
496
588
  {
497
- Napi::TypeError::New(env, "addGradients(y, x, dx?)")
498
- .ThrowAsJavaScriptException();
499
- return env.Undefined();
589
+ return Napi::Array::New(env, 0);
500
590
  }
591
+ std::string target_type = info[0].As<Napi::String>().Utf8Value();
501
592
 
502
- auto resolve_outputs = [&](Napi::Array arr,
503
- std::vector<TF_Output> &out,
504
- const std::string &label) -> bool
593
+ std::vector<std::string> matches;
594
+ size_t pos = 0;
595
+ TF_Operation *op = nullptr;
596
+ while ((op = TF_GraphNextOperation(graph_, &pos)) != nullptr)
505
597
  {
506
- for (uint32_t i = 0; i < arr.Length(); ++i)
598
+ const char *op_type = TF_OperationOpType(op);
599
+ if (op_type && target_type == op_type)
507
600
  {
508
- auto obj = arr.Get(i).As<Napi::Object>();
509
- std::string name = obj.Get("opName").As<Napi::String>().Utf8Value();
510
- int idx = obj.Get("index").As<Napi::Number>().Int32Value();
511
- TF_Operation *op = TF_GraphOperationByName(graph_, name.c_str());
512
- if (!op)
513
- {
514
- Napi::Error::New(env, label + " op not found: " + name)
515
- .ThrowAsJavaScriptException();
516
- return false;
517
- }
518
- out.push_back({op, idx});
601
+ matches.push_back(TF_OperationName(op));
519
602
  }
520
- return true;
521
- };
603
+ }
522
604
 
523
- std::vector<TF_Output> y_vec, x_vec, dx_vec;
605
+ Napi::Array result = Napi::Array::New(env, matches.size());
606
+ for (size_t i = 0; i < matches.size(); ++i)
607
+ result.Set(static_cast<uint32_t>(i), Napi::String::New(env, matches[i]));
608
+ return result;
609
+ }
524
610
 
525
- if (!resolve_outputs(info[0].As<Napi::Array>(), y_vec, "y"))
526
- return env.Undefined();
527
- if (!resolve_outputs(info[1].As<Napi::Array>(), x_vec, "x"))
528
- return env.Undefined();
611
+ Napi::Value GraphWrap::ListSinkOps(const Napi::CallbackInfo &info)
612
+ {
613
+ Napi::Env env = info.Env();
614
+ if (!graph_)
615
+ return Napi::Array::New(env, 0);
616
+
617
+ // Op types that must never be returned as output ops.
618
+ // These are control/infrastructure nodes that appear in frozen SavedModel
619
+ // graphs but have zero outputs or are not tensor-producing ops.
620
+ // NoOp — zero outputs, used as a control dependency fence
621
+ // VarHandleOp — resource handle for a variable, not a model output
622
+ // Placeholder — graph inputs, explicitly not outputs
623
+ // Const — constant initializer values, not inference outputs
624
+ // AssignVariableOp / ReadVariableOp — variable lifecycle ops
625
+ static const std::unordered_set<std::string> excluded_types = {
626
+ "NoOp",
627
+ "VarHandleOp",
628
+ "Placeholder",
629
+ "Const",
630
+ "AssignVariableOp",
631
+ "ReadVariableOp",
632
+ "SaveV2",
633
+ "RestoreV2",
634
+ "MergeV2Checkpoints",
635
+ "StringJoin",
636
+ "ShardedFilename",
637
+ "_Arg",
638
+ "_Retval",
639
+ };
529
640
 
530
- // Optional initial gradients.
531
- TF_Output *dx_ptr = nullptr;
532
- if (info.Length() >= 3 && info[2].IsArray())
641
+ // Collect every op that is consumed as a data input by some other op.
642
+ std::unordered_set<TF_Operation *> consumed;
643
+
644
+ size_t pos = 0;
645
+ TF_Operation *op = nullptr;
646
+ while ((op = TF_GraphNextOperation(graph_, &pos)) != nullptr)
533
647
  {
534
- if (!resolve_outputs(info[2].As<Napi::Array>(), dx_vec, "dx"))
535
- return env.Undefined();
536
- if (dx_vec.size() != y_vec.size())
648
+ int num_inputs = TF_OperationNumInputs(op);
649
+ for (int i = 0; i < num_inputs; ++i)
537
650
  {
538
- Napi::Error::New(env,
539
- "addGradients: dx length must equal y length")
540
- .ThrowAsJavaScriptException();
541
- return env.Undefined();
651
+ TF_Output src = TF_OperationInput({op, i});
652
+ consumed.insert(src.oper);
542
653
  }
543
- dx_ptr = dx_vec.data();
544
654
  }
545
655
 
546
- // Allocate output array TF_AddGradients writes one gradient per x.
547
- std::vector<TF_Output> dy(x_vec.size());
548
-
549
- StatusGuard status;
550
- TF_AddGradients(
551
- graph_,
552
- y_vec.data(), static_cast<int>(y_vec.size()),
553
- x_vec.data(), static_cast<int>(x_vec.size()),
554
- dx_ptr,
555
- status.s,
556
- dy.data());
557
-
558
- if (!status.ok())
656
+ // A sink op is one that:
657
+ // 1. Is not consumed as an input by any other op
658
+ // 2. Has at least one output (excludes NoOp and other zero-output nodes)
659
+ // 3. Is not an excluded infrastructure op type
660
+ std::vector<std::string> sinks;
661
+ pos = 0;
662
+ while ((op = TF_GraphNextOperation(graph_, &pos)) != nullptr)
559
663
  {
560
- Napi::Error::New(env,
561
- "TF_AddGradients failed: " + status.message())
562
- .ThrowAsJavaScriptException();
563
- return env.Undefined();
664
+ if (consumed.find(op) != consumed.end())
665
+ continue;
666
+ if (TF_OperationNumOutputs(op) < 1)
667
+ continue;
668
+ const char *op_type = TF_OperationOpType(op);
669
+ if (op_type && excluded_types.count(op_type))
670
+ continue;
671
+ sinks.push_back(TF_OperationName(op));
564
672
  }
565
673
 
566
- // Return array of { opName, index } — same shape as x.
567
- Napi::Array result = Napi::Array::New(env, dy.size());
568
- for (size_t i = 0; i < dy.size(); ++i)
569
- {
570
- const char *raw_name = TF_OperationName(dy[i].oper);
571
- Napi::Object obj = Napi::Object::New(env);
572
- obj.Set("opName", Napi::String::New(env, raw_name ? raw_name : ""));
573
- obj.Set("index", Napi::Number::New(env, dy[i].index));
574
- result.Set(static_cast<uint32_t>(i), obj);
575
- }
674
+ Napi::Array result = Napi::Array::New(env, sinks.size());
675
+ for (size_t i = 0; i < sinks.size(); ++i)
676
+ result.Set(static_cast<uint32_t>(i), Napi::String::New(env, sinks[i]));
576
677
  return result;
577
678
  }