@weavelogic/knowledge-graph-agent 0.6.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. package/LICENSE +21 -0
  2. package/README.md +70 -3
  3. package/dist/_virtual/__vite-browser-external.js +2 -2
  4. package/dist/_virtual/__vite-browser-external.js.map +1 -1
  5. package/dist/_virtual/index12.js +7 -0
  6. package/dist/_virtual/index12.js.map +1 -0
  7. package/dist/_virtual/ort-web.min.js +8 -0
  8. package/dist/_virtual/ort-web.min.js.map +1 -0
  9. package/dist/_virtual/ort-web.min2.js +5 -0
  10. package/dist/_virtual/ort-web.min2.js.map +1 -0
  11. package/dist/agents/base-agent.d.ts +63 -0
  12. package/dist/agents/base-agent.d.ts.map +1 -1
  13. package/dist/agents/base-agent.js +139 -0
  14. package/dist/agents/base-agent.js.map +1 -1
  15. package/dist/agents/coordinator-agent.d.ts +422 -0
  16. package/dist/agents/coordinator-agent.d.ts.map +1 -0
  17. package/dist/agents/documenter-agent.d.ts +298 -0
  18. package/dist/agents/documenter-agent.d.ts.map +1 -0
  19. package/dist/agents/index.d.ts +11 -1
  20. package/dist/agents/index.d.ts.map +1 -1
  21. package/dist/agents/index.js +4 -0
  22. package/dist/agents/index.js.map +1 -1
  23. package/dist/agents/mixins/index.d.ts +9 -0
  24. package/dist/agents/mixins/index.d.ts.map +1 -0
  25. package/dist/agents/mixins/trajectory-mixin.d.ts +112 -0
  26. package/dist/agents/mixins/trajectory-mixin.d.ts.map +1 -0
  27. package/dist/agents/optimizer-agent.d.ts +388 -0
  28. package/dist/agents/optimizer-agent.d.ts.map +1 -0
  29. package/dist/agents/planner-agent.d.ts +395 -0
  30. package/dist/agents/planner-agent.d.ts.map +1 -0
  31. package/dist/agents/registry.d.ts.map +1 -1
  32. package/dist/agents/registry.js +5 -0
  33. package/dist/agents/registry.js.map +1 -1
  34. package/dist/agents/reviewer-agent.d.ts +330 -0
  35. package/dist/agents/reviewer-agent.d.ts.map +1 -0
  36. package/dist/agents/types.d.ts +12 -1
  37. package/dist/agents/types.d.ts.map +1 -1
  38. package/dist/agents/types.js +1 -0
  39. package/dist/agents/types.js.map +1 -1
  40. package/dist/cli/commands/hive-mind/add-frontmatter.d.ts +102 -0
  41. package/dist/cli/commands/hive-mind/add-frontmatter.d.ts.map +1 -0
  42. package/dist/cli/commands/hive-mind/add-frontmatter.js +439 -0
  43. package/dist/cli/commands/hive-mind/add-frontmatter.js.map +1 -0
  44. package/dist/cli/commands/hive-mind/analyze-links.d.ts +80 -0
  45. package/dist/cli/commands/hive-mind/analyze-links.d.ts.map +1 -0
  46. package/dist/cli/commands/hive-mind/analyze-links.js +367 -0
  47. package/dist/cli/commands/hive-mind/analyze-links.js.map +1 -0
  48. package/dist/cli/commands/hive-mind/find-connections.d.ts +75 -0
  49. package/dist/cli/commands/hive-mind/find-connections.d.ts.map +1 -0
  50. package/dist/cli/commands/hive-mind/find-connections.js +347 -0
  51. package/dist/cli/commands/hive-mind/find-connections.js.map +1 -0
  52. package/dist/cli/commands/hive-mind/index.d.ts +37 -0
  53. package/dist/cli/commands/hive-mind/index.d.ts.map +1 -0
  54. package/dist/cli/commands/hive-mind/index.js +33 -0
  55. package/dist/cli/commands/hive-mind/index.js.map +1 -0
  56. package/dist/cli/commands/hive-mind/validate-names.d.ts +79 -0
  57. package/dist/cli/commands/hive-mind/validate-names.d.ts.map +1 -0
  58. package/dist/cli/commands/hive-mind/validate-names.js +353 -0
  59. package/dist/cli/commands/hive-mind/validate-names.js.map +1 -0
  60. package/dist/cli/commands/vector.js +2 -0
  61. package/dist/cli/commands/vector.js.map +1 -1
  62. package/dist/cli/index.d.ts.map +1 -1
  63. package/dist/cli/index.js +7 -0
  64. package/dist/cli/index.js.map +1 -1
  65. package/dist/equilibrium/agent-equilibrium.d.ts +194 -0
  66. package/dist/equilibrium/agent-equilibrium.d.ts.map +1 -0
  67. package/dist/equilibrium/agent-equilibrium.js +304 -0
  68. package/dist/equilibrium/agent-equilibrium.js.map +1 -0
  69. package/dist/equilibrium/graph-equilibrium.d.ts +177 -0
  70. package/dist/equilibrium/graph-equilibrium.d.ts.map +1 -0
  71. package/dist/equilibrium/index.d.ts +11 -0
  72. package/dist/equilibrium/index.d.ts.map +1 -0
  73. package/dist/equilibrium/memory-equilibrium.d.ts +153 -0
  74. package/dist/equilibrium/memory-equilibrium.d.ts.map +1 -0
  75. package/dist/graphql/resolvers/index.d.ts.map +1 -1
  76. package/dist/graphql/resolvers/queries.d.ts +11 -0
  77. package/dist/graphql/resolvers/queries.d.ts.map +1 -1
  78. package/dist/index.d.ts +2 -0
  79. package/dist/index.d.ts.map +1 -1
  80. package/dist/index.js +10 -4
  81. package/dist/index.js.map +1 -1
  82. package/dist/inference/index.d.ts +9 -0
  83. package/dist/inference/index.d.ts.map +1 -0
  84. package/dist/inference/model-selection.d.ts +131 -0
  85. package/dist/inference/model-selection.d.ts.map +1 -0
  86. package/dist/integrations/agentic-flow/adapters/agent-booster-adapter.d.ts +265 -0
  87. package/dist/integrations/agentic-flow/adapters/agent-booster-adapter.d.ts.map +1 -0
  88. package/dist/integrations/agentic-flow/adapters/agentdb-adapter.d.ts +197 -0
  89. package/dist/integrations/agentic-flow/adapters/agentdb-adapter.d.ts.map +1 -0
  90. package/dist/integrations/agentic-flow/adapters/agentdb-vector-store.d.ts +249 -0
  91. package/dist/integrations/agentic-flow/adapters/agentdb-vector-store.d.ts.map +1 -0
  92. package/dist/integrations/agentic-flow/adapters/base-adapter.d.ts +120 -0
  93. package/dist/integrations/agentic-flow/adapters/base-adapter.d.ts.map +1 -0
  94. package/dist/integrations/agentic-flow/adapters/federation-hub-adapter.d.ts +444 -0
  95. package/dist/integrations/agentic-flow/adapters/federation-hub-adapter.d.ts.map +1 -0
  96. package/dist/integrations/agentic-flow/adapters/index.d.ts +17 -0
  97. package/dist/integrations/agentic-flow/adapters/index.d.ts.map +1 -0
  98. package/dist/integrations/agentic-flow/adapters/model-router-adapter.d.ts +242 -0
  99. package/dist/integrations/agentic-flow/adapters/model-router-adapter.d.ts.map +1 -0
  100. package/dist/integrations/agentic-flow/adapters/quic-transport-adapter.d.ts +364 -0
  101. package/dist/integrations/agentic-flow/adapters/quic-transport-adapter.d.ts.map +1 -0
  102. package/dist/integrations/agentic-flow/adapters/reasoning-bank-adapter.d.ts +209 -0
  103. package/dist/integrations/agentic-flow/adapters/reasoning-bank-adapter.d.ts.map +1 -0
  104. package/dist/integrations/agentic-flow/benchmark/index.d.ts +9 -0
  105. package/dist/integrations/agentic-flow/benchmark/index.d.ts.map +1 -0
  106. package/dist/integrations/agentic-flow/benchmark/vector-benchmark.d.ts +253 -0
  107. package/dist/integrations/agentic-flow/benchmark/vector-benchmark.d.ts.map +1 -0
  108. package/dist/integrations/agentic-flow/config.d.ts +109 -0
  109. package/dist/integrations/agentic-flow/config.d.ts.map +1 -0
  110. package/dist/integrations/agentic-flow/feature-flags.d.ts +140 -0
  111. package/dist/integrations/agentic-flow/feature-flags.d.ts.map +1 -0
  112. package/dist/integrations/agentic-flow/index.d.ts +22 -0
  113. package/dist/integrations/agentic-flow/index.d.ts.map +1 -0
  114. package/dist/integrations/agentic-flow/migration/index.d.ts +9 -0
  115. package/dist/integrations/agentic-flow/migration/index.d.ts.map +1 -0
  116. package/dist/integrations/agentic-flow/migration/migrate-to-agentdb.d.ts +242 -0
  117. package/dist/integrations/agentic-flow/migration/migrate-to-agentdb.d.ts.map +1 -0
  118. package/dist/learning/index.d.ts +91 -0
  119. package/dist/learning/index.d.ts.map +1 -0
  120. package/dist/learning/learning-loop.d.ts +176 -0
  121. package/dist/learning/learning-loop.d.ts.map +1 -0
  122. package/dist/learning/services/ab-testing-framework.d.ts +135 -0
  123. package/dist/learning/services/ab-testing-framework.d.ts.map +1 -0
  124. package/dist/learning/services/agent-priming-service.d.ts +207 -0
  125. package/dist/learning/services/agent-priming-service.d.ts.map +1 -0
  126. package/dist/learning/services/daily-log-generator.d.ts +113 -0
  127. package/dist/learning/services/daily-log-generator.d.ts.map +1 -0
  128. package/dist/learning/services/index.d.ts +14 -0
  129. package/dist/learning/services/index.d.ts.map +1 -0
  130. package/dist/learning/services/memory-extraction-service.d.ts +87 -0
  131. package/dist/learning/services/memory-extraction-service.d.ts.map +1 -0
  132. package/dist/learning/services/task-completion-consumer.d.ts +162 -0
  133. package/dist/learning/services/task-completion-consumer.d.ts.map +1 -0
  134. package/dist/learning/services/trajectory-tracker.d.ts +174 -0
  135. package/dist/learning/services/trajectory-tracker.d.ts.map +1 -0
  136. package/dist/learning/types.d.ts +516 -0
  137. package/dist/learning/types.d.ts.map +1 -0
  138. package/dist/mcp/clients/claude-flow-memory-client.d.ts +259 -0
  139. package/dist/mcp/clients/claude-flow-memory-client.d.ts.map +1 -0
  140. package/dist/mcp/clients/claude-flow-memory-client.js +305 -0
  141. package/dist/mcp/clients/claude-flow-memory-client.js.map +1 -0
  142. package/dist/mcp/clients/index.d.ts +11 -0
  143. package/dist/mcp/clients/index.d.ts.map +1 -0
  144. package/dist/mcp/clients/mcp-client-adapter.d.ts +146 -0
  145. package/dist/mcp/clients/mcp-client-adapter.d.ts.map +1 -0
  146. package/dist/mcp/clients/mcp-client-adapter.js +372 -0
  147. package/dist/mcp/clients/mcp-client-adapter.js.map +1 -0
  148. package/dist/mcp/index.d.ts +10 -0
  149. package/dist/mcp/index.d.ts.map +1 -0
  150. package/dist/memory/vault-sync.d.ts +12 -0
  151. package/dist/memory/vault-sync.d.ts.map +1 -1
  152. package/dist/memory/vault-sync.js +94 -11
  153. package/dist/memory/vault-sync.js.map +1 -1
  154. package/dist/node_modules/@huggingface/jinja/dist/index.js +118 -0
  155. package/dist/node_modules/@huggingface/jinja/dist/index.js.map +1 -0
  156. package/dist/node_modules/@typescript-eslint/project-service/dist/index.js +1 -1
  157. package/dist/node_modules/@xenova/transformers/src/backends/onnx.js +24 -0
  158. package/dist/node_modules/@xenova/transformers/src/backends/onnx.js.map +1 -0
  159. package/dist/node_modules/@xenova/transformers/src/configs.js +52 -0
  160. package/dist/node_modules/@xenova/transformers/src/configs.js.map +1 -0
  161. package/dist/node_modules/@xenova/transformers/src/env.js +35 -0
  162. package/dist/node_modules/@xenova/transformers/src/env.js.map +1 -0
  163. package/dist/node_modules/@xenova/transformers/src/models.js +3852 -0
  164. package/dist/node_modules/@xenova/transformers/src/models.js.map +1 -0
  165. package/dist/node_modules/@xenova/transformers/src/tokenizers.js +144 -0
  166. package/dist/node_modules/@xenova/transformers/src/tokenizers.js.map +1 -0
  167. package/dist/node_modules/@xenova/transformers/src/utils/core.js +52 -0
  168. package/dist/node_modules/@xenova/transformers/src/utils/core.js.map +1 -0
  169. package/dist/node_modules/@xenova/transformers/src/utils/generation.js +623 -0
  170. package/dist/node_modules/@xenova/transformers/src/utils/generation.js.map +1 -0
  171. package/dist/node_modules/@xenova/transformers/src/utils/hub.js +395 -0
  172. package/dist/node_modules/@xenova/transformers/src/utils/hub.js.map +1 -0
  173. package/dist/node_modules/@xenova/transformers/src/utils/image.js +12 -0
  174. package/dist/node_modules/@xenova/transformers/src/utils/image.js.map +1 -0
  175. package/dist/node_modules/@xenova/transformers/src/utils/maths.js +89 -0
  176. package/dist/node_modules/@xenova/transformers/src/utils/maths.js.map +1 -0
  177. package/dist/node_modules/@xenova/transformers/src/utils/tensor.js +750 -0
  178. package/dist/node_modules/@xenova/transformers/src/utils/tensor.js.map +1 -0
  179. package/dist/node_modules/fdir/dist/index.js +13 -13
  180. package/dist/node_modules/fdir/dist/index.js.map +1 -1
  181. package/dist/node_modules/onnxruntime-common/dist/lib/backend-impl.js +67 -0
  182. package/dist/node_modules/onnxruntime-common/dist/lib/backend-impl.js.map +1 -0
  183. package/dist/node_modules/onnxruntime-common/dist/lib/env-impl.js +24 -0
  184. package/dist/node_modules/onnxruntime-common/dist/lib/env-impl.js.map +1 -0
  185. package/dist/node_modules/onnxruntime-common/dist/lib/env.js +6 -0
  186. package/dist/node_modules/onnxruntime-common/dist/lib/env.js.map +1 -0
  187. package/dist/node_modules/onnxruntime-common/dist/lib/index.js +11 -0
  188. package/dist/node_modules/onnxruntime-common/dist/lib/index.js.map +1 -0
  189. package/dist/node_modules/onnxruntime-common/dist/lib/inference-session-impl.js +162 -0
  190. package/dist/node_modules/onnxruntime-common/dist/lib/inference-session-impl.js.map +1 -0
  191. package/dist/node_modules/onnxruntime-common/dist/lib/inference-session.js +6 -0
  192. package/dist/node_modules/onnxruntime-common/dist/lib/inference-session.js.map +1 -0
  193. package/dist/node_modules/onnxruntime-common/dist/lib/tensor-impl.js +393 -0
  194. package/dist/node_modules/onnxruntime-common/dist/lib/tensor-impl.js.map +1 -0
  195. package/dist/node_modules/onnxruntime-common/dist/lib/tensor.js +6 -0
  196. package/dist/node_modules/onnxruntime-common/dist/lib/tensor.js.map +1 -0
  197. package/dist/node_modules/onnxruntime-web/dist/ort-web.min.js +12919 -0
  198. package/dist/node_modules/onnxruntime-web/dist/ort-web.min.js.map +1 -0
  199. package/dist/node_modules/tinyglobby/dist/index.js +14 -14
  200. package/dist/node_modules/tinyglobby/dist/index.js.map +1 -1
  201. package/dist/node_modules/typescript/lib/typescript.js +24 -24
  202. package/dist/node_modules/typescript/lib/typescript.js.map +1 -1
  203. package/dist/transport/agent-transport.d.ts +269 -0
  204. package/dist/transport/agent-transport.d.ts.map +1 -0
  205. package/dist/transport/index.d.ts +10 -0
  206. package/dist/transport/index.d.ts.map +1 -0
  207. package/dist/vector/index.d.ts +1 -1
  208. package/dist/vector/index.d.ts.map +1 -1
  209. package/dist/vector/services/embedding-service.d.ts +244 -0
  210. package/dist/vector/services/embedding-service.d.ts.map +1 -0
  211. package/dist/vector/services/embedding-service.js +10 -0
  212. package/dist/vector/services/embedding-service.js.map +1 -0
  213. package/dist/vector/services/hybrid-search.d.ts +320 -0
  214. package/dist/vector/services/hybrid-search.d.ts.map +1 -0
  215. package/dist/vector/services/hybrid-search.js +3 -0
  216. package/dist/vector/services/hybrid-search.js.map +1 -0
  217. package/dist/vector/services/index.d.ts +4 -0
  218. package/dist/vector/services/index.d.ts.map +1 -1
  219. package/package.json +10 -1
@@ -0,0 +1,3852 @@
1
+ import { AutoConfig } from "./configs.js";
2
+ import { mergeArrays, Callable, isTypedArray, isIntegralNumber } from "./utils/core.js";
3
+ import { getModelJSON, getModelFile } from "./utils/hub.js";
4
+ import { WhisperTimeStampLogitsProcessor, LogitsProcessorList, RepetitionPenaltyLogitsProcessor, NoRepeatNGramLogitsProcessor, NoBadWordsLogitsProcessor, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, ForceTokensLogitsProcessor, GenerationConfig, Sampler } from "./utils/generation.js";
5
+ import { cat, stack, std_mean, mean, Tensor, dynamicTimeWarping, ones_like } from "./utils/tensor.js";
6
+ import { ONNX, executionProviders } from "./backends/onnx.js";
7
+ import "./tokenizers.js";
8
+ import { medianFilter } from "./utils/maths.js";
9
+ import "./utils/image.js";
10
+ import "./env.js";
11
// ONNX runtime primitives used throughout this module. The runtime's own tensor
// class is aliased so it does not collide with this library's `Tensor`.
const { InferenceSession, env, Tensor: ONNXTensor } = ONNX;

/**
 * Architecture families. The value registered for a model class selects which
 * ONNX files `from_pretrained` loads and which forward/generation helpers the
 * `PreTrainedModel` constructor wires up.
 */
const MODEL_TYPES = {
  EncoderOnly: 0,
  EncoderDecoder: 1,
  Seq2Seq: 2,
  Vision2Seq: 3,
  DecoderOnly: 4,
  MaskGeneration: 5,
};

// Model name -> MODEL_TYPES value (read via MODEL_TYPE_MAPPING.get(modelName)).
const MODEL_TYPE_MAPPING = /* @__PURE__ */ new Map();
// NOTE(review): presumably model name -> model class; populated outside this view.
const MODEL_NAME_TO_CLASS_MAPPING = /* @__PURE__ */ new Map();
// Model class -> model name (read via MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor)).
const MODEL_CLASS_TO_NAME_MAPPING = /* @__PURE__ */ new Map();
23
/**
 * Load an ONNX model file and create an inference session for it.
 *
 * The file fetched is `onnx/<fileName>.onnx`, or `onnx/<fileName>_quantized.onnx`
 * when `options.quantized` is truthy. If session creation with the configured
 * execution providers fails, the error is logged and creation is retried once
 * with the `wasm` provider. The retry is skipped when `wasm` was already the only
 * provider; in that case the error is rethrown.
 *
 * @param {string} pretrained_model_name_or_path Model id or local directory.
 * @param {string} fileName Base name of the ONNX file, without suffix or extension.
 * @param {Object} options Loading options, forwarded to `getModelFile`.
 * @returns {Promise<any>} The created `InferenceSession`.
 * @throws Rethrows the creation error when no fallback provider is available.
 */
async function constructSession(pretrained_model_name_or_path, fileName, options) {
  const quantSuffix = options.quantized ? "_quantized" : "";
  const modelFileName = `onnx/${fileName}${quantSuffix}.onnx`;
  const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);

  try {
    return await InferenceSession.create(buffer, { executionProviders });
  } catch (err) {
    const alreadyWasmOnly = executionProviders.length === 1 && executionProviders[0] === "wasm";
    if (alreadyWasmOnly) {
      // Nothing left to fall back to.
      throw err;
    }
    console.warn(err);
    console.warn(
      "Something went wrong during model construction (most likely a missing operation). Using `wasm` as a fallback. "
    );
    return await InferenceSession.create(buffer, { executionProviders: ["wasm"] });
  }
}
43
/**
 * Check that every input declared by an ONNX session is supplied as a `Tensor`,
 * and collect those inputs into a fresh prototype-less feeds object.
 *
 * Inputs the session does not declare are left out of the result. When any are
 * present, a console warning lists them.
 *
 * @param {Object} session ONNX session; only `session.inputNames` is read.
 * @param {Object} inputs Map of input name to tensor, as supplied by the caller.
 * @returns {Object} Feeds object keyed by the names in `session.inputNames`.
 * @throws {Error} If any declared input is missing or is not a `Tensor`.
 */
function validateInputs(session, inputs) {
  const checkedInputs = /* @__PURE__ */ Object.create(null);
  const missingInputs = [];
  for (const inputName of session.inputNames) {
    const tensor = inputs[inputName];
    if (!(tensor instanceof Tensor)) {
      missingInputs.push(inputName);
      continue;
    }
    // NOTE(review): presumably the proxied WASM backend takes ownership of the
    // tensor buffers, so a clone protects the caller's tensor — confirm.
    checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor;
  }
  if (missingInputs.length > 0) {
    // The quoted detail is closed here. Previously the opening quote was never
    // terminated, unlike sessionRun's `"<detail>".` format.
    throw new Error(
      `An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(", ")}".`
    );
  }
  const numInputsProvided = Object.keys(inputs).length;
  const numInputsNeeded = session.inputNames.length;
  // Every declared input is present at this point, so a larger count means extras were passed.
  if (numInputsProvided > numInputsNeeded) {
    const ignored = Object.keys(inputs).filter((inputName) => !session.inputNames.includes(inputName));
    console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(", ")}".`);
  }
  return checkedInputs;
}
67
/**
 * Validate the inputs, run the ONNX session, and wrap its output tensors.
 *
 * If the run or the wrapping step fails, the error and the validated inputs are
 * logged to the console and the original error is rethrown.
 *
 * @param {Object} session ONNX inference session.
 * @param {Object} inputs Input tensors keyed by name.
 * @returns {Promise<Object>} Session outputs with ONNX tensors replaced by library `Tensor`s.
 */
