@gradio/client 0.0.1 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/client.ts CHANGED
@@ -1,36 +1,58 @@
1
+ import semiver from "semiver";
2
+
1
3
  import {
2
4
  process_endpoint,
3
5
  RE_SPACE_NAME,
4
6
  map_names_to_ids,
5
- discussions_enabled
6
- } from "./utils";
7
+ discussions_enabled,
8
+ get_space_hardware,
9
+ set_space_hardware,
10
+ set_space_timeout,
11
+ hardware_types
12
+ } from "./utils.js";
7
13
 
8
14
  import type {
9
15
  EventType,
10
16
  EventListener,
11
17
  ListenerMap,
12
18
  Event,
13
- Config,
14
19
  Payload,
15
20
  PostResponse,
16
21
  UploadResponse,
17
22
  Status,
18
23
  SpaceStatus,
19
- SpaceStatusCallback
20
- } from "./types";
24
+ SpaceStatusCallback,
25
+ FileData
26
+ } from "./types.js";
27
+
28
+ import type { Config } from "./types.js";
21
29
 
22
30
  type event = <K extends EventType>(
23
31
  eventType: K,
24
32
  listener: EventListener<K>
25
- ) => ReturnType<predict>;
26
- type predict = (endpoint: string, payload: Payload) => {};
33
+ ) => SubmitReturn;
34
+ type predict = (
35
+ endpoint: string | number,
36
+ data?: unknown[],
37
+ event_data?: unknown
38
+ ) => Promise<unknown>;
27
39
 
28
40
  type client_return = {
29
41
  predict: predict;
30
42
  config: Config;
43
+ submit: (
44
+ endpoint: string | number,
45
+ data?: unknown[],
46
+ event_data?: unknown
47
+ ) => SubmitReturn;
48
+ view_api: (c?: Config) => Promise<Record<string, any>>;
49
+ };
50
+
51
+ type SubmitReturn = {
31
52
  on: event;
32
53
  off: event;
33
- cancel: (endpoint: string, fn_index?: number) => void;
54
+ cancel: () => Promise<void>;
55
+ destroy: () => void;
34
56
  };
35
57
 
36
58
  const QUEUE_FULL_MSG = "This application is too busy. Keep trying!";
@@ -38,13 +60,21 @@ const BROKEN_CONNECTION_MSG = "Connection errored out.";
38
60
 
