@gradio/client 0.19.4 → 0.20.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. package/CHANGELOG.md +14 -0
  2. package/README.md +8 -1
  3. package/dist/client.d.ts +4 -0
  4. package/dist/client.d.ts.map +1 -1
  5. package/dist/constants.d.ts +4 -0
  6. package/dist/constants.d.ts.map +1 -1
  7. package/dist/helpers/api_info.d.ts +1 -0
  8. package/dist/helpers/api_info.d.ts.map +1 -1
  9. package/dist/helpers/data.d.ts.map +1 -1
  10. package/dist/helpers/init_helpers.d.ts +4 -1
  11. package/dist/helpers/init_helpers.d.ts.map +1 -1
  12. package/dist/index.js +229 -61
  13. package/dist/test/handlers.d.ts.map +1 -1
  14. package/dist/test/test_data.d.ts.map +1 -1
  15. package/dist/types.d.ts +6 -0
  16. package/dist/types.d.ts.map +1 -1
  17. package/dist/utils/duplicate.d.ts.map +1 -1
  18. package/dist/utils/post_data.d.ts.map +1 -1
  19. package/dist/utils/predict.d.ts.map +1 -1
  20. package/dist/utils/submit.d.ts.map +1 -1
  21. package/dist/utils/upload_files.d.ts.map +1 -1
  22. package/dist/utils/view_api.d.ts.map +1 -1
  23. package/package.json +1 -1
  24. package/src/client.ts +70 -28
  25. package/src/constants.ts +5 -0
  26. package/src/helpers/api_info.ts +44 -17
  27. package/src/helpers/data.ts +3 -2
  28. package/src/helpers/init_helpers.ts +98 -9
  29. package/src/test/api_info.test.ts +69 -4
  30. package/src/test/data.test.ts +4 -4
  31. package/src/test/handlers.ts +249 -2
  32. package/src/test/init.test.ts +2 -2
  33. package/src/test/init_helpers.test.ts +53 -1
  34. package/src/test/test_data.ts +3 -0
  35. package/src/types.ts +6 -0
  36. package/src/utils/duplicate.ts +27 -2
  37. package/src/utils/post_data.ts +2 -1
  38. package/src/utils/predict.ts +4 -2
  39. package/src/utils/submit.ts +37 -8
  40. package/src/utils/upload_files.ts +2 -1
  41. package/src/utils/view_api.ts +7 -4