async function sessionRun(session, inputs) {
  const feeds = validateInputs(session, inputs);
  try {
    const rawOutput = await session.run(feeds);
    return replaceTensors(rawOutput);
  } catch (e) {
    console.error(`An error occurred during model execution: "${e}".`);
    console.error("Inputs given to model:", feeds);
    throw e;
  }
}
79
/**
 * Recursively walk `obj`, replacing every ONNX runtime tensor with a library `Tensor`.
 *
 * The object is mutated in place and also returned. Only values whose `typeof`
 * is "object" are descended into.
 *
 * @param {Object} obj Output object, typically from `session.run`.
 * @returns {Object} The same `obj` after replacement.
 */
function replaceTensors(obj) {
  for (const key in obj) {
    const value = obj[key];
    if (value instanceof ONNXTensor) {
      obj[key] = new Tensor(value);
      continue;
    }
    if (typeof value === "object") {
      replaceTensors(value);
    }
  }
  return obj;
}
89
/**
 * Convert token ids into an int64 `Tensor`.
 *
 * - A `Tensor` is returned unchanged.
 * - A flat array `[a, b, ...]` becomes a tensor of dims `[1, length]`.
 * - A 2-D array of equal-length rows becomes a tensor of dims `[rows, rowLength]`.
 *
 * @param {Tensor|number[]|number[][]} items Token ids.
 * @returns {Tensor} An int64 tensor backed by a `BigInt64Array`.
 * @throws {Error} If `items` is empty, or if the rows of a 2-D array differ in length.
 */
function toI64Tensor(items) {
  if (items instanceof Tensor) return items;
  if (items.length === 0) {
    throw new Error("items must be non-empty");
  }

  if (!Array.isArray(items[0])) {
    const flatData = BigInt64Array.from(items.map((x) => BigInt(x)));
    return new Tensor("int64", flatData, [1, items.length]);
  }

  const rowLength = items[0].length;
  const ragged = items.some((row) => row.length !== rowLength);
  if (ragged) {
    throw new Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.");
  }
  const batchData = BigInt64Array.from(items.flat().map((x) => BigInt(x)));
  return new Tensor("int64", batchData, [items.length, rowLength]);
}
113
/**
 * Build an attention mask for `tokens`.
 *
 * A mask that is 0 at pad positions and 1 elsewhere is produced only when the pad
 * token occurs in `tokens` and is not one of the EOS token ids. Otherwise an
 * all-ones mask (`ones_like(tokens)`) is returned.
 *
 * @param {Object} self Model; `self.config.pad_token_id` and `self.config.eos_token_id` are read.
 * @param {Tensor} tokens Token id tensor.
 * @returns {Tensor} int64 mask with the same dims as `tokens`.
 */
function prepareAttentionMask(self, tokens) {
  const padId = self.config.pad_token_id ?? null;
  let eosIds = self.config.eos_token_id ?? null;
  if (isIntegralNumber(eosIds)) {
    // Normalize a single id to a list so `.includes` works below.
    eosIds = [eosIds];
  }

  const padPresent = tokens.indexOf(padId) !== -1;
  const padIsNotEos = eosIds === null || !eosIds.includes(padId);
  if (!(padPresent && padIsNotEos)) {
    return ones_like(tokens);
  }

  // Loose `!=` is intentional so a Number pad id matches BigInt token values.
  // The boolean results are converted to 1n/0n when stored in a BigInt64Array.
  // @ts-ignore
  const maskData = BigInt64Array.from(tokens.data.map((x) => x != padId));
  return new Tensor("int64", maskData, tokens.dims);
}
132
/**
 * Add `position_ids` to `feeds` when the session declares that input.
 *
 * For each row of `feeds.attention_mask`, an unmasked position receives the
 * running sum of the mask values before it, so the first unmasked position gets 0.
 * A masked position (mask value 0n) receives 1n. When `use_cache_branch` is true,
 * the result is reduced with `.slice(null, -1).unsqueeze_(-1)`.
 *
 * @param {Object} session ONNX session; only `session.inputNames` is read.
 * @param {Object} feeds Decoder feeds. `feeds.attention_mask` is read and `feeds.position_ids` is written.
 * @param {boolean} use_cache_branch Whether past key/values are being fed.
 */
function preparePositionIds(session, feeds, use_cache_branch) {
  if (!session.inputNames.includes("position_ids")) return;

  const mask = feeds.attention_mask;
  const [batchSize, seqLen] = mask.dims;
  const positions = new BigInt64Array(mask.data.length);

  for (let row = 0; row < batchSize; ++row) {
    const offset = row * seqLen;
    let running = 0n;
    for (let col = 0; col < seqLen; ++col) {
      const idx = offset + col;
      const maskValue = mask.data[idx];
      if (maskValue === 0n) {
        positions[idx] = 1n;
      } else {
        positions[idx] = running;
        running += maskValue;
      }
    }
  }

  feeds.position_ids = new Tensor("int64", positions, mask.dims);
  if (use_cache_branch) {
    feeds.position_ids = feeds.position_ids.slice(null, -1).unsqueeze_(-1);
  }
}
153
/**
 * Wrap a single boolean in a one-element `bool` tensor.
 * @param {boolean} value The flag to wrap.
 * @returns {Tensor} Tensor of type "bool" with dims `[1]`.
 */
function boolTensor(value) {
  const data = [value];
  const dims = [1];
  return new Tensor("bool", data, dims);
}
156
/**
 * Forward pass for encoder-decoder (seq2seq / vision2seq) models.
 *
 * Runs the encoder when `model_inputs.encoder_outputs` is not supplied. It then
 * feeds the decoder merged session with the decoder ids and the encoder hidden
 * states. The `use_cache_branch` and `encoder_attention_mask` inputs are added
 * only when that session declares them.
 *
 * @param {Object} self The model instance.
 * @param {Object} model_inputs Inputs; reads `encoder_outputs`, `past_key_values`,
 *   `decoder_input_ids` and `attention_mask`.
 * @returns {Promise<Seq2SeqLMOutput>} Logits, updated past key/values, encoder outputs and attentions.
 */
async function seq2seqForward(self, model_inputs) {
  let { encoder_outputs, past_key_values } = model_inputs;
  if (!encoder_outputs) {
    const encoderResult = await encoderForward(self, model_inputs);
    encoder_outputs = encoderResult.last_hidden_state;
  }

  const decoderFeeds = {
    input_ids: model_inputs.decoder_input_ids,
    encoder_hidden_states: encoder_outputs,
  };
  const use_cache_branch = Boolean(past_key_values);
  const decoderSession = self.decoder_merged_session;
  const acceptedInputs = decoderSession.inputNames;

  if (acceptedInputs.includes("use_cache_branch")) {
    decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
  }
  if (acceptedInputs.includes("encoder_attention_mask")) {
    decoderFeeds.encoder_attention_mask = model_inputs.attention_mask;
  }
  preparePositionIds(decoderSession, decoderFeeds, use_cache_branch);
  self.addPastKeyValues(decoderFeeds, past_key_values);

  const decoderResults = await sessionRun(decoderSession, decoderFeeds);
  const { logits } = decoderResults;
  const nextPastKeyValues = self.getPastKeyValues(decoderResults, past_key_values);
  const attentions = self.getAttentions(decoderResults);
  return new Seq2SeqLMOutput({ logits, past_key_values: nextPastKeyValues, encoder_outputs, ...attentions });
}
180
/**
 * Create the initial beams for seq2seq generation, one per entry of `inputTokenIds`.
 *
 * The decoder sequence is seeded from the first non-nullish value among
 * `decoder_input_ids`, `decoder_start_token_id`, `bos_token_id` and `eos_token_id`.
 * The seed is normalized to a flat array, and all beams initially share that same
 * array instance. Each input tensor gets a leading batch dim of 1, in place. An
 * attention mask is attached unless `self.requires_attention_mask` is false.
 *
 * @param {Object} self The model instance.
 * @param {Iterable<Tensor>} inputTokenIds Per-beam input tensors.
 * @param {Object} generation_config Generation settings.
 * @param {number} numOutputTokens Not used by this implementation.
 * @returns {Object[]} Beam state objects with ids 0, 1, 2, ...
 */
function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputTokens) {
  const requires_attention_mask = self.requires_attention_mask ?? true;

  const seed =
    generation_config.decoder_input_ids ??
    generation_config.decoder_start_token_id ??
    generation_config.bos_token_id ??
    generation_config.eos_token_id;
  let decoder_input_ids;
  if (seed instanceof Tensor) {
    decoder_input_ids = seed.tolist().flat();
  } else if (Array.isArray(seed)) {
    decoder_input_ids = seed;
  } else {
    decoder_input_ids = [seed];
  }

  const beams = [];
  for (const tokens of inputTokenIds) {
    tokens.dims = [1, ...tokens.dims];
    const beam = {
      inputs: tokens,
      encoder_outputs: null,
      prev_model_outputs: null,
      output_token_ids: decoder_input_ids,
      done: false,
      score: 0,
      id: beams.length, // unique per beam: its position in the result
    };
    if (requires_attention_mask) {
      beam.attention_mask = prepareAttentionMask(self, tokens);
    }
    beams.push(beam);
  }
  return beams;
}
209
/**
 * Run one generation step for a seq2seq beam.
 *
 * On the first step the full decoder id list is fed. On later steps only the last
 * id is fed, together with the previous step's `past_key_values`. The step's
 * output and its `encoder_outputs` are stored back on the beam.
 *
 * @param {Object} self The model instance; `main_input_name` and `forward` are used.
 * @param {Object} beam Beam state, mutated in place.
 * @returns {Promise<Object>} The model output for this step.
 */
async function seq2seqRunBeam(self, beam) {
  const inputName = self.main_input_name;
  // NOTE(review): earlier context is presumably carried by past_key_values once a step has run.
  const idsToFeed = beam.prev_model_outputs
    ? beam.output_token_ids.slice(-1)
    : beam.output_token_ids;

  const model_inputs = {
    [inputName]: beam.inputs,
    decoder_input_ids: toI64Tensor(idsToFeed),
    encoder_outputs: beam.encoder_outputs,
    past_key_values: beam.prev_model_outputs?.past_key_values,
  };
  if (beam.attention_mask) {
    model_inputs.attention_mask = beam.attention_mask;
  }

  const output = await self.forward(model_inputs);
  beam.prev_model_outputs = output;
  beam.encoder_outputs = output.encoder_outputs;
  return output;
}
229
/**
 * Append a generated token to a seq2seq beam.
 * A new array is assigned rather than pushing onto the existing one.
 * @param {Object} beam Beam state; `output_token_ids` is replaced.
 * @param {number} newTokenId Token id to append.
 */
function seq2seqUpdatebeam(beam, newTokenId) {
  const extended = [...beam.output_token_ids];
  extended.push(newTokenId);
  beam.output_token_ids = extended;
}
232
/**
 * Run the model's main `session` using only the inputs that session declares.
 *
 * If the session expects `token_type_ids` and none were supplied, a zero-filled
 * int64 tensor shaped like `input_ids` is used.
 *
 * @param {Object} self The model instance; `self.session` is used.
 * @param {Object} model_inputs Candidate inputs keyed by name.
 * @returns {Promise<Object>} The session outputs.
 */
async function encoderForward(self, model_inputs) {
  const { session } = self;
  const encoderFeeds = /* @__PURE__ */ Object.create(null);
  for (const name of session.inputNames) {
    encoderFeeds[name] = model_inputs[name];
  }

  const wantsTokenTypes = session.inputNames.includes("token_type_ids");
  if (wantsTokenTypes && !encoderFeeds.token_type_ids) {
    const { input_ids } = encoderFeeds;
    const zeros = new BigInt64Array(input_ids.data.length);
    encoderFeeds.token_type_ids = new Tensor("int64", zeros, input_ids.dims);
  }
  return await sessionRun(session, encoderFeeds);
}
246
/**
 * Forward pass for decoder-only models.
 *
 * Uses the supplied attention mask, or builds one with prepareAttentionMask. The
 * `use_cache_branch` flag is added when the session declares it. Position ids and
 * past key/values are then added before the session is run.
 *
 * @param {Object} self The model instance; `self.session` is used.
 * @param {Object} model_inputs Reads `input_ids`, `past_key_values` and `attention_mask`.
 * @returns {Promise<{logits: any, past_key_values: any}>} Logits and updated past key/values.
 */
async function decoderForward(self, model_inputs) {
  const { input_ids, past_key_values: previousPast, attention_mask } = model_inputs;
  const decoderFeeds = {
    input_ids,
    attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids),
  };
  const use_cache_branch = Boolean(previousPast);
  const { session } = self;

  if (session.inputNames.includes("use_cache_branch")) {
    decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
  }
  preparePositionIds(session, decoderFeeds, use_cache_branch);
  self.addPastKeyValues(decoderFeeds, previousPast);

  const decoderResults = await sessionRun(session, decoderFeeds);
  return {
    logits: decoderResults.logits,
    past_key_values: self.getPastKeyValues(decoderResults, previousPast),
  };
}
263
/**
 * Create the initial beams for decoder-only generation, one per input tensor.
 *
 * Each beam's `output_token_ids` starts as the tensor's ids converted to Numbers.
 * The tensor gets a leading batch dim of 1, in place. When `inputs_attention_mask`
 * is given, the entry at the beam's index is used (and also reshaped in place).
 * Otherwise a mask is built with prepareAttentionMask.
 *
 * @param {Object} self The model instance.
 * @param {Iterable<Tensor>} inputTokenIds Per-beam input tensors.
 * @param {Object} generation_config Not used by this implementation.
 * @param {number} numOutputTokens Stored on each beam as `num_output_tokens`.
 * @param {Tensor} [inputs_attention_mask] Optional masks indexed by beam position.
 * @returns {Object[]} Beam state objects with ids 0, 1, 2, ...
 */
function decoderStartBeams(self, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) {
  const beams = [];
  for (const tokens of inputTokenIds) {
    const index = beams.length;
    // Ids are read before `dims` is reshaped, preserving the original ordering.
    const output_token_ids = tokens.tolist().map(Number);
    tokens.dims = [1, ...tokens.dims];

    let attention_mask;
    if (inputs_attention_mask) {
      attention_mask = inputs_attention_mask[index];
      attention_mask.dims = [1, ...attention_mask.dims];
    } else {
      attention_mask = prepareAttentionMask(self, tokens);
    }

    beams.push({
      input: tokens,
      model_input_ids: tokens,
      attention_mask,
      prev_model_outputs: null,
      output_token_ids,
      num_output_tokens: numOutputTokens,
      done: false,
      score: 0,
      id: index, // unique per beam: its position in the result
    });
  }
  return beams;
}
292
/**
 * Run one generation step for a decoder-only beam.
 *
 * The attention mask is all ones, with one entry per token generated so far
 * (`beam.output_token_ids.length`), and dims `[1, length]`. The output is stored
 * on the beam as `prev_model_outputs`.
 *
 * @param {Object} self The model instance; `forward` is called.
 * @param {Object} beam Beam state, mutated in place.
 * @returns {Promise<Object>} The model output for this step.
 */
async function decoderRunBeam(self, beam) {
  const seqLen = beam.output_token_ids.length;
  const maskValues = new BigInt64Array(seqLen).fill(1n);
  const output = await self.forward({
    input_ids: beam.model_input_ids,
    attention_mask: new Tensor("int64", maskValues, [1, seqLen]),
    past_key_values: beam.prev_model_outputs?.past_key_values,
  });
  beam.prev_model_outputs = output;
  return output;
}
307
/**
 * Append a generated token to a decoder-only beam.
 *
 * `output_token_ids` is replaced by a new array, and `model_input_ids` becomes a
 * `[1, 1]` int64 tensor holding just the new token.
 *
 * @param {Object} beam Beam state, mutated in place.
 * @param {number} newTokenId Token id to append.
 */
