@huggingface/inference 3.7.1 → 3.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. package/dist/index.cjs +247 -132
  2. package/dist/index.js +247 -132
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/makeRequestOptions.d.ts +5 -3
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/consts.d.ts +2 -3
  8. package/dist/src/providers/consts.d.ts.map +1 -1
  9. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  10. package/dist/src/providers/hf-inference.d.ts +1 -0
  11. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  12. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  13. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  14. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  15. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  16. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  17. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  18. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  19. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  20. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  21. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  22. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  23. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  24. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  25. package/dist/src/types.d.ts +2 -0
  26. package/dist/src/types.d.ts.map +1 -1
  27. package/dist/src/utils/request.d.ts +3 -2
  28. package/dist/src/utils/request.d.ts.map +1 -1
  29. package/package.json +3 -3
  30. package/src/lib/getInferenceProviderMapping.ts +96 -0
  31. package/src/lib/makeRequestOptions.ts +50 -12
  32. package/src/providers/consts.ts +5 -2
  33. package/src/providers/fal-ai.ts +31 -2
  34. package/src/providers/hf-inference.ts +8 -6
  35. package/src/snippets/getInferenceSnippets.ts +26 -8
  36. package/src/snippets/templates.exported.ts +25 -25
  37. package/src/tasks/audio/audioClassification.ts +1 -1
  38. package/src/tasks/audio/audioToAudio.ts +1 -1
  39. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  40. package/src/tasks/audio/textToSpeech.ts +1 -1
  41. package/src/tasks/custom/request.ts +3 -1
  42. package/src/tasks/custom/streamingRequest.ts +4 -1
  43. package/src/tasks/cv/imageClassification.ts +1 -1
  44. package/src/tasks/cv/imageSegmentation.ts +1 -1
  45. package/src/tasks/cv/imageToImage.ts +1 -1
  46. package/src/tasks/cv/imageToText.ts +1 -1
  47. package/src/tasks/cv/objectDetection.ts +1 -1
  48. package/src/tasks/cv/textToImage.ts +2 -2
  49. package/src/tasks/cv/textToVideo.ts +9 -5
  50. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  51. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -0
  52. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  53. package/src/tasks/nlp/chatCompletion.ts +1 -1
  54. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  55. package/src/tasks/nlp/featureExtraction.ts +1 -1
  56. package/src/tasks/nlp/fillMask.ts +1 -1
  57. package/src/tasks/nlp/questionAnswering.ts +8 -4
  58. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  59. package/src/tasks/nlp/summarization.ts +1 -1
  60. package/src/tasks/nlp/tableQuestionAnswering.ts +8 -4
  61. package/src/tasks/nlp/textClassification.ts +1 -1
  62. package/src/tasks/nlp/textGeneration.ts +2 -3
  63. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  64. package/src/tasks/nlp/tokenClassification.ts +8 -5
  65. package/src/tasks/nlp/translation.ts +1 -1
  66. package/src/tasks/nlp/zeroShotClassification.ts +8 -5
  67. package/src/tasks/tabular/tabularClassification.ts +1 -1
  68. package/src/tasks/tabular/tabularRegression.ts +1 -1
  69. package/src/types.ts +2 -0
  70. package/src/utils/request.ts +7 -4
  71. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  72. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  73. package/src/lib/getProviderModelId.ts +0 -74
package/dist/index.js CHANGED
@@ -41,15 +41,6 @@ __export(tasks_exports, {
41
41
  zeroShotImageClassification: () => zeroShotImageClassification
42
42
  });
43
43
 
44
- // package.json
45
- var name = "@huggingface/inference";
46
- var version = "3.7.1";
47
-
48
- // src/config.ts
49
- var HF_HUB_URL = "https://huggingface.co";
50
- var HF_ROUTER_URL = "https://router.huggingface.co";
51
- var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
52
-
53
44
  // src/lib/InferenceOutputError.ts