@@ -1 +1 @@
1
- {"version":3,"file":"view_api.d.ts","sourceRoot":"","sources":["../../src/utils/view_api.ts"],"names":[],"mappings":"AAGA,OAAO,EAAE,MAAM,EAAE,MAAM,WAAW,CAAC;AAInC,wBAAsB,QAAQ,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,GAAG,CAAC,CA2DzD"}
1
+ {"version":3,"file":"view_api.d.ts","sourceRoot":"","sources":["../../src/utils/view_api.ts"],"names":[],"mappings":"AAGA,OAAO,EAAE,MAAM,EAAE,MAAM,WAAW,CAAC;AAInC,wBAAsB,QAAQ,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,GAAG,CAAC,CA8DzD"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@gradio/client",
3
- "version": "0.19.4",
3
+ "version": "0.20.0",
4
4
  "description": "Gradio API client",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
package/src/client.ts CHANGED
@@ -23,8 +23,10 @@ import { submit } from "./utils/submit";
23
23
  import { RE_SPACE_NAME, process_endpoint } from "./helpers/api_info";
24
24
  import {
25
25
  map_names_to_ids,
26
+ resolve_cookies,
26
27
  resolve_config,
27
- get_jwt
28
+ get_jwt,
29
+ parse_and_set_cookies
28
30
  } from "./helpers/init_helpers";
29
31
  import { check_space_status } from "./helpers/spaces";
30
32
  import { open_stream } from "./utils/stream";
@@ -47,6 +49,8 @@ export class Client {
47
49
  jwt: string | false = false;
48
50
  last_status: Record<string, Status["stage"]> = {};
49
51
 
52
+ private cookies: string | null = null;
53
+
50
54
  // streaming
51
55
  stream_status = { open: false };
52
56
  pending_stream_messages: Record<string, any[][]> = {};
@@ -56,7 +60,12 @@ export class Client {
56
60
  heartbeat_event: EventSource | null = null;
57
61
 
58
62
  fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
59
- return fetch(input, init);
63
+ const headers = new Headers(init?.headers || {});
64
+ if (this && this.cookies) {
65
+ headers.append("Cookie", this.cookies);
66
+ }
67
+
68
+ return fetch(input, { ...init, headers });
60
69
  }
61
70
 
62
71
  async stream(url: URL): Promise<EventSource> {
@@ -108,6 +117,7 @@ export class Client {
108
117
  ) => Promise<SubmitReturn>;
109
118
  open_stream: () => Promise<void>;
110
119
  private resolve_config: (endpoint: string) => Promise<Config | undefined>;
120
+ private resolve_cookies: () => Promise<void>;
111
121
  constructor(app_reference: string, options: ClientOptions = {}) {
112
122
  this.app_reference = app_reference;
113
123
  this.options = options;
@@ -120,6 +130,7 @@ export class Client {
120
130
  this.predict = predict.bind(this);
121
131
  this.open_stream = open_stream.bind(this);
122
132
  this.resolve_config = resolve_config.bind(this);
133
+ this.resolve_cookies = resolve_cookies.bind(this);
123
134
  this.upload = upload.bind(this);
124
135
  }
125
136
 
@@ -135,33 +146,56 @@ export class Client {
135
146
  }
136
147
 
137
148
  try {
138
- await this._resolve_config().then(async ({ config }) => {
139
- this.config = config;
140
-
141
- if (config.space_id && this.options.hf_token) {
142
- this.jwt = await get_jwt(config.space_id, this.options.hf_token);
143
- }
149
+ if (this.options.auth) {
150
+ await this.resolve_cookies();
151
+ }
144
152
 
145
- if (this.config && this.config.connect_heartbeat) {
146
- // connect to the heartbeat endpoint via GET request
147
- const heartbeat_url = new URL(
148
- `${this.config.root}/heartbeat/${this.session_hash}`
149
- );
153
+ await this._resolve_config().then(({ config }) =>
154
+ this._resolve_hearbeat(config)
155
+ );
156
+ } catch (e: any) {
157
+ throw Error(e);
158
+ }
150
159
 
151
- // if the jwt is available, add it to the query params
152
- if (this.jwt) {
153
- heartbeat_url.searchParams.set("__sign", this.jwt);
154
- }
160
+ this.api_info = await this.view_api();
161
+ this.api_map = map_names_to_ids(this.config?.dependencies || []);
162
+ }
155
163
 
156
- this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
164
+ async _resolve_hearbeat(_config: Config): Promise<void> {
165
+ if (_config) {
166
+ this.config = _config;
167
+ if (this.config && this.config.connect_heartbeat) {
168
+ if (this.config.space_id && this.options.hf_token) {
169
+ this.jwt = await get_jwt(
170
+ this.config.space_id,
171
+ this.options.hf_token,
172
+ this.cookies
173
+ );
157
174
  }
158
- });
159
- } catch (e) {
160
- throw Error(CONFIG_ERROR_MSG + (e as Error).message);
175
+ }
161
176
  }
162
177
 
163
- this.api_info = await this.view_api();
164
- this.api_map = map_names_to_ids(this.config?.dependencies || []);
178
+ if (_config.space_id && this.options.hf_token) {
179
+ this.jwt = await get_jwt(_config.space_id, this.options.hf_token);
180
+ }
181
+
182
+ if (this.config && this.config.connect_heartbeat) {
183
+ // connect to the heartbeat endpoint via GET request
184
+ const heartbeat_url = new URL(
185
+ `${this.config.root}/heartbeat/${this.session_hash}`
186
+ );
187
+
188
+ // if the jwt is available, add it to the query params
189
+ if (this.jwt) {
190
+ heartbeat_url.searchParams.set("__sign", this.jwt);
191
+ }
192
+
193
+ // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
194
+ if (!this.heartbeat_event)
195
+ this.heartbeat_event = await this.stream(heartbeat_url);
196
+ } else {
197
+ this.heartbeat_event?.close();
198
+ }
165
199
  }
166
200
 
167
201
  static async connect(
@@ -201,9 +235,8 @@ export class Client {
201
235
  }
202
236
 
203
237
  return this.config_success(config);
204
- } catch (e) {
205
- console.error(e);
206
- if (space_id) {
238
+ } catch (e: any) {
239
+ if (space_id && status_callback) {
207
240
  check_space_status(
208
241
  space_id,
209
242
  RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
@@ -217,6 +250,7 @@ export class Client {
217
250
  load_status: "error",
218
251
  detail: "NOT_FOUND"
219
252
  });
253
+ throw Error(e);
220
254
  }
221
255
  }
222
256
  }
@@ -246,6 +280,9 @@ export class Client {
246
280
  }
247
281
 
248
282
  async handle_space_success(status: SpaceStatus): Promise<Config | void> {
283
+ if (!this) {
284
+ throw new Error(CONFIG_ERROR_MSG);
285
+ }
249
286
  const { status_callback } = this.options;
250
287
  if (status_callback) status_callback(status);
251
288
  if (status.status === "running") {
@@ -259,7 +296,6 @@ export class Client {
259
296
 
260
297
  return _config as Config;
261
298
  } catch (e) {
262
- console.error(e);
263
299
  if (status_callback) {
264
300
  status_callback({
265
301
  status: "error",
@@ -268,6 +304,7 @@ export class Client {
268
304
  detail: "NOT_FOUND"
269
305
  });
270
306
  }
307
+ throw e;
271
308
  }
272
309
  }
273
310
  }
@@ -333,7 +370,8 @@ export class Client {
333
370
  const response = await this.fetch(`${root_url}/component_server/`, {
334
371
  method: "POST",
335
372
  body: body,
336
- headers
373
+ headers,
374
+ credentials: "include"
337
375
  });
338
376
 
339
377
  if (!response.ok) {
@@ -349,6 +387,10 @@ export class Client {
349
387
  }
350
388
  }
351
389
 
390
+ public set_cookies(raw_cookies: string): void {
391
+ this.cookies = parse_and_set_cookies(raw_cookies).join("; ");
392
+ }
393
+
352
394
  private prepare_return_obj(): client_return {
353
395
  return {
354
396
  config: this.config,
package/src/constants.ts CHANGED
@@ -25,3 +25,8 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. ";
25
25
  export const SPACE_STATUS_ERROR_MSG = "Could not get space status. ";
26
26
  export const API_INFO_ERROR_MSG = "Could not get API info. ";
27
27
  export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. ";
28
+ export const INVALID_URL_MSG = "Invalid URL. A full URL path is required.";
29
+ export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
30
+ export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
31
+ export const MISSING_CREDENTIALS_MSG =
32
+ "Login credentials are required to access this space.";
@@ -1,9 +1,14 @@
1
1
  import type { Status } from "../types";
2
- import { QUEUE_FULL_MSG } from "../constants";
2
+ import {
3
+ HOST_URL,
4
+ INVALID_URL_MSG,
5
+ QUEUE_FULL_MSG,
6
+ SPACE_METADATA_ERROR_MSG
7
+ } from "../constants";
3
8
  import type { ApiData, ApiInfo, Config, JsApiData } from "../types";
4
9
  import { determine_protocol } from "./init_helpers";
5
10
 
6
- export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/;
11
+ export const RE_SPACE_NAME = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/;
7
12
  export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/;
8
13
 
9
14
  export async function process_endpoint(
@@ -20,12 +25,13 @@ export async function process_endpoint(
20
25
  headers.Authorization = `Bearer ${hf_token}`;
21
26
  }
22
27
 
23
- const _app_reference = app_reference.trim();
28
+ const _app_reference = app_reference.trim().replace(/\/$/, "");
24
29
 
25
30
  if (RE_SPACE_NAME.test(_app_reference)) {
31
+ // app_reference is a HF space name
26
32
  try {
27
33
  const res = await fetch(
28
- `https://huggingface.co/api/spaces/${_app_reference}/host`,
34
+ `https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`,
29
35
  { headers }
30
36
  );
31
37
 
@@ -36,13 +42,12 @@ export async function process_endpoint(
36
42
  ...determine_protocol(_host)
37
43
  };
38
44
  } catch (e) {
39
- throw new Error(
40
- "Space metadata could not be loaded. " + (e as Error).message
41
- );
45
+ throw new Error(SPACE_METADATA_ERROR_MSG);
42
46
  }
43
47
  }
44
48
 
45
49
  if (RE_SPACE_DOMAIN.test(_app_reference)) {
50
+ // app_reference is a direct HF space domain
46
51
  const { ws_protocol, http_protocol, host } =
47
52
  determine_protocol(_app_reference);
48
53
 
@@ -60,6 +65,18 @@ export async function process_endpoint(
60
65
  };
61
66
  }
62
67
 
68
+ export const join_urls = (...urls: string[]): string => {
69
+ try {
70
+ return urls.reduce((base_url: string, part: string) => {
71
+ base_url = base_url.replace(/\/+$/, "");
72
+ part = part.replace(/^\/+/, "");
73
+ return new URL(part, base_url + "/").toString();
74
+ });
75
+ } catch (e) {
76
+ throw new Error(INVALID_URL_MSG);
77
+ }
78
+ };
79
+
63
80
  export function transform_api_info(
64
81
  api_info: ApiInfo<ApiData>,
65
82
  config: Config,
@@ -77,27 +94,30 @@ export function transform_api_info(
77
94
  Object.entries(api_info[category]).forEach(
78
95
  ([endpoint, { parameters, returns }]) => {
79
96
  const dependencyIndex =
80
- config.dependencies.findIndex(
97
+ config.dependencies.find(
81
98
  (dep) =>
82
99
  dep.api_name === endpoint ||
83
100
  dep.api_name === endpoint.replace("/", "")
84
- ) ||
101
+ )?.id ||
85
102
  api_map[endpoint.replace("/", "")] ||
86
103
  -1;
87
104
 
88
105
  const dependencyTypes =
89
106
  dependencyIndex !== -1
90
- ? config.dependencies[dependencyIndex].types
107
+ ? config.dependencies.find((dep) => dep.id == dependencyIndex)
108
+ ?.types
91
109
  : { continuous: false, generator: false };
92
110
 
93
111
  if (
94
112
  dependencyIndex !== -1 &&
95
- config.dependencies[dependencyIndex]?.inputs?.length !==
96
- parameters.length
113
+ config.dependencies.find((dep) => dep.id == dependencyIndex)?.inputs
114
+ ?.length !== parameters.length
97
115
  ) {
98
- const components = config.dependencies[dependencyIndex].inputs.map(
99
- (input) => config.components.find((c) => c.id === input)?.type
100
- );
116
+ const components = config.dependencies
117
+ .find((dep) => dep.id == dependencyIndex)!
118
+ .inputs.map(
119
+ (input) => config.components.find((c) => c.id === input)?.type
120
+ );
101
121
 
102
122
  try {
103
123
  components.forEach((comp, idx) => {
@@ -115,7 +135,9 @@ export function transform_api_info(
115
135
  parameters.splice(idx, 0, new_param);
116
136
  }
117
137
  });
118
- } catch (e) {}
138
+ } catch (e) {
139
+ console.error(e);
140
+ }
119
141
  }
120
142
 
121
143
  const transform_type = (
@@ -201,6 +223,7 @@ export function get_description(
201
223
  return type?.description;
202
224
  }
203
225
 
226
+ /* eslint-disable complexity */
204
227
  export function handle_message(
205
228
  data: any,
206
229
  last_status: Status["stage"]
@@ -308,7 +331,10 @@ export function handle_message(
308
331
  message: !data.success ? data.output.error : undefined,
309
332
  stage: data.success ? "complete" : "error",
310
333
  code: data.code,
311
- progress_data: data.progress_data
334
+ progress_data: data.progress_data,
335
+ changed_state_ids: data.success
336
+ ? data.output.changed_state_ids
337
+ : undefined
312
338
  },
313
339
  data: data.success ? data.output : null
314
340
  };
@@ -330,6 +356,7 @@ export function handle_message(
330
356
 
331
357
  return { type: "none", status: { stage: "error", queue } };
332
358
  }
359
+ /* eslint-enable complexity */
333
360
 
334
361
  /**
335
362
  * Maps the provided `data` to the parameters defined by the `/info` endpoint response.
@@ -96,8 +96,9 @@ export async function walk_and_store_blobs(
96
96
  }
97
97
 
98
98
  export function skip_queue(id: number, config: Config): boolean {
99
- if (config?.dependencies?.[id]?.queue !== null) {
100
- return !config.dependencies[id].queue;
99
+ let fn_queue = config?.dependencies?.find((dep) => dep.id == id)?.queue;
100
+ if (fn_queue != null) {
101
+ return !fn_queue;
101
102
  }
102
103
  return !config.enable_queue;
103
104
  }
@@ -1,6 +1,15 @@
1
1
  import type { Config } from "../types";
2
- import { CONFIG_ERROR_MSG, CONFIG_URL } from "../constants";
2
+ import {
3
+ CONFIG_ERROR_MSG,
4
+ CONFIG_URL,
5
+ INVALID_CREDENTIALS_MSG,
6
+ LOGIN_URL,
7
+ MISSING_CREDENTIALS_MSG,
8
+ SPACE_METADATA_ERROR_MSG,
9
+ UNAUTHORIZED_MSG
10
+ } from "../constants";
3
11
  import { Client } from "..";
12
+ import { join_urls, process_endpoint } from "./api_info";
4
13
 
5
14
  /**
6
15
  * This function is used to resolve the URL for making requests when the app has a root path.
@@ -25,12 +34,14 @@ export function resolve_root(
25
34
 
26
35
  export async function get_jwt(
27
36
  space: string,
28
- token: `hf_${string}`
37
+ token: `hf_${string}`,
38
+ cookies?: string | null
29
39
  ): Promise<string | false> {
30
40
  try {
31
41
  const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, {
32
42
  headers: {
33
- Authorization: `Bearer ${token}`
43
+ Authorization: `Bearer ${token}`,
44
+ ...(cookies ? { Cookie: cookies } : {})
34
45
  }
35
46
  });
36
47
 
@@ -47,8 +58,8 @@ export function map_names_to_ids(
47
58
  ): Record<string, number> {
48
59
  let apis: Record<string, number> = {};
49
60
 
50
- fns.forEach(({ api_name }, i: number) => {
51
- if (api_name) apis[api_name] = i;
61
+ fns.forEach(({ api_name, id }) => {
62
+ if (api_name) apis[api_name] = id;
52
63
  });
53
64
  return apis;
54
65
  }
@@ -75,15 +86,24 @@ export async function resolve_config(
75
86
  config.root = config_root;
76
87
  return { ...config, path } as Config;
77
88
  } else if (endpoint) {
78
- const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, {
79
- headers
89
+ const config_url = join_urls(endpoint, CONFIG_URL);
90
+ const response = await this.fetch(config_url, {
91
+ headers,
92
+ credentials: "include"
80
93
  });
81
94
 
95
+ if (response?.status === 401 && !this.options.auth) {
96
+ throw new Error(MISSING_CREDENTIALS_MSG);
97
+ } else if (response?.status === 401 && this.options.auth) {
98
+ throw new Error(INVALID_CREDENTIALS_MSG);
99
+ }
82
100
  if (response?.status === 200) {
83
101
  let config = await response.json();
84
102
  config.path = config.path ?? "";
85
103
  config.root = endpoint;
86
104
  return config;
105
+ } else if (response?.status === 401) {
106
+ throw new Error(UNAUTHORIZED_MSG);
87
107
  }
88
108
  throw new Error(CONFIG_ERROR_MSG);
89
109
  }
@@ -91,13 +111,70 @@ export async function resolve_config(
91
111
  throw new Error(CONFIG_ERROR_MSG);
92
112
  }
93
113
 
114
+ export async function resolve_cookies(this: Client): Promise<void> {
115
+ const { http_protocol, host } = await process_endpoint(
116
+ this.app_reference,
117
+ this.options.hf_token
118
+ );
119
+
120
+ try {
121
+ if (this.options.auth) {
122
+ const cookie_header = await get_cookie_header(
123
+ http_protocol,
124
+ host,
125
+ this.options.auth,
126
+ this.fetch,
127
+ this.options.hf_token
128
+ );
129
+
130
+ if (cookie_header) this.set_cookies(cookie_header);
131
+ }
132
+ } catch (e: unknown) {
133
+ throw Error((e as Error).message);
134
+ }
135
+ }
136
+
137
+ // separating this from client-bound resolve_cookies so that it can be used in duplicate
138
+ export async function get_cookie_header(
139
+ http_protocol: string,
140
+ host: string,
141
+ auth: [string, string],
142
+ _fetch: typeof fetch,
143
+ hf_token?: `hf_${string}`
144
+ ): Promise<string | null> {
145
+ const formData = new FormData();
146
+ formData.append("username", auth?.[0]);
147
+ formData.append("password", auth?.[1]);
148
+
149
+ let headers: { Authorization?: string } = {};
150
+
151
+ if (hf_token) {
152
+ headers.Authorization = `Bearer ${hf_token}`;
153
+ }
154
+
155
+ const res = await _fetch(`${http_protocol}//${host}/${LOGIN_URL}`, {
156
+ headers,
157
+ method: "POST",
158
+ body: formData,
159
+ credentials: "include"
160
+ });
161
+
162
+ if (res.status === 200) {
163
+ return res.headers.get("set-cookie");
164
+ } else if (res.status === 401) {
165
+ throw new Error(INVALID_CREDENTIALS_MSG);
166
+ } else {
167
+ throw new Error(SPACE_METADATA_ERROR_MSG);
168
+ }
169
+ }
170
+
94
171
  export function determine_protocol(endpoint: string): {
95
172
  ws_protocol: "ws" | "wss";
96
173
  http_protocol: "http:" | "https:";
97
174
  host: string;
98
175
  } {
99
176
  if (endpoint.startsWith("http")) {
100
- const { protocol, host } = new URL(endpoint);
177
+ const { protocol, host, pathname } = new URL(endpoint);
101
178
 
102
179
  if (host.endsWith("hf.space")) {
103
180
  return {
@@ -109,7 +186,7 @@ export function determine_protocol(endpoint: string): {
109
186
  return {
110
187
  ws_protocol: protocol === "https:" ? "wss" : "ws",
111
188
  http_protocol: protocol as "http:" | "https:",
112
- host
189
+ host: host + (pathname !== "/" ? pathname : "")
113
190
  };
114
191
  } else if (endpoint.startsWith("file:")) {
115
192
  // This case is only expected to be used for the Wasm mode (Gradio-lite),
@@ -128,3 +205,15 @@ export function determine_protocol(endpoint: string): {
128
205
  host: endpoint
129
206
  };
130
207
  }
208
+
209
+ export const parse_and_set_cookies = (cookie_header: string): string[] => {
210
+ let cookies: string[] = [];
211
+ const parts = cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/);
212
+ parts.forEach((cookie) => {
213
+ const [cookie_name, cookie_value] = cookie.split(";")[0].split("=");
214
+ if (cookie_name && cookie_value) {
215
+ cookies.push(`${cookie_name.trim()}=${cookie_value.trim()}`);
216
+ }
217
+ });
218
+ return cookies;
219
+ };
@@ -1,16 +1,22 @@
1
- import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
1
+ import {
2
+ INVALID_URL_MSG,
3
+ QUEUE_FULL_MSG,
4
+ SPACE_METADATA_ERROR_MSG
5
+ } from "../constants";
2
6
  import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest";
3
7
  import {
4
8
  handle_message,
5
9
  get_description,
6
10
  get_type,
7
11
  process_endpoint,
12
+ join_urls,
8
13
  map_data_to_params
9
14
  } from "../helpers/api_info";
10
15
  import { initialise_server } from "./server";
11
16
  import { transformed_api_info } from "./test_data";
12
17
 
13
18
  const server = initialise_server();
19
+ const IS_NODE = process.env.TEST_MODE === "node";
14
20
 
15
21
  beforeAll(() => server.listen());
16
22
  afterEach(() => server.resetHandlers());
@@ -435,9 +441,7 @@ describe("process_endpoint", () => {
435
441
  try {
436
442
  await process_endpoint(app_reference, hf_token);
437
443
  } catch (error) {
438
- expect(error.message).toEqual(
439
- SPACE_METADATA_ERROR_MSG + "Unexpected end of JSON input"
440
- );
444
+ expect(error.message).toEqual(SPACE_METADATA_ERROR_MSG);
441
445
  }
442
446
  });
443
447
 
@@ -455,6 +459,67 @@ describe("process_endpoint", () => {
455
459
  const result = await process_endpoint("hmb/hello_world");
456
460
  expect(result).toEqual(expected);
457
461
  });
462
+
463
+ it("processes local server URLs correctly", async () => {
464
+ const local_url = "http://localhost:7860/gradio";
465
+ const response_local_url = await process_endpoint(local_url);
466
+ expect(response_local_url.space_id).toBe(false);
467
+ expect(response_local_url.host).toBe("localhost:7860/gradio");
468
+
469
+ const local_url_2 = "http://localhost:7860/gradio/";
470
+ const response_local_url_2 = await process_endpoint(local_url_2);
471
+ expect(response_local_url_2.space_id).toBe(false);
472
+ expect(response_local_url_2.host).toBe("localhost:7860/gradio");
473
+ });
474
+
475
+ it("handles hugging face space references", async () => {
476
+ const space_id = "hmb/hello_world";
477
+
478
+ const response = await process_endpoint(space_id);
479
+ expect(response.space_id).toBe(space_id);
480
+ expect(response.host).toContain("hf.space");
481
+ });
482
+
483
+ it("handles hugging face domain URLs", async () => {
484
+ const app_reference = "https://hmb-hello-world.hf.space/";
485
+ const response = await process_endpoint(app_reference);
486
+ expect(response.space_id).toBe("hmb-hello-world");
487
+ expect(response.host).toBe("hmb-hello-world.hf.space");
488
+ });
489
+ });
490
+
491
+ describe("join_urls", () => {
492
+ it("joins URLs correctly", () => {
493
+ expect(join_urls("http://localhost:7860", "/gradio")).toBe(
494
+ "http://localhost:7860/gradio"
495
+ );
496
+ expect(join_urls("http://localhost:7860/", "/gradio")).toBe(
497
+ "http://localhost:7860/gradio"
498
+ );
499
+ expect(join_urls("http://localhost:7860", "app/", "/gradio")).toBe(
500
+ "http://localhost:7860/app/gradio"
501
+ );
502
+ expect(join_urls("http://localhost:7860/", "/app/", "/gradio/")).toBe(
503
+ "http://localhost:7860/app/gradio/"
504
+ );
505
+
506
+ expect(join_urls("http://127.0.0.1:8000/app", "/config")).toBe(
507
+ "http://127.0.0.1:8000/app/config"
508
+ );
509
+
510
+ expect(join_urls("http://127.0.0.1:8000/app/gradio", "/config")).toBe(
511
+ "http://127.0.0.1:8000/app/gradio/config"
512
+ );
513
+ });
514
+ it("throws an error when the URLs are not valid", () => {
515
+ expect(() => join_urls("localhost:7860", "/gradio")).toThrowError(
516
+ INVALID_URL_MSG
517
+ );
518
+
519
+ expect(() => join_urls("localhost:7860", "/gradio", "app")).toThrowError(
520
+ INVALID_URL_MSG
521
+ );
522
+ });
458
523
  });
459
524
 
460
525
  describe("map_data_params", () => {
@@ -195,7 +195,7 @@ describe("skip_queue", () => {
195
195
 
196
196
  it("should not skip queue when global and dependency queue is enabled", () => {
197
197
  config.enable_queue = true;
198
- config.dependencies[id].queue = true;
198
+ config.dependencies.find((dep) => dep.id === id)!.queue = true;
199
199
 
200
200
  const result = skip_queue(id, config_response);
201
201
 
@@ -204,7 +204,7 @@ describe("skip_queue", () => {
204
204
 
205
205
  it("should not skip queue when global queue is disabled and dependency queue is enabled", () => {
206
206
  config.enable_queue = false;
207
- config.dependencies[id].queue = true;
207
+ config.dependencies.find((dep) => dep.id === id)!.queue = true;
208
208
 
209
209
  const result = skip_queue(id, config_response);
210
210
 
@@ -213,7 +213,7 @@ describe("skip_queue", () => {
213
213
 
214
214
  it("should should skip queue when global queue and dependency queue is disabled", () => {
215
215
  config.enable_queue = false;
216
- config.dependencies[id].queue = false;
216
+ config.dependencies.find((dep) => dep.id === id)!.queue = false;
217
217
 
218
218
  const result = skip_queue(id, config_response);
219
219
 
@@ -222,7 +222,7 @@ describe("skip_queue", () => {
222
222
 
223
223
  it("should should skip queue when global queue is enabled and dependency queue is disabled", () => {
224
224
  config.enable_queue = true;
225
- config.dependencies[id].queue = false;
225
+ config.dependencies.find((dep) => dep.id === id)!.queue = false;
226
226
 
227
227
  const result = skip_queue(id, config_response);
228
228