function decoderUpdatebeam(beam, newTokenId) {
  const extended = [...beam.output_token_ids];
  extended.push(newTokenId);
  beam.output_token_ids = extended;
  const nextInput = [BigInt(newTokenId)];
  beam.model_input_ids = new Tensor("int64", nextInput, [1, 1]);
}
311
+ class PreTrainedModel extends Callable {
312
+ main_input_name = "input_ids";
313
+ /**
314
+ * Creates a new instance of the `PreTrainedModel` class.
315
+ * @param {Object} config The model configuration.
316
+ * @param {any} session session for the model.
317
+ */
318
+ constructor(config, session) {
319
+ super();
320
+ this.config = config;
321
+ this.session = session;
322
+ const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
323
+ const modelType = MODEL_TYPE_MAPPING.get(modelName);
324
+ this.can_generate = false;
325
+ this._runBeam = null;
326
+ this._getStartBeams = null;
327
+ this._updateBeam = null;
328
+ this._forward = null;
329
+ if (modelType === MODEL_TYPES.DecoderOnly) {
330
+ this.can_generate = true;
331
+ this._runBeam = decoderRunBeam;
332
+ this._getStartBeams = decoderStartBeams;
333
+ this._updateBeam = decoderUpdatebeam;
334
+ this._forward = decoderForward;
335
+ } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
336
+ this.can_generate = true;
337
+ this._runBeam = seq2seqRunBeam;
338
+ this._getStartBeams = seq2seqStartBeams;
339
+ this._updateBeam = seq2seqUpdatebeam;
340
+ this._forward = seq2seqForward;
341
+ } else if (modelType === MODEL_TYPES.EncoderDecoder) {
342
+ this._forward = encoderForward;
343
+ } else {
344
+ this._forward = encoderForward;
345
+ }
346
+ }
347
+ /**
348
+ * Disposes of all the ONNX sessions that were created during inference.
349
+ * @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
350
+ * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
351
+ */
352
+ async dispose() {
353
+ const promises = [];
354
+ for (let key of Object.keys(this)) {
355
+ const item = this[key];
356
+ if (item instanceof InferenceSession) {
357
+ promises.push(item.handler.dispose());
358
+ }
359
+ }
360
+ return await Promise.all(promises);
361
+ }
362
+ /**
363
+ * Instantiate one of the model classes of the library from a pretrained model.
364
+ *
365
+ * The model class to instantiate is selected based on the `model_type` property of the config object
366
+ * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
367
+ *
368
+ * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
369
+ * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
370
+ * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
371
+ * user or organization name, like `dbmdz/bert-base-german-cased`.
372
+ * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
373
+ * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
374
+ *
375
+ * @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
376
+ */
377
+ static async from_pretrained(pretrained_model_name_or_path, {
378
+ quantized = true,
379
+ progress_callback = null,
380
+ config = null,
381
+ cache_dir = null,
382
+ local_files_only = false,
383
+ revision = "main",
384
+ model_file_name = null
385
+ } = {}) {
386
+ let options = {
387
+ quantized,
388
+ progress_callback,
389
+ config,
390
+ cache_dir,
391
+ local_files_only,
392
+ revision,
393
+ model_file_name
394
+ };
395
+ const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
396
+ const modelType = MODEL_TYPE_MAPPING.get(modelName);
397
+ let info;
398
+ if (modelType === MODEL_TYPES.DecoderOnly) {
399
+ info = await Promise.all([
400
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
401
+ constructSession(pretrained_model_name_or_path, options.model_file_name ?? "decoder_model_merged", options),
402
+ getModelJSON(pretrained_model_name_or_path, "generation_config.json", false, options)
403
+ ]);
404
+ } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
405
+ info = await Promise.all([
406
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
407
+ constructSession(pretrained_model_name_or_path, "encoder_model", options),
408
+ constructSession(pretrained_model_name_or_path, "decoder_model_merged", options),
409
+ getModelJSON(pretrained_model_name_or_path, "generation_config.json", false, options)
410
+ ]);
411
+ } else if (modelType === MODEL_TYPES.MaskGeneration) {
412
+ info = await Promise.all([
413
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
414
+ constructSession(pretrained_model_name_or_path, "vision_encoder", options),
415
+ constructSession(pretrained_model_name_or_path, "prompt_encoder_mask_decoder", options)
416
+ ]);
417
+ } else if (modelType === MODEL_TYPES.EncoderDecoder) {
418
+ info = await Promise.all([
419
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
420
+ constructSession(pretrained_model_name_or_path, "encoder_model", options),
421
+ constructSession(pretrained_model_name_or_path, "decoder_model_merged", options)
422
+ ]);
423
+ } else {
424
+ if (modelType !== MODEL_TYPES.EncoderOnly) {
425
+ console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`);
426
+ }
427
+ info = await Promise.all([
428
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
429
+ constructSession(pretrained_model_name_or_path, options.model_file_name ?? "model", options)
430
+ ]);
431
+ }
432
+ return new this(...info);
433
+ }
434
+ /**
435
+ * Runs the model with the provided inputs
436
+ * @param {Object} model_inputs Object containing input tensors
437
+ * @returns {Promise<Object>} Object containing output tensors
438
+ */
439
+ async _call(model_inputs) {
440
+ return await this.forward(model_inputs);
441
+ }
442
+ /**
443
+ * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
444
+ * will be chosen based on the model type.
445
+ * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
446
+ * @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model.
447
+ * @throws {Error} This method must be implemented in subclasses.
448
+ */
449
+ async forward(model_inputs) {
450
+ return await this._forward(this, model_inputs);
451
+ }
452
+ /**
453
+ * @param {import('./utils/generation.js').GenerationConfigType} generation_config
454
+ * @param {number} input_ids_seq_length The starting sequence length for the input ids.
455
+ * @returns {LogitsProcessorList}
456
+ * @private
457
+ */
458
+ _get_logits_processor(generation_config, input_ids_seq_length, logits_processor = null) {
459
+ const processors = new LogitsProcessorList();
460
+ if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1) {
461
+ processors.push(new RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty));
462
+ }
463
+ if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) {
464
+ processors.push(new NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size));
465
+ }
466
+ if (generation_config.bad_words_ids !== null) {
467
+ processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
468
+ }
469
+ if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
470
+ processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
471
+ }
472
+ if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
473
+ processors.push(new MinNewTokensLengthLogitsProcessor(
474
+ input_ids_seq_length,
475
+ generation_config.min_new_tokens,
476
+ generation_config.eos_token_id
477
+ ));
478
+ }
479
+ if (generation_config.forced_bos_token_id !== null) {
480
+ processors.push(new ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id));
481
+ }
482
+ if (generation_config.forced_eos_token_id !== null) {
483
+ processors.push(new ForcedEOSTokenLogitsProcessor(
484
+ generation_config.max_length,
485
+ generation_config.forced_eos_token_id
486
+ ));
487
+ }
488
+ if (generation_config.begin_suppress_tokens !== null) {
489
+ let begin_index = input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null ? input_ids_seq_length : input_ids_seq_length + 1;
490
+ if (generation_config.forced_decoder_ids !== null) {
491
+ begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0];
492
+ }
493
+ processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
494
+ }
495
+ if (generation_config.forced_decoder_ids !== null) {
496
+ processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
497
+ }
498
+ if (logits_processor !== null) {
499
+ processors.extend(logits_processor);
500
+ }
501
+ return processors;
502
+ }
503
+ /**
504
+ * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation.
505
+ * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object.
506
+ * @param {import('./utils/generation.js').GenerationConfigType} generation_config A `GenerationConfig` object containing generation parameters.
507
+ * @returns {import('./utils/generation.js').GenerationConfigType} The final generation config object to be used by the model for text generation.
508
+ */
509
+ _get_generation_config(generation_config) {
510
+ let gen_config = new GenerationConfig(this.config);
511
+ if ("generation_config" in this) {
512
+ Object.assign(gen_config, this.generation_config);
513
+ }
514
+ if (generation_config !== null) {
515
+ Object.assign(gen_config, generation_config);
516
+ }
517
+ return gen_config;
518
+ }
519
+ /**
520
+ * @typedef {import('./utils/maths.js').TypedArray} TypedArray
521
+ */
522
+ /**
523
+ * @typedef {{ sequences: Tensor, decoder_attentions: Tensor, cross_attentions: Tensor }} EncoderDecoderOutput
524
+ * @typedef {Object} DecoderOutput
525
+ *
526
+ * Generates text based on the given inputs and generation configuration using the model.
527
+ * @param {Tensor|Array|TypedArray} inputs An array of input token IDs.
528
+ * @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used.
529
+ * @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
530
+ * @param {Object} options options
531
+ * @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs.
532
+ * @returns {Promise<number[][]|EncoderDecoderOutput|DecoderOutput>} An array of generated output sequences, where each sequence is an array of token IDs.
533
+ * @throws {Error} Throws an error if the inputs array is empty.
534
+ */
535
+ async generate(inputs, generation_config = null, logits_processor = null, {
536
+ inputs_attention_mask = null
537
+ } = {}) {
538
+ if (!this.can_generate) {
539
+ const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
540
+ let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`;
541
+ const modelType = this.config.model_type;
542
+ const possibleInfo = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType);
543
+ if (possibleInfo) {
544
+ errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`;
545
+ }
546
+ throw Error(errorMessage);
547
+ }
548
+ if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) {
549
+ throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`);
550
+ }
551
+ let input_ids_seq_length;
552
+ if (this.config.is_encoder_decoder) {
553
+ input_ids_seq_length = 0;
554
+ } else {
555
+ input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;
556
+ if (input_ids_seq_length === 0) {
557
+ throw Error("Must supply a non-empty array of input token ids.");
558
+ }
559
+ }
560
+ generation_config = this._get_generation_config(generation_config);
561
+ logits_processor = logits_processor ?? new LogitsProcessorList();
562
+ logits_processor = this._get_logits_processor(
563
+ generation_config,
564
+ input_ids_seq_length,
565
+ logits_processor
566
+ );
567
+ let eos_token_ids = generation_config.eos_token_id;
568
+ if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
569
+ eos_token_ids = [eos_token_ids];
570
+ }
571
+ let numOutputTokens = 1;
572
+ const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity);
573
+ const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null;
574
+ let sampler = Sampler.getSampler(generation_config);
575
+ let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask);
576
+ while (beams.some((x) => !x.done) && numOutputTokens < maxOutputTokens) {
577
+ let newest_beams = [];
578
+ for (let beam of beams) {
579
+ if (beam.done) {
580
+ newest_beams.push(beam);
581
+ continue;
582
+ }
583
+ if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) {
584
+ beam.done = true;
585
+ newest_beams.push(beam);
586
+ continue;
587
+ }
588
+ let output = await this.runBeam(beam);
589
+ if (generation_config.output_attentions) {
590
+ this.addAttentionsToBeam(beam, output);
591
+ }
592
+ if (generation_config.output_scores) ;
593
+ let logits = output.logits.slice(null, -1, null);
594
+ logits_processor(beam.output_token_ids, logits);
595
+ let sampledTokens = sampler(logits);
596
+ for (let [newTokenId, logProb] of sampledTokens) {
597
+ let newBeam = { ...beam };
598
+ this.updateBeam(newBeam, newTokenId);
599
+ newBeam.score += logProb;
600
+ if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
601
+ newBeam.done = true;
602
+ }
603
+ newest_beams.push(newBeam);
604
+ }
605
+ }
606
+ ++numOutputTokens;
607
+ newest_beams = this.groupBeams(newest_beams).map(
608
+ (group) => group.sort((a, b) => b.score - a.score).slice(0, generation_config.num_beams)
609
+ // remove outside beam width
610
+ );
611
+ beams = newest_beams.flat();
612
+ if (generation_config.callback_function) {
613
+ generation_config.callback_function(beams);
614
+ }
615
+ }
616
+ const groupedBeams = this.groupBeams(beams);
617
+ const getFlattened = (key) => groupedBeams.map(
618
+ (batch) => {
619
+ if (generation_config.num_return_sequences > 1) {
620
+ return batch.slice(0, generation_config.num_return_sequences).map((x) => x[key]);
621
+ } else {
622
+ return [batch[0][key]];
623
+ }
624
+ }
625
+ ).flat();
626
+ const sequences = getFlattened("output_token_ids");
627
+ if (generation_config.return_dict_in_generate) {
628
+ const decoder_attentions = getFlattened("decoder_attentions");
629
+ const cross_attentions = getFlattened("cross_attentions");
630
+ return {
631
+ sequences,
632
+ decoder_attentions,
633
+ cross_attentions
634
+ };
635
+ } else {
636
+ return sequences;
637
+ }
638
+ }
639
+ /**
640
+ * Helper function to add attentions to beam
641
+ * @param {Object} beam
642
+ * @param {Object} output
643
+ * @private
644
+ */
645
+ addAttentionsToBeam(beam, output) {
646
+ if (this.config.is_encoder_decoder) {
647
+ if (!output.cross_attentions || output.cross_attentions.length === 0) {
648
+ throw Error(
649
+ "`output_attentions` is true, but the model did not produce cross-attentions. This is most likely because the model was not exported with `output_attentions=True`."
650
+ );
651
+ }
652
+ if (!beam.cross_attentions) {
653
+ beam.cross_attentions = [];
654
+ }
655
+ beam.cross_attentions.push(output.cross_attentions);
656
+ }
657
+ if (!output.decoder_attentions || output.decoder_attentions.length === 0) {
658
+ throw Error(
659
+ "`output_attentions` is true, but the model did not produce decoder-attentions. This is most likely because the model was not exported with `output_attentions=True`."
660
+ );
661
+ }
662
+ if (!beam.decoder_attentions) {
663
+ beam.decoder_attentions = [];
664
+ }
665
+ beam.decoder_attentions.push(output.decoder_attentions);
666
+ }
667
+ /**
668
+ * Groups an array of beam objects by their ids.
669
+ *
670
+ * @param {Array} beams The array of beam objects to group.
671
+ * @returns {Array} An array of arrays, where each inner array contains beam objects with the same id.
672
+ */
673
+ groupBeams(beams) {
674
+ const groups = /* @__PURE__ */ Object.create(null);
675
+ for (const obj of beams) {
676
+ if (groups[obj.id] === void 0) {
677
+ groups[obj.id] = [obj];
678
+ } else {
679
+ groups[obj.id].push(obj);
680
+ }
681
+ }
682
+ return Object.values(groups);
683
+ }
684
+ /**
685
+ * Returns an object containing past key values from the given decoder results object.
686
+ *
687
+ * @param {Object} decoderResults The decoder results object.
688
+ * @param {Object} pastKeyValues The previous past key values.
689
+ * @returns {Object} An object containing past key values.
690
+ */
691
+ getPastKeyValues(decoderResults, pastKeyValues) {
692
+ const pkvs = /* @__PURE__ */ Object.create(null);
693
+ for (const name in decoderResults) {
694
+ if (name.startsWith("present")) {
695
+ let newName = name.replace("present", "past_key_values");
696
+ if (pastKeyValues && name.includes("encoder")) {
697
+ pkvs[newName] = pastKeyValues[newName];
698
+ } else {
699
+ pkvs[newName] = decoderResults[name];
700
+ }
701
+ }
702
+ }
703
+ return pkvs;
704
+ }
705
+ /**
706
+ * Returns an object containing attentions from the given decoder results object.
707
+ *
708
+ * @param {Object} decoderResults The decoder results object.
709
+ * @returns {Object} An object containing attentions.
710
+ */
711
+ getAttentions(decoderResults) {
712
+ const attns = /* @__PURE__ */ Object.create(null);
713
+ for (const attnName of ["cross_attentions", "decoder_attentions"]) {
714
+ const result = [];
715
+ for (const name in decoderResults) {
716
+ if (name.startsWith(attnName)) {
717
+ const index = name.split(".").pop();
718
+ result[index] = decoderResults[name];
719
+ }
720
+ }
721
+ attns[attnName] = result;
722
+ }
723
+ return attns;
724
+ }
725
+ /**
726
+ * Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values.
727
+ *
728
+ * @param {Object} decoderFeeds The decoder feeds object to add past key values to.
729
+ * @param {Object} pastKeyValues An object containing past key values.
730
+ */
731
+ addPastKeyValues(decoderFeeds, pastKeyValues) {
732
+ if (pastKeyValues) {
733
+ Object.assign(decoderFeeds, pastKeyValues);
734
+ } else {
735
+ const batch_size = 1;
736
+ if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
737
+ let encoder_dims = [batch_size, this.num_encoder_heads, 0, this.encoder_dim_kv];
738
+ let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv];
739
+ for (let i = 0; i < this.num_decoder_layers; ++i) {
740
+ decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor("float32", [], encoder_dims);
741
+ decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor("float32", [], encoder_dims);
742
+ decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor("float32", [], decoder_dims);
743
+ decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor("float32", [], decoder_dims);
744
+ }
745
+ } else if (this.config.model_type === "falcon") {
746
+ let dims = [batch_size * this.num_heads, 0, this.dim_kv];
747
+ for (let i = 0; i < this.num_layers; ++i) {
748
+ decoderFeeds[`past_key_values.${i}.key`] = new Tensor("float32", [], dims);
749
+ decoderFeeds[`past_key_values.${i}.value`] = new Tensor("float32", [], dims);
750
+ }
751
+ } else if (this.config.multi_query) {
752
+ let dims = [batch_size * this.num_heads, 0, 2 * this.dim_kv];
753
+ for (let i = 0; i < this.num_layers; ++i) {
754
+ decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor("float32", [], dims);
755
+ }
756
+ } else if (this.config.model_type === "bloom") {
757
+ let keyDims = [batch_size * this.num_heads, this.dim_kv, 0];
758
+ let valueDims = [batch_size * this.num_heads, 0, this.dim_kv];
759
+ for (let i = 0; i < this.num_layers; ++i) {
760
+ decoderFeeds[`past_key_values.${i}.key`] = new Tensor("float32", [], keyDims);
761
+ decoderFeeds[`past_key_values.${i}.value`] = new Tensor("float32", [], valueDims);
762
+ }
763
+ } else {
764
+ let dims = [batch_size, this.num_heads, 0, this.dim_kv];
765
+ for (let i = 0; i < this.num_layers; ++i) {
766
+ decoderFeeds[`past_key_values.${i}.key`] = new Tensor("float32", [], dims);
767
+ decoderFeeds[`past_key_values.${i}.value`] = new Tensor("float32", [], dims);
768
+ }
769
+ }
770
+ }
771
+ }
772
+ /**
773
+ * Initializes and returns the beam for text generation task
774
+ * @param {Tensor} inputTokenIds The input token ids.
775
+ * @param {Object} generation_config The generation config.
776
+ * @param {number} numOutputTokens The number of tokens to be generated.
777
+ * @param {Tensor} inputs_attention_mask Optional input attention mask.
778
+ * @returns {any} A Beam object representing the initialized beam.
779
+ * @private
780
+ */
781
+ getStartBeams(inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) {
782
+ return this._getStartBeams(this, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask);
783
+ }
784
+ /**
785
+ * Runs a single step of the beam search generation algorithm.
786
+ * @param {any} beam The current beam being generated.
787
+ * @returns {Promise<any>} The updated beam after a single generation step.
788
+ * @private
789
+ */
790
+ async runBeam(beam) {
791
+ return await this._runBeam(this, beam);
792
+ }
793
+ /**
794
+ * Update a beam with a new token ID.
795
+ * @param {Object} beam The beam to update.
796
+ * @param {number} newTokenId The new token ID to add to the beam's output.
797
+ * @private
798
+ */
799
+ updateBeam(beam, newTokenId) {
800
+ return this._updateBeam(beam, newTokenId);
801
+ }
802
+ }
803
/**
 * Marker base class for model output objects; declares no members of its own.
 */
class ModelOutput {}
805
/**
 * Shared base class for the BERT model variants; adds nothing to `PreTrainedModel`.
 */
class BertPreTrainedModel extends PreTrainedModel {}

/**
 * Bare BERT model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class BertModel extends BertPreTrainedModel {}

/**
 * BERT variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class BertForMaskedLM extends BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * BERT variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class BertForSequenceClassification extends BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * BERT variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class BertForTokenClassification extends BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * BERT variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class BertForQuestionAnswering extends BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
853
/**
 * Shared base class for the Nomic BERT model variants; adds nothing to `PreTrainedModel`.
 */
class NomicBertPreTrainedModel extends PreTrainedModel {}

/**
 * Bare Nomic BERT model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class NomicBertModel extends NomicBertPreTrainedModel {}
857
/**
 * Shared base class for the RoFormer model variants; adds nothing to `PreTrainedModel`.
 */
class RoFormerPreTrainedModel extends PreTrainedModel {}

/**
 * Bare RoFormer model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class RoFormerModel extends RoFormerPreTrainedModel {}

/**
 * RoFormer variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class RoFormerForMaskedLM extends RoFormerPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * RoFormer variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class RoFormerForSequenceClassification extends RoFormerPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * RoFormer variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class RoFormerForTokenClassification extends RoFormerPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * RoFormer variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class RoFormerForQuestionAnswering extends RoFormerPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
905
/**
 * Shared base class for the ConvBERT model variants; adds nothing to `PreTrainedModel`.
 */
class ConvBertPreTrainedModel extends PreTrainedModel {}

/**
 * Bare ConvBERT model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class ConvBertModel extends ConvBertPreTrainedModel {}

/**
 * ConvBERT variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class ConvBertForMaskedLM extends ConvBertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * ConvBERT variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class ConvBertForSequenceClassification extends ConvBertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * ConvBERT variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class ConvBertForTokenClassification extends ConvBertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * ConvBERT variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class ConvBertForQuestionAnswering extends ConvBertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
953
/**
 * Shared base class for the ELECTRA model variants; adds nothing to `PreTrainedModel`.
 */
class ElectraPreTrainedModel extends PreTrainedModel {}

/**
 * Bare ELECTRA model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class ElectraModel extends ElectraPreTrainedModel {}

/**
 * ELECTRA variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class ElectraForMaskedLM extends ElectraPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * ELECTRA variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class ElectraForSequenceClassification extends ElectraPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * ELECTRA variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class ElectraForTokenClassification extends ElectraPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * ELECTRA variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class ElectraForQuestionAnswering extends ElectraPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
1001
/**
 * Shared base class for the CamemBERT model variants; adds nothing to `PreTrainedModel`.
 */
class CamembertPreTrainedModel extends PreTrainedModel {}

/**
 * Bare CamemBERT model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class CamembertModel extends CamembertPreTrainedModel {}

/**
 * CamemBERT variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class CamembertForMaskedLM extends CamembertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * CamemBERT variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class CamembertForSequenceClassification extends CamembertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * CamemBERT variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class CamembertForTokenClassification extends CamembertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * CamemBERT variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class CamembertForQuestionAnswering extends CamembertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
1049
/**
 * Shared base class for the DeBERTa model variants; adds nothing to `PreTrainedModel`.
 */
class DebertaPreTrainedModel extends PreTrainedModel {}

/**
 * Bare DeBERTa model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class DebertaModel extends DebertaPreTrainedModel {}

/**
 * DeBERTa variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class DebertaForMaskedLM extends DebertaPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * DeBERTa variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class DebertaForSequenceClassification extends DebertaPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * DeBERTa variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class DebertaForTokenClassification extends DebertaPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * DeBERTa variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class DebertaForQuestionAnswering extends DebertaPreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
1097
/**
 * Shared base class for the DeBERTa-v2 model variants; adds nothing to `PreTrainedModel`.
 */
class DebertaV2PreTrainedModel extends PreTrainedModel {}

/**
 * Bare DeBERTa-v2 model: inherits `_call` unchanged, so it resolves to the raw model outputs.
 */
class DebertaV2Model extends DebertaV2PreTrainedModel {}

/**
 * DeBERTa-v2 variant whose call result is wrapped in a `MaskedLMOutput`.
 */
class DebertaV2ForMaskedLM extends DebertaV2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<MaskedLMOutput>} The masked language modeling output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new MaskedLMOutput(outputs);
  }
}

/**
 * DeBERTa-v2 variant whose call result is wrapped in a `SequenceClassifierOutput`.
 */
class DebertaV2ForSequenceClassification extends DebertaV2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The sequence classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(outputs);
  }
}

/**
 * DeBERTa-v2 variant whose call result is wrapped in a `TokenClassifierOutput`.
 */
class DebertaV2ForTokenClassification extends DebertaV2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} The token classification output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new TokenClassifierOutput(outputs);
  }
}

/**
 * DeBERTa-v2 variant whose call result is wrapped in a `QuestionAnsweringModelOutput`.
 */
