@huggingface/inference 2.6.6 → 2.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. package/README.md +126 -27
  2. package/dist/index.cjs +81 -17
  3. package/dist/index.js +79 -17
  4. package/dist/src/HfInference.d.ts +28 -0
  5. package/dist/src/HfInference.d.ts.map +1 -0
  6. package/dist/src/index.d.ts +5 -0
  7. package/dist/src/index.d.ts.map +1 -0
  8. package/dist/src/lib/InferenceOutputError.d.ts +4 -0
  9. package/dist/src/lib/InferenceOutputError.d.ts.map +1 -0
  10. package/dist/src/lib/getDefaultTask.d.ts +12 -0
  11. package/dist/src/lib/getDefaultTask.d.ts.map +1 -0
  12. package/dist/src/lib/isUrl.d.ts +2 -0
  13. package/dist/src/lib/isUrl.d.ts.map +1 -0
  14. package/dist/src/lib/makeRequestOptions.d.ts +18 -0
  15. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -0
  16. package/dist/src/tasks/audio/audioClassification.d.ts +24 -0
  17. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -0
  18. package/dist/src/tasks/audio/audioToAudio.d.ts +28 -0
  19. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -0
  20. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +19 -0
  21. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/textToSpeech.d.ts +14 -0
  23. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -0
  24. package/dist/src/tasks/custom/request.d.ts +13 -0
  25. package/dist/src/tasks/custom/request.d.ts.map +1 -0
  26. package/dist/src/tasks/custom/streamingRequest.d.ts +13 -0
  27. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -0
  28. package/dist/src/tasks/cv/imageClassification.d.ts +24 -0
  29. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -0
  30. package/dist/src/tasks/cv/imageSegmentation.d.ts +28 -0
  31. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -0
  32. package/dist/src/tasks/cv/imageToImage.d.ts +55 -0
  33. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -0
  34. package/dist/src/tasks/cv/imageToText.d.ts +18 -0
  35. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -0
  36. package/dist/src/tasks/cv/objectDetection.d.ts +33 -0
  37. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -0
  38. package/dist/src/tasks/cv/textToImage.d.ts +36 -0
  39. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -0
  40. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +26 -0
  41. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -0
  42. package/dist/src/tasks/index.d.ts +32 -0
  43. package/dist/src/tasks/index.d.ts.map +1 -0
  44. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +35 -0
  45. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -0
  46. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +27 -0
  47. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -0
  48. package/dist/src/tasks/nlp/chatCompletion.d.ts +7 -0
  49. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -0
  50. package/dist/src/tasks/nlp/chatCompletionStream.d.ts +7 -0
  51. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -0
  52. package/dist/src/tasks/nlp/featureExtraction.d.ts +19 -0
  53. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -0
  54. package/dist/src/tasks/nlp/fillMask.d.ts +27 -0
  55. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -0
  56. package/dist/src/tasks/nlp/questionAnswering.d.ts +30 -0
  57. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -0
  58. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +19 -0
  59. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -0
  60. package/dist/src/tasks/nlp/summarization.d.ts +48 -0
  61. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -0
  62. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +36 -0
  63. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -0
  64. package/dist/src/tasks/nlp/textClassification.d.ts +22 -0
  65. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -0
  66. package/dist/src/tasks/nlp/textGeneration.d.ts +8 -0
  67. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -0
  68. package/dist/src/tasks/nlp/textGenerationStream.d.ts +81 -0
  69. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -0
  70. package/dist/src/tasks/nlp/tokenClassification.d.ts +51 -0
  71. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -0
  72. package/dist/src/tasks/nlp/translation.d.ts +19 -0
  73. package/dist/src/tasks/nlp/translation.d.ts.map +1 -0
  74. package/dist/src/tasks/nlp/zeroShotClassification.d.ts +28 -0
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -0
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts +20 -0
  77. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -0
  78. package/dist/src/tasks/tabular/tabularRegression.d.ts +20 -0
  79. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -0
  80. package/dist/src/types.d.ts +69 -0
  81. package/dist/src/types.d.ts.map +1 -0
  82. package/dist/src/utils/base64FromBytes.d.ts +2 -0
  83. package/dist/src/utils/base64FromBytes.d.ts.map +1 -0
  84. package/dist/src/utils/distributive-omit.d.ts +9 -0
  85. package/dist/src/utils/distributive-omit.d.ts.map +1 -0
  86. package/dist/src/utils/isBackend.d.ts +2 -0
  87. package/dist/src/utils/isBackend.d.ts.map +1 -0
  88. package/dist/src/utils/isFrontend.d.ts +2 -0
  89. package/dist/src/utils/isFrontend.d.ts.map +1 -0
  90. package/dist/src/utils/omit.d.ts +5 -0
  91. package/dist/src/utils/omit.d.ts.map +1 -0
  92. package/dist/src/utils/pick.d.ts +5 -0
  93. package/dist/src/utils/pick.d.ts.map +1 -0
  94. package/dist/src/utils/toArray.d.ts +2 -0
  95. package/dist/src/utils/toArray.d.ts.map +1 -0
  96. package/dist/src/utils/typedInclude.d.ts +2 -0
  97. package/dist/src/utils/typedInclude.d.ts.map +1 -0
  98. package/dist/src/vendor/fetch-event-source/parse.d.ts +69 -0
  99. package/dist/src/vendor/fetch-event-source/parse.d.ts.map +1 -0
  100. package/dist/src/vendor/fetch-event-source/parse.spec.d.ts +2 -0
  101. package/dist/src/vendor/fetch-event-source/parse.spec.d.ts.map +1 -0
  102. package/dist/test/HfInference.spec.d.ts +2 -0
  103. package/dist/test/HfInference.spec.d.ts.map +1 -0
  104. package/dist/test/expect-closeto.d.ts +2 -0
  105. package/dist/test/expect-closeto.d.ts.map +1 -0
  106. package/dist/test/test-files.d.ts +2 -0
  107. package/dist/test/test-files.d.ts.map +1 -0
  108. package/dist/test/vcr.d.ts +2 -0
  109. package/dist/test/vcr.d.ts.map +1 -0
  110. package/package.json +9 -7
  111. package/src/HfInference.ts +7 -6
  112. package/src/lib/makeRequestOptions.ts +23 -18
  113. package/src/tasks/custom/request.ts +5 -0
  114. package/src/tasks/custom/streamingRequest.ts +8 -0
  115. package/src/tasks/cv/imageToImage.ts +1 -1
  116. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  117. package/src/tasks/index.ts +2 -0
  118. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  119. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  120. package/src/tasks/nlp/chatCompletion.ts +32 -0
  121. package/src/tasks/nlp/chatCompletionStream.ts +17 -0
  122. package/src/tasks/nlp/textGeneration.ts +3 -1
  123. package/src/tasks/nlp/textGenerationStream.ts +2 -2
  124. package/src/types.ts +13 -2
  125. package/src/utils/base64FromBytes.ts +11 -0
  126. package/src/utils/{distributive-omit.d.ts → distributive-omit.ts} +0 -2
  127. package/src/utils/isBackend.ts +6 -0
  128. package/src/utils/isFrontend.ts +3 -0
  129. package/dist/index.d.ts +0 -1341
