@isidorus/cpu 0.0.0-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +47 -0
- package/binding.gyp +103 -0
- package/dist/ts/_native.d.ts +13 -0
- package/dist/ts/_native.d.ts.map +1 -0
- package/dist/ts/_native.js +22 -0
- package/dist/ts/_native.js.map +1 -0
- package/dist/ts/graph.d.ts +91 -0
- package/dist/ts/graph.d.ts.map +1 -0
- package/dist/ts/graph.js +95 -0
- package/dist/ts/graph.js.map +1 -0
- package/dist/ts/index.d.ts +47 -0
- package/dist/ts/index.d.ts.map +1 -0
- package/dist/ts/index.js +58 -0
- package/dist/ts/index.js.map +1 -0
- package/dist/ts/inference-pool.d.ts +84 -0
- package/dist/ts/inference-pool.d.ts.map +1 -0
- package/dist/ts/inference-pool.js +625 -0
- package/dist/ts/inference-pool.js.map +1 -0
- package/dist/ts/inference_pool.d.ts +99 -0
- package/dist/ts/inference_pool.d.ts.map +1 -0
- package/dist/ts/inference_pool.js +370 -0
- package/dist/ts/inference_pool.js.map +1 -0
- package/dist/ts/install-libtensorflow.d.ts +34 -0
- package/dist/ts/install-libtensorflow.d.ts.map +1 -0
- package/dist/ts/install-libtensorflow.js +254 -0
- package/dist/ts/install-libtensorflow.js.map +1 -0
- package/dist/ts/ops/array_ops.d.ts +29 -0
- package/dist/ts/ops/array_ops.d.ts.map +1 -0
- package/dist/ts/ops/array_ops.js +54 -0
- package/dist/ts/ops/array_ops.js.map +1 -0
- package/dist/ts/ops/index.d.ts +5 -0
- package/dist/ts/ops/index.d.ts.map +1 -0
- package/dist/ts/ops/index.js +5 -0
- package/dist/ts/ops/index.js.map +1 -0
- package/dist/ts/ops/math_ops.d.ts +96 -0
- package/dist/ts/ops/math_ops.d.ts.map +1 -0
- package/dist/ts/ops/math_ops.js +277 -0
- package/dist/ts/ops/math_ops.js.map +1 -0
- package/dist/ts/ops/nn_ops.d.ts +130 -0
- package/dist/ts/ops/nn_ops.d.ts.map +1 -0
- package/dist/ts/ops/nn_ops.js +340 -0
- package/dist/ts/ops/nn_ops.js.map +1 -0
- package/dist/ts/ops/variable_ops.d.ts +128 -0
- package/dist/ts/ops/variable_ops.d.ts.map +1 -0
- package/dist/ts/ops/variable_ops.js +267 -0
- package/dist/ts/ops/variable_ops.js.map +1 -0
- package/dist/ts/session.d.ts +83 -0
- package/dist/ts/session.d.ts.map +1 -0
- package/dist/ts/session.js +81 -0
- package/dist/ts/session.js.map +1 -0
- package/package.json +63 -0
- package/scripts/install.js +100 -0
- package/scripts/test-install.js +82 -0
- package/scripts/test.js +45 -0
- package/src/native/addon.cc +12 -0
- package/src/native/graph.cc +442 -0
- package/src/native/graph.h +52 -0
- package/src/native/platform_tf.h +8 -0
- package/src/native/session.cc +716 -0
- package/src/native/session.h +92 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
|
|
3
|
+
/**
|
|
4
|
+
* Test installation behavior
|
|
5
|
+
* Usage: node scripts/test-install.js
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { arch, platform } from "os";
|
|
9
|
+
import { existsSync, rmSync } from "fs";
|
|
10
|
+
import { execSync } from "child_process";
|
|
11
|
+
import { join, dirname } from "path";
|
|
12
|
+
import { fileURLToPath } from "url";
|
|
13
|
+
|
|
14
|
+
const __filename = fileURLToPath(import.meta.url);
|
|
15
|
+
const __dirname = dirname(__filename);
|
|
16
|
+
const projectRoot = join(__dirname, "..");
|
|
17
|
+
|
|
18
|
+
console.log("\nTesting jude-map installation behavior\n");
|
|
19
|
+
console.log(`Platform: ${platform()}`);
|
|
20
|
+
console.log(`Architecture: ${arch()}`);
|
|
21
|
+
|
|
22
|
+
const isArm = arch().includes("arm") || arch().includes("aarch");
|
|
23
|
+
const buildDir = join(projectRoot, "build");
|
|
24
|
+
const prebuildsDir = join(projectRoot, "prebuilds");
|
|
25
|
+
|
|
26
|
+
console.log(
|
|
27
|
+
`\nBuild directory: ${existsSync(buildDir) ? " exists" : " missing"}`,
|
|
28
|
+
);
|
|
29
|
+
console.log(
|
|
30
|
+
`Prebuilds directory: ${existsSync(prebuildsDir) ? " exists" : " missing"}`,
|
|
31
|
+
);
|
|
32
|
+
|
|
33
|
+
console.log("\n Test Options:");
|
|
34
|
+
console.log("1. Test with prebuilds (simulate x64 user)");
|
|
35
|
+
console.log("2. Test without prebuilds (force compilation)");
|
|
36
|
+
console.log("3. Test current state");
|
|
37
|
+
console.log("4. Clean all build artifacts");
|
|
38
|
+
|
|
39
|
+
// For simplicity, run test 3 (current state)
|
|
40
|
+
console.log("\n Running test: Current state\n");
|
|
41
|
+
|
|
42
|
+
try {
|
|
43
|
+
// Try to load the module
|
|
44
|
+
const startTime = Date.now();
|
|
45
|
+
|
|
46
|
+
console.log("Attempting to load native addon...");
|
|
47
|
+
|
|
48
|
+
execSync("node scripts/postinstall-verify.js", {
|
|
49
|
+
cwd: projectRoot,
|
|
50
|
+
stdio: "inherit",
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
const elapsed = Date.now() - startTime;
|
|
54
|
+
console.log(`\n Success! Loaded in ${elapsed}ms`);
|
|
55
|
+
|
|
56
|
+
if (isArm) {
|
|
57
|
+
console.log("\n ARM detected: Module was compiled locally");
|
|
58
|
+
} else {
|
|
59
|
+
if (existsSync(prebuildsDir)) {
|
|
60
|
+
console.log("\n x64 detected: Module loaded from prebuilds");
|
|
61
|
+
} else {
|
|
62
|
+
console.log(
|
|
63
|
+
"\n💡 x64 detected: Module was compiled (no prebuilds available)",
|
|
64
|
+
);
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
} catch (error) {
|
|
68
|
+
console.error("\nFailed to load module");
|
|
69
|
+
console.error(error.message);
|
|
70
|
+
process.exit(1);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
console.log("\nInstallation test complete!\n");
|
|
74
|
+
|
|
75
|
+
// Show next steps
|
|
76
|
+
console.log("Next steps:");
|
|
77
|
+
console.log(' - Run "npm test" to verify functionality');
|
|
78
|
+
console.log(" - Check PREBUILDS.md for prebuild generation");
|
|
79
|
+
if (!existsSync(prebuildsDir)) {
|
|
80
|
+
console.log(' - Run "npm run prebuildify" to generate prebuilds');
|
|
81
|
+
}
|
|
82
|
+
console.log("");
|
package/scripts/test.js
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { readdirSync } from "fs";
|
|
2
|
+
import { join } from "path";
|
|
3
|
+
import { spawn } from "child_process";
|
|
4
|
+
import { fileURLToPath } from "url";
|
|
5
|
+
import { dirname } from "path";
|
|
6
|
+
|
|
7
|
+
const __dirname = dirname(fileURLToPath(import.meta.url));
|
|
8
|
+
const testDir = join(__dirname, "../src/ts/__tests__");
|
|
9
|
+
|
|
10
|
+
// Get all test files
|
|
11
|
+
const testFiles = readdirSync(testDir)
|
|
12
|
+
.filter((file) => file.endsWith(".test.ts"))
|
|
13
|
+
.map((file) => join(testDir, file));
|
|
14
|
+
|
|
15
|
+
// Run node test runner with all test files serially to avoid cross-file
|
|
16
|
+
// native teardown races while still preserving all test coverage.
|
|
17
|
+
const args = [
|
|
18
|
+
"--import",
|
|
19
|
+
"tsx",
|
|
20
|
+
"--test",
|
|
21
|
+
"--test-concurrency=1",
|
|
22
|
+
...testFiles,
|
|
23
|
+
];
|
|
24
|
+
|
|
25
|
+
const env = { ...process.env };
|
|
26
|
+
if (process.platform === "win32") {
|
|
27
|
+
const libtfPath = env.LIBTENSORFLOW_PATH || "C:\\libtensorflow";
|
|
28
|
+
const dllDir = libtfPath.toLowerCase().endsWith("\\lib")
|
|
29
|
+
? libtfPath
|
|
30
|
+
: join(libtfPath, "lib");
|
|
31
|
+
const currentPath = env.Path || env.PATH || "";
|
|
32
|
+
const nextPath = `${dllDir};${currentPath}`;
|
|
33
|
+
env.Path = nextPath;
|
|
34
|
+
env.PATH = nextPath;
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
const child = spawn("node", args, {
|
|
38
|
+
stdio: "inherit",
|
|
39
|
+
shell: true,
|
|
40
|
+
env,
|
|
41
|
+
});
|
|
42
|
+
|
|
43
|
+
child.on("exit", (code) => {
|
|
44
|
+
process.exit(code);
|
|
45
|
+
});
|
|
#include <napi.h>
#include "graph.h"
#include "session.h"

// Addon entry point: registers the Graph and Session wrapper classes on
// the module's exports object. Called once by Node when the native module
// is first loaded.
Napi::Object InitAll(Napi::Env env, Napi::Object exports)
{
    GraphWrap::Init(env, exports);
    SessionWrap::Init(env, exports);
    return exports;
}

// Declare InitAll as this module's N-API initializer.
NODE_API_MODULE(isidorus_cpu, InitAll)
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
#include "graph.h"
|
|
2
|
+
#include <cstring>
|
|
3
|
+
#include <stdexcept>
|
|
4
|
+
#include <unordered_map>
|
|
5
|
+
|
|
6
|
+
// -------------------------------------------------------------------
|
|
7
|
+
// GraphWrap - N-API ObjectWrap around TF_Graph
|
|
8
|
+
//
|
|
9
|
+
// Exposes graph construction primitives:
|
|
10
|
+
// addOp(type, inputs, attrs) -> {opName, numOutputs}
|
|
11
|
+
// hasOp(name) -> boolean
|
|
12
|
+
// opOutputType(name, index) -> DType integer
|
|
13
|
+
// opOutputShape(name, index) -> number[] (-1 = unknown dim)
|
|
14
|
+
// toGraphDef() -> Buffer (serialized GraphDef proto)
|
|
15
|
+
// numOps() -> number
|
|
16
|
+
//-------------------------------------------------------------------
|
|
17
|
+
|
|
18
|
+
// Registers the "Graph" JS class and its instance methods on `exports`.
Napi::Object GraphWrap::Init(Napi::Env env, Napi::Object exports)
{
    // One instance method per graph-construction primitive.
    Napi::Function klass = DefineClass(env, "Graph", {
        InstanceMethod<&GraphWrap::AddOp>("addOp"),
        InstanceMethod<&GraphWrap::HasOp>("hasOp"),
        InstanceMethod<&GraphWrap::OpOutputType>("opOutputType"),
        InstanceMethod<&GraphWrap::OpOutputShape>("opOutputShape"),
        InstanceMethod<&GraphWrap::ToGraphDef>("toGraphDef"),
        InstanceMethod<&GraphWrap::NumOps>("numOps"),
        InstanceMethod<&GraphWrap::ImportGraphDef>("importGraphDef"),
    });

    // Keep the constructor reference alive for the lifetime of the env;
    // SetInstanceData takes ownership and frees it on env teardown.
    auto *constructor_ref = new Napi::FunctionReference(Napi::Persistent(klass));
    env.SetInstanceData<Napi::FunctionReference>(constructor_ref);

    exports.Set("Graph", klass);
    return exports;
}
|
|
34
|
+
|
|
35
|
+
// Constructs the wrapper and allocates the underlying TF_Graph.
GraphWrap::GraphWrap(const Napi::CallbackInfo &info) : Napi::ObjectWrap<GraphWrap>(info)
{
    graph_ = TF_NewGraph();
    if (graph_ == nullptr)
    {
        Napi::Error::New(info.Env(), "Failed to create TF_Graph").ThrowAsJavaScriptException();
    }
}
|
|
41
|
+
|
|
42
|
+
// Releases the TF_Graph (if still owned) when the JS object is collected.
GraphWrap::~GraphWrap()
{
    if (graph_ == nullptr)
        return;
    TF_DeleteGraph(graph_);
    graph_ = nullptr;
}
|
|
50
|
+
|
|
51
|
+
// ---------------------------------------------------------------
|
|
52
|
+
// addOp(type: string, inputs: {opName: string, index: number}[],
|
|
53
|
+
// attrs: Record<string, AttrValue>) -> { opName: string, numOutputs : number }
|
|
54
|
+
//
|
|
55
|
+
// AttrValue is one of:
|
|
56
|
+
// { kind: "int", value: number }
|
|
57
|
+
// { kind: "float", value: number }
|
|
58
|
+
// { kind: "bool", value: boolean }
|
|
59
|
+
// { kind: "type", value: number } <- DType integer
|
|
60
|
+
// { kind: "shape", value: number[] } <- -1 for unknown dim
|
|
61
|
+
// { kind: "shape", value: { dtype, shape, data: Buffer } }
|
|
62
|
+
// { kind: "list_type", value: number[] }
|
|
63
|
+
// { kind: "list_shape", value: number[] }
|
|
64
|
+
// ---------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
// Adds a single operation to the graph.
//
// JS signature: addOp(type: string, inputs: {opName, index}[], attrs?, name?)
// Returns { opName, numOutputs } on success; throws a JS exception and
// returns undefined on failure. The op name is auto-generated from the op
// type plus a per-graph counter unless an explicit 4th-argument name is given.
Napi::Value GraphWrap::AddOp(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();

    // Guard against use after the graph has been torn down.
    if (!graph_)
    {
        Napi::Error::New(env, "Graph has been destroyed").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    // Minimal argument validation: type string + inputs array are required.
    if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray())
    {
        Napi::TypeError::New(env, "addOp(type: string, inputs: TFOutput[], attrs?)").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    std::string op_type = info[0].As<Napi::String>().Utf8Value();

    // Auto-generate a unique op name: type + "_" + counter.
    std::string op_name = op_type + "_" + std::to_string(op_counter_++);
    // Allow caller to override the name via optional 3rd arg.
    // (Positionally this is info[3], i.e. the 4th JS argument, after attrs.)
    if (info.Length() >= 4 && info[3].IsString())
        op_name = info[3].As<Napi::String>().Utf8Value();

    // Resolve inputs first so we can retry op construction with TF_AddInputList
    // for list-input ops (e.g. IdentityN) when needed.
    std::vector<TF_Output> resolved_inputs;
    auto inputs_arr = info[1].As<Napi::Array>();
    resolved_inputs.reserve(inputs_arr.Length());
    for (uint32_t i = 0; i < inputs_arr.Length(); i++)
    {
        auto input_obj = inputs_arr.Get(i).As<Napi::Object>();
        std::string input_op_name = input_obj.Get("opName").As<Napi::String>().Utf8Value();
        int input_idx = input_obj.Get("index").As<Napi::Number>().Int32Value();

        // Inputs must reference ops that already exist in this graph.
        TF_Operation *input_op = TF_GraphOperationByName(graph_, input_op_name.c_str());
        if (!input_op)
        {
            Napi::Error::New(env, "Input op not found: " + input_op_name).ThrowAsJavaScriptException();
            return env.Undefined();
        }
        resolved_inputs.push_back({input_op, input_idx});
    }

    // Applies the attrs object (info[2]) to a fresh op description.
    // Returns false (with a pending JS exception) on an unsupported kind;
    // returning false aborts the enclosing finish_op attempt.
    auto apply_attrs = [&](TF_OperationDescription *desc) -> bool
    {
        // No attrs argument means nothing to apply.
        if (!(info.Length() >= 3 && info[2].IsObject()))
            return true;

        // Tag for the "kind" discriminator carried by each attr value.
        enum class AttrKind
        {
            Int,
            Float,
            Bool,
            Type,
            Shape,
            ListType,
            ListInt,
            Tensor,
            Unknown
        };

        // Maps the JS "kind" string to the enum; static so it is built once.
        static const std::unordered_map<std::string, AttrKind> kind_map = {
            {"int", AttrKind::Int},
            {"float", AttrKind::Float},
            {"bool", AttrKind::Bool},
            {"type", AttrKind::Type},
            {"shape", AttrKind::Shape},
            {"list_type", AttrKind::ListType},
            {"list_int", AttrKind::ListInt},
            {"tensor", AttrKind::Tensor}};

        auto attrs = info[2].As<Napi::Object>();
        auto attrs_keys = attrs.GetPropertyNames();
        for (uint32_t i = 0; i < attrs_keys.Length(); i++)
        {
            std::string attr_name = attrs_keys.Get(i).As<Napi::String>().Utf8Value();
            auto attr_val = attrs.Get(attr_name).As<Napi::Object>();
            std::string kind_str = attr_val.Get("kind").As<Napi::String>().Utf8Value();
            auto it = kind_map.find(kind_str);
            AttrKind kind = (it != kind_map.end()) ? it->second : AttrKind::Unknown;
            switch (kind)
            {
            case AttrKind::Int:
            {
                int64_t v = static_cast<int64_t>(attr_val.Get("value").As<Napi::Number>().Int64Value());
                TF_SetAttrInt(desc, attr_name.c_str(), v);
                break;
            }
            case AttrKind::Float:
            {
                float f = attr_val.Get("value").As<Napi::Number>().FloatValue();
                TF_SetAttrFloat(desc, attr_name.c_str(), f);
                break;
            }
            case AttrKind::Bool:
            {
                unsigned char b = attr_val.Get("value").As<Napi::Boolean>().Value() ? 1
                                                                                    : 0;
                TF_SetAttrBool(desc, attr_name.c_str(), b);
                break;
            }
            case AttrKind::Type:
            {
                // value is a DType integer, cast straight to TF_DataType.
                TF_DataType v = static_cast<TF_DataType>(
                    attr_val.Get("value").As<Napi::Number>().Int32Value());
                TF_SetAttrType(desc, attr_name.c_str(), v);
                break;
            }
            case AttrKind::Shape:
            {
                // value is a number[]; -1 entries denote unknown dims.
                auto dims_arr = attr_val.Get("value").As<Napi::Array>();
                std::vector<int64_t> dims(dims_arr.Length());
                for (uint32_t j = 0; j < dims_arr.Length(); ++j)
                    dims[j] = static_cast<int64_t>(dims_arr.Get(j).As<Napi::Number>().Int64Value());
                TF_SetAttrShape(desc, attr_name.c_str(),
                                dims.data(), static_cast<int>(dims.size()));
                break;
            }
            case AttrKind::ListType:
            {
                auto vals = attr_val.Get("value").As<Napi::Array>();
                std::vector<TF_DataType> types(vals.Length());
                for (uint32_t j = 0; j < vals.Length(); ++j)
                    types[j] = static_cast<TF_DataType>(vals.Get(j).As<Napi::Number>().Int32Value());
                TF_SetAttrTypeList(desc, attr_name.c_str(), types.data(), static_cast<int>(types.size()));
                break;
            }
            case AttrKind::ListInt:
            {
                auto vals_int = attr_val.Get("value").As<Napi::Array>();
                std::vector<int64_t> ints(vals_int.Length());
                for (uint32_t j = 0; j < vals_int.Length(); ++j)
                    ints[j] = static_cast<int64_t>(vals_int.Get(j).As<Napi::Number>().Int64Value());
                TF_SetAttrIntList(desc, attr_name.c_str(), ints.data(), static_cast<int>(ints.size()));
                break;
            }
            case AttrKind::Tensor:
            {
                // Inline constant tensor
                // value is { dtype, shape: number[], data: Buffer }.
                auto tv = attr_val.Get("value").As<Napi::Object>();
                TF_DataType dtype = static_cast<TF_DataType>(tv.Get("dtype").As<Napi::Number>().Int32Value());
                auto data_buf = tv.Get("data").As<Napi::Buffer<uint8_t>>();
                auto dims_arr = tv.Get("shape").As<Napi::Array>();

                std::vector<int64_t> dims(dims_arr.Length());
                for (uint32_t j = 0; j < dims_arr.Length(); ++j)
                    dims[j] = static_cast<int64_t>(dims_arr.Get(j).As<Napi::Number>().Int64Value());

                // NOTE(review): allocation failure (tensor == nullptr) and a
                // non-OK `ts` status after TF_SetAttrTensor are silently
                // ignored here — the error surfaces later (if at all) from
                // TF_FinishOperation. Confirm this is intentional.
                StatusGuard ts;
                TF_Tensor *tensor = TF_AllocateTensor(dtype, dims.data(), static_cast<int>(dims.size()), data_buf.Length());
                if (tensor)
                {
                    std::memcpy(TF_TensorData(tensor), data_buf.Data(), data_buf.ByteLength());
                    TF_SetAttrTensor(desc, attr_name.c_str(), tensor, ts.s);
                    TF_DeleteTensor(tensor); // Graph takes ownership, safe to delete here.
                }
                break;
            }
            case AttrKind::Unknown:
            default:
                Napi::Error::New(env, "Unsupported attr kind: " + kind_str).ThrowAsJavaScriptException();
                return false;
            }
        }
        return true;
    };

    // Builds + finalizes the op. On failure returns nullptr and fills
    // error_out (unless a JS exception is already pending from apply_attrs).
    auto finish_op = [&](bool add_as_input_list, std::string &error_out) -> TF_Operation *
    {
        StatusGuard status;
        TF_OperationDescription *desc = TF_NewOperation(graph_, op_type.c_str(), op_name.c_str());
        if (!desc)
        {
            error_out = "TF_NewOperation failed for type " + op_type;
            return nullptr;
        }

        // Ops with a single list-valued input take all inputs in one
        // TF_AddInputList call; others take one TF_AddInput per input.
        if (add_as_input_list)
        {
            if (!resolved_inputs.empty())
                TF_AddInputList(desc, resolved_inputs.data(), static_cast<int>(resolved_inputs.size()));
        }
        else
        {
            for (const auto &input : resolved_inputs)
                TF_AddInput(desc, input);
        }

        if (!apply_attrs(desc))
            return nullptr;

        TF_Operation *op = TF_FinishOperation(desc, status.s);
        if (!status.ok() || !op)
        {
            error_out = status.message();
            return nullptr;
        }
        return op;
    };

    std::string first_error;
    TF_Operation *op = finish_op(false, first_error);

    // Some ops have a single list-valued input arg and reject repeated
    // TF_AddInput calls even for one element; retry with TF_AddInputList.
    // The retry keys off TF's "expected list" error-message substring.
    if (!op && first_error.find("expected list") != std::string::npos &&
        !resolved_inputs.empty())
    {
        std::string retry_error;
        op = finish_op(true, retry_error);
        if (!op)
            first_error = retry_error;
    }

    if (!op)
    {
        // apply_attrs may already have thrown; don't mask that exception.
        if (!env.IsExceptionPending())
        {
            Napi::Error::New(env, "TF_FinishOperation failed for " + op_type + ": " + first_error).ThrowAsJavaScriptException();
        }
        return env.Undefined();
    }

    // Success: hand the generated name and output count back to JS.
    Napi::Object result = Napi::Object::New(env);
    result.Set("opName", Napi::String::New(env, op_name));
    result.Set("numOutputs", Napi::Number::New(env, static_cast<double>(TF_OperationNumOutputs(op))));
    return result;
}
|
|
295
|
+
|
|
296
|
+
// Returns true when an op with the given name exists in this graph.
// A missing/non-string argument or a destroyed graph yields false.
Napi::Value GraphWrap::HasOp(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (!info[0].IsString())
        return Napi::Boolean::New(env, false);

    const std::string op_name = info[0].As<Napi::String>().Utf8Value();
    const bool exists =
        graph_ != nullptr && TF_GraphOperationByName(graph_, op_name.c_str()) != nullptr;
    return Napi::Boolean::New(env, exists);
}
|
|
303
|
+
|
|
304
|
+
// opOutputType(name, index?) -> DType integer of the op's output,
// or null when the graph is gone / argument invalid / op not found.
Napi::Value GraphWrap::OpOutputType(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (graph_ == nullptr || !info[0].IsString())
        return env.Null();

    const std::string op_name = info[0].As<Napi::String>().Utf8Value();
    // Output index defaults to 0 when omitted.
    const int output_index =
        (info.Length() >= 2) ? info[1].As<Napi::Number>().Int32Value() : 0;

    TF_Operation *op = TF_GraphOperationByName(graph_, op_name.c_str());
    if (op == nullptr)
        return env.Null();

    const TF_DataType dtype = TF_OperationOutputType({op, output_index});
    return Napi::Number::New(env, static_cast<double>(dtype));
}
|
|
319
|
+
|
|
320
|
+
// opOutputShape(name, index?) -> number[] with -1 for unknown dims,
// or null when the rank itself is unknown / op not found / bad args.
Napi::Value GraphWrap::OpOutputShape(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (graph_ == nullptr || !info[0].IsString())
        return env.Null();

    const std::string op_name = info[0].As<Napi::String>().Utf8Value();
    // Output index defaults to 0 when omitted.
    const int output_index =
        (info.Length() >= 2) ? info[1].As<Napi::Number>().Int32Value() : 0;

    TF_Operation *op = TF_GraphOperationByName(graph_, op_name.c_str());
    if (op == nullptr)
        return env.Null();

    const TF_Output output{op, output_index};

    StatusGuard ndims_status;
    const int ndims = TF_GraphGetTensorNumDims(graph_, output, ndims_status.s);
    // Unknown rank (or a status failure) is reported as null.
    if (!ndims_status.ok() || ndims < 0)
        return env.Null();

    // Pre-fill with -1 so dims stay "unknown" even if the shape query fails.
    std::vector<int64_t> dims(ndims, -1);
    StatusGuard shape_status;
    TF_GraphGetTensorShape(graph_, output, dims.data(), ndims, shape_status.s);

    Napi::Array result = Napi::Array::New(env, ndims);
    for (int d = 0; d < ndims; ++d)
        result.Set(d, Napi::Number::New(env, static_cast<double>(dims[d])));
    return result;
}
|
|
347
|
+
|
|
348
|
+
// toGraphDef() -> Buffer holding the serialized GraphDef proto, or null
// if the graph has been destroyed. Throws on serialization failure.
Napi::Value GraphWrap::ToGraphDef(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (graph_ == nullptr)
        return env.Null();

    TF_Buffer *proto = TF_NewBuffer();
    StatusGuard status;
    TF_GraphToGraphDef(graph_, proto, status.s);

    if (!status.ok())
    {
        // Capture the message before releasing the TF buffer.
        const std::string msg = "TF_GraphToGraphDef failed: " + status.message();
        TF_DeleteBuffer(proto);
        Napi::Error::New(env, msg).ThrowAsJavaScriptException();
        return env.Undefined();
    }

    // Copy the serialized bytes into a JS-owned Buffer so the TF buffer
    // can be freed immediately.
    Napi::Buffer<uint8_t> js_buf = Napi::Buffer<uint8_t>::Copy(
        env,
        static_cast<const uint8_t *>(proto->data),
        proto->length);
    TF_DeleteBuffer(proto);
    return js_buf;
}
|
|
373
|
+
|
|
374
|
+
// numOps() -> number of operations currently in the graph (0 if destroyed).
Napi::Value GraphWrap::NumOps(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();
    if (graph_ == nullptr)
        return Napi::Number::New(env, 0);

    // Walk the graph's op iterator and tally entries.
    int total = 0;
    size_t iter_pos = 0;
    for (; TF_GraphNextOperation(graph_, &iter_pos) != nullptr; ++total)
    {
    }
    return Napi::Number::New(env, total);
}
|
|
384
|
+
|
|
385
|
+
// ---------------------------------------------------------------
|
|
386
|
+
// importGraphDef(buffer: Buffer): void
|
|
387
|
+
//
|
|
388
|
+
// Deserialises a binary GraphDef proto into this graph.
|
|
389
|
+
// Intended use: load a frozen .pb model so the graph can be
|
|
390
|
+
// executed via the native Session (with ConfigProto + affinity).
|
|
391
|
+
//
|
|
392
|
+
// The graph must be empty before calling this — importing into a
|
|
393
|
+
// non-empty graph will produce op name collisions.
|
|
394
|
+
//
|
|
395
|
+
// TF_GraphImportGraphDef uses a prefix option (default "")
|
|
396
|
+
// so all imported op names are used as-is, matching the op
|
|
397
|
+
// names in the original frozen graph.
|
|
398
|
+
// ---------------------------------------------------------------
|
|
399
|
+
// importGraphDef(buffer: Buffer): void
//
// Deserializes a binary GraphDef proto into this graph with an empty
// name prefix, so imported op names are used verbatim. The graph should
// be empty beforehand; importing into a non-empty graph risks op-name
// collisions. Throws a JS exception on any failure.
Napi::Value GraphWrap::ImportGraphDef(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();

    if (graph_ == nullptr)
    {
        Napi::Error::New(env, "Graph destroyed").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    if (info.Length() < 1 || !info[0].IsBuffer())
    {
        Napi::TypeError::New(env, "importGraphDef(buffer: Buffer)").ThrowAsJavaScriptException();
        return env.Undefined();
    }

    auto proto_bytes = info[0].As<Napi::Buffer<uint8_t>>();

    // TF_NewBufferFromString copies the bytes into TF-owned memory, so the
    // JS Buffer does not need to outlive this call.
    TF_Buffer *serialized = TF_NewBufferFromString(proto_bytes.Data(), proto_bytes.ByteLength());

    TF_ImportGraphDefOptions *import_opts = TF_NewImportGraphDefOptions();
    // Empty prefix: imported op names match the original frozen graph.
    TF_ImportGraphDefOptionsSetPrefix(import_opts, "");

    StatusGuard status;
    TF_GraphImportGraphDef(graph_, serialized, import_opts, status.s);

    TF_DeleteImportGraphDefOptions(import_opts);
    TF_DeleteBuffer(serialized);

    if (!status.ok())
    {
        Napi::Error::New(env, "TF_GraphImportGraphDef failed: " + status.message())
            .ThrowAsJavaScriptException();
    }
    return env.Undefined();
}
|
|
#pragma once
#include <napi.h>
#include <string>
#include <vector>
#include "platform_tf.h"

#ifndef ISIDORUS_STATUS_GUARD_DEFINED
#define ISIDORUS_STATUS_GUARD_DEFINED
// RAII owner of a TF_Status: allocates on construction, frees on
// destruction. Intended for use as a short-lived local around a single
// TF C-API call.
struct StatusGuard
{
    TF_Status *s;
    StatusGuard() : s(TF_NewStatus()) {}
    ~StatusGuard() { TF_DeleteStatus(s); }
    // Rule of five: the user-defined destructor makes the implicit copy
    // operations unsafe — a copy would alias `s` and both destructors would
    // call TF_DeleteStatus on the same pointer (double free). Deleting copy
    // also suppresses the implicit moves.
    StatusGuard(const StatusGuard &) = delete;
    StatusGuard &operator=(const StatusGuard &) = delete;
    // True when the wrapped status holds TF_OK.
    bool ok() const { return TF_GetCode(s) == TF_OK; }
    // Human-readable status message (empty for TF_OK).
    std::string message() const { return TF_Message(s); }
};
#endif

// N-API ObjectWrap around a TF_Graph, exposing graph-construction
// primitives to JS (addOp, hasOp, opOutputType, opOutputShape,
// toGraphDef, numOps, importGraphDef).
class GraphWrap : public Napi::ObjectWrap<GraphWrap>
{
public:
    // Registers the "Graph" class on `exports`.
    static Napi::Object Init(Napi::Env env, Napi::Object exports);

    explicit GraphWrap(const Napi::CallbackInfo &info);
    ~GraphWrap();

    // Expose the raw graph pointer to Session and op helpers.
    // Non-owning: the GraphWrap instance keeps ownership.
    TF_Graph *GetGraph() const { return graph_; }

private:
    TF_Graph *graph_ = nullptr;  // owned; freed in the destructor
    int op_counter_ = 0;         // suffix counter for auto-generated op names

    Napi::Value AddOp(const Napi::CallbackInfo &info);
    Napi::Value HasOp(const Napi::CallbackInfo &info);
    Napi::Value OpOutputType(const Napi::CallbackInfo &info);
    Napi::Value OpOutputShape(const Napi::CallbackInfo &info);
    Napi::Value ToGraphDef(const Napi::CallbackInfo &info);
    Napi::Value NumOps(const Napi::CallbackInfo &info);

    // ── Frozen graph import ─────────────────────────────────────────────────
    // importGraphDef(buffer: Buffer): void
    //
    // Deserialises a binary GraphDef proto (e.g. a frozen .pb file) into
    // this graph. The graph must be empty — importing into a non-empty graph
    // causes op name conflicts.
    //
    // Used by InferencePool's tf-parallel path to load frozen models via the
    // native Session (with ConfigProto thread config + affinity fence) rather
    // than through jude-tf.
    Napi::Value ImportGraphDef(const Napi::CallbackInfo &info);
};