class DebertaV2ForQuestionAnswering extends DebertaV2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<QuestionAnsweringModelOutput>} The question answering output (logits).
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new QuestionAnsweringModelOutput(outputs);
  }
}
1145
+ class DistilBertPreTrainedModel extends PreTrainedModel {
1146
+ }
1147
+ class DistilBertModel extends DistilBertPreTrainedModel {
1148
+ }
1149
+ class DistilBertForSequenceClassification extends DistilBertPreTrainedModel {
1150
+ /**
1151
+ * Calls the model on new inputs.
1152
+ *
1153
+ * @param {Object} model_inputs The inputs to the model.
1154
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
1155
+ */
1156
+ async _call(model_inputs) {
1157
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1158
+ }
1159
+ }
1160
+ class DistilBertForTokenClassification extends DistilBertPreTrainedModel {
1161
+ /**
1162
+ * Calls the model on new inputs.
1163
+ *
1164
+ * @param {Object} model_inputs The inputs to the model.
1165
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1166
+ */
1167
+ async _call(model_inputs) {
1168
+ return new TokenClassifierOutput(await super._call(model_inputs));
1169
+ }
1170
+ }
1171
+ class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel {
1172
+ /**
1173
+ * Calls the model on new inputs.
1174
+ *
1175
+ * @param {Object} model_inputs The inputs to the model.
1176
+ * @returns {Promise<QuestionAnsweringModelOutput>} An object containing the model's output logits for question answering.
1177
+ */
1178
+ async _call(model_inputs) {
1179
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1180
+ }
1181
+ }
1182
+ class DistilBertForMaskedLM extends DistilBertPreTrainedModel {
1183
+ /**
1184
+ * Calls the model on new inputs.
1185
+ *
1186
+ * @param {Object} model_inputs The inputs to the model.
1187
+ * @returns {Promise<MaskedLMOutput>} returned object
1188
+ */
1189
+ async _call(model_inputs) {
1190
+ return new MaskedLMOutput(await super._call(model_inputs));
1191
+ }
1192
+ }
1193
+ class EsmPreTrainedModel extends PreTrainedModel {
1194
+ }
1195
+ class EsmModel extends EsmPreTrainedModel {
1196
+ }
1197
+ class EsmForMaskedLM extends EsmPreTrainedModel {
1198
+ /**
1199
+ * Calls the model on new inputs.
1200
+ *
1201
+ * @param {Object} model_inputs The inputs to the model.
1202
+ * @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
1203
+ */
1204
+ async _call(model_inputs) {
1205
+ return new MaskedLMOutput(await super._call(model_inputs));
1206
+ }
1207
+ }
1208
+ class EsmForSequenceClassification extends EsmPreTrainedModel {
1209
+ /**
1210
+ * Calls the model on new inputs.
1211
+ *
1212
+ * @param {Object} model_inputs The inputs to the model.
1213
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
1214
+ */
1215
+ async _call(model_inputs) {
1216
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1217
+ }
1218
+ }
1219
+ class EsmForTokenClassification extends EsmPreTrainedModel {
1220
+ /**
1221
+ * Calls the model on new inputs.
1222
+ *
1223
+ * @param {Object} model_inputs The inputs to the model.
1224
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1225
+ */
1226
+ async _call(model_inputs) {
1227
+ return new TokenClassifierOutput(await super._call(model_inputs));
1228
+ }
1229
+ }
1230
+ class MobileBertPreTrainedModel extends PreTrainedModel {
1231
+ }
1232
+ class MobileBertModel extends MobileBertPreTrainedModel {
1233
+ }
1234
+ class MobileBertForMaskedLM extends MobileBertPreTrainedModel {
1235
+ /**
1236
+ * Calls the model on new inputs.
1237
+ *
1238
+ * @param {Object} model_inputs The inputs to the model.
1239
+ * @returns {Promise<MaskedLMOutput>} returned object
1240
+ */
1241
+ async _call(model_inputs) {
1242
+ return new MaskedLMOutput(await super._call(model_inputs));
1243
+ }
1244
+ }
1245
+ class MobileBertForSequenceClassification extends MobileBertPreTrainedModel {
1246
+ /**
1247
+ * Calls the model on new inputs.
1248
+ *
1249
+ * @param {Object} model_inputs The inputs to the model.
1250
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1251
+ */
1252
+ async _call(model_inputs) {
1253
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1254
+ }
1255
+ }
1256
+ class MobileBertForQuestionAnswering extends MobileBertPreTrainedModel {
1257
+ /**
1258
+ * Calls the model on new inputs.
1259
+ *
1260
+ * @param {Object} model_inputs The inputs to the model.
1261
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1262
+ */
1263
+ async _call(model_inputs) {
1264
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1265
+ }
1266
+ }
1267
+ class MPNetPreTrainedModel extends PreTrainedModel {
1268
+ }
1269
+ class MPNetModel extends MPNetPreTrainedModel {
1270
+ }
1271
+ class MPNetForMaskedLM extends MPNetPreTrainedModel {
1272
+ /**
1273
+ * Calls the model on new inputs.
1274
+ *
1275
+ * @param {Object} model_inputs The inputs to the model.
1276
+ * @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
1277
+ */
1278
+ async _call(model_inputs) {
1279
+ return new MaskedLMOutput(await super._call(model_inputs));
1280
+ }
1281
+ }
1282
+ class MPNetForSequenceClassification extends MPNetPreTrainedModel {
1283
+ /**
1284
+ * Calls the model on new inputs.
1285
+ *
1286
+ * @param {Object} model_inputs The inputs to the model.
1287
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
1288
+ */
1289
+ async _call(model_inputs) {
1290
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1291
+ }
1292
+ }
1293
+ class MPNetForTokenClassification extends MPNetPreTrainedModel {
1294
+ /**
1295
+ * Calls the model on new inputs.
1296
+ *
1297
+ * @param {Object} model_inputs The inputs to the model.
1298
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1299
+ */
1300
+ async _call(model_inputs) {
1301
+ return new TokenClassifierOutput(await super._call(model_inputs));
1302
+ }
1303
+ }
1304
+ class MPNetForQuestionAnswering extends MPNetPreTrainedModel {
1305
+ /**
1306
+ * Calls the model on new inputs.
1307
+ *
1308
+ * @param {Object} model_inputs The inputs to the model.
1309
+ * @returns {Promise<QuestionAnsweringModelOutput>} An object containing the model's output logits for question answering.
1310
+ */
1311
+ async _call(model_inputs) {
1312
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1313
+ }
1314
+ }
1315
+ class SqueezeBertPreTrainedModel extends PreTrainedModel {
1316
+ }
1317
+ class SqueezeBertModel extends SqueezeBertPreTrainedModel {
1318
+ }
1319
+ class SqueezeBertForMaskedLM extends SqueezeBertPreTrainedModel {
1320
+ /**
1321
+ * Calls the model on new inputs.
1322
+ *
1323
+ * @param {Object} model_inputs The inputs to the model.
1324
+ * @returns {Promise<MaskedLMOutput>} returned object
1325
+ */
1326
+ async _call(model_inputs) {
1327
+ return new MaskedLMOutput(await super._call(model_inputs));
1328
+ }
1329
+ }
1330
+ class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedModel {
1331
+ /**
1332
+ * Calls the model on new inputs.
1333
+ *
1334
+ * @param {Object} model_inputs The inputs to the model.
1335
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1336
+ */
1337
+ async _call(model_inputs) {
1338
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1339
+ }
1340
+ }
1341
+ class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel {
1342
+ /**
1343
+ * Calls the model on new inputs.
1344
+ *
1345
+ * @param {Object} model_inputs The inputs to the model.
1346
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1347
+ */
1348
+ async _call(model_inputs) {
1349
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1350
+ }
1351
+ }
1352
+ class AlbertPreTrainedModel extends PreTrainedModel {
1353
+ }
1354
+ class AlbertModel extends AlbertPreTrainedModel {
1355
+ }
1356
+ class AlbertForSequenceClassification extends AlbertPreTrainedModel {
1357
+ /**
1358
+ * Calls the model on new inputs.
1359
+ *
1360
+ * @param {Object} model_inputs The inputs to the model.
1361
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1362
+ */
1363
+ async _call(model_inputs) {
1364
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1365
+ }
1366
+ }
1367
+ class AlbertForQuestionAnswering extends AlbertPreTrainedModel {
1368
+ /**
1369
+ * Calls the model on new inputs.
1370
+ *
1371
+ * @param {Object} model_inputs The inputs to the model.
1372
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1373
+ */
1374
+ async _call(model_inputs) {
1375
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1376
+ }
1377
+ }
1378
+ class AlbertForMaskedLM extends AlbertPreTrainedModel {
1379
+ /**
1380
+ * Calls the model on new inputs.
1381
+ *
1382
+ * @param {Object} model_inputs The inputs to the model.
1383
+ * @returns {Promise<MaskedLMOutput>} returned object
1384
+ */
1385
+ async _call(model_inputs) {
1386
+ return new MaskedLMOutput(await super._call(model_inputs));
1387
+ }
1388
+ }
1389
+ class T5PreTrainedModel extends PreTrainedModel {
1390
+ }
1391
+ class T5Model extends T5PreTrainedModel {
1392
+ }
1393
+ class T5ForConditionalGeneration extends T5PreTrainedModel {
1394
+ /**
1395
+ * Creates a new instance of the `T5ForConditionalGeneration` class.
1396
+ * @param {Object} config The model configuration.
1397
+ * @param {any} session session for the model.
1398
+ * @param {any} decoder_merged_session session for the decoder.
1399
+ * @param {GenerationConfig} generation_config The generation configuration.
1400
+ */
1401
+ constructor(config, session, decoder_merged_session, generation_config) {
1402
+ super(config, session);
1403
+ this.decoder_merged_session = decoder_merged_session;
1404
+ this.generation_config = generation_config;
1405
+ this.num_decoder_layers = this.config.num_decoder_layers;
1406
+ this.num_decoder_heads = this.config.num_heads;
1407
+ this.decoder_dim_kv = this.config.d_kv;
1408
+ this.num_encoder_layers = this.config.num_layers;
1409
+ this.num_encoder_heads = this.config.num_heads;
1410
+ this.encoder_dim_kv = this.config.d_kv;
1411
+ }
1412
+ }
1413
+ class LongT5PreTrainedModel extends PreTrainedModel {
1414
+ }
1415
+ class LongT5Model extends LongT5PreTrainedModel {
1416
+ }
1417
+ class LongT5ForConditionalGeneration extends LongT5PreTrainedModel {
1418
+ /**
1419
+ * Creates a new instance of the `LongT5ForConditionalGeneration` class.
1420
+ * @param {Object} config The model configuration.
1421
+ * @param {any} session session for the model.
1422
+ * @param {any} decoder_merged_session session for the decoder.
1423
+ * @param {GenerationConfig} generation_config The generation configuration.
1424
+ */
1425
+ constructor(config, session, decoder_merged_session, generation_config) {
1426
+ super(config, session);
1427
+ this.decoder_merged_session = decoder_merged_session;
1428
+ this.generation_config = generation_config;
1429
+ this.num_decoder_layers = this.config.num_decoder_layers;
1430
+ this.num_decoder_heads = this.config.num_heads;
1431
+ this.decoder_dim_kv = this.config.d_kv;
1432
+ this.num_encoder_layers = this.config.num_layers;
1433
+ this.num_encoder_heads = this.config.num_heads;
1434
+ this.encoder_dim_kv = this.config.d_kv;
1435
+ }
1436
+ }
1437
+ class MT5PreTrainedModel extends PreTrainedModel {
1438
+ }
1439
+ class MT5Model extends MT5PreTrainedModel {
1440
+ }
1441
+ class MT5ForConditionalGeneration extends MT5PreTrainedModel {
1442
+ /**
1443
+ * Creates a new instance of the `MT5ForConditionalGeneration` class.
1444
+ * @param {any} config The model configuration.
1445
+ * @param {any} session The ONNX session containing the encoder weights.
1446
+ * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
1447
+ * @param {GenerationConfig} generation_config The generation configuration.
1448
+ */
1449
+ constructor(config, session, decoder_merged_session, generation_config) {
1450
+ super(config, session);
1451
+ this.decoder_merged_session = decoder_merged_session;
1452
+ this.generation_config = generation_config;
1453
+ this.num_decoder_layers = this.config.num_decoder_layers;
1454
+ this.num_decoder_heads = this.config.num_heads;
1455
+ this.decoder_dim_kv = this.config.d_kv;
1456
+ this.num_encoder_layers = this.config.num_layers;
1457
+ this.num_encoder_heads = this.config.num_heads;
1458
+ this.encoder_dim_kv = this.config.d_kv;
1459
+ }
1460
+ }
1461
+ class BartPretrainedModel extends PreTrainedModel {
1462
+ }
1463
+ class BartModel extends BartPretrainedModel {
1464
+ }
1465
+ class BartForConditionalGeneration extends BartPretrainedModel {
1466
+ /**
1467
+ * Creates a new instance of the `BartForConditionalGeneration` class.
1468
+ * @param {Object} config The configuration object for the Bart model.
1469
+ * @param {Object} session The ONNX session used to execute the model.
1470
+ * @param {Object} decoder_merged_session The ONNX session used to execute the decoder.
1471
+ * @param {Object} generation_config The generation configuration object.
1472
+ */
1473
+ constructor(config, session, decoder_merged_session, generation_config) {
1474
+ super(config, session);
1475
+ this.decoder_merged_session = decoder_merged_session;
1476
+ this.generation_config = generation_config;
1477
+ this.num_decoder_layers = this.config.decoder_layers;
1478
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1479
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1480
+ this.num_encoder_layers = this.config.encoder_layers;
1481
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1482
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1483
+ }
1484
+ }
1485
+ class BartForSequenceClassification extends BartPretrainedModel {
1486
+ /**
1487
+ * Calls the model on new inputs.
1488
+ *
1489
+ * @param {Object} model_inputs The inputs to the model.
1490
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
1491
+ */
1492
+ async _call(model_inputs) {
1493
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1494
+ }
1495
+ }
1496
+ class MBartPreTrainedModel extends PreTrainedModel {
1497
+ }
1498
+ class MBartModel extends MBartPreTrainedModel {
1499
+ }
1500
+ class MBartForConditionalGeneration extends MBartPreTrainedModel {
1501
+ /**
1502
+ * Creates a new instance of the `MBartForConditionalGeneration` class.
1503
+ * @param {Object} config The configuration object for the Bart model.
1504
+ * @param {Object} session The ONNX session used to execute the model.
1505
+ * @param {Object} decoder_merged_session The ONNX session used to execute the decoder.
1506
+ * @param {Object} generation_config The generation configuration object.
1507
+ */
1508
+ constructor(config, session, decoder_merged_session, generation_config) {
1509
+ super(config, session);
1510
+ this.decoder_merged_session = decoder_merged_session;
1511
+ this.generation_config = generation_config;
1512
+ this.num_decoder_layers = this.config.decoder_layers;
1513
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1514
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1515
+ this.num_encoder_layers = this.config.encoder_layers;
1516
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1517
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1518
+ }
1519
+ }
1520
+ class MBartForSequenceClassification extends MBartPreTrainedModel {
1521
+ /**
1522
+ * Calls the model on new inputs.
1523
+ *
1524
+ * @param {Object} model_inputs The inputs to the model.
1525
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
1526
+ */
1527
+ async _call(model_inputs) {
1528
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1529
+ }
1530
+ }
1531
+ class MBartForCausalLM extends MBartPreTrainedModel {
1532
+ /**
1533
+ * Creates a new instance of the `MBartForCausalLM` class.
1534
+ * @param {Object} config Configuration object for the model.
1535
+ * @param {Object} decoder_merged_session ONNX Session object for the decoder.
1536
+ * @param {Object} generation_config Configuration object for the generation process.
1537
+ */
1538
+ constructor(config, decoder_merged_session, generation_config) {
1539
+ super(config, decoder_merged_session);
1540
+ this.generation_config = generation_config;
1541
+ this.num_decoder_layers = this.config.decoder_layers;
1542
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1543
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1544
+ this.num_encoder_layers = this.config.encoder_layers;
1545
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1546
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1547
+ }
1548
+ }
1549
+ class BlenderbotPreTrainedModel extends PreTrainedModel {
1550
+ }
1551
+ class BlenderbotModel extends BlenderbotPreTrainedModel {
1552
+ }
1553
+ class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel {
1554
+ /**
1555
+ * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
1556
+ * @param {any} config The model configuration.
1557
+ * @param {any} session The ONNX session containing the encoder weights.
1558
+ * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
1559
+ * @param {GenerationConfig} generation_config The generation configuration.
1560
+ */
1561
+ constructor(config, session, decoder_merged_session, generation_config) {
1562
+ super(config, session);
1563
+ this.decoder_merged_session = decoder_merged_session;
1564
+ this.generation_config = generation_config;
1565
+ this.num_decoder_layers = this.config.decoder_layers;
1566
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1567
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1568
+ this.num_encoder_layers = this.config.encoder_layers;
1569
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1570
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1571
+ }
1572
+ }
1573
+ class BlenderbotSmallPreTrainedModel extends PreTrainedModel {
1574
+ }
1575
+ class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel {
1576
+ }
1577
+ class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel {
1578
+ /**
1579
+ * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
1580
+ * @param {any} config The model configuration.
1581
+ * @param {any} session The ONNX session containing the encoder weights.
1582
+ * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
1583
+ * @param {GenerationConfig} generation_config The generation configuration.
1584
+ */
1585
+ constructor(config, session, decoder_merged_session, generation_config) {
1586
+ super(config, session);
1587
+ this.decoder_merged_session = decoder_merged_session;
1588
+ this.generation_config = generation_config;
1589
+ this.num_decoder_layers = this.config.decoder_layers;
1590
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1591
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1592
+ this.num_encoder_layers = this.config.encoder_layers;
1593
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1594
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1595
+ }
1596
+ }
1597
+ class RobertaPreTrainedModel extends PreTrainedModel {
1598
+ }
1599
+ class RobertaModel extends RobertaPreTrainedModel {
1600
+ }
1601
+ class RobertaForMaskedLM extends RobertaPreTrainedModel {
1602
+ /**
1603
+ * Calls the model on new inputs.
1604
+ *
1605
+ * @param {Object} model_inputs The inputs to the model.
1606
+ * @returns {Promise<MaskedLMOutput>} returned object
1607
+ */
1608
+ async _call(model_inputs) {
1609
+ return new MaskedLMOutput(await super._call(model_inputs));
1610
+ }
1611
+ }
1612
+ class RobertaForSequenceClassification extends RobertaPreTrainedModel {
1613
+ /**
1614
+ * Calls the model on new inputs.
1615
+ *
1616
+ * @param {Object} model_inputs The inputs to the model.
1617
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1618
+ */
1619
+ async _call(model_inputs) {
1620
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1621
+ }
1622
+ }
1623
+ class RobertaForTokenClassification extends RobertaPreTrainedModel {
1624
+ /**
1625
+ * Calls the model on new inputs.
1626
+ *
1627
+ * @param {Object} model_inputs The inputs to the model.
1628
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1629
+ */
1630
+ async _call(model_inputs) {
1631
+ return new TokenClassifierOutput(await super._call(model_inputs));
1632
+ }
1633
+ }
1634
+ class RobertaForQuestionAnswering extends RobertaPreTrainedModel {
1635
+ /**
1636
+ * Calls the model on new inputs.
1637
+ *
1638
+ * @param {Object} model_inputs The inputs to the model.
1639
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1640
+ */
1641
+ async _call(model_inputs) {
1642
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1643
+ }
1644
+ }
1645
+ class XLMPreTrainedModel extends PreTrainedModel {
1646
+ }
1647
+ class XLMModel extends XLMPreTrainedModel {
1648
+ }
1649
+ class XLMWithLMHeadModel extends XLMPreTrainedModel {
1650
+ /**
1651
+ * Calls the model on new inputs.
1652
+ *
1653
+ * @param {Object} model_inputs The inputs to the model.
1654
+ * @returns {Promise<MaskedLMOutput>} returned object
1655
+ */
1656
+ async _call(model_inputs) {
1657
+ return new MaskedLMOutput(await super._call(model_inputs));
1658
+ }
1659
+ }
1660
+ class XLMForSequenceClassification extends XLMPreTrainedModel {
1661
+ /**
1662
+ * Calls the model on new inputs.
1663
+ *
1664
+ * @param {Object} model_inputs The inputs to the model.
1665
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1666
+ */
1667
+ async _call(model_inputs) {
1668
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1669
+ }
1670
+ }
1671
+ class XLMForTokenClassification extends XLMPreTrainedModel {
1672
+ /**
1673
+ * Calls the model on new inputs.
1674
+ *
1675
+ * @param {Object} model_inputs The inputs to the model.
1676
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1677
+ */
1678
+ async _call(model_inputs) {
1679
+ return new TokenClassifierOutput(await super._call(model_inputs));
1680
+ }
1681
+ }
1682
+ class XLMForQuestionAnswering extends XLMPreTrainedModel {
1683
+ /**
1684
+ * Calls the model on new inputs.
1685
+ *
1686
+ * @param {Object} model_inputs The inputs to the model.
1687
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1688
+ */
1689
+ async _call(model_inputs) {
1690
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1691
+ }
1692
+ }
1693
+ class XLMRobertaPreTrainedModel extends PreTrainedModel {
1694
+ }
1695
+ class XLMRobertaModel extends XLMRobertaPreTrainedModel {
1696
+ }
1697
+ class XLMRobertaForMaskedLM extends XLMRobertaPreTrainedModel {
1698
+ /**
1699
+ * Calls the model on new inputs.
1700
+ *
1701
+ * @param {Object} model_inputs The inputs to the model.
1702
+ * @returns {Promise<MaskedLMOutput>} returned object
1703
+ */
1704
+ async _call(model_inputs) {
1705
+ return new MaskedLMOutput(await super._call(model_inputs));
1706
+ }
1707
+ }
1708
+ class XLMRobertaForSequenceClassification extends XLMRobertaPreTrainedModel {
1709
+ /**
1710
+ * Calls the model on new inputs.
1711
+ *
1712
+ * @param {Object} model_inputs The inputs to the model.
1713
+ * @returns {Promise<SequenceClassifierOutput>} returned object
1714
+ */
1715
+ async _call(model_inputs) {
1716
+ return new SequenceClassifierOutput(await super._call(model_inputs));
1717
+ }
1718
+ }
1719
+ class XLMRobertaForTokenClassification extends XLMRobertaPreTrainedModel {
1720
+ /**
1721
+ * Calls the model on new inputs.
1722
+ *
1723
+ * @param {Object} model_inputs The inputs to the model.
1724
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
1725
+ */
1726
+ async _call(model_inputs) {
1727
+ return new TokenClassifierOutput(await super._call(model_inputs));
1728
+ }
1729
+ }
1730
+ class XLMRobertaForQuestionAnswering extends XLMRobertaPreTrainedModel {
1731
+ /**
1732
+ * Calls the model on new inputs.
1733
+ *
1734
+ * @param {Object} model_inputs The inputs to the model.
1735
+ * @returns {Promise<QuestionAnsweringModelOutput>} returned object
1736
+ */
1737
+ async _call(model_inputs) {
1738
+ return new QuestionAnsweringModelOutput(await super._call(model_inputs));
1739
+ }
1740
+ }
1741
+ class ASTPreTrainedModel extends PreTrainedModel {
1742
+ }
1743
+ class ASTModel extends ASTPreTrainedModel {
1744
+ }
1745
+ class ASTForAudioClassification extends ASTPreTrainedModel {
1746
+ }
1747
+ class WhisperPreTrainedModel extends PreTrainedModel {
1748
+ }
1749
+ class WhisperModel extends WhisperPreTrainedModel {
1750
+ }
1751
+ class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
1752
+ requires_attention_mask = false;
1753
+ main_input_name = "input_features";
1754
+ /**
1755
+ * Creates a new instance of the `WhisperForConditionalGeneration` class.
1756
+ * @param {Object} config Configuration object for the model.
1757
+ * @param {Object} session ONNX Session object for the model.
1758
+ * @param {Object} decoder_merged_session ONNX Session object for the decoder.
1759
+ * @param {Object} generation_config Configuration object for the generation process.
1760
+ */
1761
+ constructor(config, session, decoder_merged_session, generation_config) {
1762
+ super(config, session);
1763
+ this.decoder_merged_session = decoder_merged_session;
1764
+ this.generation_config = generation_config;
1765
+ this.num_decoder_layers = this.config.decoder_layers;
1766
+ this.num_decoder_heads = this.config.decoder_attention_heads;
1767
+ this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
1768
+ this.num_encoder_layers = this.config.encoder_layers;
1769
+ this.num_encoder_heads = this.config.encoder_attention_heads;
1770
+ this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
1771
+ }
1772
+ /**
1773
+ * @typedef {Object} WhisperGenerationConfig
1774
+ * @extends GenerationConfig
1775
+ * @property {boolean} [return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
1776
+ * @property {boolean} [return_token_timestamps=null] Whether to return token-level timestamps
1777
+ * with the text. This can be used with or without the `return_timestamps` option. To get word-level
1778
+ * timestamps, use the tokenizer to group the tokens into words.
1779
+ * @property {number} [num_frames=null] The number of audio frames available in this chunk. This is only used generating word-level timestamps.
1780
+ */
1781
+ /**
1782
+ * Generates outputs based on input and generation configuration.
1783
+ * @param {Object} inputs Input data for the model.
1784
+ * @param {WhisperGenerationConfig} generation_config Configuration object for the generation process.
1785
+ * @param {Object} logits_processor Optional logits processor object.
1786
+ * @returns {Promise<Object>} Promise object represents the generated outputs.
1787
+ */
1788
+ async generate(inputs, generation_config = null, logits_processor = null) {
1789
+ generation_config = this._get_generation_config(generation_config);
1790
+ generation_config.return_timestamps ??= false;
1791
+ if (generation_config.return_timestamps) {
1792
+ logits_processor = [new WhisperTimeStampLogitsProcessor(generation_config)];
1793
+ }
1794
+ if (generation_config.return_token_timestamps) {
1795
+ generation_config.output_attentions = true;
1796
+ generation_config.return_dict_in_generate = true;
1797
+ if (generation_config.task === "translate") {
1798
+ console.warn("Token-level timestamps may not be reliable for task 'translate'.");
1799
+ }
1800
+ if (!generation_config.alignment_heads) {
1801
+ throw new Error(
1802
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1803
+ );
1804
+ }
1805
+ }
1806
+ const outputs = await super.generate(inputs, generation_config, logits_processor);
1807
+ if (generation_config.return_token_timestamps && generation_config.alignment_heads) {
1808
+ outputs["token_timestamps"] = this._extract_token_timestamps(
1809
+ outputs,
1810
+ generation_config.alignment_heads,
1811
+ generation_config.num_frames
1812
+ );
1813
+ }
1814
+ return outputs;
1815
+ }
1816
+ /**
1817
+ * Calculates token-level timestamps using the encoder-decoder cross-attentions and
1818
+ * dynamic time-warping (DTW) to map each output token to a position in the input audio.
1819
+ * @param {Object} generate_outputs Outputs generated by the model
1820
+ * @param {Tensor[][][]} generate_outputs.cross_attentions The cross attentions output by the model
1821
+ * @param {Tensor[][][]} generate_outputs.decoder_attentions The decoder attentions output by the model
1822
+ * @param {number[][]} generate_outputs.sequences The sequences output by the model
1823
+ * @param {number[][]} alignment_heads Alignment heads of the model
1824
+ * @param {number} [num_frames=null] Number of frames in the input audio.
1825
+ * @param {number} [time_precision=0.02] Precision of the timestamps in seconds
1826
+ * @returns {Tensor} tensor containing the timestamps in seconds for each predicted token
1827
+ */
1828
_extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
  // Timestamps are derived from cross attentions, so refuse to run without them.
  if (!generate_outputs.cross_attentions) {
    throw new Error(
      "Model outputs must contain cross attentions to extract timestamps. This is most likely because the model was not exported with `output_attentions=True`."
    );
  }
  // Width of the median filter used to smooth each standardized attention row;
  // falls back to 7 (with a warning) when the model config does not define it.
  let median_filter_width = this.config.median_filter_width;
  if (median_filter_width === void 0) {
    console.warn("Model config has no `median_filter_width`, using default value of 7.");
    median_filter_width = 7;
  }
  // Build one smoothed alignment matrix per entry of `cross_attentions`
  // (indexed by batch_idx further down).
  const batchedMatrices = generate_outputs.cross_attentions.map((batch) => {
    // For every decoder layer i, concatenate layer i's attention from each element
    // of `batch` along dim 2.
    // NOTE(review): presumably each element of `batch` is one generation step and
    // dim 2 is the target-token axis — confirm against the exported model.
    let cross_attentions = Array.from(
      { length: this.config.decoder_layers },
      (_, i) => cat(batch.map((x) => x[i]), 2)
    );
    // Keep only the (layer, head) pairs listed in `alignment_heads`; when `num_frames`
    // is truthy, also crop the last axis to [0, num_frames).
    let weights = stack(alignment_heads.map(([l, h]) => {
      return num_frames ? cross_attentions[l].slice(null, h, null, [0, num_frames]) : cross_attentions[l].slice(null, h);
    }));
    // Swap the first two axes. Assumes stack() inserted the selected-heads axis at
    // dim 0, so after this it sits at dim 1 — TODO confirm stack()'s default axis.
    weights = weights.transpose(1, 0, 2, 3);
    // Std and mean over axis -2. Assumes the trailing args are (correction = 0,
    // keepdim = true) — TODO confirm std_mean's signature. With keepdim the reduced
    // axis has size 1, which is why it is indexed with [0] below.
    let [std, calculatedMean] = std_mean(weights, -2, 0, true);
    let smoothedWeights = weights.clone();
    for (let a = 0; a < smoothedWeights.dims[0]; ++a) {
      let aTensor = smoothedWeights[a];
      for (let b = 0; b < aTensor.dims[0]; ++b) {
        let bTensor = aTensor[b];
        const stdTensor = std[a][b][0];
        const meanTensor = calculatedMean[a][b][0];
        for (let c = 0; c < bTensor.dims[0]; ++c) {
          let cTensor = bTensor[c];
          // Standardize the row in place, element-wise along the last axis: (x - mean) / std.
          for (let d = 0; d < cTensor.data.length; ++d) {
            cTensor.data[d] = (cTensor.data[d] - meanTensor.data[d]) / stdTensor.data[d];
          }
          // Replace the row with its median-filtered version (window = median_filter_width).
          cTensor.data.set(medianFilter(cTensor.data, median_filter_width));
        }
      }
    }
    // Average over axis 1 (the selected alignment heads, given the transpose above).
    const matrix = mean(smoothedWeights, 1);
    return matrix;
  });
  // Output is a zero-initialised float32 tensor of shape [num_sequences, sequence_length].
  const timestampsShape = [generate_outputs.sequences.length, generate_outputs.sequences[0].length];
  const timestamps = new Tensor(
    "float32",
    new Float32Array(timestampsShape[0] * timestampsShape[1]),
    timestampsShape
  );
  for (let batch_idx = 0; batch_idx < timestampsShape[0]; ++batch_idx) {
    // Negate so that DTW's minimum-cost path follows the strongest attention, and drop
    // the leading singleton axis.
    const matrix = batchedMatrices[batch_idx].neg().squeeze_(0);
    let [text_indices, time_indices] = dynamicTimeWarping(matrix);
    // A "jump" marks the first path step on a new text index. Index 0 always counts as
    // a jump (the leading [1]); Array.from clamps a -1 length to 0 for an empty path.
    let diffs = Array.from({ length: text_indices.length - 1 }, (v, i) => text_indices[i + 1] - text_indices[i]);
    let jumps = mergeArrays([1], diffs).map((x) => !!x);
    // Time (in seconds) at which each text index is first reached on the path.
    let jump_times = [];
    for (let i = 0; i < jumps.length; ++i) {
      if (jumps[i]) {
        jump_times.push(time_indices[i] * time_precision);
      }
    }
    // Write starting at offset 1, leaving position 0 at 0.
    // NOTE(review): presumably position 0 is the decoder start token — confirm.
    // TypedArray#set throws RangeError if jump_times.length + 1 exceeds the row length.
    timestamps[batch_idx].data.set(jump_times, 1);
  }
  return timestamps;
}
1889
+ }
1890
/**
 * Image-to-text model made of a vision encoder and a language-model decoder.
 * The decoder class is resolved from `config.decoder.model_type`, and its
 * head/layer dimensions are mirrored onto this instance.
 */