@@ -0,0 +1,69 @@
1
+ /**
2
+ This file is a part of fetch-event-source package (as of v2.0.1)
3
+ https://github.com/Azure/fetch-event-source/blob/v2.0.1/src/parse.ts
4
+
5
+ Full package can be used after it is made compatible with nodejs:
6
+ https://github.com/Azure/fetch-event-source/issues/20
7
+
8
+ Below is the fetch-event-source package license:
9
+
10
+ MIT License
11
+
12
+ Copyright (c) Microsoft Corporation.
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ The above copyright notice and this permission notice shall be included in all
22
+ copies or substantial portions of the Software.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE
31
+
32
+ */
33
+ /**
34
+ * Represents a message sent in an event stream
35
+ * https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format
36
+ */
37
+ export interface EventSourceMessage {
38
+ /** The event ID to set the EventSource object's last event ID value. */
39
+ id: string;
40
+ /** A string identifying the type of event described. */
41
+ event: string;
42
+ /** The event data */
43
+ data: string;
44
+ /** The reconnection interval (in milliseconds) to wait before retrying the connection */
45
+ retry?: number;
46
+ }
47
+ /**
48
+ * Converts a ReadableStream into a callback pattern.
49
+ * @param stream The input ReadableStream.
50
+ * @param onChunk A function that will be called on each new byte chunk in the stream.
51
+ * @returns {Promise<void>} A promise that will be resolved when the stream closes.
52
+ */
53
+ export declare function getBytes(stream: ReadableStream<Uint8Array>, onChunk: (arr: Uint8Array) => void): Promise<void>;
54
+ /**
55
+ * Parses arbitary byte chunks into EventSource line buffers.
56
+ * Each line should be of the format "field: value" and ends with \r, \n, or \r\n.
57
+ * @param onLine A function that will be called on each new EventSource line.
58
+ * @returns A function that should be called for each incoming byte chunk.
59
+ */
60
+ export declare function getLines(onLine: (line: Uint8Array, fieldLength: number) => void): (arr: Uint8Array) => void;
61
+ /**
62
+ * Parses line buffers into EventSourceMessages.
63
+ * @param onId A function that will be called on each `id` field.
64
+ * @param onRetry A function that will be called on each `retry` field.
65
+ * @param onMessage A function that will be called on each message.
66
+ * @returns A function that should be called for each incoming line buffer.
67
+ */
68
+ export declare function getMessages(onId: (id: string) => void, onRetry: (retry: number) => void, onMessage?: (msg: EventSourceMessage) => void): (line: Uint8Array, fieldLength: number) => void;
69
+ //# sourceMappingURL=parse.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"parse.d.ts","sourceRoot":"","sources":["../../../../src/vendor/fetch-event-source/parse.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA+BG;AAEH;;;GAGG;AACH,MAAM,WAAW,kBAAkB;IAC/B,wEAAwE;IACxE,EAAE,EAAE,MAAM,CAAC;IACX,wDAAwD;IACxD,KAAK,EAAE,MAAM,CAAC;IACd,qBAAqB;IACrB,IAAI,EAAE,MAAM,CAAC;IACb,yFAAyF;IACzF,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAED;;;;;GAKG;AACH,wBAAsB,QAAQ,CAAC,MAAM,EAAE,cAAc,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,CAAC,GAAG,EAAE,UAAU,KAAK,IAAI,iBAMpG;AASD;;;;;GAKG;AACH,wBAAgB,QAAQ,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,KAAK,IAAI,SAO/C,UAAU,UA4D1C;AAED;;;;;;GAMG;AACH,wBAAgB,WAAW,CACvB,IAAI,EAAE,CAAC,EAAE,EAAE,MAAM,KAAK,IAAI,EAC1B,OAAO,EAAE,CAAC,KAAK,EAAE,MAAM,KAAK,IAAI,EAChC,SAAS,CAAC,EAAE,CAAC,GAAG,EAAE,kBAAkB,KAAK,IAAI,UAMhB,UAAU,eAAe,MAAM,UAmC/D"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=parse.spec.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"parse.spec.d.ts","sourceRoot":"","sources":["../../../../src/vendor/fetch-event-source/parse.spec.ts"],"names":[],"mappings":""}
@@ -0,0 +1,2 @@
1
+ import "./vcr";
2
+ //# sourceMappingURL=HfInference.spec.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"HfInference.spec.d.ts","sourceRoot":"","sources":["../../test/HfInference.spec.ts"],"names":[],"mappings":"AAKA,OAAO,OAAO,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=expect-closeto.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"expect-closeto.d.ts","sourceRoot":"","sources":["../../test/expect-closeto.ts"],"names":[],"mappings":""}
@@ -0,0 +1,2 @@
1
+ export declare const readTestFile: (filename: string) => Uint8Array;
2
+ //# sourceMappingURL=test-files.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"test-files.d.ts","sourceRoot":"","sources":["../../test/test-files.ts"],"names":[],"mappings":"AAGA,eAAO,MAAM,YAAY,aAAc,MAAM,KAAG,UAK/C,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=vcr.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"vcr.d.ts","sourceRoot":"","sources":["../../test/vcr.ts"],"names":[],"mappings":""}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@huggingface/inference",
3
- "version": "2.6.6",
3
+ "version": "2.7.0",
4
4
  "packageManager": "pnpm@8.10.5",