54
45
  var InferenceOutputError = class extends TypeError {
55
46
  constructor(message) {
@@ -91,6 +82,11 @@ function omit(o, props) {
91
82
  return pick(o, letsKeep);
92
83
  }
93
84
 
85
+ // src/config.ts
86
+ var HF_HUB_URL = "https://huggingface.co";
87
+ var HF_ROUTER_URL = "https://router.huggingface.co";
88
+ var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
89
+
94
90
  // src/utils/toArray.ts
95
91
  function toArray(obj) {
96
92
  if (Array.isArray(obj)) {
@@ -280,14 +276,37 @@ var FalAITask = class extends TaskProviderHelper {
280
276
  return headers;
281
277
  }
282
278
  };
279
+ function buildLoraPath(modelId, adapterWeightsPath) {
280
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
281
+ }
283
282
  var FalAITextToImageTask = class extends FalAITask {
284
283
  preparePayload(params) {
285
- return {
284
+ const payload = {
286
285
  ...omit(params.args, ["inputs", "parameters"]),
287
286
  ...params.args.parameters,
288
287
  sync_mode: true,
289
- prompt: params.args.inputs
288
+ prompt: params.args.inputs,
289
+ ...params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? {
290
+ loras: [
291
+ {
292
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
293
+ scale: 1
294
+ }
295
+ ]
296
+ } : void 0
290
297
  };
298
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
299
+ payload.loras = [
300
+ {
301
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
302
+ scale: 1
303
+ }
304
+ ];
305
+ if (params.mapping.providerId === "fal-ai/lora") {
306
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
307
+ }
308
+ }
309
+ return payload;
291
310
  }
292
311
  async getResponse(response, outputType) {
293
312
  if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
@@ -417,6 +436,7 @@ var FireworksConversationalTask = class extends BaseConversationalTask {
417
436
  };
418
437
 
419
438
  // src/providers/hf-inference.ts
439
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
420
440
  var HFInferenceTask = class extends TaskProviderHelper {
421
441
  constructor() {
422
442
  super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
@@ -1124,8 +1144,12 @@ function getProviderHelper(provider, task) {
1124
1144
  return providerTasks[task];
1125
1145
  }
1126
1146
 
1147
+ // package.json
1148
+ var name = "@huggingface/inference";
1149
+ var version = "3.8.0";
1150
+
1127
1151
  // src/providers/consts.ts
1128
- var HARDCODED_MODEL_ID_MAPPING = {
1152
+ var HARDCODED_MODEL_INFERENCE_MAPPING = {
1129
1153
  /**
1130
1154
  * "HF model ID" => "Model ID on Inference Provider's side"
1131
1155
  *
@@ -1147,53 +1171,67 @@ var HARDCODED_MODEL_ID_MAPPING = {
1147
1171
  together: {}
1148
1172
  };
1149
1173
 
1150
- // src/lib/getProviderModelId.ts
1174
+ // src/lib/getInferenceProviderMapping.ts
1151
1175
  var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
1152
- async function getProviderModelId(params, args, options = {}) {
1153
- if (params.provider === "hf-inference") {
1154
- return params.model;
1155
- }
1156
- if (!options.task) {
1157
- throw new Error("task must be specified when using a third-party provider");
1158
- }
1159
- const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
1160
- if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
1161
- return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
1176
+ async function getInferenceProviderMapping(params, options) {
1177
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
1178
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
1162
1179
  }
1163
1180
  let inferenceProviderMapping;
1164
- if (inferenceProviderMappingCache.has(params.model)) {
1165
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.model);
1181
+ if (inferenceProviderMappingCache.has(params.modelId)) {
1182
+ inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
1166
1183
  } else {
1167
- inferenceProviderMapping = await (options?.fetch ?? fetch)(
1168
- `${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
1184
+ const resp = await (options?.fetch ?? fetch)(
1185
+ `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
1169
1186
  {
1170
- headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {}
1187
+ headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
1171
1188
  }
1172
- ).then((resp) => resp.json()).then((json) => json.inferenceProviderMapping).catch(() => null);
1189
+ );
1190
+ if (resp.status === 404) {
1191
+ throw new Error(`Model ${params.modelId} does not exist`);
1192
+ }
1193
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
1173
1194
  }
1174
1195
  if (!inferenceProviderMapping) {
1175
- throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
1196
+ throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
1176
1197
  }
1177
1198
  const providerMapping = inferenceProviderMapping[params.provider];
1178
1199
  if (providerMapping) {
1179
- if (providerMapping.task !== task) {
1200
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
1201
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
1180
1202
  throw new Error(
1181
- `Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
1203
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
1182
1204
  );
1183
1205
  }
1184
1206
  if (providerMapping.status === "staging") {
1185
1207
  console.warn(
1186
- `Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
1208
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
1187
1209
  );
1188
1210
  }
1189
- return providerMapping.providerId;
1211
+ if (providerMapping.adapter === "lora") {
1212
+ const treeResp = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${params.modelId}/tree/main`);
1213
+ if (!treeResp.ok) {
1214
+ throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
1215
+ }
1216
+ const tree = await treeResp.json();
1217
+ const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
1218
+ if (!adapterWeightsPath) {
1219
+ throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
1220
+ }
1221
+ return {
1222
+ ...providerMapping,
1223
+ hfModelId: params.modelId,
1224
+ adapterWeightsPath
1225
+ };
1226
+ }
1227
+ return { ...providerMapping, hfModelId: params.modelId };
1190
1228
  }
1191
- throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
1229
+ return null;
1192
1230
  }
1193
1231
 
1194
1232
  // src/lib/makeRequestOptions.ts
1195
1233
  var tasks = null;
1196
- async function makeRequestOptions(args, options) {
1234
+ async function makeRequestOptions(args, providerHelper, options) {
1197
1235
  const { provider: maybeProvider, model: maybeModel } = args;
1198
1236
  const provider = maybeProvider ?? "hf-inference";
1199
1237
  const { task } = options ?? {};
@@ -1203,28 +1241,55 @@ async function makeRequestOptions(args, options) {
1203
1241
  if (maybeModel && isUrl(maybeModel)) {
1204
1242
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
1205
1243
  }
1244
+ if (args.endpointUrl) {
1245
+ return makeRequestOptionsFromResolvedModel(
1246
+ maybeModel ?? args.endpointUrl,
1247
+ providerHelper,
1248
+ args,
1249
+ void 0,
1250
+ options
1251
+ );
1252
+ }
1206
1253
  if (!maybeModel && !task) {
1207
1254
  throw new Error("No model provided, and no task has been specified.");
1208
1255
  }
1209
1256
  const hfModel = maybeModel ?? await loadDefaultModel(task);
1210
- const providerHelper = getProviderHelper(provider, task);
1211
1257
  if (providerHelper.clientSideRoutingOnly && !maybeModel) {
1212
1258
  throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
1213
1259
  }
1214
- const resolvedModel = providerHelper.clientSideRoutingOnly ? (
1260
+ const inferenceProviderMapping = providerHelper.clientSideRoutingOnly ? {
1215
1261
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1216
- removeProviderPrefix(maybeModel, provider)
1217
- ) : await getProviderModelId({ model: hfModel, provider }, args, {
1218
- task,
1219
- fetch: options?.fetch
1220
- });
1221
- return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
1262
+ providerId: removeProviderPrefix(maybeModel, provider),
1263
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1264
+ hfModelId: maybeModel,
1265
+ status: "live",
1266
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1267
+ task
1268
+ } : await getInferenceProviderMapping(
1269
+ {
1270
+ modelId: hfModel,
1271
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1272
+ task,
1273
+ provider,
1274
+ accessToken: args.accessToken
1275
+ },
1276
+ { fetch: options?.fetch }
1277
+ );
1278
+ if (!inferenceProviderMapping) {
1279
+ throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
1280
+ }
1281
+ return makeRequestOptionsFromResolvedModel(
1282
+ inferenceProviderMapping.providerId,
1283
+ providerHelper,
1284
+ args,
1285
+ inferenceProviderMapping,
1286
+ options
1287
+ );
1222
1288
  }
1223
- function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
1289
+ function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
1224
1290
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
1225
1291
  const provider = maybeProvider ?? "hf-inference";
1226
1292
  const { includeCredentials, task, signal, billTo } = options ?? {};
1227
- const providerHelper = getProviderHelper(provider, task);
1228
1293
  const authMethod = (() => {
1229
1294
  if (providerHelper.clientSideRoutingOnly) {
1230
1295
  if (accessToken && accessToken.startsWith("hf_")) {
@@ -1262,7 +1327,8 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
1262
1327
  const body = providerHelper.makeBody({
1263
1328
  args: remainingArgs,
1264
1329
  model: resolvedModel,
1265
- task
1330
+ task,
1331
+ mapping
1266
1332
  });
1267
1333
  let credentials;
1268
1334
  if (typeof includeCredentials === "string") {
@@ -1403,12 +1469,12 @@ function newMessage() {
1403
1469
  }
1404
1470
 
1405
1471
  // src/utils/request.ts
1406
- async function innerRequest(args, options) {
1407
- const { url, info } = await makeRequestOptions(args, options);
1472
+ async function innerRequest(args, providerHelper, options) {
1473
+ const { url, info } = await makeRequestOptions(args, providerHelper, options);
1408
1474
  const response = await (options?.fetch ?? fetch)(url, info);
1409
1475
  const requestContext = { url, info };
1410
1476
  if (options?.retry_on_error !== false && response.status === 503) {
1411
- return innerRequest(args, options);
1477
+ return innerRequest(args, providerHelper, options);
1412
1478
  }
1413
1479
  if (!response.ok) {
1414
1480
  const contentType = response.headers.get("Content-Type");
@@ -1435,11 +1501,11 @@ async function innerRequest(args, options) {
1435
1501
  const blob = await response.blob();
1436
1502
  return { data: blob, requestContext };
1437
1503
  }
1438
- async function* innerStreamingRequest(args, options) {
1439
- const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
1504
+ async function* innerStreamingRequest(args, providerHelper, options) {
1505
+ const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
1440
1506
  const response = await (options?.fetch ?? fetch)(url, info);
1441
1507
  if (options?.retry_on_error !== false && response.status === 503) {
1442
- return yield* innerStreamingRequest(args, options);
1508
+ return yield* innerStreamingRequest(args, providerHelper, options);
1443
1509
  }
1444
1510
  if (!response.ok) {
1445
1511
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -1513,7 +1579,8 @@ async function request(args, options) {
1513
1579
  console.warn(
1514
1580
  "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1515
1581
  );
1516
- const result = await innerRequest(args, options);
1582
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1583
+ const result = await innerRequest(args, providerHelper, options);
1517
1584
  return result.data;
1518
1585
  }
1519
1586
 
@@ -1522,7 +1589,8 @@ async function* streamingRequest(args, options) {
1522
1589
  console.warn(
1523
1590
  "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1524
1591
  );
1525
- yield* innerStreamingRequest(args, options);
1592
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1593
+ yield* innerStreamingRequest(args, providerHelper, options);
1526
1594
  }
1527
1595
 
1528
1596
  // src/tasks/audio/utils.ts
@@ -1537,7 +1605,7 @@ function preparePayload(args) {
1537
1605
  async function audioClassification(args, options) {
1538
1606
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
1539
1607
  const payload = preparePayload(args);
1540
- const { data: res } = await innerRequest(payload, {
1608
+ const { data: res } = await innerRequest(payload, providerHelper, {
1541
1609
  ...options,
1542
1610
  task: "audio-classification"
1543
1611
  });
@@ -1548,7 +1616,7 @@ async function audioClassification(args, options) {
1548
1616
  async function audioToAudio(args, options) {
1549
1617
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
1550
1618
  const payload = preparePayload(args);
1551
- const { data: res } = await innerRequest(payload, {
1619
+ const { data: res } = await innerRequest(payload, providerHelper, {
1552
1620
  ...options,
1553
1621
  task: "audio-to-audio"
1554
1622
  });
@@ -1572,7 +1640,7 @@ function base64FromBytes(arr) {
1572
1640
  async function automaticSpeechRecognition(args, options) {
1573
1641
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
1574
1642
  const payload = await buildPayload(args);
1575
- const { data: res } = await innerRequest(payload, {
1643
+ const { data: res } = await innerRequest(payload, providerHelper, {
1576
1644
  ...options,
1577
1645
  task: "automatic-speech-recognition"
1578
1646
  });
@@ -1612,7 +1680,7 @@ async function buildPayload(args) {
1612
1680
  async function textToSpeech(args, options) {
1613
1681
  const provider = args.provider ?? "hf-inference";
1614
1682
  const providerHelper = getProviderHelper(provider, "text-to-speech");
1615
- const { data: res } = await innerRequest(args, {
1683
+ const { data: res } = await innerRequest(args, providerHelper, {
1616
1684
  ...options,
1617
1685
  task: "text-to-speech"
1618
1686
  });
@@ -1628,7 +1696,7 @@ function preparePayload2(args) {
1628
1696
  async function imageClassification(args, options) {
1629
1697
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1630
1698
  const payload = preparePayload2(args);
1631
- const { data: res } = await innerRequest(payload, {
1699
+ const { data: res } = await innerRequest(payload, providerHelper, {
1632
1700
  ...options,
1633
1701
  task: "image-classification"
1634
1702
  });
@@ -1639,7 +1707,7 @@ async function imageClassification(args, options) {
1639
1707
  async function imageSegmentation(args, options) {
1640
1708
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1641
1709
  const payload = preparePayload2(args);
1642
- const { data: res } = await innerRequest(payload, {
1710
+ const { data: res } = await innerRequest(payload, providerHelper, {
1643
1711
  ...options,
1644
1712
  task: "image-segmentation"
1645
1713
  });
@@ -1664,7 +1732,7 @@ async function imageToImage(args, options) {
1664
1732
  )
1665
1733
  };
1666
1734
  }
1667
- const { data: res } = await innerRequest(reqArgs, {
1735
+ const { data: res } = await innerRequest(reqArgs, providerHelper, {
1668
1736
  ...options,
1669
1737
  task: "image-to-image"
1670
1738
  });
@@ -1675,7 +1743,7 @@ async function imageToImage(args, options) {
1675
1743
  async function imageToText(args, options) {
1676
1744
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1677
1745
  const payload = preparePayload2(args);
1678
- const { data: res } = await innerRequest(payload, {
1746
+ const { data: res } = await innerRequest(payload, providerHelper, {
1679
1747
  ...options,
1680
1748
  task: "image-to-text"
1681
1749
  });
@@ -1686,7 +1754,7 @@ async function imageToText(args, options) {
1686
1754
  async function objectDetection(args, options) {
1687
1755
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1688
1756
  const payload = preparePayload2(args);
1689
- const { data: res } = await innerRequest(payload, {
1757
+ const { data: res } = await innerRequest(payload, providerHelper, {
1690
1758
  ...options,
1691
1759
  task: "object-detection"
1692
1760
  });
@@ -1697,11 +1765,11 @@ async function objectDetection(args, options) {
1697
1765
  async function textToImage(args, options) {
1698
1766
  const provider = args.provider ?? "hf-inference";
1699
1767
  const providerHelper = getProviderHelper(provider, "text-to-image");
1700
- const { data: res } = await innerRequest(args, {
1768
+ const { data: res } = await innerRequest(args, providerHelper, {
1701
1769
  ...options,
1702
1770
  task: "text-to-image"
1703
1771
  });
1704
- const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
1772
+ const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" });
1705
1773
  return providerHelper.getResponse(res, url, info.headers, options?.outputType);
1706
1774
  }
1707
1775
 
@@ -1709,11 +1777,15 @@ async function textToImage(args, options) {
1709
1777
  async function textToVideo(args, options) {
1710
1778
  const provider = args.provider ?? "hf-inference";
1711
1779
  const providerHelper = getProviderHelper(provider, "text-to-video");
1712
- const { data: response } = await innerRequest(args, {
1713
- ...options,
1714
- task: "text-to-video"
1715
- });
1716
- const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
1780
+ const { data: response } = await innerRequest(
1781
+ args,
1782
+ providerHelper,
1783
+ {
1784
+ ...options,
1785
+ task: "text-to-video"
1786
+ }
1787
+ );
1788
+ const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
1717
1789
  return providerHelper.getResponse(response, url, info.headers);
1718
1790
  }
1719
1791
 
@@ -1742,7 +1814,7 @@ async function preparePayload3(args) {
1742
1814
  async function zeroShotImageClassification(args, options) {
1743
1815
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
1744
1816
  const payload = await preparePayload3(args);
1745
- const { data: res } = await innerRequest(payload, {
1817
+ const { data: res } = await innerRequest(payload, providerHelper, {
1746
1818
  ...options,
1747
1819
  task: "zero-shot-image-classification"
1748
1820
  });
@@ -1752,7 +1824,7 @@ async function zeroShotImageClassification(args, options) {
1752
1824
  // src/tasks/nlp/chatCompletion.ts
1753
1825
  async function chatCompletion(args, options) {
1754
1826
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
1755
- const { data: response } = await innerRequest(args, {
1827
+ const { data: response } = await innerRequest(args, providerHelper, {
1756
1828
  ...options,
1757
1829
  task: "conversational"
1758
1830
  });
@@ -1761,7 +1833,8 @@ async function chatCompletion(args, options) {
1761
1833
 
1762
1834
  // src/tasks/nlp/chatCompletionStream.ts
1763
1835
  async function* chatCompletionStream(args, options) {
1764
- yield* innerStreamingRequest(args, {
1836
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
1837
+ yield* innerStreamingRequest(args, providerHelper, {
1765
1838
  ...options,
1766
1839
  task: "conversational"
1767
1840
  });
@@ -1770,7 +1843,7 @@ async function* chatCompletionStream(args, options) {
1770
1843
  // src/tasks/nlp/featureExtraction.ts
1771
1844
  async function featureExtraction(args, options) {
1772
1845
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
1773
- const { data: res } = await innerRequest(args, {
1846
+ const { data: res } = await innerRequest(args, providerHelper, {
1774
1847
  ...options,
1775
1848
  task: "feature-extraction"
1776
1849
  });
@@ -1780,7 +1853,7 @@ async function featureExtraction(args, options) {
1780
1853
  // src/tasks/nlp/fillMask.ts
1781
1854
  async function fillMask(args, options) {
1782
1855
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
1783
- const { data: res } = await innerRequest(args, {
1856
+ const { data: res } = await innerRequest(args, providerHelper, {
1784
1857
  ...options,
1785
1858
  task: "fill-mask"
1786
1859
  });
@@ -1790,17 +1863,21 @@ async function fillMask(args, options) {
1790
1863
  // src/tasks/nlp/questionAnswering.ts
1791
1864
  async function questionAnswering(args, options) {
1792
1865
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
1793
- const { data: res } = await innerRequest(args, {
1794
- ...options,
1795
- task: "question-answering"
1796
- });
1866
+ const { data: res } = await innerRequest(
1867
+ args,
1868
+ providerHelper,
1869
+ {
1870
+ ...options,
1871
+ task: "question-answering"
1872
+ }
1873
+ );
1797
1874
  return providerHelper.getResponse(res);
1798
1875
  }
1799
1876
 
1800
1877
  // src/tasks/nlp/sentenceSimilarity.ts
1801
1878
  async function sentenceSimilarity(args, options) {
1802
1879
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
1803
- const { data: res } = await innerRequest(args, {
1880
+ const { data: res } = await innerRequest(args, providerHelper, {
1804
1881
  ...options,
1805
1882
  task: "sentence-similarity"
1806
1883
  });
@@ -1810,7 +1887,7 @@ async function sentenceSimilarity(args, options) {
1810
1887
  // src/tasks/nlp/summarization.ts
1811
1888
  async function summarization(args, options) {
1812
1889
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
1813
- const { data: res } = await innerRequest(args, {
1890
+ const { data: res } = await innerRequest(args, providerHelper, {
1814
1891
  ...options,
1815
1892
  task: "summarization"
1816
1893
  });
@@ -1820,17 +1897,21 @@ async function summarization(args, options) {
1820
1897
  // src/tasks/nlp/tableQuestionAnswering.ts
1821
1898
  async function tableQuestionAnswering(args, options) {
1822
1899
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
1823
- const { data: res } = await innerRequest(args, {
1824
- ...options,
1825
- task: "table-question-answering"
1826
- });
1900
+ const { data: res } = await innerRequest(
1901
+ args,
1902
+ providerHelper,
1903
+ {
1904
+ ...options,
1905
+ task: "table-question-answering"
1906
+ }
1907
+ );
1827
1908
  return providerHelper.getResponse(res);
1828
1909
  }
1829
1910
 
1830
1911
  // src/tasks/nlp/textClassification.ts
1831
1912
  async function textClassification(args, options) {
1832
1913
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
1833
- const { data: res } = await innerRequest(args, {
1914
+ const { data: res } = await innerRequest(args, providerHelper, {
1834
1915
  ...options,
1835
1916
  task: "text-classification"
1836
1917
  });
@@ -1839,9 +1920,8 @@ async function textClassification(args, options) {
1839
1920
 
1840
1921
  // src/tasks/nlp/textGeneration.ts
1841
1922
  async function textGeneration(args, options) {
1842
- const provider = args.provider ?? "hf-inference";
1843
- const providerHelper = getProviderHelper(provider, "text-generation");
1844
- const { data: response } = await innerRequest(args, {
1923
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
1924
+ const { data: response } = await innerRequest(args, providerHelper, {
1845
1925
  ...options,
1846
1926
  task: "text-generation"
1847
1927
  });
@@ -1850,7 +1930,8 @@ async function textGeneration(args, options) {
1850
1930
 
1851
1931
  // src/tasks/nlp/textGenerationStream.ts
1852
1932
  async function* textGenerationStream(args, options) {
1853
- yield* innerStreamingRequest(args, {
1933
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
1934
+ yield* innerStreamingRequest(args, providerHelper, {
1854
1935
  ...options,
1855
1936
  task: "text-generation"
1856
1937
  });
@@ -1859,17 +1940,21 @@ async function* textGenerationStream(args, options) {
1859
1940
  // src/tasks/nlp/tokenClassification.ts
1860
1941
  async function tokenClassification(args, options) {
1861
1942
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
1862
- const { data: res } = await innerRequest(args, {
1863
- ...options,
1864
- task: "token-classification"
1865
- });
1943
+ const { data: res } = await innerRequest(
1944
+ args,
1945
+ providerHelper,
1946
+ {
1947
+ ...options,
1948
+ task: "token-classification"
1949
+ }
1950
+ );
1866
1951
  return providerHelper.getResponse(res);
1867
1952
  }
1868
1953
 
1869
1954
  // src/tasks/nlp/translation.ts
1870
1955
  async function translation(args, options) {
1871
1956
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
1872
- const { data: res } = await innerRequest(args, {
1957
+ const { data: res } = await innerRequest(args, providerHelper, {
1873
1958
  ...options,
1874
1959
  task: "translation"
1875
1960
  });
@@ -1879,10 +1964,14 @@ async function translation(args, options) {
1879
1964
  // src/tasks/nlp/zeroShotClassification.ts
1880
1965
  async function zeroShotClassification(args, options) {
1881
1966
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
1882
- const { data: res } = await innerRequest(args, {
1883
- ...options,
1884
- task: "zero-shot-classification"
1885
- });
1967
+ const { data: res } = await innerRequest(
1968
+ args,
1969
+ providerHelper,
1970
+ {
1971
+ ...options,
1972
+ task: "zero-shot-classification"
1973
+ }
1974
+ );
1886
1975
  return providerHelper.getResponse(res);
1887
1976
  }
1888
1977
 
@@ -1899,6 +1988,7 @@ async function documentQuestionAnswering(args, options) {
1899
1988
  };
1900
1989
  const { data: res } = await innerRequest(
1901
1990
  reqArgs,
1991
+ providerHelper,
1902
1992
  {
1903
1993
  ...options,
1904
1994
  task: "document-question-answering"
@@ -1918,7 +2008,7 @@ async function visualQuestionAnswering(args, options) {
1918
2008
  image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
1919
2009
  }
1920
2010
  };
1921
- const { data: res } = await innerRequest(reqArgs, {
2011
+ const { data: res } = await innerRequest(reqArgs, providerHelper, {
1922
2012
  ...options,
1923
2013
  task: "visual-question-answering"
1924
2014
  });
@@ -1928,7 +2018,7 @@ async function visualQuestionAnswering(args, options) {
1928
2018
  // src/tasks/tabular/tabularClassification.ts
1929
2019
  async function tabularClassification(args, options) {
1930
2020
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
1931
- const { data: res } = await innerRequest(args, {
2021
+ const { data: res } = await innerRequest(args, providerHelper, {
1932
2022
  ...options,
1933
2023
  task: "tabular-classification"
1934
2024
  });
@@ -1938,7 +2028,7 @@ async function tabularClassification(args, options) {
1938
2028
  // src/tasks/tabular/tabularRegression.ts
1939
2029
  async function tabularRegression(args, options) {
1940
2030
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
1941
- const { data: res } = await innerRequest(args, {
2031
+ const { data: res } = await innerRequest(args, providerHelper, {
1942
2032
  ...options,
1943
2033
  task: "tabular-regression"
1944
2034
  });
@@ -2021,19 +2111,19 @@ import {
2021
2111
  var templates = {
2022
2112
  "js": {
2023
2113
  "fetch": {
2024
- "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2025
- "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2026
- "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2027
- "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
2028
- "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
2029
- "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
2114
+ "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2115
+ "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2116
+ "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2117
+ "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
2118
+ "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
2119
+ "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
2030
2120
  },
2031
2121
  "huggingface.js": {
2032
- "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
2033
- "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
2034
- "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
2035
- "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
2036
- "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
2122
+ "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
2123
+ "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
2124
+ "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
2125
+ "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(chatCompletion.choices[0].message);',
2126
+ "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
2037
2127
  "textToImage": `import { InferenceClient } from "@huggingface/inference";
2038
2128
 
2039
2129
  const client = new InferenceClient("{{ accessToken }}");
@@ -2043,7 +2133,9 @@ const image = await client.textToImage({
2043
2133
  model: "{{ model.id }}",
2044
2134
  inputs: {{ inputs.asObj.inputs }},
2045
2135
  parameters: { num_inference_steps: 5 },
2046
- });
2136
+ }{% if billTo %}, {
2137
+ billTo: "{{ billTo }}",
2138
+ }{% endif %});
2047
2139
  /// Use the generated image (it's a Blob)`,
2048
2140
  "textToVideo": `import { InferenceClient } from "@huggingface/inference";
2049
2141
 
@@ -2053,12 +2145,14 @@ const image = await client.textToVideo({
2053
2145
  provider: "{{ provider }}",
2054
2146
  model: "{{ model.id }}",
2055
2147
  inputs: {{ inputs.asObj.inputs }},
2056
- });
2148
+ }{% if billTo %}, {
2149
+ billTo: "{{ billTo }}",
2150
+ }{% endif %});
2057
2151
  // Use the generated video (it's a Blob)`
2058
2152
  },
2059
2153
  "openai": {
2060
- "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
2061
- "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
2154
+ "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
2155
+ "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
2062
2156
  }
2063
2157
  },
2064
2158
  "python": {
@@ -2073,13 +2167,13 @@ const image = await client.textToVideo({
2073
2167
  "conversationalStream": 'stream = client.chat.completions.create(\n model="{{ model.id }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="") ',
2074
2168
  "documentQuestionAnswering": 'output = client.document_question_answering(\n "{{ inputs.asObj.image }}",\n question="{{ inputs.asObj.question }}",\n model="{{ model.id }}",\n) ',
2075
2169
  "imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
2076
- "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n)',
2170
+ "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
2077
2171
  "textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
2078
2172
  "textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
2079
2173
  },
2080
2174
  "openai": {
2081
- "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
2082
- "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
2175
+ "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
2176
+ "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
2083
2177
  },
2084
2178
  "requests": {
2085
2179
  "basic": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n}) ',
@@ -2089,7 +2183,7 @@ const image = await client.textToVideo({
2089
2183
  "conversationalStream": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload, stream=True)\n for line in response.iter_lines():\n if not line.startswith(b"data:"):\n continue\n if line.strip() == b"data: [DONE]":\n return\n yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))\n\nchunks = query({\n{{ providerInputs.asJsonString }},\n "stream": True,\n})\n\nfor chunk in chunks:\n print(chunk["choices"][0]["delta"]["content"], end="")',
2090
2184
  "documentQuestionAnswering": 'def query(payload):\n with open(payload["image"], "rb") as f:\n img = f.read()\n payload["image"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {\n "image": "{{ inputs.asObj.image }}",\n "question": "{{ inputs.asObj.question }}",\n },\n}) ',
2091
2185
  "imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
2092
- "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {"Authorization": "{{ authorizationHeader }}"}',
2186
+ "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
2093
2187
  "tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
2094
2188
  "textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
2095
2189
  "textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
@@ -2099,12 +2193,15 @@ const image = await client.textToVideo({
2099
2193
  },
2100
2194
  "sh": {
2101
2195
  "curl": {
2102
- "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
2103
- "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
2104
- "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
2196
+ "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
2197
+ "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
2198
+ "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
2105
2199
  "conversational": `curl {{ fullUrl }} \\
2106
2200
  -H 'Authorization: {{ authorizationHeader }}' \\
2107
2201
  -H 'Content-Type: application/json' \\
2202
+ {% if billTo %}
2203
+ -H 'X-HF-Bill-To: {{ billTo }}' \\
2204
+ {% endif %}
2108
2205
  -d '{
2109
2206
  {{ providerInputs.asCurlString }},
2110
2207
  "stream": false
@@ -2112,6 +2209,9 @@ const image = await client.textToVideo({
2112
2209
  "conversationalStream": `curl {{ fullUrl }} \\
2113
2210
  -H 'Authorization: {{ authorizationHeader }}' \\
2114
2211
  -H 'Content-Type: application/json' \\
2212
+ {% if billTo %}
2213
+ -H 'X-HF-Bill-To: {{ billTo }}' \\
2214
+ {% endif %}
2115
2215
  -d '{
2116
2216
  {{ providerInputs.asCurlString }},
2117
2217
  "stream": true
@@ -2120,7 +2220,10 @@ const image = await client.textToVideo({
2120
2220
  -X POST \\
2121
2221
  -d '{"inputs": {{ providerInputs.asObj.inputs }}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
2122
2222
  -H 'Content-Type: application/json' \\
2123
- -H 'Authorization: {{ authorizationHeader }}'`
2223
+ -H 'Authorization: {{ authorizationHeader }}'
2224
+ {% if billTo %} \\
2225
+ -H 'X-HF-Bill-To: {{ billTo }}'
2226
+ {% endif %}`
2124
2227
  }
2125
2228
  }
2126
2229
  };
@@ -2189,23 +2292,34 @@ var HF_JS_METHODS = {
2189
2292
  translation: "translation"
2190
2293
  };
2191
2294
  var snippetGenerator = (templateName, inputPreparationFn) => {
2192
- return (model, accessToken, provider, providerModelId, opts) => {
2295
+ return (model, accessToken, provider, inferenceProviderMapping, billTo, opts) => {
2296
+ const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
2193
2297
  let task = model.pipeline_tag;
2194
2298
  if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
2195
2299
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
2196
2300
  inputPreparationFn = prepareConversationalInput;
2197
2301
  task = "conversational";
2198
2302
  }
2303
+ let providerHelper;
2304
+ try {
2305
+ providerHelper = getProviderHelper(provider, task);
2306
+ } catch (e) {
2307
+ console.error(`Failed to get provider helper for ${provider} (${task})`, e);
2308
+ return [];
2309
+ }
2199
2310
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
2200
2311
  const request2 = makeRequestOptionsFromResolvedModel(
2201
- providerModelId ?? model.id,
2312
+ providerModelId,
2313
+ providerHelper,
2202
2314
  {
2203
2315
  accessToken,
2204
2316
  provider,
2205
2317
  ...inputs
2206
2318
  },
2319
+ inferenceProviderMapping,
2207
2320
  {
2208
- task
2321
+ task,
2322
+ billTo
2209
2323
  }
2210
2324
  );
2211
2325
  let providerInputs = inputs;
@@ -2238,7 +2352,8 @@ var snippetGenerator = (templateName, inputPreparationFn) => {
2238
2352
  },
2239
2353
  model,
2240
2354
  provider,
2241
- providerModelId: providerModelId ?? model.id
2355
+ providerModelId: providerModelId ?? model.id,
2356
+ billTo
2242
2357
  };
2243
2358
  return inferenceSnippetLanguages.map((language) => {
2244
2359
  return CLIENTS[language].map((client) => {
@@ -2328,8 +2443,8 @@ var snippets = {
2328
2443
  "zero-shot-classification": snippetGenerator("zeroShotClassification"),
2329
2444
  "zero-shot-image-classification": snippetGenerator("zeroShotImageClassification")
2330
2445
  };
2331
- function getInferenceSnippets(model, accessToken, provider, providerModelId, opts) {
2332
- return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId, opts) ?? [] : [];
2446
+ function getInferenceSnippets(model, accessToken, provider, inferenceProviderMapping, billTo, opts) {
2447
+ return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, billTo, opts) ?? [] : [];
2333
2448
  }
2334
2449
  function formatBody(obj, format) {
2335
2450
  switch (format) {