class VisionEncoderDecoderModel extends PreTrainedModel {
  main_input_name = "pixel_values";

  /**
   * @param {Object} config Model configuration; must contain `encoder` and `decoder` sub-configs.
   * @param {Object} session ONNX session for the encoder.
   * @param {any} decoder_merged_session ONNX session for the merged decoder.
   * @param {Object} generation_config Settings used during generation.
   * @throws {Error} When the decoder's `model_type` has no language-model-head class registered.
   */
  constructor(config, session, decoder_merged_session, generation_config) {
    super(config, session);
    this.decoder_merged_session = decoder_merged_session;
    this.generation_config = generation_config;

    const { encoder: encoderConfig, decoder: decoderConfig } = this.config;

    // An unknown encoder type is tolerated (encoder-only is assumed); only warn.
    const encoderType = encoderConfig.model_type;
    const knownEncoder =
      MODEL_MAPPING_NAMES_ENCODER_ONLY.get(encoderType) ??
      MODEL_MAPPING_NAMES_ENCODER_DECODER.get(encoderType);
    if (!knownEncoder) {
      console.warn(`Model type for encoder '${encoderType}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`);
    }

    // An unknown decoder type is fatal: its class is needed to read the dimensions below.
    const decoderEntry = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(decoderConfig.model_type);
    if (!decoderEntry) {
      throw new Error(`Unable to construct \`VisionEncoderDecoder\` due to unsupported decoder: "${this.config.decoder.model_type}"`);
    }
    const DecoderClass = decoderEntry[1];
    const decoder = new DecoderClass(decoderConfig, decoder_merged_session, generation_config);

    // Encoder-decoder style decoders expose separate encoder/decoder dimensions;
    // decoder-only ones expose a single set. Copy whichever set applies.
    this.add_encoder_pkv = "num_decoder_layers" in decoder;
    const mirroredKeys = this.add_encoder_pkv
      ? [
        "num_decoder_layers",
        "num_decoder_heads",
        "decoder_dim_kv",
        "num_encoder_layers",
        "num_encoder_heads",
        "encoder_dim_kv",
      ]
      : ["num_layers", "num_heads", "dim_kv"];
    for (const key of mirroredKeys) {
      this[key] = decoder[key];
    }
  }
}
1931
/** Common base for all CLIP model variants. */
class CLIPPreTrainedModel extends PreTrainedModel {
}

/** Full CLIP model (text and vision towers). */
class CLIPModel extends CLIPPreTrainedModel {
}

/**
 * CLIP text tower with projection head. Loads the `text_model` weights file
 * unless `options.model_file_name` is provided.
 */
class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Pass a copy instead of assigning onto `options`. A caller may reuse one options
    // object for both towers, and mutating it would make the later load reuse this
    // tower's file name.
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options.model_file_name ?? "text_model",
    });
  }
}

/**
 * CLIP vision tower with projection head. Loads the `vision_model` weights file
 * unless `options.model_file_name` is provided.
 */
class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Copy rather than mutate the caller's options (see CLIPTextModelWithProjection).
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options.model_file_name ?? "vision_model",
    });
  }
}
1949
/** Common base for all SigLIP model variants. */
class SiglipPreTrainedModel extends PreTrainedModel {
}

/** Full SigLIP model (text and vision towers). */
class SiglipModel extends SiglipPreTrainedModel {
}

/**
 * SigLIP text tower. Loads the `text_model` weights file unless
 * `options.model_file_name` is provided.
 */
class SiglipTextModel extends SiglipPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Pass a copy instead of assigning onto `options`, so a caller that reuses one
    // options object for both towers does not get this tower's file name on the next load.
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options.model_file_name ?? "text_model",
    });
  }
}

/**
 * SigLIP vision tower. Loads the `vision_model` weights file unless
 * `options.model_file_name` is provided.
 * Derives from SiglipPreTrainedModel, like its siblings; it previously extended
 * CLIPPreTrainedModel by mistake (both add nothing to PreTrainedModel).
 */
class SiglipVisionModel extends SiglipPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Copy rather than mutate the caller's options (see SiglipTextModel).
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options.model_file_name ?? "vision_model",
    });
  }
}
1967
/** Common base for Chinese-CLIP models. */
class ChineseCLIPPreTrainedModel extends PreTrainedModel {}

/** Full Chinese-CLIP model; inherits all behavior from its base. */
class ChineseCLIPModel extends ChineseCLIPPreTrainedModel {}

/** Common base for CLIPSeg models. */
class CLIPSegPreTrainedModel extends PreTrainedModel {}

/** Bare CLIPSeg model; inherits all behavior from its base. */
class CLIPSegModel extends CLIPSegPreTrainedModel {}

/** CLIPSeg with an image-segmentation head; inherits all behavior from its base. */
class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel {}
1977
/** Base for GPT-2 models; reads `n_head`, `n_layer` and `n_embd` from the config. */
class GPT2PreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_head: heads, n_layer: layers, n_embd: embedDim } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = embedDim / heads;
  }
}

class GPT2Model extends GPT2PreTrainedModel {}

class GPT2LMHeadModel extends GPT2PreTrainedModel {}

/** Base for GPT-Neo models; reads `num_heads`, `num_layers` and `hidden_size` from the config. */
class GPTNeoPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_heads: heads, num_layers: layers, hidden_size: hidden } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = hidden / heads;
  }
}

class GPTNeoModel extends GPTNeoPreTrainedModel {}

class GPTNeoForCausalLM extends GPTNeoPreTrainedModel {}

/** Base for GPT-NeoX models; reads `num_attention_heads`, `num_hidden_layers` and `hidden_size`. */
class GPTNeoXPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_attention_heads: heads, num_hidden_layers: layers, hidden_size: hidden } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = hidden / heads;
  }
}

class GPTNeoXModel extends GPTNeoXPreTrainedModel {}

class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel {}

/** Base for GPT-J models; reads `n_head`, `n_layer` and `n_embd` from the config. */
class GPTJPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_head: heads, n_layer: layers, n_embd: embedDim } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = embedDim / heads;
  }
}

class GPTJModel extends GPTJPreTrainedModel {}

class GPTJForCausalLM extends GPTJPreTrainedModel {}
2057
/** Base for GPTBigCode models; reads `n_head`, `n_layer` and `n_embd` from the config. */
class GPTBigCodePreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_head: heads, n_layer: layers, n_embd: embedDim } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = embedDim / heads;
  }
}

class GPTBigCodeModel extends GPTBigCodePreTrainedModel {}

class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel {}

/** Base for CodeGen models; reads `n_head`, `n_layer` and `n_embd` from the config. */
class CodeGenPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_head: heads, n_layer: layers, n_embd: embedDim } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = embedDim / heads;
  }
}

class CodeGenModel extends CodeGenPreTrainedModel {}

class CodeGenForCausalLM extends CodeGenPreTrainedModel {}
2097
/**
 * Base for Llama models. `num_heads` prefers `num_key_value_heads` (when set) over
 * `num_attention_heads`, while the per-head size always divides by `num_attention_heads`.
 */
class LlamaPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_key_value_heads, num_attention_heads, num_hidden_layers, hidden_size } = cfg;
    this.num_heads = num_key_value_heads ?? num_attention_heads;
    this.num_layers = num_hidden_layers;
    this.dim_kv = hidden_size / num_attention_heads;
  }
}

class LlamaModel extends LlamaPreTrainedModel {}

class LlamaForCausalLM extends LlamaPreTrainedModel {}

/** Base for Qwen2 models; head bookkeeping identical to Llama. */
class Qwen2PreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_key_value_heads, num_attention_heads, num_hidden_layers, hidden_size } = cfg;
    this.num_heads = num_key_value_heads ?? num_attention_heads;
    this.num_layers = num_hidden_layers;
    this.dim_kv = hidden_size / num_attention_heads;
  }
}

class Qwen2Model extends Qwen2PreTrainedModel {}

class Qwen2ForCausalLM extends Qwen2PreTrainedModel {}

/** Base for Phi models; reads `num_attention_heads`, `num_hidden_layers` and `hidden_size`. */
class PhiPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_attention_heads: heads, num_hidden_layers: layers, hidden_size: hidden } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = hidden / heads;
  }
}

class PhiModel extends PhiPreTrainedModel {}

class PhiForCausalLM extends PhiPreTrainedModel {}
2157
/** Base for BLOOM models; reads `n_head`, `n_layer` and `hidden_size` from the config. */
class BloomPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {any} session ONNX session holding the model weights.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_head: heads, n_layer: layers, hidden_size: hidden } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = hidden / heads;
  }
}

class BloomModel extends BloomPreTrainedModel {}

class BloomForCausalLM extends BloomPreTrainedModel {}

/** Base for MPT models; reads `n_heads`, `n_layers` and `d_model` from the config. */
class MptPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { n_heads: heads, n_layers: layers, d_model: modelDim } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = modelDim / heads;
  }
}

class MptModel extends MptPreTrainedModel {}

class MptForCausalLM extends MptPreTrainedModel {}

/** Base for OPT models; reads `num_attention_heads`, `num_hidden_layers` and `hidden_size`. */
class OPTPreTrainedModel extends PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session object.
   * @param {GenerationConfig} generation_config Generation settings.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    const cfg = this.config;
    // The EOS token id doubles as the padding token id.
    cfg.pad_token_id = cfg.eos_token_id;
    const { num_attention_heads: heads, num_hidden_layers: layers, hidden_size: hidden } = cfg;
    this.num_heads = heads;
    this.num_layers = layers;
    // Per-head dimension.
    this.dim_kv = hidden / heads;
  }
}

class OPTModel extends OPTPreTrainedModel {}

class OPTForCausalLM extends OPTPreTrainedModel {}
2217
class ViTPreTrainedModel extends PreTrainedModel {}

class ViTModel extends ViTPreTrainedModel {}

/** ViT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class ViTForImageClassification extends ViTPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class FastViTPreTrainedModel extends PreTrainedModel {}

class FastViTModel extends FastViTPreTrainedModel {}

/** FastViT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class FastViTForImageClassification extends FastViTPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class VitMattePreTrainedModel extends PreTrainedModel {}

/** ViTMatte for image matting; wraps raw outputs in an ImageMattingOutput. */
class VitMatteForImageMatting extends VitMattePreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new ImageMattingOutput(rawOutputs);
  }
}

class MobileViTPreTrainedModel extends PreTrainedModel {}

class MobileViTModel extends MobileViTPreTrainedModel {}

/** MobileViT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class MobileViTForImageClassification extends MobileViTPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class MobileViTV2PreTrainedModel extends PreTrainedModel {}

class MobileViTV2Model extends MobileViTV2PreTrainedModel {}

/** MobileViTV2 with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class MobileViTV2ForImageClassification extends MobileViTV2PreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}
2275
class OwlViTPreTrainedModel extends PreTrainedModel {}

class OwlViTModel extends OwlViTPreTrainedModel {}

class OwlViTForObjectDetection extends OwlViTPreTrainedModel {}

class Owlv2PreTrainedModel extends PreTrainedModel {}

class Owlv2Model extends Owlv2PreTrainedModel {}

class Owlv2ForObjectDetection extends Owlv2PreTrainedModel {}

class BeitPreTrainedModel extends PreTrainedModel {}

class BeitModel extends BeitPreTrainedModel {}

/** BEiT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class BeitForImageClassification extends BeitPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}
2299
class DetrPreTrainedModel extends PreTrainedModel {}

class DetrModel extends DetrPreTrainedModel {}

/** DETR object detector; wraps raw outputs in a DetrObjectDetectionOutput. */
class DetrForObjectDetection extends DetrPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new DetrObjectDetectionOutput(rawOutputs);
  }
}

/** DETR segmentation model; wraps raw outputs in a DetrSegmentationOutput. */
class DetrForSegmentation extends DetrPreTrainedModel {
  /**
   * Runs the model on the given inputs.
   * @param {Object} model_inputs Model inputs.
   * @returns {Promise<DetrSegmentationOutput>} Segmentation outputs.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new DetrSegmentationOutput(rawOutputs);
  }
}

/** Detection results: per-query class logits and normalized boxes. */
class DetrObjectDetectionOutput extends ModelOutput {
  /**
   * @param {Object} output Model output.
   * @param {Tensor} output.logits Classification logits (including no-object) for every query.
   * @param {Tensor} output.pred_boxes Boxes as (center_x, center_y, width, height), normalized to [0, 1]
   * relative to each image in the batch (padding disregarded).
   */
  constructor({ logits, pred_boxes }) {
    super();
    this.logits = logits;
    this.pred_boxes = pred_boxes;
  }
}

/** Segmentation results: logits, boxes and masks. */
class DetrSegmentationOutput extends ModelOutput {
  /**
   * @param {Object} output Model output.
   * @param {Tensor} output.logits Output logits.
   * @param {Tensor} output.pred_boxes Predicted boxes.
   * @param {Tensor} output.pred_masks Predicted masks.
   */
  constructor({ logits, pred_boxes, pred_masks }) {
    super();
    this.logits = logits;
    this.pred_boxes = pred_boxes;
    this.pred_masks = pred_masks;
  }
}
2348
class TableTransformerPreTrainedModel extends PreTrainedModel {}

class TableTransformerModel extends TableTransformerPreTrainedModel {}

/** Table Transformer detector; wraps raw outputs in a TableTransformerObjectDetectionOutput. */
class TableTransformerForObjectDetection extends TableTransformerPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new TableTransformerObjectDetectionOutput(rawOutputs);
  }
}

/** Same shape as DETR's detection output. */
class TableTransformerObjectDetectionOutput extends DetrObjectDetectionOutput {}

class DeiTPreTrainedModel extends PreTrainedModel {}

class DeiTModel extends DeiTPreTrainedModel {}

/** DeiT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class DeiTForImageClassification extends DeiTPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class ResNetPreTrainedModel extends PreTrainedModel {}

class ResNetModel extends ResNetPreTrainedModel {}

/** ResNet with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class ResNetForImageClassification extends ResNetPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class SwinPreTrainedModel extends PreTrainedModel {}

class SwinModel extends SwinPreTrainedModel {}

/** Swin with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class SwinForImageClassification extends SwinPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}
2398
// Architectures below add nothing to PreTrainedModel; the subclasses exist so each
// model type has its own named class.
class Swin2SRPreTrainedModel extends PreTrainedModel {}

class Swin2SRModel extends Swin2SRPreTrainedModel {}

class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel {}

class DPTPreTrainedModel extends PreTrainedModel {}

class DPTModel extends DPTPreTrainedModel {}

class DPTForDepthEstimation extends DPTPreTrainedModel {}

class DepthAnythingPreTrainedModel extends PreTrainedModel {}

class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedModel {}

class GLPNPreTrainedModel extends PreTrainedModel {}

class GLPNModel extends GLPNPreTrainedModel {}

class GLPNForDepthEstimation extends GLPNPreTrainedModel {}

class DonutSwinPreTrainedModel extends PreTrainedModel {}

class DonutSwinModel extends DonutSwinPreTrainedModel {}
2424
class ConvNextPreTrainedModel extends PreTrainedModel {}

class ConvNextModel extends ConvNextPreTrainedModel {}

/** ConvNeXT with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class ConvNextForImageClassification extends ConvNextPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class ConvNextV2PreTrainedModel extends PreTrainedModel {}

class ConvNextV2Model extends ConvNextV2PreTrainedModel {}

/** ConvNeXT V2 with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class ConvNextV2ForImageClassification extends ConvNextV2PreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}

class Dinov2PreTrainedModel extends PreTrainedModel {}

class Dinov2Model extends Dinov2PreTrainedModel {}

/** DINOv2 with a classification head; wraps raw outputs in a SequenceClassifierOutput. */
class Dinov2ForImageClassification extends Dinov2PreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}
2460
class YolosPreTrainedModel extends PreTrainedModel {}

class YolosModel extends YolosPreTrainedModel {}