5
5
  "license": "MIT",
6
6
  "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -29,23 +29,25 @@
29
29
  "src"
30
30
  ],
31
31
  "source": "src/index.ts",
32
- "types": "./dist/index.d.ts",
32
+ "types": "./dist/src/index.d.ts",
33
33
  "main": "./dist/index.cjs",
34
34
  "module": "./dist/index.js",
35
35
  "exports": {
36
- "types": "./dist/index.d.ts",
36
+ "types": "./dist/src/index.d.ts",
37
37
  "require": "./dist/index.cjs",
38
38
  "import": "./dist/index.js"
39
39
  },
40
40
  "type": "module",
41
+ "dependencies": {
42
+ "@huggingface/tasks": "^0.10.0"
43
+ },
41
44
  "devDependencies": {
42
- "@types/node": "18.13.0",
43
- "@huggingface/tasks": "^0.6.0"
45
+ "@types/node": "18.13.0"
44
46
  },
45
47
  "resolutions": {},
46
48
  "scripts": {
47
- "build": "tsup src/index.ts --format cjs,esm --clean && pnpm run dts",
48
- "dts": "tsx scripts/generate-dts.ts",
49
+ "build": "tsup src/index.ts --format cjs,esm --clean && tsc --emitDeclarationOnly --declaration",
50
+ "dts": "tsx scripts/generate-dts.ts && tsc --noEmit dist/index.d.ts",
49
51
  "lint": "eslint --quiet --fix --ext .cjs,.ts .",
50
52
  "lint:check": "eslint --ext .cjs,.ts .",
51
53
  "format": "prettier --write .",
@@ -2,6 +2,9 @@ import * as tasks from "./tasks";
2
2
  import type { Options, RequestArgs } from "./types";
3
3
  import type { DistributiveOmit } from "./utils/distributive-omit";
4
4
 
5
+ /* eslint-disable @typescript-eslint/no-empty-interface */
6
+ /* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
7
+
5
8
  type Task = typeof tasks;
6
9
 
7
10
  type TaskWithNoAccessToken = {
@@ -11,9 +14,9 @@ type TaskWithNoAccessToken = {
11
14
  ) => ReturnType<Task[key]>;
12
15
  };
13
16
 
14
- type TaskWithNoAccessTokenNoModel = {
17
+ type TaskWithNoAccessTokenNoEndpointUrl = {
15
18
  [key in keyof Task]: (
16
- args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
19
+ args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
17
20
  options?: Parameters<Task[key]>[1]
18
21
  ) => ReturnType<Task[key]>;
19
22
  };
@@ -54,14 +57,12 @@ export class HfInferenceEndpoint {
54
57
  enumerable: false,
55
58
  value: (params: RequestArgs, options: Options) =>
56
59
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
57
- fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
60
+ fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
58
61
  });
59
62
  }
60
63
  }
61
64
  }
62
65
 
63
- // eslint-disable-next-line @typescript-eslint/no-empty-interface
64
66
  export interface HfInference extends TaskWithNoAccessToken {}
65
67
 
66
- // eslint-disable-next-line @typescript-eslint/no-empty-interface
67
- export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
68
+ export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
@@ -1,4 +1,5 @@
1
1
  import type { InferenceTask, Options, RequestArgs } from "../types";
2
+ import { omit } from "../utils/omit";
2
3
  import { HF_HUB_URL } from "./getDefaultTask";
3
4
  import { isUrl } from "./isUrl";
4
5
 
@@ -22,10 +23,10 @@ export async function makeRequestOptions(
22
23
  forceTask?: string | InferenceTask;
23
24
  /** To load default model if needed */
24
25
  taskHint?: InferenceTask;
26
+ chatCompletion?: boolean;
25
27
  }
