@huggingface/inference 2.6.7 → 2.7.0

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (129)
  1. package/README.md +126 -27
  2. package/dist/index.cjs +78 -12
  3. package/dist/index.js +76 -12
  4. package/dist/src/HfInference.d.ts +28 -0
  5. package/dist/src/HfInference.d.ts.map +1 -0
  6. package/dist/src/index.d.ts +5 -0
  7. package/dist/src/index.d.ts.map +1 -0
  8. package/dist/src/lib/InferenceOutputError.d.ts +4 -0
  9. package/dist/src/lib/InferenceOutputError.d.ts.map +1 -0
  10. package/dist/src/lib/getDefaultTask.d.ts +12 -0
  11. package/dist/src/lib/getDefaultTask.d.ts.map +1 -0
  12. package/dist/src/lib/isUrl.d.ts +2 -0
  13. package/dist/src/lib/isUrl.d.ts.map +1 -0
  14. package/dist/src/lib/makeRequestOptions.d.ts +18 -0
  15. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -0
  16. package/dist/src/tasks/audio/audioClassification.d.ts +24 -0
  17. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -0
  18. package/dist/src/tasks/audio/audioToAudio.d.ts +28 -0
  19. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -0
  20. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +19 -0
  21. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/textToSpeech.d.ts +14 -0
  23. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -0
  24. package/dist/src/tasks/custom/request.d.ts +13 -0
  25. package/dist/src/tasks/custom/request.d.ts.map +1 -0
  26. package/dist/src/tasks/custom/streamingRequest.d.ts +13 -0
  27. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -0
  28. package/dist/src/tasks/cv/imageClassification.d.ts +24 -0
  29. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -0
  30. package/dist/src/tasks/cv/imageSegmentation.d.ts +28 -0
  31. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -0
  32. package/dist/src/tasks/cv/imageToImage.d.ts +55 -0
  33. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -0
  34. package/dist/src/tasks/cv/imageToText.d.ts +18 -0
  35. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -0
  36. package/dist/src/tasks/cv/objectDetection.d.ts +33 -0
  37. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -0
  38. package/dist/src/tasks/cv/textToImage.d.ts +36 -0
  39. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -0
  40. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +26 -0
  41. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -0
  42. package/dist/src/tasks/index.d.ts +32 -0
  43. package/dist/src/tasks/index.d.ts.map +1 -0
  44. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +35 -0
  45. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -0
  46. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +27 -0
  47. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -0
  48. package/dist/src/tasks/nlp/chatCompletion.d.ts +7 -0
  49. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -0
  50. package/dist/src/tasks/nlp/chatCompletionStream.d.ts +7 -0
  51. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -0
  52. package/dist/src/tasks/nlp/featureExtraction.d.ts +19 -0
  53. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -0
  54. package/dist/src/tasks/nlp/fillMask.d.ts +27 -0
  55. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -0
  56. package/dist/src/tasks/nlp/questionAnswering.d.ts +30 -0
  57. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -0
  58. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +19 -0
  59. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -0
  60. package/dist/src/tasks/nlp/summarization.d.ts +48 -0
  61. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -0
  62. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +36 -0
  63. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -0
  64. package/dist/src/tasks/nlp/textClassification.d.ts +22 -0
  65. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -0
  66. package/dist/src/tasks/nlp/textGeneration.d.ts +8 -0
  67. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -0
  68. package/dist/src/tasks/nlp/textGenerationStream.d.ts +81 -0
  69. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -0
  70. package/dist/src/tasks/nlp/tokenClassification.d.ts +51 -0
  71. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -0
  72. package/dist/src/tasks/nlp/translation.d.ts +19 -0
  73. package/dist/src/tasks/nlp/translation.d.ts.map +1 -0
  74. package/dist/src/tasks/nlp/zeroShotClassification.d.ts +28 -0
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -0
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts +20 -0
  77. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -0
  78. package/dist/src/tasks/tabular/tabularRegression.d.ts +20 -0
  79. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -0
  80. package/dist/src/types.d.ts +69 -0
  81. package/dist/src/types.d.ts.map +1 -0
  82. package/dist/src/utils/base64FromBytes.d.ts +2 -0
  83. package/dist/src/utils/base64FromBytes.d.ts.map +1 -0
  84. package/dist/src/utils/distributive-omit.d.ts +9 -0
  85. package/dist/src/utils/distributive-omit.d.ts.map +1 -0
  86. package/dist/src/utils/isBackend.d.ts +2 -0
  87. package/dist/src/utils/isBackend.d.ts.map +1 -0
  88. package/dist/src/utils/isFrontend.d.ts +2 -0
  89. package/dist/src/utils/isFrontend.d.ts.map +1 -0
  90. package/dist/src/utils/omit.d.ts +5 -0
  91. package/dist/src/utils/omit.d.ts.map +1 -0
  92. package/dist/src/utils/pick.d.ts +5 -0
  93. package/dist/src/utils/pick.d.ts.map +1 -0
  94. package/dist/src/utils/toArray.d.ts +2 -0
  95. package/dist/src/utils/toArray.d.ts.map +1 -0
  96. package/dist/src/utils/typedInclude.d.ts +2 -0
  97. package/dist/src/utils/typedInclude.d.ts.map +1 -0
  98. package/dist/src/vendor/fetch-event-source/parse.d.ts +69 -0
  99. package/dist/src/vendor/fetch-event-source/parse.d.ts.map +1 -0
  100. package/dist/src/vendor/fetch-event-source/parse.spec.d.ts +2 -0
  101. package/dist/src/vendor/fetch-event-source/parse.spec.d.ts.map +1 -0
  102. package/dist/test/HfInference.spec.d.ts +2 -0
  103. package/dist/test/HfInference.spec.d.ts.map +1 -0
  104. package/dist/test/expect-closeto.d.ts +2 -0
  105. package/dist/test/expect-closeto.d.ts.map +1 -0
  106. package/dist/test/test-files.d.ts +2 -0
  107. package/dist/test/test-files.d.ts.map +1 -0
  108. package/dist/test/vcr.d.ts +2 -0
  109. package/dist/test/vcr.d.ts.map +1 -0
  110. package/package.json +9 -7
  111. package/src/HfInference.ts +4 -4
  112. package/src/lib/makeRequestOptions.ts +17 -7
  113. package/src/tasks/custom/request.ts +5 -0
  114. package/src/tasks/custom/streamingRequest.ts +8 -0
  115. package/src/tasks/cv/imageToImage.ts +1 -1
  116. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  117. package/src/tasks/index.ts +2 -0
  118. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  119. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  120. package/src/tasks/nlp/chatCompletion.ts +32 -0
  121. package/src/tasks/nlp/chatCompletionStream.ts +17 -0
  122. package/src/tasks/nlp/textGeneration.ts +2 -202
  123. package/src/tasks/nlp/textGenerationStream.ts +2 -1
  124. package/src/types.ts +14 -3
  125. package/src/utils/base64FromBytes.ts +11 -0
  126. package/src/utils/{distributive-omit.d.ts → distributive-omit.ts} +0 -2
  127. package/src/utils/isBackend.ts +6 -0
  128. package/src/utils/isFrontend.ts +3 -0
  129. package/dist/index.d.ts +0 -1536
package/dist/src/vendor/fetch-event-source/parse.d.ts ADDED
@@ -0,0 +1,69 @@
+ /**
+ This file is a part of fetch-event-source package (as of v2.0.1)
+ https://github.com/Azure/fetch-event-source/blob/v2.0.1/src/parse.ts
+
+ Full package can be used after it is made compatible with nodejs:
+ https://github.com/Azure/fetch-event-source/issues/20
+
+ Below is the fetch-event-source package license:
+
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
+
+ */
+ /**
+ * Represents a message sent in an event stream
+ * https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format
+ */
+ export interface EventSourceMessage {
+ /** The event ID to set the EventSource object's last event ID value. */
+ id: string;
+ /** A string identifying the type of event described. */
+ event: string;
+ /** The event data */
+ data: string;
+ /** The reconnection interval (in milliseconds) to wait before retrying the connection */
+ retry?: number;
+ }
+ /**
+ * Converts a ReadableStream into a callback pattern.
+ * @param stream The input ReadableStream.
+ * @param onChunk A function that will be called on each new byte chunk in the stream.
+ * @returns {Promise<void>} A promise that will be resolved when the stream closes.
+ */
+ export declare function getBytes(stream: ReadableStream<Uint8Array>, onChunk: (arr: Uint8Array) => void): Promise<void>;
+ /**
+ * Parses arbitary byte chunks into EventSource line buffers.
+ * Each line should be of the format "field: value" and ends with \r, \n, or \r\n.
+ * @param onLine A function that will be called on each new EventSource line.
+ * @returns A function that should be called for each incoming byte chunk.
+ */
+ export declare function getLines(onLine: (line: Uint8Array, fieldLength: number) => void): (arr: Uint8Array) => void;
+ /**
+ * Parses line buffers into EventSourceMessages.
+ * @param onId A function that will be called on each `id` field.
+ * @param onRetry A function that will be called on each `retry` field.
+ * @param onMessage A function that will be called on each message.
+ * @returns A function that should be called for each incoming line buffer.
+ */
+ export declare function getMessages(onId: (id: string) => void, onRetry: (retry: number) => void, onMessage?: (msg: EventSourceMessage) => void): (line: Uint8Array, fieldLength: number) => void;
+ //# sourceMappingURL=parse.d.ts.map
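Taken together, these three declarations form a byte-to-message pipeline: `getBytes` pumps a `ReadableStream`, `getLines` buffers bytes into EventSource lines, and `getMessages` assembles those lines into `EventSourceMessage`s. A minimal sketch of how the package can wire them together internally (the fetch URL is a placeholder, and the import path assumes the package's src tree):

```ts
import { getBytes, getLines, getMessages } from "./vendor/fetch-event-source/parse";

// Placeholder endpoint; anything that responds with text/event-stream works.
const response = await fetch("https://example.com/stream");
if (!response.body) throw new Error("Response has no body");

await getBytes(
  response.body,
  // bytes -> lines -> parsed EventSource messages
  getLines(
    getMessages(
      (id) => console.log("last event id:", id),
      (retry) => console.log("retry interval (ms):", retry),
      (msg) => console.log("event data:", msg.data)
    )
  )
);
```

Each stage returns exactly the callback the previous stage expects, which is why the composition nests cleanly.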
package/dist/src/vendor/fetch-event-source/parse.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"parse.d.ts","sourceRoot":"","sources":["../../../../src/vendor/fetch-event-source/parse.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA+BG;AAEH;;;GAGG;AACH,MAAM,WAAW,kBAAkB;IAC/B,wEAAwE;IACxE,EAAE,EAAE,MAAM,CAAC;IACX,wDAAwD;IACxD,KAAK,EAAE,MAAM,CAAC;IACd,qBAAqB;IACrB,IAAI,EAAE,MAAM,CAAC;IACb,yFAAyF;IACzF,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAED;;;;;GAKG;AACH,wBAAsB,QAAQ,CAAC,MAAM,EAAE,cAAc,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,CAAC,GAAG,EAAE,UAAU,KAAK,IAAI,iBAMpG;AASD;;;;;GAKG;AACH,wBAAgB,QAAQ,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,KAAK,IAAI,SAO/C,UAAU,UA4D1C;AAED;;;;;;GAMG;AACH,wBAAgB,WAAW,CACvB,IAAI,EAAE,CAAC,EAAE,EAAE,MAAM,KAAK,IAAI,EAC1B,OAAO,EAAE,CAAC,KAAK,EAAE,MAAM,KAAK,IAAI,EAChC,SAAS,CAAC,EAAE,CAAC,GAAG,EAAE,kBAAkB,KAAK,IAAI,UAMhB,UAAU,eAAe,MAAM,UAmC/D"}
package/dist/src/vendor/fetch-event-source/parse.spec.d.ts ADDED
@@ -0,0 +1,2 @@
+ export {};
+ //# sourceMappingURL=parse.spec.d.ts.map
package/dist/src/vendor/fetch-event-source/parse.spec.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"parse.spec.d.ts","sourceRoot":"","sources":["../../../../src/vendor/fetch-event-source/parse.spec.ts"],"names":[],"mappings":""}
package/dist/test/HfInference.spec.d.ts ADDED
@@ -0,0 +1,2 @@
+ import "./vcr";
+ //# sourceMappingURL=HfInference.spec.d.ts.map
package/dist/test/HfInference.spec.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"HfInference.spec.d.ts","sourceRoot":"","sources":["../../test/HfInference.spec.ts"],"names":[],"mappings":"AAKA,OAAO,OAAO,CAAC"}
package/dist/test/expect-closeto.d.ts ADDED
@@ -0,0 +1,2 @@
+ export {};
+ //# sourceMappingURL=expect-closeto.d.ts.map
package/dist/test/expect-closeto.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"expect-closeto.d.ts","sourceRoot":"","sources":["../../test/expect-closeto.ts"],"names":[],"mappings":""}
package/dist/test/test-files.d.ts ADDED
@@ -0,0 +1,2 @@
+ export declare const readTestFile: (filename: string) => Uint8Array;
+ //# sourceMappingURL=test-files.d.ts.map
package/dist/test/test-files.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"test-files.d.ts","sourceRoot":"","sources":["../../test/test-files.ts"],"names":[],"mappings":"AAGA,eAAO,MAAM,YAAY,aAAc,MAAM,KAAG,UAK/C,CAAC"}
package/dist/test/vcr.d.ts ADDED
@@ -0,0 +1,2 @@
+ export {};
+ //# sourceMappingURL=vcr.d.ts.map
package/dist/test/vcr.d.ts.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"file":"vcr.d.ts","sourceRoot":"","sources":["../../test/vcr.ts"],"names":[],"mappings":""}
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@huggingface/inference",
- "version": "2.6.7",
+ "version": "2.7.0",
  "packageManager": "pnpm@8.10.5",
  "license": "MIT",
  "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -29,23 +29,25 @@
  "src"
  ],
  "source": "src/index.ts",
- "types": "./dist/index.d.ts",
+ "types": "./dist/src/index.d.ts",
  "main": "./dist/index.cjs",
  "module": "./dist/index.js",
  "exports": {
- "types": "./dist/index.d.ts",
+ "types": "./dist/src/index.d.ts",
  "require": "./dist/index.cjs",
  "import": "./dist/index.js"
  },
  "type": "module",
+ "dependencies": {
+ "@huggingface/tasks": "^0.10.0"
+ },
  "devDependencies": {
- "@types/node": "18.13.0",
- "@huggingface/tasks": "^0.8.0"
+ "@types/node": "18.13.0"
  },
  "resolutions": {},
  "scripts": {
- "build": "tsup src/index.ts --format cjs,esm --clean && pnpm run dts",
- "dts": "tsx scripts/generate-dts.ts",
+ "build": "tsup src/index.ts --format cjs,esm --clean && tsc --emitDeclarationOnly --declaration",
+ "dts": "tsx scripts/generate-dts.ts && tsc --noEmit dist/index.d.ts",
  "lint": "eslint --quiet --fix --ext .cjs,.ts .",
  "lint:check": "eslint --ext .cjs,.ts .",
  "format": "prettier --write .",
package/src/HfInference.ts CHANGED
@@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
  ) => ReturnType<Task[key]>;
  };

- type TaskWithNoAccessTokenNoModel = {
+ type TaskWithNoAccessTokenNoEndpointUrl = {
  [key in keyof Task]: (
- args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
+ args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
  options?: Parameters<Task[key]>[1]
  ) => ReturnType<Task[key]>;
  };
@@ -57,7 +57,7 @@ export class HfInferenceEndpoint {
  enumerable: false,
  value: (params: RequestArgs, options: Options) =>
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
- fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
+ fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
  });
  }
  }
@@ -65,4 +65,4 @@ export class HfInferenceEndpoint {

  export interface HfInference extends TaskWithNoAccessToken {}

- export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
+ export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
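With `HfInferenceEndpoint` now binding a dedicated `endpointUrl` rather than smuggling the URL through `model`, the per-call surface stays the same. A minimal usage sketch (endpoint URL and token are placeholders):

```ts
import { HfInferenceEndpoint } from "@huggingface/inference";

// Placeholder dedicated-endpoint URL and access token.
const endpoint = new HfInferenceEndpoint(
  "https://my-endpoint.endpoints.huggingface.cloud",
  "hf_xxx"
);

// `endpointUrl` is bound above, so task args omit it,
// per the TaskWithNoAccessTokenNoEndpointUrl type in this diff.
const out = await endpoint.textGeneration({ inputs: "Hello, my name is" });
console.log(out.generated_text);
```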
package/src/lib/makeRequestOptions.ts CHANGED
@@ -1,4 +1,5 @@
  import type { InferenceTask, Options, RequestArgs } from "../types";
+ import { omit } from "../utils/omit";
  import { HF_HUB_URL } from "./getDefaultTask";
  import { isUrl } from "./isUrl";

@@ -22,10 +23,10 @@ export async function makeRequestOptions(
  forceTask?: string | InferenceTask;
  /** To load default model if needed */
  taskHint?: InferenceTask;
+ chatCompletion?: boolean;
  }
  ): Promise<{ url: string; info: RequestInit }> {
- // eslint-disable-next-line @typescript-eslint/no-unused-vars
- const { accessToken, model: _model, ...otherArgs } = args;
+ const { accessToken, endpointUrl, ...otherArgs } = args;
  let { model } = args;
  const {
  forceTask: task,
@@ -34,7 +35,7 @@ export async function makeRequestOptions(
  wait_for_model,
  use_cache,
  dont_load_model,
- ...otherOptions
+ chatCompletion,
  } = options ?? {};

  const headers: Record<string, string> = {};
@@ -77,11 +78,17 @@ export async function makeRequestOptions(
  headers["X-Load-Model"] = "0";
  }

- const url = (() => {
+ let url = (() => {
+ if (endpointUrl && isUrl(model)) {
+ throw new TypeError("Both model and endpointUrl cannot be URLs");
+ }
  if (isUrl(model)) {
+ console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
  return model;
  }
-
+ if (endpointUrl) {
+ return endpointUrl;
+ }
  if (task) {
  return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
  }
@@ -89,6 +96,10 @@ export async function makeRequestOptions(
  return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
  })();

+ if (chatCompletion && !url.endsWith("/chat/completions")) {
+ url += "/v1/chat/completions";
+ }
+
  /**
  * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
  */
@@ -105,8 +116,7 @@ export async function makeRequestOptions(
  body: binary
  ? args.data
  : JSON.stringify({
- ...otherArgs,
- options: options && otherOptions,
+ ...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
  }),
  ...(credentials && { credentials }),
  signal: options?.signal,
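The URL selection above follows a fixed precedence: a (deprecated) URL passed as `model`, then `endpointUrl`, then a forced task route, then the plain models route, with `/v1/chat/completions` appended for chat-completion requests. A condensed restatement of that order, with stand-ins inlined for the module's `HF_INFERENCE_API_BASE_URL` constant and `isUrl` helper so the sketch is self-contained:

```ts
// Condensed restatement of the URL resolution in makeRequestOptions.
function resolveUrl(opts: {
  model?: string;
  endpointUrl?: string;
  task?: string;
  chatCompletion?: boolean;
}): string {
  const base = "https://api-inference.huggingface.co"; // stands in for HF_INFERENCE_API_BASE_URL
  const isUrl = (s?: string): s is string => !!s && /^https?:\/\//.test(s);

  if (opts.endpointUrl && isUrl(opts.model)) {
    throw new TypeError("Both model and endpointUrl cannot be URLs");
  }
  let url: string;
  if (isUrl(opts.model)) {
    url = opts.model; // deprecated legacy path, warns at runtime
  } else if (opts.endpointUrl) {
    url = opts.endpointUrl;
  } else if (opts.task) {
    url = `${base}/pipeline/${opts.task}/${opts.model}`;
  } else {
    url = `${base}/models/${opts.model}`;
  }
  if (opts.chatCompletion && !url.endsWith("/chat/completions")) {
    url += "/v1/chat/completions";
  }
  return url;
}
```

The suffix check also means a caller who already points `endpointUrl` at a `.../v1/chat/completions` route is left alone.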
package/src/tasks/custom/request.ts CHANGED
@@ -11,6 +11,8 @@ export async function request<T>(
  task?: string | InferenceTask;
  /** To load default model if needed */
  taskHint?: InferenceTask;
+ /** Is chat completion compatible */
+ chatCompletion?: boolean;
  }
  ): Promise<T> {
  const { url, info } = await makeRequestOptions(args, options);
@@ -26,6 +28,9 @@ export async function request<T>(
  if (!response.ok) {
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
  const output = await response.json();
+ if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
+ throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
+ }
  if (output.error) {
  throw new Error(output.error);
  }
package/src/tasks/custom/streamingRequest.ts CHANGED
@@ -13,6 +13,8 @@ export async function* streamingRequest<T>(
  task?: string | InferenceTask;
  /** To load default model if needed */
  taskHint?: InferenceTask;
+ /** Is chat completion compatible */
+ chatCompletion?: boolean;
  }
  ): AsyncGenerator<T> {
  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
@@ -27,6 +29,9 @@ export async function* streamingRequest<T>(
  if (!response.ok) {
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
  const output = await response.json();
+ if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
+ throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
+ }
  if (output.error) {
  throw new Error(output.error);
  }
@@ -67,6 +72,9 @@ export async function* streamingRequest<T>(
  onChunk(value);
  for (const event of events) {
  if (event.data.length > 0) {
+ if (event.data === "[DONE]") {
+ return;
+ }
  const data = JSON.parse(event.data);
  if (typeof data === "object" && data !== null && "error" in data) {
  throw new Error(data.error);
package/src/tasks/cv/imageToImage.ts CHANGED
@@ -1,7 +1,7 @@
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
  import type { BaseArgs, Options, RequestArgs } from "../../types";
+ import { base64FromBytes } from "../../utils/base64FromBytes";
  import { request } from "../custom/request";
- import { base64FromBytes } from "../../../../shared";

  export type ImageToImageArgs = BaseArgs & {
  /**
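The relocated helper lives at `package/src/utils/base64FromBytes.ts` (+11 lines per the file list above); its body is not shown in this diff view, but a plausible environment-agnostic implementation, sketched here as an assumption rather than the file's verbatim contents, would be:

```ts
export function base64FromBytes(arr: Uint8Array): string {
  // Backends (Node, Bun, ...) expose Buffer; browsers fall back to btoa.
  if (globalThis.Buffer) {
    return globalThis.Buffer.from(arr).toString("base64");
  }
  const bin: string[] = [];
  arr.forEach((byte) => bin.push(String.fromCharCode(byte)));
  return globalThis.btoa(bin.join(""));
}
```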
package/src/tasks/cv/zeroShotImageClassification.ts CHANGED
@@ -2,7 +2,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
  import type { BaseArgs, Options } from "../../types";
  import { request } from "../custom/request";
  import type { RequestArgs } from "../../types";
- import { base64FromBytes } from "../../../../shared";
+ import { base64FromBytes } from "../../utils/base64FromBytes";

  export type ZeroShotImageClassificationArgs = BaseArgs & {
  inputs: {
package/src/tasks/index.ts CHANGED
@@ -30,6 +30,8 @@ export * from "./nlp/textGenerationStream";
  export * from "./nlp/tokenClassification";
  export * from "./nlp/translation";
  export * from "./nlp/zeroShotClassification";
+ export * from "./nlp/chatCompletion";
+ export * from "./nlp/chatCompletionStream";

  // Multimodal tasks
  export * from "./multimodal/documentQuestionAnswering";
package/src/tasks/multimodal/documentQuestionAnswering.ts CHANGED
@@ -2,8 +2,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
  import type { BaseArgs, Options } from "../../types";
  import { request } from "../custom/request";
  import type { RequestArgs } from "../../types";
- import { base64FromBytes } from "../../../../shared";
  import { toArray } from "../../utils/toArray";
+ import { base64FromBytes } from "../../utils/base64FromBytes";

  export type DocumentQuestionAnsweringArgs = BaseArgs & {
  inputs: {
package/src/tasks/multimodal/visualQuestionAnswering.ts CHANGED
@@ -1,7 +1,7 @@
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
  import type { BaseArgs, Options, RequestArgs } from "../../types";
+ import { base64FromBytes } from "../../utils/base64FromBytes";
  import { request } from "../custom/request";
- import { base64FromBytes } from "../../../../shared";

  export type VisualQuestionAnsweringArgs = BaseArgs & {
  inputs: {
package/src/tasks/nlp/chatCompletion.ts ADDED
@@ -0,0 +1,32 @@
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
+ import type { BaseArgs, Options } from "../../types";
+ import { request } from "../custom/request";
+ import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+
+ /**
+ * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
+ */
+
+ export async function chatCompletion(
+ args: BaseArgs & ChatCompletionInput,
+ options?: Options
+ ): Promise<ChatCompletionOutput> {
+ const res = await request<ChatCompletionOutput>(args, {
+ ...options,
+ taskHint: "text-generation",
+ chatCompletion: true,
+ });
+ const isValidOutput =
+ typeof res === "object" &&
+ Array.isArray(res?.choices) &&
+ typeof res?.created === "number" &&
+ typeof res?.id === "string" &&
+ typeof res?.model === "string" &&
+ typeof res?.system_fingerprint === "string" &&
+ typeof res?.usage === "object";
+
+ if (!isValidOutput) {
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
+ }
+ return res;
+ }
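A minimal non-streaming call through the new task (token and model name are placeholders/examples):

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token

const res = await hf.chatCompletion({
  model: "mistralai/Mistral-7B-Instruct-v0.2", // example model
  messages: [{ role: "user", content: "What is the capital of France?" }],
  max_tokens: 100,
});
console.log(res.choices[0].message);
```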
package/src/tasks/nlp/chatCompletionStream.ts ADDED
@@ -0,0 +1,17 @@
+ import type { BaseArgs, Options } from "../../types";
+ import { streamingRequest } from "../custom/streamingRequest";
+ import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+
+ /**
+ * Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
+ */
+ export async function* chatCompletionStream(
+ args: BaseArgs & ChatCompletionInput,
+ options?: Options
+ ): AsyncGenerator<ChatCompletionStreamOutput> {
+ yield* streamingRequest<ChatCompletionStreamOutput>(args, {
+ ...options,
+ taskHint: "text-generation",
+ chatCompletion: true,
+ });
+ }
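And the streaming variant, consumed with `for await` (same placeholders):

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token

let out = "";
for await (const chunk of hf.chatCompletionStream({
  model: "mistralai/Mistral-7B-Instruct-v0.2", // example model
  messages: [{ role: "user", content: "Write one sentence about diffs." }],
  max_tokens: 100,
})) {
  if (chunk.choices && chunk.choices.length > 0) {
    out += chunk.choices[0].delta.content ?? "";
  }
}
console.log(out);
```

Internally this rides the `streamingRequest` changes above: the `/v1/chat/completions` suffix, the chat-specific error message, and the `[DONE]` sentinel handling.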
package/src/tasks/nlp/textGeneration.ts CHANGED
@@ -1,209 +1,9 @@
+ import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
  import type { BaseArgs, Options } from "../../types";
  import { request } from "../custom/request";

- /**
- * Inputs for Text Generation inference
- */
- export interface TextGenerationInput {
- /**
- * The text to initialize generation with
- */
- inputs: string;
- /**
- * Additional inference parameters
- */
- parameters?: TextGenerationParameters;
- /**
- * Whether to stream output tokens
- */
- stream?: boolean;
- [property: string]: unknown;
- }
-
- /**
- * Additional inference parameters
- *
- * Additional inference parameters for Text Generation
- */
- export interface TextGenerationParameters {
- /**
- * The number of sampling queries to run. Only the best one (in terms of total logprob) will
- * be returned.
- */
- best_of?: number;
- /**
- * Whether or not to output decoder input details
- */
- decoder_input_details?: boolean;
- /**
- * Whether or not to output details
- */
- details?: boolean;
- /**
- * Whether to use logits sampling instead of greedy decoding when generating new tokens.
- */
- do_sample?: boolean;
- /**
- * The maximum number of tokens to generate.
- */
- max_new_tokens?: number;
- /**
- * The parameter for repetition penalty. A value of 1.0 means no penalty. See [this
- * paper](https://hf.co/papers/1909.05858) for more details.
- */
- repetition_penalty?: number;
- /**
- * Whether to prepend the prompt to the generated text.
- */
- return_full_text?: boolean;
- /**
- * The random sampling seed.
- */
- seed?: number;
- /**
- * Stop generating tokens if a member of `stop_sequences` is generated.
- */
- stop_sequences?: string[];
- /**
- * The value used to modulate the logits distribution.
- */
- temperature?: number;
- /**
- * The number of highest probability vocabulary tokens to keep for top-k-filtering.
- */
- top_k?: number;
- /**
- * If set to < 1, only the smallest set of most probable tokens with probabilities that add
- * up to `top_p` or higher are kept for generation.
- */
- top_p?: number;
- /**
- * Truncate input tokens to the given size.
- */
- truncate?: number;
- /**
- * Typical Decoding mass. See [Typical Decoding for Natural Language
- * Generation](https://hf.co/papers/2202.00666) for more information
- */
- typical_p?: number;
- /**
- * Watermarking with [A Watermark for Large Language Models](https://hf.co/papers/2301.10226)
- */
- watermark?: boolean;
- [property: string]: unknown;
- }
-
- /**
- * Outputs for Text Generation inference
- */
- export interface TextGenerationOutput {
- /**
- * When enabled, details about the generation
- */
- details?: TextGenerationOutputDetails;
- /**
- * The generated text
- */
- generated_text: string;
- [property: string]: unknown;
- }
-
- /**
- * When enabled, details about the generation
- */
- export interface TextGenerationOutputDetails {
- /**
- * Details about additional sequences when best_of is provided
- */
- best_of_sequences?: TextGenerationOutputSequenceDetails[];
- /**
- * The reason why the generation was stopped.
- */
- finish_reason: TextGenerationFinishReason;
- /**
- * The number of generated tokens
- */
- generated_tokens: number;
- prefill: TextGenerationPrefillToken[];
- /**
- * The random seed used for generation
- */
- seed?: number;
- /**
- * The generated tokens and associated details
- */
- tokens: TextGenerationOutputToken[];
- /**
- * Most likely tokens
- */
- top_tokens?: Array<TextGenerationOutputToken[]>;
- [property: string]: unknown;
- }
-
- export interface TextGenerationOutputSequenceDetails {
- finish_reason: TextGenerationFinishReason;
- /**
- * The generated text
- */
- generated_text: string;
- /**
- * The number of generated tokens
- */
- generated_tokens: number;
- prefill: TextGenerationPrefillToken[];
- /**
- * The random seed used for generation
- */
- seed?: number;
- /**
- * The generated tokens and associated details
- */
- tokens: TextGenerationOutputToken[];
- /**
- * Most likely tokens
- */
- top_tokens?: Array<TextGenerationOutputToken[]>;
- [property: string]: unknown;
- }
-
- export interface TextGenerationPrefillToken {
- id: number;
- logprob: number;
- /**
- * The text associated with that token
- */
- text: string;
- [property: string]: unknown;
- }
-
- /**
- * Generated token.
- */
- export interface TextGenerationOutputToken {
- id: number;
- logprob?: number;
- /**
- * Whether or not that token is a special one
- */
- special: boolean;
- /**
- * The text associated with that token
- */
- text: string;
- [property: string]: unknown;
- }
-
- /**
- * The reason why the generation was stopped.
- *
- * length: The generated sequence reached the maximum allowed length
- *
- * eos_token: The model generated an end-of-sentence (EOS) token
- *
- * stop_sequence: One of the sequence in stop_sequences was generated
- */
- export type TextGenerationFinishReason = "length" | "eos_token" | "stop_sequence";
+ export type { TextGenerationInput, TextGenerationOutput };

  /**
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
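The ~200 removed lines are a pure relocation: the text-generation input/output types now live in `@huggingface/tasks` and are re-exported here, so existing imports from `@huggingface/inference` should keep compiling. A sketch:

```ts
import { textGeneration } from "@huggingface/inference";
// Re-exported, so this type could equally be imported from "@huggingface/inference".
import type { TextGenerationOutput } from "@huggingface/tasks";

// Token is a placeholder; gpt2 is the model the doc comment itself recommends.
const out: TextGenerationOutput = await textGeneration({
  accessToken: "hf_xxx",
  model: "gpt2",
  inputs: "The answer to the universe is",
});
console.log(out.generated_text);
```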
package/src/tasks/nlp/textGenerationStream.ts CHANGED
@@ -1,6 +1,6 @@
+ import type { TextGenerationInput } from "@huggingface/tasks";
  import type { BaseArgs, Options } from "../../types";
  import { streamingRequest } from "../custom/streamingRequest";
- import type { TextGenerationInput } from "./textGeneration";

  export interface TextGenerationStreamToken {
  /** Token ID from the model tokenizer */
@@ -67,6 +67,7 @@ export interface TextGenerationStreamDetails {
  }

  export interface TextGenerationStreamOutput {
+ index?: number;
  /** Generated token, one at a time */
  token: TextGenerationStreamToken;
  /**
package/src/types.ts CHANGED
@@ -1,4 +1,5 @@
  import type { PipelineType } from "@huggingface/tasks";
+ import type { ChatCompletionInput } from "@huggingface/tasks";

  export interface Options {
  /**
@@ -32,7 +33,7 @@ export interface Options {
  signal?: AbortSignal;

  /**
- * Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all (which defaults to "same-origin" inside browsers).
+ * (Default: "same-origin"). String | Boolean. Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all.
  */
  includeCredentials?: string | boolean;
  }
@@ -47,15 +48,25 @@ export interface BaseArgs {
  */
  accessToken?: string;
  /**
- * The model to use. Can be a full URL for a dedicated inference endpoint.
+ * The model to use.
  *
  * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
+ *
+ * /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
+ * Use the `endpointUrl` parameter instead.
  */
  model?: string;
+
+ /**
+ * The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
+ *
+ * If specified, will use this URL instead of the default one.
+ */
+ endpointUrl?: string;
  }

  export type RequestArgs = BaseArgs &
- ({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
+ ({ data: Blob | ArrayBuffer } | { inputs: unknown } | ChatCompletionInput) & {
  parameters?: Record<string, unknown>;
  accessToken?: string;
  };
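In practice the deprecation reads like this: both calls reach the same dedicated endpoint, but only the second avoids the runtime `console.warn` added in `makeRequestOptions` (URL and token are placeholders):

```ts
import { textGeneration } from "@huggingface/inference";

const ENDPOINT = "https://my-endpoint.endpoints.huggingface.cloud"; // placeholder

// Deprecated: a URL passed through `model` still works, but warns.
await textGeneration({ accessToken: "hf_xxx", model: ENDPOINT, inputs: "Hi" });

// Preferred as of 2.7.0: pass the URL via `endpointUrl`.
await textGeneration({ accessToken: "hf_xxx", endpointUrl: ENDPOINT, inputs: "Hi" });
```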