/** YOLOS object detector; wraps raw outputs in a YolosObjectDetectionOutput. */
class YolosForObjectDetection extends YolosPreTrainedModel {
  /**
   * @param {any} model_inputs Inputs forwarded unchanged to the base `_call`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new YolosObjectDetectionOutput(rawOutputs);
  }
}

/** Detection results: per-query class logits and normalized boxes. */
class YolosObjectDetectionOutput extends ModelOutput {
  /**
   * @param {Object} output Model output.
   * @param {Tensor} output.logits Classification logits (including no-object) for every query.
   * @param {Tensor} output.pred_boxes Boxes as (center_x, center_y, width, height), normalized to [0, 1]
   * relative to each image in the batch (padding disregarded).
   */
  constructor({ logits, pred_boxes }) {
    super();
    this.logits = logits;
    this.pred_boxes = pred_boxes;
  }
}
2485
class SamPreTrainedModel extends PreTrainedModel {
}

/**
 * Segment Anything model: a vision encoder session plus a combined
 * prompt-encoder / mask-decoder session.
 */
class SamModel extends SamPreTrainedModel {
  /**
   * Creates a new instance of the `SamModel` class.
   * @param {Object} config The configuration object specifying the hyperparameters and other model settings.
   * @param {Object} vision_encoder The ONNX session containing the vision encoder model.
   * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
   */
  constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
    super(config, vision_encoder);
    this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
  }

  /**
   * Compute image embeddings and positional image embeddings, given the pixel values of an image.
   * @param {Object} model_inputs Object containing the model inputs.
   * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`.
   * @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings.
   */
  async get_image_embeddings({ pixel_values }) {
    return await encoderForward(this, { pixel_values });
  }

  /**
   * @typedef {Object} SamModelInputs Object containing the model inputs.
   * @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
   * These can be obtained using a `SamProcessor`.
   * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, point_batch_size, num_points, 2)`.
   * This is used by the prompt encoder to encode the prompt.
   * @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`
   * (i.e. the shape of `input_points` without its last axis). When omitted, every point is labelled `1`.
   * This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
   *  - `1`: the point is a point that contains the object of interest
   *  - `0`: the point is a point that does not contain the object of interest
   *  - `-1`: the point corresponds to the background
   *  - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
   * @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
   * @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
   */

  /**
   * Runs the prompt encoder / mask decoder. Image embeddings are computed first when
   * either embedding input is missing, and default labels are filled in when
   * `input_labels` is missing. The caller's `model_inputs` object is never modified.
   * @param {SamModelInputs} model_inputs Object containing the model inputs.
   * @returns {Promise<Object>} The output of the model.
   */
  async forward(model_inputs) {
    let inputs = model_inputs;
    if (!inputs.image_embeddings || !inputs.image_positional_embeddings) {
      inputs = {
        ...inputs,
        ...await this.get_image_embeddings(inputs),
      };
    }
    if (!inputs.input_labels) {
      // Default: label every point 1, shaped like input_points minus its last axis.
      // Built on a copy so the default is not written back onto the caller's object.
      const shape = inputs.input_points.dims.slice(0, -1);
      const numElements = shape.reduce((a, b) => a * b, 1);
      inputs = {
        ...inputs,
        input_labels: new Tensor(
          "int64",
          new BigInt64Array(numElements).fill(1n),
          shape
        ),
      };
    }
    return await sessionRun(this.prompt_encoder_mask_decoder, {
      input_points: inputs.input_points,
      input_labels: inputs.input_labels,
      image_embeddings: inputs.image_embeddings,
      image_positional_embeddings: inputs.image_positional_embeddings,
    });
  }

  /**
   * Runs the model with the provided inputs
   * @param {Object} model_inputs Model inputs
   * @returns {Promise<SamImageSegmentationOutput>} Object containing segmentation outputs
   */
  async _call(model_inputs) {
    return new SamImageSegmentationOutput(await super._call(model_inputs));
  }
}

/** Output of `SamModel`: predicted masks and their IoU scores. */
class SamImageSegmentationOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.iou_scores Predicted IoU scores for the predicted masks.
   * @param {Tensor} output.pred_masks Predicted masks.
   */
  constructor({ iou_scores, pred_masks }) {
    super();
    this.iou_scores = iou_scores;
    this.pred_masks = pred_masks;
  }
}
2570
class MarianPreTrainedModel extends PreTrainedModel {}

class MarianModel extends MarianPreTrainedModel {}

/**
 * Marian translation model. Records encoder and decoder layer/head counts and
 * per-head sizes derived from `d_model`.
 */
class MarianMTModel extends MarianPreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session for the encoder.
   * @param {any} decoder_merged_session ONNX session for the merged decoder.
   * @param {any} generation_config Generation settings.
   */
  constructor(config, session, decoder_merged_session, generation_config) {
    super(config, session);
    this.decoder_merged_session = decoder_merged_session;
    this.generation_config = generation_config;
    const {
      d_model,
      decoder_layers,
      decoder_attention_heads,
      encoder_layers,
      encoder_attention_heads,
    } = this.config;
    this.num_decoder_layers = decoder_layers;
    this.num_decoder_heads = decoder_attention_heads;
    this.decoder_dim_kv = d_model / decoder_attention_heads;
    this.num_encoder_layers = encoder_layers;
    this.num_encoder_heads = encoder_attention_heads;
    this.encoder_dim_kv = d_model / encoder_attention_heads;
  }
}

class M2M100PreTrainedModel extends PreTrainedModel {}

class M2M100Model extends M2M100PreTrainedModel {}

/**
 * M2M100 translation model. Records encoder and decoder layer/head counts and
 * per-head sizes derived from `d_model`.
 */
class M2M100ForConditionalGeneration extends M2M100PreTrainedModel {
  /**
   * @param {Object} config Model configuration.
   * @param {Object} session ONNX session for the encoder.
   * @param {any} decoder_merged_session ONNX session for the merged decoder.
   * @param {any} generation_config Generation settings.
   */
  constructor(config, session, decoder_merged_session, generation_config) {
    super(config, session);
    this.decoder_merged_session = decoder_merged_session;
    this.generation_config = generation_config;
    const {
      d_model,
      decoder_layers,
      decoder_attention_heads,
      encoder_layers,
      encoder_attention_heads,
    } = this.config;
    this.num_decoder_layers = decoder_layers;
    this.num_decoder_heads = decoder_attention_heads;
    this.decoder_dim_kv = d_model / decoder_attention_heads;
    this.num_encoder_layers = encoder_layers;
    this.num_encoder_heads = encoder_attention_heads;
    this.encoder_dim_kv = d_model / encoder_attention_heads;
  }
}
2618
/** Shared base class for Wav2Vec2 models; adds nothing over `PreTrainedModel`. */
class Wav2Vec2PreTrainedModel extends PreTrainedModel {
}

/** Bare Wav2Vec2 model; inherits everything from `Wav2Vec2PreTrainedModel`. */
class Wav2Vec2Model extends Wav2Vec2PreTrainedModel {
}

class Wav2Vec2ForCTC extends Wav2Vec2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>} The model outputs wrapped in a `CausalLMOutput`.
   */
  async _call(model_inputs) {
    return new CausalLMOutput(await super._call(model_inputs));
  }
}

class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
   */
  async _call(model_inputs) {
    return new SequenceClassifierOutput(await super._call(model_inputs));
  }
}

class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} An object containing the model's per-frame classification logits
   * (frame-level, not sequence-level, hence the token-style output).
   */
  async _call(model_inputs) {
    return new TokenClassifierOutput(await super._call(model_inputs));
  }
}
2652
/** Shared base class for UniSpeech models; adds nothing over `PreTrainedModel`. */
class UniSpeechPreTrainedModel extends PreTrainedModel {
}

/** Bare UniSpeech model; inherits everything from `UniSpeechPreTrainedModel`. */
class UniSpeechModel extends UniSpeechPreTrainedModel {
}

class UniSpeechForCTC extends UniSpeechPreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>}
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new CausalLMOutput(rawOutputs);
  }
}

class UniSpeechForSequenceClassification extends UniSpeechPreTrainedModel {
  /**
   * Runs the model and wraps its outputs for sequence classification.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The model's logits wrapped in a `SequenceClassifierOutput`.
   */
  async _call(model_inputs) {
    const rawOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutputs);
  }
}
2676
/** Shared base class for UniSpeechSat models; adds nothing over `PreTrainedModel`. */
class UniSpeechSatPreTrainedModel extends PreTrainedModel {
}

/** Bare UniSpeechSat model; inherits everything from `UniSpeechSatPreTrainedModel`. */
class UniSpeechSatModel extends UniSpeechSatPreTrainedModel {
}

class UniSpeechSatForCTC extends UniSpeechSatPreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>} The model outputs wrapped in a `CausalLMOutput`.
   */
  async _call(model_inputs) {
    return new CausalLMOutput(await super._call(model_inputs));
  }
}

class UniSpeechSatForSequenceClassification extends UniSpeechSatPreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
   */
  async _call(model_inputs) {
    return new SequenceClassifierOutput(await super._call(model_inputs));
  }
}

class UniSpeechSatForAudioFrameClassification extends UniSpeechSatPreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} An object containing the model's per-frame classification logits
   * (frame-level, not sequence-level, hence the token-style output).
   */
  async _call(model_inputs) {
    return new TokenClassifierOutput(await super._call(model_inputs));
  }
}
2710
/** Shared base class for Wav2Vec2-BERT models; adds nothing over `PreTrainedModel`. */
class Wav2Vec2BertPreTrainedModel extends PreTrainedModel {
}

/** Bare Wav2Vec2-BERT model; inherits everything from `Wav2Vec2BertPreTrainedModel`. */
class Wav2Vec2BertModel extends Wav2Vec2BertPreTrainedModel {
}

class Wav2Vec2BertForCTC extends Wav2Vec2BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_features Float values of input mel-spectrogram.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>}
   */
  async _call(model_inputs) {
    const sessionOutputs = await super._call(model_inputs);
    return new CausalLMOutput(sessionOutputs);
  }
}

class Wav2Vec2BertForSequenceClassification extends Wav2Vec2BertPreTrainedModel {
  /**
   * Runs the model and wraps its outputs for sequence classification.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The model's logits wrapped in a `SequenceClassifierOutput`.
   */
  async _call(model_inputs) {
    const sessionOutputs = await super._call(model_inputs);
    return new SequenceClassifierOutput(sessionOutputs);
  }
}
2734
/** Bare HuBERT model; reuses the Wav2Vec2 base class unchanged. */
class HubertModel extends Wav2Vec2PreTrainedModel {
}

class HubertForCTC extends Wav2Vec2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>}
   */
  async _call(model_inputs) {
    const result = await super._call(model_inputs);
    return new CausalLMOutput(result);
  }
}

class HubertForSequenceClassification extends Wav2Vec2PreTrainedModel {
  /**
   * Runs the model and wraps its outputs for sequence classification.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} The model's logits wrapped in a `SequenceClassifierOutput`.
   */
  async _call(model_inputs) {
    const result = await super._call(model_inputs);
    return new SequenceClassifierOutput(result);
  }
}
2756
/** Shared base class for WavLM models; adds nothing over `PreTrainedModel`. */
class WavLMPreTrainedModel extends PreTrainedModel {
}

/** Bare WavLM model; inherits everything from `WavLMPreTrainedModel`. */
class WavLMModel extends WavLMPreTrainedModel {
}

class WavLMForCTC extends WavLMPreTrainedModel {
  /**
   * Runs the model and wraps its outputs in a `CausalLMOutput`.
   * @param {Object} model_inputs
   * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
   * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
   * @returns {Promise<CausalLMOutput>} The model outputs wrapped in a `CausalLMOutput`.
   */
  async _call(model_inputs) {
    return new CausalLMOutput(await super._call(model_inputs));
  }
}

class WavLMForSequenceClassification extends WavLMPreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
   */
  async _call(model_inputs) {
    return new SequenceClassifierOutput(await super._call(model_inputs));
  }
}

class WavLMForXVector extends WavLMPreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<XVectorOutput>} An object containing the model's output logits and speaker embeddings.
   */
  async _call(model_inputs) {
    return new XVectorOutput(await super._call(model_inputs));
  }
}

class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
  /**
   * Calls the model on new inputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<TokenClassifierOutput>} An object containing the model's per-frame classification logits
   * (frame-level, not sequence-level, hence the token-style output).
   */
  async _call(model_inputs) {
    return new TokenClassifierOutput(await super._call(model_inputs));
  }
}
2800
/** Shared base class for SpeechT5 models; adds nothing over `PreTrainedModel`. */
class SpeechT5PreTrainedModel extends PreTrainedModel {
}

/** SpeechT5 speech-to-text head; inherits everything from `SpeechT5PreTrainedModel`. */
class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel {
}

class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel {
  /**
   * Creates a new instance of the `SpeechT5ForTextToSpeech` class.
   * @param {Object} config The model configuration.
   * @param {any} session session for the model.
   * @param {any} decoder_merged_session session for the decoder.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, decoder_merged_session, generation_config) {
    super(config, session);
    this.decoder_merged_session = decoder_merged_session;
    this.generation_config = generation_config;
    this.num_decoder_layers = this.config.decoder_layers;
    this.num_decoder_heads = this.config.decoder_attention_heads;
    this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads;
    this.num_encoder_layers = this.config.encoder_layers;
    this.num_encoder_heads = this.config.encoder_attention_heads;
    this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads;
  }

  /**
   * @typedef {Object} SpeechOutput
   * @property {Tensor} [spectrogram] The predicted log-mel spectrogram of shape
   * `(output_sequence_length, config.num_mel_bins)`. Always returned.
   * @property {Tensor} [waveform] The predicted waveform of shape `(num_frames,)`. Returned only when a `vocoder` is provided.
   * @property {Tensor} [cross_attentions] The outputs of the decoder's cross-attention layers of shape
   * `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, input_sequence_length)`. returned when `output_cross_attentions` is `true`.
   */

  /**
   * Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a speech waveform using a vocoder.
   * @param {Tensor} input_values Indices of input sequence tokens in the vocabulary.
   * @param {Tensor} speaker_embeddings Tensor containing the speaker embeddings.
   * @param {Object} options Optional parameters for generating speech.
   * @param {number} [options.threshold=0.5] The generated sequence ends when the predicted stop token probability exceeds this value.
   * @param {number} [options.minlenratio=0.0] Used to calculate the minimum required length for the output sequence.
   * @param {number} [options.maxlenratio=20.0] Used to calculate the maximum allowed length for the output sequence.
   * @param {Object} [options.vocoder=null] The vocoder that converts the mel spectrogram into a speech waveform. If `null`, only the mel spectrogram is returned.
   * @param {boolean} [options.output_cross_attentions=false] Whether or not to return the attentions tensors of the decoder's cross-attention layers.
   * @returns {Promise<SpeechOutput>} A promise which resolves to an object containing the spectrogram and, when a vocoder is given, the waveform.
   */
  async generate_speech(input_values, speaker_embeddings, {
    threshold = 0.5,
    minlenratio = 0,
    maxlenratio = 20,
    vocoder = null
    // output_cross_attentions = false, // TODO add
  } = {}) {
    const model_inputs = {
      input_ids: input_values
    };
    const { encoder_outputs, encoder_attention_mask } = await encoderForward(this, model_inputs);

    // Length bounds scale with the encoder sequence length divided by the reduction factor.
    const r = encoder_outputs.dims[1] / this.config.reduction_factor;
    const maxlen = Math.floor(r * maxlenratio);
    const minlen = Math.floor(r * minlenratio);
    const num_mel_bins = this.config.num_mel_bins;

    const spectrogramParts = [];
    let past_key_values = null;
    let decoder_outputs = null;
    let idx = 0;

    while (true) {
      ++idx;
      const use_cache_branch = boolTensor(!!decoder_outputs);

      // First step feeds an all-zero frame; later steps feed the decoder's previous output.
      const output_sequence = decoder_outputs
        ? decoder_outputs.output_sequence_out
        : new Tensor("float32", new Float32Array(num_mel_bins), [1, 1, num_mel_bins]);

      const decoderFeeds = {
        use_cache_branch,
        output_sequence,
        encoder_attention_mask,
        speaker_embeddings,
        encoder_hidden_states: encoder_outputs
      };
      this.addPastKeyValues(decoderFeeds, past_key_values);
      decoder_outputs = await sessionRun(this.decoder_merged_session, decoderFeeds);
      past_key_values = this.getPastKeyValues(decoder_outputs, past_key_values);

      const { prob, spectrum } = decoder_outputs;
      spectrogramParts.push(spectrum);

      // Finished when the stop token fires or the maximum length is reached (never before minlen).
      if (idx >= minlen && (Array.from(prob.data).some((p) => p >= threshold) || idx >= maxlen)) {
        break;
      }
    }

    const spectrogram = cat(spectrogramParts);

    // Without a vocoder there is nothing to synthesize; previously this dereferenced
    // `null.session` and threw, despite `null` being the documented default.
    if (vocoder == null) {
      return { spectrogram };
    }

    const { waveform } = await sessionRun(vocoder.session, { spectrogram });
    return {
      spectrogram,
      waveform
      // cross_attentions: null, // TODO add
    };
  }
}

/** HiFi-GAN vocoder used with SpeechT5; its main input is the spectrogram. */
class SpeechT5HifiGan extends PreTrainedModel {
  main_input_name = "spectrogram";
}
2904
class TrOCRPreTrainedModel extends PreTrainedModel {
  /**
   * Builds a TrOCR model whose encoder-side cache geometry mirrors the decoder's.
   * @param {Object} config The configuration of the model.
   * @param {any} session The ONNX session containing the model weights.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    // pad_token_id is overwritten with the EOS token id.
    this.config.pad_token_id = this.config.eos_token_id;

    // Both encoder_* and decoder_* counters are taken from the decoder_* config fields.
    const { decoder_layers, decoder_attention_heads, d_model } = this.config;
    this.num_decoder_layers = decoder_layers;
    this.num_encoder_layers = decoder_layers;
    this.num_decoder_heads = decoder_attention_heads;
    this.num_encoder_heads = decoder_attention_heads;

    const dimPerHead = d_model / decoder_attention_heads;
    this.decoder_dim_kv = dimPerHead;
    this.encoder_dim_kv = dimPerHead;
  }
}

/** TrOCR causal LM head; inherits everything from `TrOCRPreTrainedModel`. */
class TrOCRForCausalLM extends TrOCRPreTrainedModel {
}
2922
class MistralPreTrainedModel extends PreTrainedModel {
  /**
   * Builds a Mistral model and records its KV-cache geometry.
   * @param {Object} config The configuration of the model.
   * @param {any} session The ONNX session containing the model weights.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;

    const cfg = this.config;
    // pad_token_id is overwritten with the EOS token id.
    cfg.pad_token_id = cfg.eos_token_id;

    // Head count comes from num_key_value_heads, while the per-head width
    // divides hidden_size by the full attention head count.
    this.num_heads = cfg.num_key_value_heads;
    this.num_layers = cfg.num_hidden_layers;
    this.dim_kv = cfg.hidden_size / cfg.num_attention_heads;
  }
}

/** Bare Mistral model; inherits everything from `MistralPreTrainedModel`. */
class MistralModel extends MistralPreTrainedModel {
}

/** Mistral causal LM head; inherits everything from `MistralPreTrainedModel`. */
class MistralForCausalLM extends MistralPreTrainedModel {
}
2942
class Starcoder2PreTrainedModel extends PreTrainedModel {
  /**
   * Builds a Starcoder2 model and records its KV-cache geometry.
   * @param {Object} config The configuration of the model.
   * @param {any} session The ONNX session containing the model weights.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    // pad_token_id is overwritten with the EOS token id.
    this.config.pad_token_id = this.config.eos_token_id;

    const {
      num_key_value_heads: kvHeads,
      num_hidden_layers: layerCount,
      hidden_size: hiddenSize,
      num_attention_heads: attentionHeads,
    } = this.config;
    this.num_heads = kvHeads;
    this.num_layers = layerCount;
    this.dim_kv = hiddenSize / attentionHeads;
  }
}

/** Bare Starcoder2 model; inherits everything from `Starcoder2PreTrainedModel`. */
class Starcoder2Model extends Starcoder2PreTrainedModel {
}

/** Starcoder2 causal LM head; inherits everything from `Starcoder2PreTrainedModel`. */
class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel {
}
2962
class FalconPreTrainedModel extends PreTrainedModel {
  /**
   * Builds a Falcon model and records its KV-cache geometry.
   * @param {Object} config The configuration of the model.
   * @param {any} session The ONNX session containing the model weights.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    // pad_token_id is overwritten with the EOS token id.
    this.config.pad_token_id = this.config.eos_token_id;

    const { num_attention_heads, num_hidden_layers, hidden_size } = this.config;
    this.num_heads = num_attention_heads;
    this.num_layers = num_hidden_layers;
    this.dim_kv = hidden_size / num_attention_heads;
  }
}

/** Bare Falcon model; inherits everything from `FalconPreTrainedModel`. */
class FalconModel extends FalconPreTrainedModel {
}

/** Falcon causal LM head; inherits everything from `FalconPreTrainedModel`. */
class FalconForCausalLM extends FalconPreTrainedModel {
}
2982
/** Shared base class for CLAP models; adds nothing over `PreTrainedModel`. */
class ClapPreTrainedModel extends PreTrainedModel {
}

/** Full CLAP model; inherits everything from `ClapPreTrainedModel`. */
class ClapModel extends ClapPreTrainedModel {
}

/**
 * CLAP text tower with projection; loads the `text_model` file unless the caller
 * explicitly passes `model_file_name`.
 */
class ClapTextModelWithProjection extends ClapPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Build a fresh options object instead of writing into the caller's one: a shared
    // options object reused for the audio tower would otherwise be pinned to "text_model".
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options?.model_file_name ?? "text_model",
    });
  }
}

/**
 * CLAP audio tower with projection; loads the `audio_model` file unless the caller
 * explicitly passes `model_file_name`.
 */
class ClapAudioModelWithProjection extends ClapPreTrainedModel {
  /** @type {PreTrainedModel.from_pretrained} */
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
    // Copy rather than mutate the caller's options (see ClapTextModelWithProjection).
    return super.from_pretrained(pretrained_model_name_or_path, {
      ...options,
      model_file_name: options?.model_file_name ?? "audio_model",
    });
  }
}
3000
/** Shared base class for VITS models; adds nothing over `PreTrainedModel`. */
class VitsPreTrainedModel extends PreTrainedModel {
}