26
28
  ): Promise<{ url: string; info: RequestInit }> {
27
- // eslint-disable-next-line @typescript-eslint/no-unused-vars
28
- const { accessToken, model: _model, ...otherArgs } = args;
29
+ const { accessToken, endpointUrl, ...otherArgs } = args;
29
30
  let { model } = args;
30
31
  const {
31
32
  forceTask: task,
@@ -34,7 +35,7 @@ export async function makeRequestOptions(
34
35
  wait_for_model,
35
36
  use_cache,
36
37
  dont_load_model,
37
- ...otherOptions
38
+ chatCompletion,
38
39
  } = options ?? {};
39
40
 
40
41
  const headers: Record<string, string> = {};
@@ -77,11 +78,17 @@ export async function makeRequestOptions(
77
78
  headers["X-Load-Model"] = "0";
78
79
  }
79
80
 
80
- const url = (() => {
81
+ let url = (() => {
82
+ if (endpointUrl && isUrl(model)) {
83
+ throw new TypeError("Both model and endpointUrl cannot be URLs");
84
+ }
81
85
  if (isUrl(model)) {
86
+ console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
82
87
  return model;
83
88
  }
84
-
89
+ if (endpointUrl) {
90
+ return endpointUrl;
91
+ }
85
92
  if (task) {
86
93
  return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
87
94
  }
@@ -89,19 +96,18 @@ export async function makeRequestOptions(
89
96
  return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
90
97
  })();
91
98
 
92
- // Let users configure credentials, or disable them all together (or keep default behavior).
93
- // ---
94
- // This used to be an internal property only and never exposed to users. This means that most usages will never define this value
95
- // So in order to make this backwards compatible, if it's undefined we go to "same-origin" (default behaviour before).
96
- // If it's a boolean and set to true then set to "include". If false, don't define credentials at all (useful for edge runtimes)
97
- // Then finally, if it's a string, use it as-is.
99
+ if (chatCompletion && !url.endsWith("/chat/completions")) {
100
+ url += "/v1/chat/completions";
101
+ }
102
+
103
+ /**
104
+ * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
105
+ */
98
106
  let credentials: RequestCredentials | undefined;
99
107
  if (typeof includeCredentials === "string") {
100
108
  credentials = includeCredentials as RequestCredentials;
101
- } else if (typeof includeCredentials === "boolean") {
102
- credentials = includeCredentials ? "include" : undefined;
103
- } else if (includeCredentials === undefined) {
104
- credentials = "same-origin";
109
+ } else if (includeCredentials === true) {
110
+ credentials = "include";
105
111
  }
106
112
 
107
113
  const info: RequestInit = {
@@ -110,10 +116,9 @@ export async function makeRequestOptions(
110
116
  body: binary
111
117
  ? args.data
112
118
  : JSON.stringify({
113
- ...otherArgs,
114
- options: options && otherOptions,
119
+ ...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
115
120
  }),
116
- credentials,
121
+ ...(credentials && { credentials }),
117
122
  signal: options?.signal,
118
123
  };
119
124
 
@@ -11,6 +11,8 @@ export async function request<T>(
11
11
  task?: string | InferenceTask;
12
12
  /** To load default model if needed */
13
13
  taskHint?: InferenceTask;
14
+ /** Is chat completion compatible */
15
+ chatCompletion?: boolean;
14
16
  }
15
17
  ): Promise<T> {
16
18
  const { url, info } = await makeRequestOptions(args, options);
@@ -26,6 +28,9 @@ export async function request<T>(
26
28
  if (!response.ok) {
27
29
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
28
30
  const output = await response.json();
31
+ if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
32
+ throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
33
+ }
29
34
  if (output.error) {
30
35
  throw new Error(output.error);
31
36
  }
@@ -13,6 +13,8 @@ export async function* streamingRequest<T>(
13
13
  task?: string | InferenceTask;
14
14
  /** To load default model if needed */
15
15
  taskHint?: InferenceTask;
16
+ /** Is chat completion compatible */
17
+ chatCompletion?: boolean;
16
18
  }
17
19
  ): AsyncGenerator<T> {
18
20
  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
@@ -27,6 +29,9 @@ export async function* streamingRequest<T>(
27
29
  if (!response.ok) {
28
30
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
29
31
  const output = await response.json();
32
+ if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
33
+ throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
34
+ }
30
35
  if (output.error) {
31
36
  throw new Error(output.error);
32
37
  }
@@ -67,6 +72,9 @@ export async function* streamingRequest<T>(
67
72
  onChunk(value);
68
73
  for (const event of events) {
69
74
  if (event.data.length > 0) {
75
+ if (event.data === "[DONE]") {
76
+ return;
77
+ }
70
78
  const data = JSON.parse(event.data);
71
79
  if (typeof data === "object" && data !== null && "error" in data) {
72
80
  throw new Error(data.error);
@@ -1,7 +1,7 @@
1
1
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
2
  import type { BaseArgs, Options, RequestArgs } from "../../types";
3
+ import { base64FromBytes } from "../../utils/base64FromBytes";
3
4
  import { request } from "../custom/request";
4
- import { base64FromBytes } from "../../../../shared";
5
5
 
6
6
  export type ImageToImageArgs = BaseArgs & {
7
7
  /**
@@ -2,7 +2,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
- import { base64FromBytes } from "../../../../shared";
5
+ import { base64FromBytes } from "../../utils/base64FromBytes";
6
6
 
7
7
  export type ZeroShotImageClassificationArgs = BaseArgs & {
8
8
  inputs: {
@@ -30,6 +30,8 @@ export * from "./nlp/textGenerationStream";
30
30
  export * from "./nlp/tokenClassification";
31
31
  export * from "./nlp/translation";
32
32
  export * from "./nlp/zeroShotClassification";
33
+ export * from "./nlp/chatCompletion";
34
+ export * from "./nlp/chatCompletionStream";
33
35
 
34
36
  // Multimodal tasks
35
37
  export * from "./multimodal/documentQuestionAnswering";
@@ -2,8 +2,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
- import { base64FromBytes } from "../../../../shared";
6
5
  import { toArray } from "../../utils/toArray";
6
+ import { base64FromBytes } from "../../utils/base64FromBytes";
7
7
 
8
8
  export type DocumentQuestionAnsweringArgs = BaseArgs & {
9
9
  inputs: {
@@ -1,7 +1,7 @@
1
1
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
2
  import type { BaseArgs, Options, RequestArgs } from "../../types";
3
+ import { base64FromBytes } from "../../utils/base64FromBytes";
3
4
  import { request } from "../custom/request";
4
- import { base64FromBytes } from "../../../../shared";
5
5
 
6
6
  export type VisualQuestionAnsweringArgs = BaseArgs & {
7
7
  inputs: {
@@ -0,0 +1,32 @@
1
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import type { BaseArgs, Options } from "../../types";
3
+ import { request } from "../custom/request";
4
+ import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
5
+
6
+ /**
7
+ * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
8
+ */
9
+
10
+ export async function chatCompletion(
11
+ args: BaseArgs & ChatCompletionInput,
12
+ options?: Options
13
+ ): Promise<ChatCompletionOutput> {
14
+ const res = await request<ChatCompletionOutput>(args, {
15
+ ...options,
16
+ taskHint: "text-generation",
17
+ chatCompletion: true,
18
+ });
19
+ const isValidOutput =
20
+ typeof res === "object" &&
21
+ Array.isArray(res?.choices) &&
22
+ typeof res?.created === "number" &&
23
+ typeof res?.id === "string" &&
24
+ typeof res?.model === "string" &&
25
+ typeof res?.system_fingerprint === "string" &&
26
+ typeof res?.usage === "object";
27
+
28
+ if (!isValidOutput) {
29
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
30
+ }
31
+ return res;
32
+ }
@@ -0,0 +1,17 @@
1
+ import type { BaseArgs, Options } from "../../types";
2
+ import { streamingRequest } from "../custom/streamingRequest";
3
+ import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
4
+
5
+ /**
6
+ * Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
7
+ */
8
+ export async function* chatCompletionStream(
9
+ args: BaseArgs & ChatCompletionInput,
10
+ options?: Options
11
+ ): AsyncGenerator<ChatCompletionStreamOutput> {
12
+ yield* streamingRequest<ChatCompletionStreamOutput>(args, {
13
+ ...options,
14
+ taskHint: "text-generation",
15
+ chatCompletion: true,
16
+ });
17
+ }
@@ -1,8 +1,10 @@
1
- import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
1
+ import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
2
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { request } from "../custom/request";
5
5
 
6
+ export type { TextGenerationInput, TextGenerationOutput };
7
+
6
8
  /**
7
9
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
8
10
  */
@@ -1,8 +1,7 @@
1
+ import type { TextGenerationInput } from "@huggingface/tasks";
1
2
  import type { BaseArgs, Options } from "../../types";
2
3
  import { streamingRequest } from "../custom/streamingRequest";
3
4
 
4
- import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";
5
-
6
5
  export interface TextGenerationStreamToken {
7
6
  /** Token ID from the model tokenizer */
8
7
  id: number;
@@ -68,6 +67,7 @@ export interface TextGenerationStreamDetails {
68
67
  }
69
68
 
70
69
  export interface TextGenerationStreamOutput {
70
+ index?: number;
71
71
  /** Generated token, one at a time */
72
72
  token: TextGenerationStreamToken;
73
73
  /**
package/src/types.ts CHANGED
@@ -1,4 +1,5 @@
1
1
  import type { PipelineType } from "@huggingface/tasks";
2
+ import type { ChatCompletionInput } from "@huggingface/tasks";
2
3
 
3
4
  export interface Options {
4
5
  /**
@@ -47,15 +48,25 @@ export interface BaseArgs {
47
48
  */
48
49
  accessToken?: string;
49
50
  /**
50
- * The model to use. Can be a full URL for a dedicated inference endpoint.
51
+ * The model to use.
51
52
  *
52
53
  * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
54
+ *
55
+ * /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
56
+ * Use the `endpointUrl` parameter instead.
53
57
  */
54
58
  model?: string;
59
+
60
+ /**
61
+ * The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
62
+ *
63
+ * If specified, will use this URL instead of the default one.
64
+ */
65
+ endpointUrl?: string;
55
66
  }
56
67
 
57
68
  export type RequestArgs = BaseArgs &
58
- ({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
69
+ ({ data: Blob | ArrayBuffer } | { inputs: unknown } | ChatCompletionInput) & {
59
70
  parameters?: Record<string, unknown>;
60
71
  accessToken?: string;
61
72
  };
@@ -0,0 +1,11 @@
1
+ export function base64FromBytes(arr: Uint8Array): string {
2
+ if (globalThis.Buffer) {
3
+ return globalThis.Buffer.from(arr).toString("base64");
4
+ } else {
5
+ const bin: string[] = [];
6
+ arr.forEach((byte) => {
7
+ bin.push(String.fromCharCode(byte));
8
+ });
9
+ return globalThis.btoa(bin.join(""));
10
+ }
11
+ }
@@ -4,8 +4,6 @@
4
4
  * This allows omitting keys from objects inside unions, without merging the individual components of the union.
5
5
  */
6
6
 
7
- type Keys<T> = keyof T;
8
- type DistributiveKeys<T> = T extends unknown ? Keys<T> : never;
9
7
  type Omit_<T, K> = Omit<T, Extract<keyof T, K>>;
10
8
 
11
9
  export type DistributiveOmit<T, K> = T extends unknown
@@ -0,0 +1,6 @@
1
+ const isBrowser = typeof window !== "undefined" && typeof window.document !== "undefined";
2
+
3
+ const isWebWorker =
4
+ typeof self === "object" && self.constructor && self.constructor.name === "DedicatedWorkerGlobalScope";
5
+
6
+ export const isBackend = !isBrowser && !isWebWorker;
@@ -0,0 +1,3 @@
1
+ import { isBackend } from "./isBackend";
2
+
3
+ export const isFrontend = !isBackend;