39
61
  export async function post_data(
40
62
  url: string,
41
- body: unknown
63
+ body: unknown,
64
+ token?: `hf_${string}`
42
65
  ): Promise<[PostResponse, number]> {
66
+ const headers: {
67
+ Authorization?: string;
68
+ "Content-Type": "application/json";
69
+ } = { "Content-Type": "application/json" };
70
+ if (token) {
71
+ headers.Authorization = `Bearer ${token}`;
72
+ }
43
73
  try {
44
74
  var response = await fetch(url, {
45
75
  method: "POST",
46
76
  body: JSON.stringify(body),
47
- headers: { "Content-Type": "application/json" }
77
+ headers
48
78
  });
49
79
  } catch (e) {
50
80
  return [{ error: BROKEN_CONNECTION_MSG }, 500];
@@ -53,10 +83,20 @@ export async function post_data(
53
83
  return [output, response.status];
54
84
  }
55
85
 
86
+ export let NodeBlob;
87
+
56
88
  export async function upload_files(
57
89
  root: string,
58
- files: Array<File>
90
+ files: Array<File>,
91
+ token?: `hf_${string}`
59
92
  ): Promise<UploadResponse> {
93
+ const headers: {
94
+ Authorization?: string;
95
+ } = {};
96
+ if (token) {
97
+ headers.Authorization = `Bearer ${token}`;
98
+ }
99
+
60
100
  const formData = new FormData();
61
101
  files.forEach((file) => {
62
102
  formData.append("files", file);
@@ -64,7 +104,8 @@ export async function upload_files(
64
104
  try {
65
105
  var response = await fetch(`${root}/upload`, {
66
106
  method: "POST",
67
- body: formData
107
+ body: formData,
108
+ headers
68
109
  });
69
110
  } catch (e) {
70
111
  return { error: BROKEN_CONNECTION_MSG };
@@ -73,88 +114,154 @@ export async function upload_files(
73
114
  return { files: output };
74
115
  }
75
116
 
117
+ export async function duplicate(
118
+ app_reference: string,
119
+ options: {
120
+ hf_token: `hf_${string}`;
121
+ private?: boolean;
122
+ status_callback: SpaceStatusCallback;
123
+ hardware?: typeof hardware_types[number];
124
+ timeout?: number;
125
+ }
126
+ ) {
127
+ const { hf_token, private: _private, hardware, timeout } = options;
128
+
129
+ if (hardware && !hardware_types.includes(hardware)) {
130
+ throw new Error(
131
+ `Invalid hardware type provided. Valid types are: ${hardware_types
132
+ .map((v) => `"${v}"`)
133
+ .join(",")}.`
134
+ );
135
+ }
136
+ const headers = {
137
+ Authorization: `Bearer ${hf_token}`
138
+ };
139
+
140
+ const user = (
141
+ await (
142
+ await fetch(`https://huggingface.co/api/whoami-v2`, {
143
+ headers
144
+ })
145
+ ).json()
146
+ ).name;
147
+
148
+ const space_name = app_reference.split("/")[1];
149
+ const body: {
150
+ repository: string;
151
+ private?: boolean;
152
+ } = {
153
+ repository: `${user}/${space_name}`
154
+ };
155
+
156
+ if (_private) {
157
+ body.private = true;
158
+ }
159
+
160
+ try {
161
+ const response = await fetch(
162
+ `https://huggingface.co/api/spaces/${app_reference}/duplicate`,
163
+ {
164
+ method: "POST",
165
+ headers: { "Content-Type": "application/json", ...headers },
166
+ body: JSON.stringify(body)
167
+ }
168
+ );
169
+
170
+ if (response.status === 409) {
171
+ return client(`${user}/${space_name}`, options);
172
+ } else {
173
+ const duplicated_space = await response.json();
174
+
175
+ let original_hardware;
176
+
177
+ if (!hardware) {
178
+ original_hardware = await get_space_hardware(app_reference, hf_token);
179
+ }
180
+
181
+ const requested_hardware = hardware || original_hardware || "cpu-basic";
182
+ await set_space_hardware(
183
+ `${user}/${space_name}`,
184
+ requested_hardware,
185
+ hf_token
186
+ );
187
+
188
+ await set_space_timeout(
189
+ `${user}/${space_name}`,
190
+ timeout || 300,
191
+ hf_token
192
+ );
193
+ return client(duplicated_space.url, options);
194
+ }
195
+ } catch (e: any) {
196
+ throw new Error(e);
197
+ }
198
+ }
199
+
76
200
  export async function client(
77
201
  app_reference: string,
78
- space_status_callback?: SpaceStatusCallback
202
+ options: {
203
+ hf_token?: `hf_${string}`;
204
+ status_callback?: SpaceStatusCallback;
205
+ normalise_files?: boolean;
206
+ } = { normalise_files: true }
79
207
  ): Promise<client_return> {
80
- return new Promise(async (res, rej) => {
208
+ return new Promise(async (res) => {
209
+ const { status_callback, hf_token, normalise_files } = options;
81
210
  const return_obj = {
82
211
  predict,
83
- on,
84
- off,
85
- cancel
212
+ submit,
213
+ view_api
214
+ // duplicate
86
215
  };
87
216
 
88
- const listener_map: ListenerMap<EventType> = {};
217
+ let transform_files = normalise_files ?? true;
218
+ if (typeof window === "undefined" || !("WebSocket" in window)) {
219
+ const ws = await import("ws");
220
+ NodeBlob = (await import("node:buffer")).Blob;
221
+ //@ts-ignore
222
+ global.WebSocket = ws.WebSocket;
223
+ }
224
+
89
225
  const { ws_protocol, http_protocol, host, space_id } =
90
- await process_endpoint(app_reference);
226
+ await process_endpoint(app_reference, hf_token);
227
+
91
228
  const session_hash = Math.random().toString(36).substring(2);
92
- const ws_map = new Map<number, WebSocket>();
93
- const last_status: Record<string, Status["status"]> = {};
229
+ const last_status: Record<string, Status["stage"]> = {};
94
230
  let config: Config;
95
231
  let api_map: Record<string, number> = {};
96
232
 
97
- function config_success(_config: Config) {
233
+ let jwt: false | string = false;
234
+
235
+ if (hf_token && space_id) {
236
+ jwt = await get_jwt(space_id, hf_token);
237
+ }
238
+
239
+ async function config_success(_config: Config) {
98
240
  config = _config;
99
241
  api_map = map_names_to_ids(_config?.dependencies || []);
242
+ try {
243
+ api = await view_api(config);
244
+ } catch (e) {
245
+ console.error(`Could not get api details: ${e.message}`);
246
+ }
247
+
100
248
  return {
101
249
  config,
102
250
  ...return_obj
103
251
  };
104
252
  }
105
-
106
- function on<K extends EventType>(eventType: K, listener: EventListener<K>) {
107
- const narrowed_listener_map: ListenerMap<K> = listener_map;
108
- let listeners = narrowed_listener_map[eventType] || [];
109
- narrowed_listener_map[eventType] = listeners;
110
- listeners?.push(listener);
111
-
112
- return { ...return_obj, config };
113
- }
114
-
115
- function off<K extends EventType>(
116
- eventType: K,
117
- listener: EventListener<K>
118
- ) {
119
- const narrowed_listener_map: ListenerMap<K> = listener_map;
120
- let listeners = narrowed_listener_map[eventType] || [];
121
- listeners = listeners?.filter((l) => l !== listener);
122
- narrowed_listener_map[eventType] = listeners;
123
-
124
- return { ...return_obj, config };
125
- }
126
-
127
- function cancel(endpoint: string, fn_index?: number) {
128
- const _index =
129
- typeof fn_index === "number" ? fn_index : api_map[endpoint];
130
-
131
- fire_event({
132
- type: "status",
133
- endpoint,
134
- fn_index: _index,
135
- status: "complete",
136
- queue: false
137
- });
138
-
139
- ws_map.get(_index)?.close();
140
- }
141
-
142
- function fire_event<K extends EventType>(event: Event<K>) {
143
- const narrowed_listener_map: ListenerMap<K> = listener_map;
144
- let listeners = narrowed_listener_map[event.type] || [];
145
- listeners?.forEach((l) => l(event));
146
- }
147
-
253
+ let api;
148
254
  async function handle_space_sucess(status: SpaceStatus) {
149
- if (space_status_callback) space_status_callback(status);
255
+ if (status_callback) status_callback(status);
150
256
  if (status.status === "running")
151
257
  try {
152
- console.log(host);
153
- config = await resolve_config(`${http_protocol}//${host}`);
154
- res(config_success(config));
258
+ config = await resolve_config(`${http_protocol}//${host}`, hf_token);
259
+
260
+ const _config = await config_success(config);
261
+ res(_config);
155
262
  } catch (e) {
156
- if (space_status_callback) {
157
- space_status_callback({
263
+ if (status_callback) {
264
+ status_callback({
158
265
  status: "error",
159
266
  message: "Could not load this space.",
160
267
  load_status: "error",
@@ -165,8 +272,10 @@ export async function client(
165
272
  }
166
273
 
167
274
  try {
168
- config = await resolve_config(`${http_protocol}//${host}`);
169
- res(config_success(config));
275
+ config = await resolve_config(`${http_protocol}//${host}`, hf_token);
276
+
277
+ const _config = await config_success(config);
278
+ res(_config);
170
279
  } catch (e) {
171
280
  if (space_id) {
172
281
  check_space_status(
@@ -175,8 +284,8 @@ export async function client(
175
284
  handle_space_sucess
176
285
  );
177
286
  } else {
178
- if (space_status_callback)
179
- space_status_callback({
287
+ if (status_callback)
288
+ status_callback({
180
289
  status: "error",
181
290
  message: "Could not load this space.",
182
291
  load_status: "error",
@@ -184,96 +293,177 @@ export async function client(
184
293
  });
185
294
  }
186
295
  }
187
- function make_predict(endpoint: string, payload: Payload) {
296
+
297
+ /**
298
+ * Run a prediction.
299
+ * @param endpoint - The prediction endpoint to use.
300
+ * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates.
301
+ * @return Returns the data for the prediction or an error message.
302
+ */
303
+ function predict(endpoint: string, data: unknown[], event_data?: unknown) {
304
+ let data_returned = false;
305
+ let status_complete = false;
188
306
  return new Promise((res, rej) => {
307
+ const app = submit(endpoint, data, event_data);
308
+
309
+ app
310
+ .on("data", (d) => {
311
+ data_returned = true;
312
+ if (status_complete) {
313
+ app.destroy();
314
+ }
315
+ res(d);
316
+ })
317
+ .on("status", (status) => {
318
+ if (status.stage === "error") rej(status);
319
+ if (status.stage === "complete" && data_returned) {
320
+ app.destroy();
321
+ }
322
+ if (status.stage === "complete") {
323
+ status_complete = true;
324
+ }
325
+ });
326
+ });
327
+ }
328
+
329
+ function submit(
330
+ endpoint: string | number,
331
+ data: unknown[],
332
+ event_data?: unknown
333
+ ): SubmitReturn {
334
+ let fn_index: number;
335
+ let api_info;
336
+
337
+ if (typeof endpoint === "number") {
338
+ fn_index = endpoint;
339
+ api_info = api.unnamed_endpoints[fn_index];
340
+ } else {
189
341
  const trimmed_endpoint = endpoint.replace(/^\//, "");
190
- let fn_index =
191
- typeof payload.fn_index === "number"
192
- ? payload.fn_index
193
- : api_map[trimmed_endpoint];
194
342
 
343
+ fn_index = api_map[trimmed_endpoint];
344
+ api_info = api.named_endpoints[endpoint.trim()];
345
+ }
346
+
347
+ if (typeof fn_index !== "number") {
348
+ throw new Error(
349
+ "There is no endpoint matching that name of fn_index matching that number."
350
+ );
351
+ }
352
+
353
+ let websocket: WebSocket;
354
+
355
+ const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
356
+ let payload: Payload;
357
+ let complete: false | Record<string, any> = false;
358
+ const listener_map: ListenerMap<EventType> = {};
359
+
360
+ //@ts-ignore
361
+ handle_blob(
362
+ `${http_protocol}//${host + config.path}`,
363
+ data,
364
+ api_info,
365
+ hf_token
366
+ ).then((_payload) => {
367
+ payload = { data: _payload || [], event_data, fn_index };
195
368
  if (skip_queue(fn_index, config)) {
196
369
  fire_event({
197
370
  type: "status",
198
- endpoint,
199
- status: "pending",
371
+ endpoint: _endpoint,
372
+ stage: "pending",
200
373
  queue: false,
201
- fn_index
374
+ fn_index,
375
+ time: new Date()
202
376
  });
203
377
 
204
378
  post_data(
205
379
  `${http_protocol}//${host + config.path}/run${
206
- endpoint.startsWith("/") ? endpoint : `/${endpoint}`
380
+ _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
207
381
  }`,
208
382
  {
209
383
  ...payload,
210
384
  session_hash
211
- }
385
+ },
386
+ hf_token
212
387
  )
213
388
  .then(([output, status_code]) => {
389
+ const data = transform_files
390
+ ? transform_output(
391
+ output.data,
392
+ api_info,
393
+ config.root,
394
+ config.root_url
395
+ )
396
+ : output.data;
214
397
  if (status_code == 200) {
215
398
  fire_event({
216
- type: "status",
217
- endpoint,
399
+ type: "data",
400
+ endpoint: _endpoint,
218
401
  fn_index,
219
- status: "complete",
220
- eta: output.average_duration,
221
- queue: false
402
+ data: output.data,
403
+ time: new Date()
222
404
  });
223
405
 
224
406
  fire_event({
225
- type: "data",
226
- endpoint,
407
+ type: "status",
408
+ endpoint: _endpoint,
227
409
  fn_index,
228
- data: output.data
410
+ stage: "complete",
411
+ eta: output.average_duration,
412
+ queue: false,
413
+ time: new Date()
229
414
  });
230
415
  } else {
231
416
  fire_event({
232
417
  type: "status",
233
- status: "error",
234
- endpoint,
418
+ stage: "error",
419
+ endpoint: _endpoint,
235
420
  fn_index,
236
421
  message: output.error,
237
- queue: false
422
+ queue: false,
423
+ time: new Date()
238
424
  });
239
425
  }
240
426
  })
241
427
  .catch((e) => {
242
428
  fire_event({
243
429
  type: "status",
244
- status: "error",
430
+ stage: "error",
245
431
  message: e.message,
246
- endpoint,
432
+ endpoint: _endpoint,
247
433
  fn_index,
248
- queue: false
434
+ queue: false,
435
+ time: new Date()
249
436
  });
250
- throw new Error(e.message);
251
437
  });
252
438
  } else {
253
439
  fire_event({
254
440
  type: "status",
255
- status: "pending",
441
+ stage: "pending",
256
442
  queue: true,
257
- endpoint,
258
- fn_index
443
+ endpoint: _endpoint,
444
+ fn_index,
445
+ time: new Date()
259
446
  });
260
447
 
261
- const ws_endpoint = `${ws_protocol}://${
262
- host + config.path
263
- }/queue/join`;
448
+ let url = new URL(`${ws_protocol}://${host}${config.path}
449
+ /queue/join`);
264
450
 
265
- const websocket = new WebSocket(ws_endpoint);
451
+ if (jwt) {
452
+ url.searchParams.set("__sign", jwt);
453
+ }
454
+
455
+ websocket = new WebSocket(url);
266
456
 
267
- ws_map.set(fn_index, websocket);
268
457
  websocket.onclose = (evt) => {
269
458
  if (!evt.wasClean) {
270
459
  fire_event({
271
460
  type: "status",
272
- status: "error",
461
+ stage: "error",
273
462
  message: BROKEN_CONNECTION_MSG,
274
463
  queue: true,
275
- endpoint,
276
- fn_index
464
+ endpoint: _endpoint,
465
+ fn_index,
466
+ time: new Date()
277
467
  });
278
468
  }
279
469
  };
@@ -285,12 +475,17 @@ export async function client(
285
475
  last_status[fn_index]
286
476
  );
287
477
 
288
- if (type === "update" && status) {
478
+ if (type === "update" && status && !complete) {
289
479
  // call 'status' listeners
290
- fire_event({ type: "status", endpoint, fn_index, ...status });
291
- if (status.status === "error") {
480
+ fire_event({
481
+ type: "status",
482
+ endpoint: _endpoint,
483
+ fn_index,
484
+ time: new Date(),
485
+ ...status
486
+ });
487
+ if (status.stage === "error") {
292
488
  websocket.close();
293
- rej(status);
294
489
  }
295
490
  } else if (type === "hash") {
296
491
  websocket.send(JSON.stringify({ fn_index, session_hash }));
@@ -298,67 +493,578 @@ export async function client(
298
493
  } else if (type === "data") {
299
494
  websocket.send(JSON.stringify({ ...payload, session_hash }));
300
495
  } else if (type === "complete") {
301
- fire_event({
302
- type: "status",
303
- ...status,
304
- status: status?.status!,
305
- queue: true,
306
- endpoint,
307
- fn_index
308
- });
309
- websocket.close();
496
+ complete = status;
310
497
  } else if (type === "generating") {
311
498
  fire_event({
312
499
  type: "status",
500
+ time: new Date(),
313
501
  ...status,
314
- status: status?.status!,
502
+ stage: status?.stage!,
315
503
  queue: true,
316
- endpoint,
504
+ endpoint: _endpoint,
317
505
  fn_index
318
506
  });
319
507
  }
320
508
  if (data) {
321
509
  fire_event({
322
510
  type: "data",
323
- data: data.data,
324
- endpoint,
511
+ time: new Date(),
512
+ data: transform_files
513
+ ? transform_output(
514
+ data.data,
515
+ api_info,
516
+ config.root,
517
+ config.root_url
518
+ )
519
+ : data.data,
520
+ endpoint: _endpoint,
325
521
  fn_index
326
522
  });
327
- res({ data: data.data });
523
+
524
+ if (complete) {
525
+ fire_event({
526
+ type: "status",
527
+ time: new Date(),
528
+ ...complete,
529
+ stage: status?.stage!,
530
+ queue: true,
531
+ endpoint: _endpoint,
532
+ fn_index
533
+ });
534
+ websocket.close();
535
+ }
328
536
  }
329
537
  };
538
+
539
+ // different ws contract for gradio versions older than 3.6.0
540
+ //@ts-ignore
541
+ if (semiver(config.version || "2.0.0", "3.6") < 0) {
542
+ addEventListener("open", () =>
543
+ websocket.send(JSON.stringify({ hash: session_hash }))
544
+ );
545
+ }
546
+ }
547
+ });
548
+
549
+ function fire_event<K extends EventType>(event: Event<K>) {
550
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
551
+ let listeners = narrowed_listener_map[event.type] || [];
552
+ listeners?.forEach((l) => l(event));
553
+ }
554
+
555
+ function on<K extends EventType>(
556
+ eventType: K,
557
+ listener: EventListener<K>
558
+ ) {
559
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
560
+ let listeners = narrowed_listener_map[eventType] || [];
561
+ narrowed_listener_map[eventType] = listeners;
562
+ listeners?.push(listener);
563
+
564
+ return { on, off, cancel, destroy };
565
+ }
566
+
567
+ function off<K extends EventType>(
568
+ eventType: K,
569
+ listener: EventListener<K>
570
+ ) {
571
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
572
+ let listeners = narrowed_listener_map[eventType] || [];
573
+ listeners = listeners?.filter((l) => l !== listener);
574
+ narrowed_listener_map[eventType] = listeners;
575
+
576
+ return { on, off, cancel, destroy };
577
+ }
578
+
579
+ async function cancel() {
580
+ const _status: Status = {
581
+ stage: "complete",
582
+ queue: false,
583
+ time: new Date()
584
+ };
585
+ complete = _status;
586
+ fire_event({
587
+ ..._status,
588
+ type: "status",
589
+ endpoint: _endpoint,
590
+ fn_index: fn_index
591
+ });
592
+
593
+ if (websocket && websocket.readyState === 0) {
594
+ websocket.addEventListener("open", () => {
595
+ websocket.close();
596
+ });
597
+ } else {
598
+ websocket.close();
330
599
  }
600
+
601
+ try {
602
+ await fetch(`${http_protocol}//${host + config.path}/reset`, {
603
+ headers: { "Content-Type": "application/json" },
604
+ method: "POST",
605
+ body: JSON.stringify({ fn_index, session_hash })
606
+ });
607
+ } catch (e) {
608
+ console.warn(
609
+ "The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
610
+ );
611
+ }
612
+ }
613
+
614
+ function destroy() {
615
+ for (const event_type in listener_map) {
616
+ listener_map[event_type as "data" | "status"].forEach((fn) => {
617
+ off(event_type as "data" | "status", fn);
618
+ });
619
+ }
620
+ }
621
+
622
+ return {
623
+ on,
624
+ off,
625
+ cancel,
626
+ destroy
627
+ };
628
+ }
629
+
630
+ async function view_api(
631
+ config?: Config
632
+ ): Promise<ApiInfo<JsApiData> | [{ error: string }, 500]> {
633
+ if (api) return api;
634
+
635
+ const headers: {
636
+ Authorization?: string;
637
+ "Content-Type": "application/json";
638
+ } = { "Content-Type": "application/json" };
639
+ if (hf_token) {
640
+ headers.Authorization = `Bearer ${hf_token}`;
641
+ }
642
+ try {
643
+ let response: Response;
644
+ // @ts-ignore
645
+ if (semiver(config.version || "2.0.0", "3.30") < 0) {
646
+ response = await fetch(
647
+ "https://gradio-space-api-fetcher-v2.hf.space/api",
648
+ {
649
+ method: "POST",
650
+ body: JSON.stringify({
651
+ serialize: false,
652
+ config: JSON.stringify(config)
653
+ }),
654
+ headers
655
+ }
656
+ );
657
+ } else {
658
+ response = await fetch(`${http_protocol}//${host}/info`, {
659
+ headers
660
+ });
661
+ }
662
+
663
+ let api_info = (await response.json()) as
664
+ | ApiInfo<ApiData>
665
+ | { api: ApiInfo<ApiData> };
666
+ if ("api" in api_info) {
667
+ api_info = api_info.api;
668
+ }
669
+
670
+ if (
671
+ api_info.named_endpoints["/predict"] &&
672
+ !api_info.unnamed_endpoints["0"]
673
+ ) {
674
+ api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"];
675
+ }
676
+
677
+ const x = transform_api_info(api_info, config, api_map);
678
+ return x;
679
+ } catch (e) {
680
+ return [{ error: BROKEN_CONNECTION_MSG }, 500];
681
+ }
682
+ }
683
+ });
684
+ }
685
+
686
+ function transform_output(
687
+ data: any[],
688
+ api_info: any,
689
+ root_url: string,
690
+ remote_url?: string
691
+ ): unknown[] {
692
+ let transformed_data = data.map((d, i) => {
693
+ if (api_info.returns?.[i]?.component === "File") {
694
+ return normalise_file(d, root_url, remote_url);
695
+ } else if (api_info.returns?.[i]?.component === "Gallery") {
696
+ return d.map((img) => {
697
+ return Array.isArray(img)
698
+ ? [normalise_file(img[0], root_url, remote_url), img[1]]
699
+ : [normalise_file(img, root_url, remote_url), null];
331
700
  });
701
+ } else if (typeof d === "object" && d.is_file) {
702
+ return normalise_file(d, root_url, remote_url);
703
+ } else {
704
+ return d;
332
705
  }
706
+ });
333
707
 
334
- /**
335
- * Run a prediction.
336
- * @param endpoint - The prediction endpoint to use.
337
- * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates.
338
- * @return Returns the data for the prediction or an error message.
339
- */
340
- function predict(endpoint: string, payload: Payload) {
341
- return make_predict(endpoint, payload);
708
+ return transformed_data;
709
+ }
710
+
711
+ export function normalise_file(
712
+ file: Array<FileData> | FileData | string | null,
713
+ root: string,
714
+ root_url: string | null
715
+ ): Array<FileData> | FileData | null {
716
+ if (file == null) return null;
717
+ if (typeof file === "string") {
718
+ return {
719
+ name: "file_data",
720
+ data: file
721
+ };
722
+ } else if (Array.isArray(file)) {
723
+ const normalized_file: Array<FileData | null> = [];
724
+
725
+ for (const x of file) {
726
+ if (x === null) {
727
+ normalized_file.push(null);
728
+ } else {
729
+ //@ts-ignore
730
+ normalized_file.push(normalise_file(x, root, root_url));
731
+ }
342
732
  }
733
+
734
+ return normalized_file as Array<FileData>;
735
+ } else if (file.is_file) {
736
+ if (!root_url) {
737
+ file.data = root + "/file=" + file.name;
738
+ } else {
739
+ file.data = "/proxy=" + root_url + "/file=" + file.name;
740
+ }
741
+ }
742
+ return file;
743
+ }
744
+
745
+ interface ApiData {
746
+ label: string;
747
+ type: {
748
+ type: any;
749
+ description: string;
750
+ };
751
+ component: string;
752
+ example_input?: any;
753
+ }
754
+
755
+ interface JsApiData {
756
+ label: string;
757
+ type: string;
758
+ component: string;
759
+ example_input: any;
760
+ }
761
+
762
+ interface EndpointInfo<T extends ApiData | JsApiData> {
763
+ parameters: T[];
764
+ returns: T[];
765
+ }
766
+ interface ApiInfo<T extends ApiData | JsApiData> {
767
+ named_endpoints: {
768
+ [key: string]: EndpointInfo<T>;
769
+ };
770
+ unnamed_endpoints: {
771
+ [key: string]: EndpointInfo<T>;
772
+ };
773
+ }
774
+
775
+ function get_type(
776
+ type: { [key: string]: any },
777
+ component: string,
778
+ serializer: string,
779
+ signature_type: "return" | "parameter"
780
+ ) {
781
+ switch (type.type) {
782
+ case "string":
783
+ return "string";
784
+ case "boolean":
785
+ return "boolean";
786
+ case "number":
787
+ return "number";
788
+ }
789
+
790
+ if (
791
+ serializer === "JSONSerializable" ||
792
+ serializer === "StringSerializable"
793
+ ) {
794
+ return "any";
795
+ } else if (serializer === "ListStringSerializable") {
796
+ return "string[]";
797
+ } else if (component === "Image") {
798
+ return signature_type === "parameter" ? "Blob | File | Buffer" : "string";
799
+ } else if (serializer === "FileSerializable") {
800
+ if (type?.type === "array") {
801
+ return signature_type === "parameter"
802
+ ? "(Blob | File | Buffer)[]"
803
+ : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}[]`;
804
+ } else {
805
+ return signature_type === "parameter"
806
+ ? "Blob | File | Buffer"
807
+ : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}`;
808
+ }
809
+ } else if (serializer === "GallerySerializable") {
810
+ return signature_type === "parameter"
811
+ ? "[(Blob | File | Buffer), (string | null)][]"
812
+ : `[{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}, (string | null))][]`;
813
+ }
814
+ }
815
+
816
+ function get_description(
817
+ type: { type: any; description: string },
818
+ serializer: string
819
+ ) {
820
+ if (serializer === "GallerySerializable") {
821
+ return "array of [file, label] tuples";
822
+ } else if (serializer === "ListStringSerializable") {
823
+ return "array of strings";
824
+ } else if (serializer === "FileSerializable") {
825
+ return "array of files or single file";
826
+ } else {
827
+ return type.description;
828
+ }
829
+ }
830
+
831
+ function transform_api_info(
832
+ api_info: ApiInfo<ApiData>,
833
+ config: Config,
834
+ api_map: Record<string, number>
835
+ ): ApiInfo<JsApiData> {
836
+ const new_data = {
837
+ named_endpoints: {},
838
+ unnamed_endpoints: {}
839
+ };
840
+ for (const key in api_info) {
841
+ const cat = api_info[key];
842
+
843
+ for (const endpoint in cat) {
844
+ const dep_index = config.dependencies[endpoint]
845
+ ? endpoint
846
+ : api_map[endpoint.replace("/", "")];
847
+
848
+ const info = cat[endpoint];
849
+ new_data[key][endpoint] = {};
850
+ new_data[key][endpoint].parameters = {};
851
+ new_data[key][endpoint].returns = {};
852
+ new_data[key][endpoint].type = config.dependencies[dep_index].types;
853
+ new_data[key][endpoint].parameters = info.parameters.map(
854
+ ({ label, component, type, serializer }) => ({
855
+ label,
856
+ component,
857
+ type: get_type(type, component, serializer, "parameter"),
858
+ description: get_description(type, serializer)
859
+ })
860
+ );
861
+
862
+ new_data[key][endpoint].returns = info.returns.map(
863
+ ({ label, component, type, serializer }) => ({
864
+ label,
865
+ component,
866
+ type: get_type(type, component, serializer, "return"),
867
+ description: get_description(type, serializer)
868
+ })
869
+ );
870
+ }
871
+ }
872
+
873
+ return new_data;
874
+ }
875
+
876
+ async function get_jwt(
877
+ space: string,
878
+ token: `hf_${string}`
879
+ ): Promise<string | false> {
880
+ try {
881
+ const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, {
882
+ headers: {
883
+ Authorization: `Bearer ${token}`
884
+ }
885
+ });
886
+
887
+ const jwt = (await r.json()).token;
888
+
889
+ return jwt || false;
890
+ } catch (e) {
891
+ console.error(e);
892
+ return false;
893
+ }
894
+ }
895
+
896
+ export async function handle_blob(
897
+ endpoint: string,
898
+ data: unknown[],
899
+ api_info,
900
+ token?: `hf_${string}`
901
+ ): Promise<unknown[]> {
902
+ const blob_refs = await walk_and_store_blobs(
903
+ data,
904
+ undefined,
905
+ [],
906
+ true,
907
+ api_info
908
+ );
909
+
910
+ return new Promise((res) => {
911
+ Promise.all(
912
+ blob_refs.map(async ({ path, blob, data, type }) => {
913
+ if (blob) {
914
+ const file_url = (await upload_files(endpoint, [blob], token))
915
+ .files[0];
916
+ return { path, file_url, type };
917
+ } else {
918
+ return { path, base64: data, type };
919
+ }
920
+ })
921
+ )
922
+ .then((r) => {
923
+ r.forEach(({ path, file_url, base64, type }) => {
924
+ if (base64) {
925
+ update_object(data, base64, path);
926
+ } else if (type === "Gallery") {
927
+ update_object(data, file_url, path);
928
+ } else if (file_url) {
929
+ const o = {
930
+ is_file: true,
931
+ name: `${file_url}`,
932
+ data: null
933
+ // orig_name: "file.csv"
934
+ };
935
+ update_object(data, o, path);
936
+ }
937
+ });
938
+
939
+ res(data);
940
+ })
941
+ .catch(console.log);
942
+ });
943
+ }
944
+
945
+ function update_object(object, newValue, stack) {
946
+ while (stack.length > 1) {
947
+ object = object[stack.shift()];
948
+ }
949
+
950
+ object[stack.shift()] = newValue;
951
+ }
952
+
953
+ export async function walk_and_store_blobs(
954
+ param,
955
+ type = undefined,
956
+ path = [],
957
+ root = false,
958
+ api_info = undefined
959
+ ) {
960
+ if (Array.isArray(param)) {
961
+ let blob_refs = [];
962
+
963
+ await Promise.all(
964
+ param.map(async (v, i) => {
965
+ let new_path = path.slice();
966
+ new_path.push(i);
967
+
968
+ const array_refs = await walk_and_store_blobs(
969
+ param[i],
970
+ root ? api_info?.parameters[i]?.component || undefined : type,
971
+ new_path,
972
+ false,
973
+ api_info
974
+ );
975
+
976
+ blob_refs = blob_refs.concat(array_refs);
977
+ })
978
+ );
979
+
980
+ return blob_refs;
981
+ } else if (globalThis.Buffer && param instanceof globalThis.Buffer) {
982
+ const is_image = type === "Image";
983
+ return [
984
+ {
985
+ path: path,
986
+ blob: is_image ? false : new NodeBlob([param]),
987
+ data: is_image ? `${param.toString("base64")}` : false,
988
+ type
989
+ }
990
+ ];
991
+ } else if (
992
+ param instanceof Blob ||
993
+ (typeof window !== "undefined" && param instanceof File)
994
+ ) {
995
+ if (type === "Image") {
996
+ let data;
997
+
998
+ if (typeof window !== "undefined") {
999
+ // browser
1000
+ data = await image_to_data_uri(param);
1001
+ } else {
1002
+ const buffer = await param.arrayBuffer();
1003
+ data = Buffer.from(buffer).toString("base64");
1004
+ }
1005
+
1006
+ return [{ path, data, type }];
1007
+ } else {
1008
+ return [{ path: path, blob: param, type }];
1009
+ }
1010
+ } else if (typeof param === "object") {
1011
+ let blob_refs = [];
1012
+ for (let key in param) {
1013
+ if (param.hasOwnProperty(key)) {
1014
+ let new_path = path.slice();
1015
+ new_path.push(key);
1016
+ blob_refs = blob_refs.concat(
1017
+ await walk_and_store_blobs(
1018
+ param[key],
1019
+ undefined,
1020
+ new_path,
1021
+ false,
1022
+ api_info
1023
+ )
1024
+ );
1025
+ }
1026
+ }
1027
+ return blob_refs;
1028
+ } else {
1029
+ return [];
1030
+ }
1031
+ }
1032
+
1033
+ function image_to_data_uri(blob: Blob) {
1034
+ return new Promise((resolve, _) => {
1035
+ const reader = new FileReader();
1036
+ reader.onloadend = () => resolve(reader.result);
1037
+ reader.readAsDataURL(blob);
343
1038
  });
344
1039
  }
345
1040
 
346
1041
  function skip_queue(id: number, config: Config) {
347
1042
  return (
348
- !(config?.dependencies?.[id].queue === null
1043
+ !(config?.dependencies?.[id]?.queue === null
349
1044
  ? config.enable_queue
350
- : config?.dependencies?.[id].queue) || false
1045
+ : config?.dependencies?.[id]?.queue) || false
351
1046
  );
352
1047
  }
353
1048
 
354
- async function resolve_config(endpoint?: string): Promise<Config> {
355
- if (window.gradio_config && location.origin !== "http://localhost:9876") {
1049
+ async function resolve_config(
1050
+ endpoint?: string,
1051
+ token?: `hf_${string}`
1052
+ ): Promise<Config> {
1053
+ const headers: { Authorization?: string } = {};
1054
+ if (token) {
1055
+ headers.Authorization = `Bearer ${token}`;
1056
+ }
1057
+ if (
1058
+ typeof window !== "undefined" &&
1059
+ window.gradio_config &&
1060
+ location.origin !== "http://localhost:9876"
1061
+ ) {
356
1062
  const path = window.gradio_config.root;
357
1063
  const config = window.gradio_config;
358
1064
  config.root = endpoint + config.root;
359
1065
  return { ...config, path: path };
360
1066
  } else if (endpoint) {
361
- let response = await fetch(`${endpoint}/config`);
1067
+ let response = await fetch(`${endpoint}/config`, { headers });
362
1068
 
363
1069
  if (response.status === 200) {
364
1070
  const config = await response.json();
@@ -376,7 +1082,7 @@ async function resolve_config(endpoint?: string): Promise<Config> {
376
1082
  async function check_space_status(
377
1083
  id: string,
378
1084
  type: "subdomain" | "space_name",
379
- space_status_callback: SpaceStatusCallback
1085
+ status_callback: SpaceStatusCallback
380
1086
  ) {
381
1087
  let endpoint =
382
1088
  type === "subdomain"
@@ -392,7 +1098,7 @@ async function check_space_status(
392
1098
  }
393
1099
  response = await response.json();
394
1100
  } catch (e) {
395
- space_status_callback({
1101
+ status_callback({
396
1102
  status: "error",
397
1103
  load_status: "error",
398
1104
  message: "Could not get space status",
@@ -410,7 +1116,7 @@ async function check_space_status(
410
1116
  switch (stage) {
411
1117
  case "STOPPED":
412
1118
  case "SLEEPING":
413
- space_status_callback({
1119
+ status_callback({
414
1120
  status: "sleeping",
415
1121
  load_status: "pending",
416
1122
  message: "Space is asleep. Waking it up...",
@@ -418,13 +1124,13 @@ async function check_space_status(
418
1124
  });
419
1125
 
420
1126
  setTimeout(() => {
421
- check_space_status(id, type, space_status_callback);
1127
+ check_space_status(id, type, status_callback);
422
1128
  }, 1000);
423
1129
  break;
424
1130
  // poll for status
425
1131
  case "RUNNING":
426
1132
  case "RUNNING_BUILDING":
427
- space_status_callback({
1133
+ status_callback({
428
1134
  status: "running",
429
1135
  load_status: "complete",
430
1136
  message: "",
@@ -434,7 +1140,7 @@ async function check_space_status(
434
1140
  // launch
435
1141
  break;
436
1142
  case "BUILDING":
437
- space_status_callback({
1143
+ status_callback({
438
1144
  status: "building",
439
1145
  load_status: "pending",
440
1146
  message: "Space is building...",
@@ -442,11 +1148,11 @@ async function check_space_status(
442
1148
  });
443
1149
 
444
1150
  setTimeout(() => {
445
- check_space_status(id, type, space_status_callback);
1151
+ check_space_status(id, type, status_callback);
446
1152
  }, 1000);
447
1153
  break;
448
1154
  default:
449
- space_status_callback({
1155
+ status_callback({
450
1156
  status: "space_error",
451
1157
  load_status: "error",
452
1158
  message: "This space is experiencing an issue.",
@@ -459,7 +1165,7 @@ async function check_space_status(
459
1165
 
460
1166
  function handle_message(
461
1167
  data: any,
462
- last_status: Status["status"]
1168
+ last_status: Status["stage"]
463
1169
  ): {
464
1170
  type: "hash" | "data" | "update" | "complete" | "generating" | "none";
465
1171
  data?: any;
@@ -477,7 +1183,9 @@ function handle_message(
477
1183
  status: {
478
1184
  queue,
479
1185
  message: QUEUE_FULL_MSG,
480
- status: "error"
1186
+ stage: "error",
1187
+ code: data.code,
1188
+ success: data.success
481
1189
  }
482
1190
  };
483
1191
  case "estimation":
@@ -485,10 +1193,12 @@ function handle_message(
485
1193
  type: "update",
486
1194
  status: {
487
1195
  queue,
488
- status: last_status || "pending",
1196
+ stage: last_status || "pending",
1197
+ code: data.code,
489
1198
  size: data.queue_size,
490
1199
  position: data.rank,
491
- eta: data.rank_eta
1200
+ eta: data.rank_eta,
1201
+ success: data.success
492
1202
  }
493
1203
  };
494
1204
  case "progress":
@@ -496,8 +1206,10 @@ function handle_message(
496
1206
  type: "update",
497
1207
  status: {
498
1208
  queue,
499
- status: "pending",
500
- progress: data.progress_data
1209
+ stage: "pending",
1210
+ code: data.code,
1211
+ progress_data: data.progress_data,
1212
+ success: data.success
501
1213
  }
502
1214
  };
503
1215
  case "process_generating":
@@ -506,8 +1218,9 @@ function handle_message(
506
1218
  status: {
507
1219
  queue,
508
1220
  message: !data.success ? data.output.error : null,
509
- status: data.success ? "generating" : "error",
510
- progress: data.progress_data,
1221
+ stage: data.success ? "generating" : "error",
1222
+ code: data.code,
1223
+ progress_data: data.progress_data,
511
1224
  eta: data.average_duration
512
1225
  },
513
1226
  data: data.success ? data.output : null
@@ -518,8 +1231,9 @@ function handle_message(
518
1231
  status: {
519
1232
  queue,
520
1233
  message: !data.success ? data.output.error : undefined,
521
- status: data.success ? "complete" : "error",
522
- progress: data.progress_data,
1234
+ stage: data.success ? "complete" : "error",
1235
+ code: data.code,
1236
+ progress_data: data.progress_data,
523
1237
  eta: data.output.average_duration
524
1238
  },
525
1239
  data: data.success ? data.output : null
@@ -529,12 +1243,14 @@ function handle_message(
529
1243
  type: "update",
530
1244
  status: {
531
1245
  queue,
532
- status: "pending",
1246
+ stage: "pending",
1247
+ code: data.code,
533
1248
  size: data.rank,
534
- position: 0
1249
+ position: 0,
1250
+ success: data.success
535
1251
  }
536
1252
  };
537
1253
  }
538
1254
 
539
- return { type: "none", status: { status: "error", queue } };
1255
+ return { type: "none", status: { stage: "error", queue } };
540
1256
  }