class VitsModel extends VitsPreTrainedModel {
  /**
   * Runs the VITS model and wraps the raw outputs.
   * @param {Object} model_inputs The inputs to the model.
   * @returns {Promise<VitsModelOutput>} The outputs for the VITS model.
   */
  async _call(model_inputs) {
    const outputs = await super._call(model_inputs);
    return new VitsModelOutput(outputs);
  }
}
3012
/** Shared base class for SegFormer models; adds nothing over `PreTrainedModel`. */
class SegformerPreTrainedModel extends PreTrainedModel {
}

/** SegFormer image-classification head; inherits everything from its base. */
class SegformerForImageClassification extends SegformerPreTrainedModel {
}

/** SegFormer semantic-segmentation head; inherits everything from its base. */
class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
}

class StableLmPreTrainedModel extends PreTrainedModel {
  /**
   * Builds a StableLM model and records its KV-cache geometry.
   * @param {Object} config The configuration of the model.
   * @param {any} session The ONNX session containing the model weights.
   * @param {GenerationConfig} generation_config The generation configuration.
   */
  constructor(config, session, generation_config) {
    super(config, session);
    this.generation_config = generation_config;
    // pad_token_id is overwritten with the EOS token id.
    this.config.pad_token_id = this.config.eos_token_id;

    const headCount = this.config.num_attention_heads;
    this.num_heads = headCount;
    this.num_layers = this.config.num_hidden_layers;
    this.dim_kv = this.config.hidden_size / headCount;
  }
}

/** StableLM causal LM head; inherits everything from `StableLmPreTrainedModel`. */
class StableLmForCausalLM extends StableLmPreTrainedModel {
}

/** Shared base class for EfficientNet models; adds nothing over `PreTrainedModel`. */
class EfficientNetPreTrainedModel extends PreTrainedModel {
}

/** Bare EfficientNet model; inherits everything from `EfficientNetPreTrainedModel`. */
class EfficientNetModel extends EfficientNetPreTrainedModel {
}

class EfficientNetForImageClassification extends EfficientNetPreTrainedModel {
  /**
   * Runs the model and wraps its logits in a `SequenceClassifierOutput`.
   * @param {any} model_inputs
   */
  async _call(model_inputs) {
    const rawOutput = await super._call(model_inputs);
    return new SequenceClassifierOutput(rawOutput);
  }
}
3048
/**
 * Base for "Auto" model classes: resolves a checkpoint's config, looks up its
 * `model_type` in `MODEL_CLASS_MAPPINGS`, and delegates loading to the matching class.
 */
class PretrainedMixin {
  /**
   * Lookup tables consulted in order; each maps a config `model_type` to a
   * `[className, ModelClass]` pair. Subclasses must override this.
   * @type {Map<string, Object>[]}
   */
  static MODEL_CLASS_MAPPINGS = null;

  /**
   * Whether to attempt to instantiate the base class (`PreTrainedModel`) if
   * the model type is not found in any mapping.
   */
  static BASE_IF_FAIL = false;

  /**
   * Loads the model class registered for the checkpoint's `model_type`.
   * @throws {Error} If the subclass defines no `MODEL_CLASS_MAPPINGS`, or if the model type
   * is unknown and `BASE_IF_FAIL` is false.
   * @type {PreTrainedModel.from_pretrained}
   */
  static async from_pretrained(pretrained_model_name_or_path, {
    quantized = true,
    progress_callback = null,
    config = null,
    cache_dir = null,
    local_files_only = false,
    revision = "main",
    model_file_name = null
  } = {}) {
    // Fail fast: no point fetching the config when this class cannot map it anyway.
    if (!this.MODEL_CLASS_MAPPINGS) {
      throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
    }

    const options = {
      quantized,
      progress_callback,
      config,
      cache_dir,
      local_files_only,
      revision,
      model_file_name
    };
    const resolvedConfig = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
    // Hand the resolved config down so the concrete class does not load it a second time.
    options.config ||= resolvedConfig;

    for (const mapping of this.MODEL_CLASS_MAPPINGS) {
      const modelInfo = mapping.get(resolvedConfig.model_type);
      if (modelInfo) {
        // modelInfo is [className, ModelClass].
        return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
      }
    }

    if (this.BASE_IF_FAIL) {
      console.warn(`Unknown model class "${resolvedConfig.model_type}", attempting to construct from base class.`);
      return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
    }
    throw new Error(`Unsupported model type: ${resolvedConfig.model_type}`);
  }
}
3100
// Lookup tables from a config `model_type` string to a `[className, ModelClass]` pair.
// `PretrainedMixin.from_pretrained` reads `map.get(config.model_type)` and loads via element [1].
// Map insertion order is preserved; which auto classes consume which tables is
// decided elsewhere in the file (not visible here).

// Base models in the "encoder-only" group; besides text encoders this table also
// holds vision, audio, and vocoder (`hifigan`) entries.
const MODEL_MAPPING_NAMES_ENCODER_ONLY = /* @__PURE__ */ new Map([
  ["bert", ["BertModel", BertModel]],
  ["nomic_bert", ["NomicBertModel", NomicBertModel]],
  ["roformer", ["RoFormerModel", RoFormerModel]],
  ["electra", ["ElectraModel", ElectraModel]],
  ["esm", ["EsmModel", EsmModel]],
  ["convbert", ["ConvBertModel", ConvBertModel]],
  ["camembert", ["CamembertModel", CamembertModel]],
  ["deberta", ["DebertaModel", DebertaModel]],
  ["deberta-v2", ["DebertaV2Model", DebertaV2Model]],
  ["mpnet", ["MPNetModel", MPNetModel]],
  ["albert", ["AlbertModel", AlbertModel]],
  ["distilbert", ["DistilBertModel", DistilBertModel]],
  ["roberta", ["RobertaModel", RobertaModel]],
  ["xlm", ["XLMModel", XLMModel]],
  ["xlm-roberta", ["XLMRobertaModel", XLMRobertaModel]],
  ["clap", ["ClapModel", ClapModel]],
  ["clip", ["CLIPModel", CLIPModel]],
  ["clipseg", ["CLIPSegModel", CLIPSegModel]],
  ["chinese_clip", ["ChineseCLIPModel", ChineseCLIPModel]],
  ["siglip", ["SiglipModel", SiglipModel]],
  ["mobilebert", ["MobileBertModel", MobileBertModel]],
  ["squeezebert", ["SqueezeBertModel", SqueezeBertModel]],
  ["wav2vec2", ["Wav2Vec2Model", Wav2Vec2Model]],
  ["wav2vec2-bert", ["Wav2Vec2BertModel", Wav2Vec2BertModel]],
  ["unispeech", ["UniSpeechModel", UniSpeechModel]],
  ["unispeech-sat", ["UniSpeechSatModel", UniSpeechSatModel]],
  ["hubert", ["HubertModel", HubertModel]],
  ["wavlm", ["WavLMModel", WavLMModel]],
  ["audio-spectrogram-transformer", ["ASTModel", ASTModel]],
  ["vits", ["VitsModel", VitsModel]],
  ["detr", ["DetrModel", DetrModel]],
  ["table-transformer", ["TableTransformerModel", TableTransformerModel]],
  ["vit", ["ViTModel", ViTModel]],
  ["fastvit", ["FastViTModel", FastViTModel]],
  ["mobilevit", ["MobileViTModel", MobileViTModel]],
  ["mobilevitv2", ["MobileViTV2Model", MobileViTV2Model]],
  ["owlvit", ["OwlViTModel", OwlViTModel]],
  ["owlv2", ["Owlv2Model", Owlv2Model]],
  ["beit", ["BeitModel", BeitModel]],
  ["deit", ["DeiTModel", DeiTModel]],
  ["convnext", ["ConvNextModel", ConvNextModel]],
  ["convnextv2", ["ConvNextV2Model", ConvNextV2Model]],
  ["dinov2", ["Dinov2Model", Dinov2Model]],
  ["resnet", ["ResNetModel", ResNetModel]],
  ["swin", ["SwinModel", SwinModel]],
  ["swin2sr", ["Swin2SRModel", Swin2SRModel]],
  ["donut-swin", ["DonutSwinModel", DonutSwinModel]],
  ["yolos", ["YolosModel", YolosModel]],
  ["dpt", ["DPTModel", DPTModel]],
  ["glpn", ["GLPNModel", GLPNModel]],
  ["hifigan", ["SpeechT5HifiGan", SpeechT5HifiGan]],
  ["efficientnet", ["EfficientNetModel", EfficientNetModel]]
]);

// Base models in the encoder-decoder group.
const MODEL_MAPPING_NAMES_ENCODER_DECODER = /* @__PURE__ */ new Map([
  ["t5", ["T5Model", T5Model]],
  ["longt5", ["LongT5Model", LongT5Model]],
  ["mt5", ["MT5Model", MT5Model]],
  ["bart", ["BartModel", BartModel]],
  ["mbart", ["MBartModel", MBartModel]],
  ["marian", ["MarianModel", MarianModel]],
  ["whisper", ["WhisperModel", WhisperModel]],
  ["m2m_100", ["M2M100Model", M2M100Model]],
  ["blenderbot", ["BlenderbotModel", BlenderbotModel]],
  ["blenderbot-small", ["BlenderbotSmallModel", BlenderbotSmallModel]]
]);

// Base models in the decoder-only group.
const MODEL_MAPPING_NAMES_DECODER_ONLY = /* @__PURE__ */ new Map([
  ["bloom", ["BloomModel", BloomModel]],
  ["gpt2", ["GPT2Model", GPT2Model]],
  ["gptj", ["GPTJModel", GPTJModel]],
  ["gpt_bigcode", ["GPTBigCodeModel", GPTBigCodeModel]],
  ["gpt_neo", ["GPTNeoModel", GPTNeoModel]],
  ["gpt_neox", ["GPTNeoXModel", GPTNeoXModel]],
  ["codegen", ["CodeGenModel", CodeGenModel]],
  ["llama", ["LlamaModel", LlamaModel]],
  ["qwen2", ["Qwen2Model", Qwen2Model]],
  ["phi", ["PhiModel", PhiModel]],
  ["mpt", ["MptModel", MptModel]],
  ["opt", ["OPTModel", OPTModel]],
  ["mistral", ["MistralModel", MistralModel]],
  ["starcoder2", ["Starcoder2Model", Starcoder2Model]],
  ["falcon", ["FalconModel", FalconModel]]
]);
3183
// Speech/audio task tables: `model_type` -> `[className, ModelClass]`.

// Speech-to-text sequence-to-sequence heads.
const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["speecht5", ["SpeechT5ForSpeechToText", SpeechT5ForSpeechToText]],
  ["whisper", ["WhisperForConditionalGeneration", WhisperForConditionalGeneration]]
]);

// Text-to-spectrogram models (SpeechT5 produces a mel spectrogram; a vocoder is separate).
const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["speecht5", ["SpeechT5ForTextToSpeech", SpeechT5ForTextToSpeech]]
]);

// Text-to-waveform models.
const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["vits", ["VitsModel", VitsModel]]
]);
3193
// Classification task tables: `model_type` -> `[className, ModelClass]`.

// Text sequence-classification heads.
const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["bert", ["BertForSequenceClassification", BertForSequenceClassification]],
  ["roformer", ["RoFormerForSequenceClassification", RoFormerForSequenceClassification]],
  ["electra", ["ElectraForSequenceClassification", ElectraForSequenceClassification]],
  ["esm", ["EsmForSequenceClassification", EsmForSequenceClassification]],
  ["convbert", ["ConvBertForSequenceClassification", ConvBertForSequenceClassification]],
  ["camembert", ["CamembertForSequenceClassification", CamembertForSequenceClassification]],
  ["deberta", ["DebertaForSequenceClassification", DebertaForSequenceClassification]],
  ["deberta-v2", ["DebertaV2ForSequenceClassification", DebertaV2ForSequenceClassification]],
  ["mpnet", ["MPNetForSequenceClassification", MPNetForSequenceClassification]],
  ["albert", ["AlbertForSequenceClassification", AlbertForSequenceClassification]],
  ["distilbert", ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]],
  ["roberta", ["RobertaForSequenceClassification", RobertaForSequenceClassification]],
  ["xlm", ["XLMForSequenceClassification", XLMForSequenceClassification]],
  ["xlm-roberta", ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]],
  ["bart", ["BartForSequenceClassification", BartForSequenceClassification]],
  ["mbart", ["MBartForSequenceClassification", MBartForSequenceClassification]],
  ["mobilebert", ["MobileBertForSequenceClassification", MobileBertForSequenceClassification]],
  ["squeezebert", ["SqueezeBertForSequenceClassification", SqueezeBertForSequenceClassification]]
]);

// Token-classification heads.
const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["bert", ["BertForTokenClassification", BertForTokenClassification]],
  ["roformer", ["RoFormerForTokenClassification", RoFormerForTokenClassification]],
  ["electra", ["ElectraForTokenClassification", ElectraForTokenClassification]],
  ["esm", ["EsmForTokenClassification", EsmForTokenClassification]],
  ["convbert", ["ConvBertForTokenClassification", ConvBertForTokenClassification]],
  ["camembert", ["CamembertForTokenClassification", CamembertForTokenClassification]],
  ["deberta", ["DebertaForTokenClassification", DebertaForTokenClassification]],
  ["deberta-v2", ["DebertaV2ForTokenClassification", DebertaV2ForTokenClassification]],
  ["mpnet", ["MPNetForTokenClassification", MPNetForTokenClassification]],
  ["distilbert", ["DistilBertForTokenClassification", DistilBertForTokenClassification]],
  ["roberta", ["RobertaForTokenClassification", RobertaForTokenClassification]],
  ["xlm", ["XLMForTokenClassification", XLMForTokenClassification]],
  ["xlm-roberta", ["XLMRobertaForTokenClassification", XLMRobertaForTokenClassification]]
]);
3228
+ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = /* @__PURE__ */ new Map([
3229
+ ["t5", ["T5ForConditionalGeneration", T5ForConditionalGeneration]],
3230
+ ["longt5", ["LongT5ForConditionalGeneration", LongT5ForConditionalGeneration]],
3231
+ ["mt5", ["MT5ForConditionalGeneration", MT5ForConditionalGeneration]],
3232
+ ["bart", ["BartForConditionalGeneration", BartForConditionalGeneration]],
3233
+ ["mbart", ["MBartForConditionalGeneration", MBartForConditionalGeneration]],
3234
+ ["marian", ["MarianMTModel", MarianMTModel]],
3235
+ ["m2m_100", ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]],
3236
+ ["blenderbot", ["BlenderbotForConditionalGeneration", BlenderbotForConditionalGeneration]],
3237
+ ["blenderbot-small", ["BlenderbotSmallForConditionalGeneration", BlenderbotSmallForConditionalGeneration]]
3238
+ ]);
3239
/**
 * Causal language-model classes (LM head on top), keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.DecoderOnly`.
 */
const MODEL_WITH_LM_HEAD_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["bloom", ["BloomForCausalLM", BloomForCausalLM]],
  ["gpt2", ["GPT2LMHeadModel", GPT2LMHeadModel]],
  ["gptj", ["GPTJForCausalLM", GPTJForCausalLM]],
  ["gpt_bigcode", ["GPTBigCodeForCausalLM", GPTBigCodeForCausalLM]],
  ["gpt_neo", ["GPTNeoForCausalLM", GPTNeoForCausalLM]],
  ["gpt_neox", ["GPTNeoXForCausalLM", GPTNeoXForCausalLM]],
  ["codegen", ["CodeGenForCausalLM", CodeGenForCausalLM]],
  ["llama", ["LlamaForCausalLM", LlamaForCausalLM]],
  ["qwen2", ["Qwen2ForCausalLM", Qwen2ForCausalLM]],
  ["phi", ["PhiForCausalLM", PhiForCausalLM]],
  ["mpt", ["MptForCausalLM", MptForCausalLM]],
  ["opt", ["OPTForCausalLM", OPTForCausalLM]],
  ["mbart", ["MBartForCausalLM", MBartForCausalLM]],
  ["mistral", ["MistralForCausalLM", MistralForCausalLM]],
  ["starcoder2", ["Starcoder2ForCausalLM", Starcoder2ForCausalLM]],
  ["falcon", ["FalconForCausalLM", FalconForCausalLM]],
  ["trocr", ["TrOCRForCausalLM", TrOCRForCausalLM]],
  ["stablelm", ["StableLmForCausalLM", StableLmForCausalLM]],
]);
/**
 * Masked language-model classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_MASKED_LM_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["bert", ["BertForMaskedLM", BertForMaskedLM]],
  ["roformer", ["RoFormerForMaskedLM", RoFormerForMaskedLM]],
  ["electra", ["ElectraForMaskedLM", ElectraForMaskedLM]],
  ["esm", ["EsmForMaskedLM", EsmForMaskedLM]],
  ["convbert", ["ConvBertForMaskedLM", ConvBertForMaskedLM]],
  ["camembert", ["CamembertForMaskedLM", CamembertForMaskedLM]],
  ["deberta", ["DebertaForMaskedLM", DebertaForMaskedLM]],
  ["deberta-v2", ["DebertaV2ForMaskedLM", DebertaV2ForMaskedLM]],
  ["mpnet", ["MPNetForMaskedLM", MPNetForMaskedLM]],
  ["albert", ["AlbertForMaskedLM", AlbertForMaskedLM]],
  ["distilbert", ["DistilBertForMaskedLM", DistilBertForMaskedLM]],
  ["roberta", ["RobertaForMaskedLM", RobertaForMaskedLM]],
  ["xlm", ["XLMWithLMHeadModel", XLMWithLMHeadModel]],
  ["xlm-roberta", ["XLMRobertaForMaskedLM", XLMRobertaForMaskedLM]],
  ["mobilebert", ["MobileBertForMaskedLM", MobileBertForMaskedLM]],
  ["squeezebert", ["SqueezeBertForMaskedLM", SqueezeBertForMaskedLM]],
]);
/**
 * Extractive question-answering classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["bert", ["BertForQuestionAnswering", BertForQuestionAnswering]],
  ["roformer", ["RoFormerForQuestionAnswering", RoFormerForQuestionAnswering]],
  ["electra", ["ElectraForQuestionAnswering", ElectraForQuestionAnswering]],
  ["convbert", ["ConvBertForQuestionAnswering", ConvBertForQuestionAnswering]],
  ["camembert", ["CamembertForQuestionAnswering", CamembertForQuestionAnswering]],
  ["deberta", ["DebertaForQuestionAnswering", DebertaForQuestionAnswering]],
  ["deberta-v2", ["DebertaV2ForQuestionAnswering", DebertaV2ForQuestionAnswering]],
  ["mpnet", ["MPNetForQuestionAnswering", MPNetForQuestionAnswering]],
  ["albert", ["AlbertForQuestionAnswering", AlbertForQuestionAnswering]],
  ["distilbert", ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]],
  ["roberta", ["RobertaForQuestionAnswering", RobertaForQuestionAnswering]],
  ["xlm", ["XLMForQuestionAnswering", XLMForQuestionAnswering]],
  ["xlm-roberta", ["XLMRobertaForQuestionAnswering", XLMRobertaForQuestionAnswering]],
  ["mobilebert", ["MobileBertForQuestionAnswering", MobileBertForQuestionAnswering]],
  ["squeezebert", ["SqueezeBertForQuestionAnswering", SqueezeBertForQuestionAnswering]],
]);
/**
 * Image-to-sequence classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.Vision2Seq`.
 */
const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["vision-encoder-decoder", ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]],
]);
/**
 * Image-classification classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["vit", ["ViTForImageClassification", ViTForImageClassification]],
  ["fastvit", ["FastViTForImageClassification", FastViTForImageClassification]],
  ["mobilevit", ["MobileViTForImageClassification", MobileViTForImageClassification]],
  ["mobilevitv2", ["MobileViTV2ForImageClassification", MobileViTV2ForImageClassification]],
  ["beit", ["BeitForImageClassification", BeitForImageClassification]],
  ["deit", ["DeiTForImageClassification", DeiTForImageClassification]],
  ["convnext", ["ConvNextForImageClassification", ConvNextForImageClassification]],
  ["convnextv2", ["ConvNextV2ForImageClassification", ConvNextV2ForImageClassification]],
  ["dinov2", ["Dinov2ForImageClassification", Dinov2ForImageClassification]],
  ["resnet", ["ResNetForImageClassification", ResNetForImageClassification]],
  ["swin", ["SwinForImageClassification", SwinForImageClassification]],
  ["segformer", ["SegformerForImageClassification", SegformerForImageClassification]],
  ["efficientnet", ["EfficientNetForImageClassification", EfficientNetForImageClassification]],
]);
/**
 * Object-detection classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["detr", ["DetrForObjectDetection", DetrForObjectDetection]],
  ["table-transformer", ["TableTransformerForObjectDetection", TableTransformerForObjectDetection]],
  ["yolos", ["YolosForObjectDetection", YolosForObjectDetection]],
]);
/**
 * Zero-shot object-detection classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["owlvit", ["OwlViTForObjectDetection", OwlViTForObjectDetection]],
  ["owlv2", ["Owlv2ForObjectDetection", Owlv2ForObjectDetection]],
]);
/**
 * Image-segmentation classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["detr", ["DetrForSegmentation", DetrForSegmentation]],
  ["clipseg", ["CLIPSegForImageSegmentation", CLIPSegForImageSegmentation]],
]);
/**
 * Semantic-segmentation classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["segformer", ["SegformerForSemanticSegmentation", SegformerForSemanticSegmentation]],
]);
/**
 * Mask-generation classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.MaskGeneration`.
 */
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["sam", ["SamModel", SamModel]],
]);
/**
 * Speech models with a CTC head, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_CTC_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["wav2vec2", ["Wav2Vec2ForCTC", Wav2Vec2ForCTC]],
  ["wav2vec2-bert", ["Wav2Vec2BertForCTC", Wav2Vec2BertForCTC]],
  ["unispeech", ["UniSpeechForCTC", UniSpeechForCTC]],
  ["unispeech-sat", ["UniSpeechSatForCTC", UniSpeechSatForCTC]],
  ["wavlm", ["WavLMForCTC", WavLMForCTC]],
  ["hubert", ["HubertForCTC", HubertForCTC]],
]);
/**
 * Audio-classification classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["wav2vec2", ["Wav2Vec2ForSequenceClassification", Wav2Vec2ForSequenceClassification]],
  ["wav2vec2-bert", ["Wav2Vec2BertForSequenceClassification", Wav2Vec2BertForSequenceClassification]],
  ["unispeech", ["UniSpeechForSequenceClassification", UniSpeechForSequenceClassification]],
  ["unispeech-sat", ["UniSpeechSatForSequenceClassification", UniSpeechSatForSequenceClassification]],
  ["wavlm", ["WavLMForSequenceClassification", WavLMForSequenceClassification]],
  ["hubert", ["HubertForSequenceClassification", HubertForSequenceClassification]],
  ["audio-spectrogram-transformer", ["ASTForAudioClassification", ASTForAudioClassification]],
]);
/**
 * X-vector (speaker-embedding) classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["wavlm", ["WavLMForXVector", WavLMForXVector]],
]);
/**
 * Audio frame-classification classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["unispeech-sat", ["UniSpeechSatForAudioFrameClassification", UniSpeechSatForAudioFrameClassification]],
  ["wavlm", ["WavLMForAudioFrameClassification", WavLMForAudioFrameClassification]],
  ["wav2vec2", ["Wav2Vec2ForAudioFrameClassification", Wav2Vec2ForAudioFrameClassification]],
]);
/**
 * Image-matting classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["vitmatte", ["VitMatteForImageMatting", VitMatteForImageMatting]],
]);
/**
 * Image-to-image (super-resolution) classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["swin2sr", ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]],
]);
/**
 * Depth-estimation classes, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["dpt", ["DPTForDepthEstimation", DPTForDepthEstimation]],
  ["depth_anything", ["DepthAnythingForDepthEstimation", DepthAnythingForDepthEstimation]],
  ["glpn", ["GLPNForDepthEstimation", GLPNForDepthEstimation]],
]);
/**
 * Vision towers used for image feature extraction, keyed by model-type string.
 * Each value is a `[className, modelClass]` pair. MODEL_CLASS_TYPE_MAPPING
 * registers this table (under its "Custom" section) with `MODEL_TYPES.EncoderOnly`.
 */
const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = /* @__PURE__ */ new Map([
  ["clip", ["CLIPVisionModelWithProjection", CLIPVisionModelWithProjection]],
  ["siglip", ["SiglipVisionModel", SiglipVisionModel]],
]);
/**
 * Ordered list of `[mappingTable, modelType]` pairs.
 *
 * The registration loop that follows walks this list front to back and calls
 * `Map.set` for each class name. If a class name appears in more than one table,
 * the LAST table's type wins. Do not reorder entries without checking for such
 * overlaps.
 *
 * `AutoModel.MODEL_CLASS_MAPPINGS` is derived from the first element of each pair.
 */
const MODEL_CLASS_TYPE_MAPPING = [
  // Base (headless) architectures.
  [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
  [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES.EncoderDecoder],
  [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly],

  // Text task heads.
  [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
  [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
  [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
  [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],

  // Vision task heads.
  [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
  [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],

  // Audio / speech task heads.
  [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
  [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
  [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],

  // Custom:
  [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];
// Populate the three class-name registries from every task table.
// Tables are visited in MODEL_CLASS_TYPE_MAPPING order and entries in Map
// insertion order. Because `Map.set` replaces existing keys, a class name that
// occurs in several tables keeps the type of the last table that lists it.
for (const [table, modelType] of MODEL_CLASS_TYPE_MAPPING) {
  for (const entry of table.values()) {
    const [className, modelClass] = entry;
    MODEL_TYPE_MAPPING.set(className, modelType);
    MODEL_CLASS_TO_NAME_MAPPING.set(modelClass, className);
    MODEL_NAME_TO_CLASS_MAPPING.set(className, modelClass);
  }
}
/**
 * Model classes registered by class name with an explicit model type,
 * separately from the task tables in MODEL_CLASS_TYPE_MAPPING.
 * Each entry is `[className, modelClass, modelType]`.
 */
const CUSTOM_MAPPING = [
  ["CLIPTextModelWithProjection", CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
  ["SiglipTextModel", SiglipTextModel, MODEL_TYPES.EncoderOnly],
  ["ClapTextModelWithProjection", ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
  ["ClapAudioModelWithProjection", ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
];

// Add the custom entries to the same registries, in list order. These run after
// the task-table registration, so they override any earlier entry with the same name.
CUSTOM_MAPPING.forEach(([className, modelClass, modelType]) => {
  MODEL_TYPE_MAPPING.set(className, modelType);
  MODEL_CLASS_TO_NAME_MAPPING.set(modelClass, className);
  MODEL_NAME_TO_CLASS_MAPPING.set(className, modelClass);
});
/**
 * Generic auto-class whose lookup tables are every mapping table listed in
 * MODEL_CLASS_TYPE_MAPPING, in that order.
 *
 * NOTE(review): the lookup itself lives in PretrainedMixin. `BASE_IF_FAIL = true`
 * presumably enables a fallback to a base model class when no table matches;
 * confirm against PretrainedMixin.
 */
class AutoModel extends PretrainedMixin {
  /** @type {Map<string, Object>[]} */
  // @ts-ignore
  static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(([table]) => table);

  static BASE_IF_FAIL = true;
}
/**
 * Output container for sequence-to-sequence language models.
 */
class Seq2SeqLMOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits The output logits of the model.
   * @param {Tensor} output.past_key_values A tensor of key/value pairs that represent the previous state of the model.
   * @param {Tensor} output.encoder_outputs The output of the encoder in a sequence-to-sequence model.
   * @param {Tensor} [output.decoder_attentions] Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. Defaults to `null`.
   * @param {Tensor} [output.cross_attentions] Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. Defaults to `null`.
   */
  constructor({ logits, past_key_values, encoder_outputs, decoder_attentions = null, cross_attentions = null }) {
    super();
    this.logits = logits;
    this.past_key_values = past_key_values;
    this.encoder_outputs = encoder_outputs;
    // The optional attention outputs are stored as null (not undefined) when omitted.
    this.decoder_attentions = decoder_attentions;
    this.cross_attentions = cross_attentions;
  }
}
/**
 * Output container for sequence-classification heads.
 */
class SequenceClassifierOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits Classification scores (or regression scores when
   *   `config.num_labels == 1`), before softmax.
   */
  constructor({ logits }) {
    super();
    this.logits = logits;
  }
}
/**
 * Output container for x-vector heads.
 */
class XVectorOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits Classification hidden states before AMSoftmax,
   *   of shape `(batch_size, config.xvector_output_dim)`.
   * @param {Tensor} output.embeddings Utterance embeddings for vector-similarity
   *   retrieval, of shape `(batch_size, config.xvector_output_dim)`.
   */
  constructor({ logits, embeddings }) {
    super();
    // Sets `logits` first, then `embeddings`.
    Object.assign(this, { logits, embeddings });
  }
}
/**
 * Output container for token-classification heads.
 */
class TokenClassifierOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits Per-token classification scores, before softmax.
   */
  constructor({ logits }) {
    super();
    this.logits = logits;
  }
}
/**
 * Output container for masked language-model heads.
 */
class MaskedLMOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits Language-modeling head scores for each vocabulary
   *   token, before softmax.
   */
  constructor({ logits }) {
    super();
    this.logits = logits;
  }
}
/**
 * Output container for extractive question-answering heads.
 */
class QuestionAnsweringModelOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.start_logits Span-start scores, before softmax.
   * @param {Tensor} output.end_logits Span-end scores, before softmax.
   */
  constructor({ start_logits, end_logits }) {
    super();
    // Sets `start_logits` first, then `end_logits`.
    Object.assign(this, { start_logits, end_logits });
  }
}
/**
 * Output container for causal language-model heads.
 */
class CausalLMOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.logits Language-modeling head scores for each vocabulary
   *   token, before softmax.
   */
  constructor({ logits }) {
    super();
    this.logits = logits;
  }
}
/**
 * Output container for image-matting models.
 */
class ImageMattingOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.alphas Estimated alpha values, of shape
   *   `(batch_size, num_channels, height, width)`.
   */
  constructor({ alphas }) {
    super();
    this.alphas = alphas;
  }
}
/**
 * Output container for VITS text-to-speech models.
 */
class VitsModelOutput extends ModelOutput {
  /**
   * @param {Object} output The output of the model.
   * @param {Tensor} output.waveform Final predicted audio waveform, of shape
   *   `(batch_size, sequence_length)`.
   * @param {Tensor} output.spectrogram Log-mel spectrogram predicted at the output of
   *   the flow model. It is passed to the Hi-Fi GAN decoder to obtain the final waveform.
   */
  constructor({ waveform, spectrogram }) {
    super();
    // Sets `waveform` first, then `spectrogram`.
    Object.assign(this, { waveform, spectrogram });
  }
}
+ export {
3531
+ ASTForAudioClassification,
3532
+ ASTModel,
3533
+ ASTPreTrainedModel,
3534
+ AlbertForMaskedLM,
3535
+ AlbertForQuestionAnswering,
3536
+ AlbertForSequenceClassification,
3537
+ AlbertModel,
3538
+ AlbertPreTrainedModel,
3539
+ AutoModel,
3540
+ BartForConditionalGeneration,
3541
+ BartForSequenceClassification,
3542
+ BartModel,
3543
+ BartPretrainedModel,
3544
+ BeitForImageClassification,
3545
+ BeitModel,
3546
+ BeitPreTrainedModel,
3547
+ BertForMaskedLM,
3548
+ BertForQuestionAnswering,
3549
+ BertForSequenceClassification,
3550
+ BertForTokenClassification,
3551
+ BertModel,
3552
+ BertPreTrainedModel,
3553
+ BlenderbotForConditionalGeneration,
3554
+ BlenderbotModel,
3555
+ BlenderbotPreTrainedModel,
3556
+ BlenderbotSmallForConditionalGeneration,
3557
+ BlenderbotSmallModel,
3558
+ BlenderbotSmallPreTrainedModel,
3559
+ BloomForCausalLM,
3560
+ BloomModel,
3561
+ BloomPreTrainedModel,
3562
+ CLIPModel,
3563
+ CLIPPreTrainedModel,
3564
+ CLIPSegForImageSegmentation,
3565
+ CLIPSegModel,
3566
+ CLIPSegPreTrainedModel,
3567
+ CLIPTextModelWithProjection,
3568
+ CLIPVisionModelWithProjection,
3569
+ CamembertForMaskedLM,
3570
+ CamembertForQuestionAnswering,
3571
+ CamembertForSequenceClassification,
3572
+ CamembertForTokenClassification,
3573
+ CamembertModel,
3574
+ CamembertPreTrainedModel,
3575
+ CausalLMOutput,
3576
+ ChineseCLIPModel,
3577
+ ChineseCLIPPreTrainedModel,
3578
+ ClapAudioModelWithProjection,
3579
+ ClapModel,
3580
+ ClapPreTrainedModel,
3581
+ ClapTextModelWithProjection,
3582
+ CodeGenForCausalLM,
3583
+ CodeGenModel,
3584
+ CodeGenPreTrainedModel,
3585
+ ConvBertForMaskedLM,
3586
+ ConvBertForQuestionAnswering,
3587
+ ConvBertForSequenceClassification,
3588
+ ConvBertForTokenClassification,
3589
+ ConvBertModel,
3590
+ ConvBertPreTrainedModel,
3591
+ ConvNextForImageClassification,
3592
+ ConvNextModel,
3593
+ ConvNextPreTrainedModel,
3594
+ ConvNextV2ForImageClassification,
3595
+ ConvNextV2Model,
3596
+ ConvNextV2PreTrainedModel,
3597
+ DPTForDepthEstimation,
3598
+ DPTModel,
3599
+ DPTPreTrainedModel,
3600
+ DebertaForMaskedLM,
3601
+ DebertaForQuestionAnswering,
3602
+ DebertaForSequenceClassification,
3603
+ DebertaForTokenClassification,
3604
+ DebertaModel,
3605
+ DebertaPreTrainedModel,
3606
+ DebertaV2ForMaskedLM,
3607
+ DebertaV2ForQuestionAnswering,
3608
+ DebertaV2ForSequenceClassification,
3609
+ DebertaV2ForTokenClassification,
3610
+ DebertaV2Model,
3611
+ DebertaV2PreTrainedModel,
3612
+ DeiTForImageClassification,
3613
+ DeiTModel,
3614
+ DeiTPreTrainedModel,
3615
+ DepthAnythingForDepthEstimation,
3616
+ DepthAnythingPreTrainedModel,
3617
+ DetrForObjectDetection,
3618
+ DetrForSegmentation,
3619
+ DetrModel,
3620
+ DetrObjectDetectionOutput,
3621
+ DetrPreTrainedModel,
3622
+ DetrSegmentationOutput,
3623
+ Dinov2ForImageClassification,
3624
+ Dinov2Model,
3625
+ Dinov2PreTrainedModel,
3626
+ DistilBertForMaskedLM,
3627
+ DistilBertForQuestionAnswering,
3628
+ DistilBertForSequenceClassification,
3629
+ DistilBertForTokenClassification,
3630
+ DistilBertModel,
3631
+ DistilBertPreTrainedModel,
3632
+ DonutSwinModel,
3633
+ DonutSwinPreTrainedModel,
3634
+ EfficientNetForImageClassification,
3635
+ EfficientNetModel,
3636
+ EfficientNetPreTrainedModel,
3637
+ ElectraForMaskedLM,
3638
+ ElectraForQuestionAnswering,
3639
+ ElectraForSequenceClassification,
3640
+ ElectraForTokenClassification,
3641
+ ElectraModel,
3642
+ ElectraPreTrainedModel,
3643
+ EsmForMaskedLM,
3644
+ EsmForSequenceClassification,
3645
+ EsmForTokenClassification,
3646
+ EsmModel,
3647
+ EsmPreTrainedModel,
3648
+ FalconForCausalLM,
3649
+ FalconModel,
3650
+ FalconPreTrainedModel,
3651
+ FastViTForImageClassification,
3652
+ FastViTModel,
3653
+ FastViTPreTrainedModel,
3654
+ GLPNForDepthEstimation,
3655
+ GLPNModel,
3656
+ GLPNPreTrainedModel,
3657
+ GPT2LMHeadModel,
3658
+ GPT2Model,
3659
+ GPT2PreTrainedModel,
3660
+ GPTBigCodeForCausalLM,
3661
+ GPTBigCodeModel,
3662
+ GPTBigCodePreTrainedModel,
3663
+ GPTJForCausalLM,
3664
+ GPTJModel,
3665
+ GPTJPreTrainedModel,
3666
+ GPTNeoForCausalLM,
3667
+ GPTNeoModel,
3668
+ GPTNeoPreTrainedModel,
3669
+ GPTNeoXForCausalLM,
3670
+ GPTNeoXModel,
3671
+ GPTNeoXPreTrainedModel,
3672
+ HubertForCTC,
3673
+ HubertForSequenceClassification,
3674
+ HubertModel,
3675
+ ImageMattingOutput,
3676
+ LlamaForCausalLM,
3677
+ LlamaModel,
3678
+ LlamaPreTrainedModel,
3679
+ LongT5ForConditionalGeneration,
3680
+ LongT5Model,
3681
+ LongT5PreTrainedModel,
3682
+ M2M100ForConditionalGeneration,
3683
+ M2M100Model,
3684
+ M2M100PreTrainedModel,
3685
+ MBartForCausalLM,
3686
+ MBartForConditionalGeneration,
3687
+ MBartForSequenceClassification,
3688
+ MBartModel,
3689
+ MBartPreTrainedModel,
3690
+ MPNetForMaskedLM,
3691
+ MPNetForQuestionAnswering,
3692
+ MPNetForSequenceClassification,
3693
+ MPNetForTokenClassification,
3694
+ MPNetModel,
3695
+ MPNetPreTrainedModel,
3696
+ MT5ForConditionalGeneration,
3697
+ MT5Model,
3698
+ MT5PreTrainedModel,
3699
+ MarianMTModel,
3700
+ MarianModel,
3701
+ MarianPreTrainedModel,
3702
+ MaskedLMOutput,
3703
+ MistralForCausalLM,
3704
+ MistralModel,
3705
+ MistralPreTrainedModel,
3706
+ MobileBertForMaskedLM,
3707
+ MobileBertForQuestionAnswering,
3708
+ MobileBertForSequenceClassification,
3709
+ MobileBertModel,
3710
+ MobileBertPreTrainedModel,
3711
+ MobileViTForImageClassification,
3712
+ MobileViTModel,
3713
+ MobileViTPreTrainedModel,
3714
+ MobileViTV2ForImageClassification,
3715
+ MobileViTV2Model,
3716
+ MobileViTV2PreTrainedModel,
3717
+ ModelOutput,
3718
+ MptForCausalLM,
3719
+ MptModel,
3720
+ MptPreTrainedModel,
3721
+ NomicBertModel,
3722
+ NomicBertPreTrainedModel,
3723
+ OPTForCausalLM,
3724
+ OPTModel,
3725
+ OPTPreTrainedModel,
3726
+ OwlViTForObjectDetection,
3727
+ OwlViTModel,
3728
+ OwlViTPreTrainedModel,
3729
+ Owlv2ForObjectDetection,
3730
+ Owlv2Model,
3731
+ Owlv2PreTrainedModel,
3732
+ PhiForCausalLM,
3733
+ PhiModel,
3734
+ PhiPreTrainedModel,
3735
+ PreTrainedModel,
3736
+ PretrainedMixin,
3737
+ QuestionAnsweringModelOutput,
3738
+ Qwen2ForCausalLM,
3739
+ Qwen2Model,
3740
+ Qwen2PreTrainedModel,
3741
+ ResNetForImageClassification,
3742
+ ResNetModel,
3743
+ ResNetPreTrainedModel,
3744
+ RoFormerForMaskedLM,
3745
+ RoFormerForQuestionAnswering,
3746
+ RoFormerForSequenceClassification,
3747
+ RoFormerForTokenClassification,
3748
+ RoFormerModel,
3749
+ RoFormerPreTrainedModel,
3750
+ RobertaForMaskedLM,
3751
+ RobertaForQuestionAnswering,
3752
+ RobertaForSequenceClassification,
3753
+ RobertaForTokenClassification,
3754
+ RobertaModel,
3755
+ RobertaPreTrainedModel,
3756
+ SamImageSegmentationOutput,
3757
+ SamModel,
3758
+ SamPreTrainedModel,
3759
+ SegformerForImageClassification,
3760
+ SegformerForSemanticSegmentation,
3761
+ SegformerPreTrainedModel,
3762
+ Seq2SeqLMOutput,
3763
+ SequenceClassifierOutput,
3764
+ SiglipModel,
3765
+ SiglipPreTrainedModel,
3766
+ SiglipTextModel,
3767
+ SiglipVisionModel,
3768
+ SpeechT5ForSpeechToText,
3769
+ SpeechT5ForTextToSpeech,
3770
+ SpeechT5HifiGan,
3771
+ SpeechT5PreTrainedModel,
3772
+ SqueezeBertForMaskedLM,
3773
+ SqueezeBertForQuestionAnswering,
3774
+ SqueezeBertForSequenceClassification,
3775
+ SqueezeBertModel,
3776
+ SqueezeBertPreTrainedModel,
3777
+ StableLmForCausalLM,
3778
+ StableLmPreTrainedModel,
3779
+ Starcoder2ForCausalLM,
3780
+ Starcoder2Model,
3781
+ Starcoder2PreTrainedModel,
3782
+ Swin2SRForImageSuperResolution,
3783
+ Swin2SRModel,
3784
+ Swin2SRPreTrainedModel,
3785
+ SwinForImageClassification,
3786
+ SwinModel,
3787
+ SwinPreTrainedModel,
3788
+ T5ForConditionalGeneration,
3789
+ T5Model,
3790
+ T5PreTrainedModel,
3791
+ TableTransformerForObjectDetection,
3792
+ TableTransformerModel,
3793
+ TableTransformerObjectDetectionOutput,
3794
+ TableTransformerPreTrainedModel,
3795
+ TokenClassifierOutput,
3796
+ TrOCRForCausalLM,
3797
+ TrOCRPreTrainedModel,
3798
+ UniSpeechForCTC,
3799
+ UniSpeechForSequenceClassification,
3800
+ UniSpeechModel,
3801
+ UniSpeechPreTrainedModel,
3802
+ UniSpeechSatForAudioFrameClassification,
3803
+ UniSpeechSatForCTC,
3804
+ UniSpeechSatForSequenceClassification,
3805
+ UniSpeechSatModel,
3806
+ UniSpeechSatPreTrainedModel,
3807
+ ViTForImageClassification,
3808
+ ViTModel,
3809
+ ViTPreTrainedModel,
3810
+ VisionEncoderDecoderModel,
3811
+ VitMatteForImageMatting,
3812
+ VitMattePreTrainedModel,
3813
+ VitsModel,
3814
+ VitsModelOutput,
3815
+ VitsPreTrainedModel,
3816
+ Wav2Vec2BertForCTC,
3817
+ Wav2Vec2BertForSequenceClassification,
3818
+ Wav2Vec2BertModel,
3819
+ Wav2Vec2BertPreTrainedModel,
3820
+ Wav2Vec2ForAudioFrameClassification,
3821
+ Wav2Vec2ForCTC,
3822
+ Wav2Vec2ForSequenceClassification,
3823
+ Wav2Vec2Model,
3824
+ Wav2Vec2PreTrainedModel,
3825
+ WavLMForAudioFrameClassification,
3826
+ WavLMForCTC,
3827
+ WavLMForSequenceClassification,
3828
+ WavLMForXVector,
3829
+ WavLMModel,
3830
+ WavLMPreTrainedModel,
3831
+ WhisperForConditionalGeneration,
3832
+ WhisperModel,
3833
+ WhisperPreTrainedModel,
3834
+ XLMForQuestionAnswering,
3835
+ XLMForSequenceClassification,
3836
+ XLMForTokenClassification,
3837
+ XLMModel,
3838
+ XLMPreTrainedModel,
3839
+ XLMRobertaForMaskedLM,
3840
+ XLMRobertaForQuestionAnswering,
3841
+ XLMRobertaForSequenceClassification,
3842
+ XLMRobertaForTokenClassification,
3843
+ XLMRobertaModel,
3844
+ XLMRobertaPreTrainedModel,
3845
+ XLMWithLMHeadModel,
3846
+ XVectorOutput,
3847
+ YolosForObjectDetection,
3848
+ YolosModel,
3849
+ YolosObjectDetectionOutput,
3850
+ YolosPreTrainedModel
3851
+ };
3852
+ //# sourceMappingURL=models.js.map