@gradio/client 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/client.ts CHANGED
@@ -1,36 +1,57 @@
1
+ import semiver from "semiver";
2
+
1
3
  import {
2
4
  process_endpoint,
3
5
  RE_SPACE_NAME,
4
6
  map_names_to_ids,
5
- discussions_enabled
6
- } from "./utils";
7
+ discussions_enabled,
8
+ get_space_hardware,
9
+ set_space_hardware,
10
+ set_space_timeout,
11
+ hardware_types
12
+ } from "./utils.js";
7
13
 
8
14
  import type {
9
15
  EventType,
10
16
  EventListener,
11
17
  ListenerMap,
12
18
  Event,
13
- Config,
14
19
  Payload,
15
20
  PostResponse,
16
21
  UploadResponse,
17
22
  Status,
18
23
  SpaceStatus,
19
24
  SpaceStatusCallback
20
- } from "./types";
25
+ } from "./types.js";
26
+
27
+ import type { Config } from "./types.js";
21
28
 
/** Subscribes or unsubscribes a listener for one event type; returns the submission handle. */
type event = <K extends EventType>(
	eventType: K,
	listener: EventListener<K>
) => SubmitReturn;

/**
 * Runs an endpoint (by route name or numeric fn_index) and resolves with the
 * first data payload it produces.
 */
type predict = (
	endpoint: string | number,
	data?: unknown[],
	event_data?: unknown
) => Promise<unknown>;

/**
 * Handle returned by `submit`: attach/detach listeners, cancel the running
 * job, or drop every registered listener.
 */
type SubmitReturn = {
	on: event;
	off: event;
	cancel: () => void;
	destroy: () => void;
};

/** Object that `client()` resolves with once the app config has been loaded. */
type client_return = {
	config: Config;
	predict: predict;
	submit: (
		endpoint: string | number,
		data?: unknown[],
		event_data?: unknown
	) => SubmitReturn;
	view_api: (c?: Config) => Promise<Record<string, any>>;
};
35
56
 
36
57
  const QUEUE_FULL_MSG = "This application is too busy. Keep trying!";
@@ -38,13 +59,21 @@ const BROKEN_CONNECTION_MSG = "Connection errored out.";
38
59
 
39
60
  export async function post_data(
40
61
  url: string,
41
- body: unknown
62
+ body: unknown,
63
+ token?: `hf_${string}`
42
64
  ): Promise<[PostResponse, number]> {
65
+ const headers: {
66
+ Authorization?: string;
67
+ "Content-Type": "application/json";
68
+ } = { "Content-Type": "application/json" };
69
+ if (token) {
70
+ headers.Authorization = `Bearer ${token}`;
71
+ }
43
72
  try {
44
73
  var response = await fetch(url, {
45
74
  method: "POST",
46
75
  body: JSON.stringify(body),
47
- headers: { "Content-Type": "application/json" }
76
+ headers
48
77
  });
49
78
  } catch (e) {
50
79
  return [{ error: BROKEN_CONNECTION_MSG }, 500];
@@ -53,10 +82,20 @@ export async function post_data(
53
82
  return [output, response.status];
54
83
  }
55
84
 
// Blob constructor used by `walk_and_store_blobs` to wrap Node.js Buffers.
// Stays undefined until `client()` assigns it from `node:buffer`, which it does
// only when no browser `WebSocket` is available.
export let NodeBlob;
86
+
56
87
  export async function upload_files(
57
88
  root: string,
58
- files: Array<File>
89
+ files: Array<File>,
90
+ token?: `hf_${string}`
59
91
  ): Promise<UploadResponse> {
92
+ const headers: {
93
+ Authorization?: string;
94
+ } = {};
95
+ if (token) {
96
+ headers.Authorization = `Bearer ${token}`;
97
+ }
98
+
60
99
  const formData = new FormData();
61
100
  files.forEach((file) => {
62
101
  formData.append("files", file);
@@ -64,7 +103,8 @@ export async function upload_files(
64
103
  try {
65
104
  var response = await fetch(`${root}/upload`, {
66
105
  method: "POST",
67
- body: formData
106
+ body: formData,
107
+ headers
68
108
  });
69
109
  } catch (e) {
70
110
  return { error: BROKEN_CONNECTION_MSG };
@@ -73,88 +113,154 @@ export async function upload_files(
73
113
  return { files: output };
74
114
  }
75
115
 
/**
 * Duplicates a Hugging Face Space into the namespace of the `hf_token` owner
 * and returns a client connected to the copy.
 *
 * If the duplicate request answers 409 (the copy already exists), the existing
 * `<user>/<space>` is connected to as-is. Otherwise the new Space's hardware is
 * set to `options.hardware`, else the source Space's hardware, else
 * "cpu-basic", and its timeout to `options.timeout` (default 300).
 *
 * @param app_reference - Source Space id in "owner/name" form.
 * @param options - `hf_token`, optional `private`, `hardware` and `timeout`;
 *   the whole object is also forwarded to `client()`.
 * @throws Error if `hardware` is not in `hardware_types`, if the token's user
 *   cannot be resolved, or if duplicating/configuring the Space fails.
 */
export async function duplicate(
	app_reference: string,
	options: {
		hf_token: `hf_${string}`;
		private?: boolean;
		status_callback: SpaceStatusCallback;
		hardware?: typeof hardware_types[number];
		timeout?: number;
	}
) {
	const { hf_token, private: _private, hardware, timeout } = options;

	if (hardware && !hardware_types.includes(hardware)) {
		throw new Error(
			`Invalid hardware type provided. Valid types are: ${hardware_types
				.map((v) => `"${v}"`)
				.join(",")}.`
		);
	}
	const headers = {
		Authorization: `Bearer ${hf_token}`
	};

	// Resolve the account that owns the token; the copy is created under it.
	// Fail early instead of trying to create "undefined/<space>".
	const whoami = await fetch(`https://huggingface.co/api/whoami-v2`, {
		headers
	});
	const user: string | undefined = whoami.ok
		? (await whoami.json()).name
		: undefined;
	if (!user) {
		throw new Error(
			`Could not resolve the user for the provided hf_token (HTTP ${whoami.status}).`
		);
	}

	const space_name = app_reference.split("/")[1];
	const target_space = `${user}/${space_name}`;
	const body: {
		repository: string;
		private?: boolean;
	} = {
		repository: target_space
	};

	if (_private) {
		body.private = true;
	}

	try {
		const response = await fetch(
			`https://huggingface.co/api/spaces/${app_reference}/duplicate`,
			{
				method: "POST",
				headers: { "Content-Type": "application/json", ...headers },
				body: JSON.stringify(body)
			}
		);

		if (response.status === 409) {
			// Already duplicated: connect to the existing copy unchanged.
			return client(target_space, options);
		}

		// Any other non-2xx answer is a failure; previously its body was
		// treated as a created Space and `client()` got an undefined url.
		if (!response.ok) {
			throw new Error(
				`Could not duplicate ${app_reference} (HTTP ${response.status}).`
			);
		}

		const duplicated_space = await response.json();

		const original_hardware = hardware
			? undefined
			: await get_space_hardware(app_reference, hf_token);
		const requested_hardware = hardware || original_hardware || "cpu-basic";
		await set_space_hardware(target_space, requested_hardware, hf_token);
		await set_space_timeout(target_space, timeout || 300, hf_token);

		return client(duplicated_space.url, options);
	} catch (e: unknown) {
		// Re-throw Errors untouched; `new Error(e)` would stringify them into
		// "Error: <message>" and drop the original stack.
		throw e instanceof Error ? e : new Error(String(e));
	}
}
198
+
76
199
  export async function client(
77
200
  app_reference: string,
78
- space_status_callback?: SpaceStatusCallback
201
+ options: {
202
+ hf_token?: `hf_${string}`;
203
+ status_callback?: SpaceStatusCallback;
204
+ } = {}
79
205
  ): Promise<client_return> {
80
- return new Promise(async (res, rej) => {
206
+ return new Promise(async (res) => {
207
+ const { status_callback, hf_token } = options;
81
208
  const return_obj = {
82
209
  predict,
83
- on,
84
- off,
85
- cancel
210
+ submit,
211
+ view_api
212
+ // duplicate
86
213
  };
87
214
 
88
- const listener_map: ListenerMap<EventType> = {};
215
+ if (typeof window === "undefined" || !("WebSocket" in window)) {
216
+ const ws = await import("ws");
217
+ NodeBlob = (await import("node:buffer")).Blob;
218
+ //@ts-ignore
219
+ global.WebSocket = ws.WebSocket;
220
+ }
221
+
89
222
  const { ws_protocol, http_protocol, host, space_id } =
90
- await process_endpoint(app_reference);
223
+ await process_endpoint(app_reference, hf_token);
224
+
91
225
  const session_hash = Math.random().toString(36).substring(2);
92
- const ws_map = new Map<number, WebSocket>();
93
- const last_status: Record<string, Status["status"]> = {};
226
+ const last_status: Record<string, Status["stage"]> = {};
94
227
  let config: Config;
95
228
  let api_map: Record<string, number> = {};
96
229
 
97
- function config_success(_config: Config) {
230
+ const listener_map: ListenerMap<EventType> = {};
231
+
232
+ let jwt: false | string = false;
233
+
234
+ if (hf_token && space_id) {
235
+ jwt = await get_jwt(space_id, hf_token);
236
+ }
237
+
/**
 * Stores a freshly resolved config and rebuilds the endpoint-name → fn_index
 * map. It then eagerly loads the endpoint description into `api`.
 *
 * A failure while loading the API description is logged rather than rethrown,
 * so `client()` still resolves.
 *
 * @returns The object `client()` resolves with: the config plus the client methods.
 */
async function config_success(_config: Config) {
	config = _config;
	api_map = map_names_to_ids(_config?.dependencies || []);
	try {
		api = await view_api(config);
	} catch (e: unknown) {
		// `e` may be any thrown value, not only an Error (and reading
		// `.message` on a thrown null would itself throw here).
		const reason = e instanceof Error ? e.message : String(e);
		console.error(`Could not get api details: ${reason}`);
	}

	return {
		config,
		...return_obj
	};
}
105
-
106
- function on<K extends EventType>(eventType: K, listener: EventListener<K>) {
107
- const narrowed_listener_map: ListenerMap<K> = listener_map;
108
- let listeners = narrowed_listener_map[eventType] || [];
109
- narrowed_listener_map[eventType] = listeners;
110
- listeners?.push(listener);
111
-
112
- return { ...return_obj, config };
113
- }
114
-
115
- function off<K extends EventType>(
116
- eventType: K,
117
- listener: EventListener<K>
118
- ) {
119
- const narrowed_listener_map: ListenerMap<K> = listener_map;
120
- let listeners = narrowed_listener_map[eventType] || [];
121
- listeners = listeners?.filter((l) => l !== listener);
122
- narrowed_listener_map[eventType] = listeners;
123
-
124
- return { ...return_obj, config };
125
- }
126
-
127
- function cancel(endpoint: string, fn_index?: number) {
128
- const _index =
129
- typeof fn_index === "number" ? fn_index : api_map[endpoint];
130
-
131
- fire_event({
132
- type: "status",
133
- endpoint,
134
- fn_index: _index,
135
- status: "complete",
136
- queue: false
137
- });
138
-
139
- ws_map.get(_index)?.close();
140
- }
141
-
142
- function fire_event<K extends EventType>(event: Event<K>) {
143
- const narrowed_listener_map: ListenerMap<K> = listener_map;
144
- let listeners = narrowed_listener_map[event.type] || [];
145
- listeners?.forEach((l) => l(event));
146
- }
147
-
252
+ let api;
148
253
  async function handle_space_sucess(status: SpaceStatus) {
149
- if (space_status_callback) space_status_callback(status);
254
+ if (status_callback) status_callback(status);
150
255
  if (status.status === "running")
151
256
  try {
152
- console.log(host);
153
- config = await resolve_config(`${http_protocol}//${host}`);
154
- res(config_success(config));
257
+ config = await resolve_config(`${http_protocol}//${host}`, hf_token);
258
+
259
+ const _config = await config_success(config);
260
+ res(_config);
155
261
  } catch (e) {
156
- if (space_status_callback) {
157
- space_status_callback({
262
+ if (status_callback) {
263
+ status_callback({
158
264
  status: "error",
159
265
  message: "Could not load this space.",
160
266
  load_status: "error",
@@ -165,9 +271,13 @@ export async function client(
165
271
  }
166
272
 
167
273
  try {
168
- config = await resolve_config(`${http_protocol}//${host}`);
169
- res(config_success(config));
274
+ console.log(`${http_protocol}//${host}`);
275
+ config = await resolve_config(`${http_protocol}//${host}`, hf_token);
276
+ console.log(config);
277
+ const _config = await config_success(config);
278
+ res(_config);
170
279
  } catch (e) {
280
+ console.log(space_id, e);
171
281
  if (space_id) {
172
282
  check_space_status(
173
283
  space_id,
@@ -175,8 +285,8 @@ export async function client(
175
285
  handle_space_sucess
176
286
  );
177
287
  } else {
178
- if (space_status_callback)
179
- space_status_callback({
288
+ if (status_callback)
289
+ status_callback({
180
290
  status: "error",
181
291
  message: "Could not load this space.",
182
292
  load_status: "error",
@@ -184,96 +294,166 @@ export async function client(
184
294
  });
185
295
  }
186
296
  }
187
- function make_predict(endpoint: string, payload: Payload) {
/**
 * Run a prediction and resolve with its first data payload.
 * @param endpoint - Named endpoint (e.g. "/predict") or numeric fn_index.
 * @param data - Input values for the endpoint, in parameter order.
 * @param event_data - Optional event payload forwarded with the request.
 * @return A promise that resolves with the first "data" event, or rejects with
 *   the status event when the job reports an error.
 */
function predict(
	endpoint: string | number,
	data: unknown[],
	event_data?: unknown
) {
	let data_returned = false;
	let status_complete = false;

	return new Promise((res, rej) => {
		const app = submit(endpoint, data, event_data);

		app
			.on("data", (d) => {
				data_returned = true;
				// Only tear listeners down once the job has also reported completion.
				if (status_complete) {
					app.destroy();
				}
				res(d);
			})
			.on("status", (status) => {
				if (status.stage === "error") {
					// A failed job will not emit further useful events; drop the
					// listeners so they are not left registered.
					app.destroy();
					rej(status);
				}
				if (status.stage === "complete") {
					status_complete = true;
					if (data_returned) {
						app.destroy();
					}
				}
			});
	});
}
329
+
330
+ function submit(
331
+ endpoint: string | number,
332
+ data: unknown[],
333
+ event_data?: unknown
334
+ ): SubmitReturn {
335
+ let fn_index: number;
336
+ let api_info;
337
+ if (typeof endpoint === "number") {
338
+ fn_index = endpoint;
339
+ api_info = api.unnamed_endpoints[fn_index];
340
+ } else {
189
341
  const trimmed_endpoint = endpoint.replace(/^\//, "");
190
- let fn_index =
191
- typeof payload.fn_index === "number"
192
- ? payload.fn_index
193
- : api_map[trimmed_endpoint];
194
342
 
343
+ fn_index = api_map[trimmed_endpoint];
344
+ api_info = api.named_endpoints[endpoint.trim()];
345
+ }
346
+
347
+ if (typeof fn_index !== "number") {
348
+ throw new Error(
349
+ "There is no endpoint matching that name of fn_index matching that number."
350
+ );
351
+ }
352
+
353
+ let websocket: WebSocket;
354
+
355
+ const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
356
+ let payload: Payload;
357
+
358
+ //@ts-ignore
359
+ handle_blob(
360
+ `${http_protocol}//${host + config.path}`,
361
+ data,
362
+ api_info,
363
+ hf_token
364
+ ).then((_payload) => {
365
+ payload = { data: _payload || [], event_data, fn_index };
195
366
  if (skip_queue(fn_index, config)) {
196
367
  fire_event({
197
368
  type: "status",
198
- endpoint,
199
- status: "pending",
369
+ endpoint: _endpoint,
370
+ stage: "pending",
200
371
  queue: false,
201
- fn_index
372
+ fn_index,
373
+ time: new Date()
202
374
  });
203
375
 
204
376
  post_data(
205
377
  `${http_protocol}//${host + config.path}/run${
206
- endpoint.startsWith("/") ? endpoint : `/${endpoint}`
378
+ _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
207
379
  }`,
208
380
  {
209
381
  ...payload,
210
382
  session_hash
211
- }
383
+ },
384
+ hf_token
212
385
  )
213
386
  .then(([output, status_code]) => {
214
387
  if (status_code == 200) {
215
388
  fire_event({
216
389
  type: "status",
217
- endpoint,
390
+ endpoint: _endpoint,
218
391
  fn_index,
219
- status: "complete",
392
+ stage: "complete",
220
393
  eta: output.average_duration,
221
- queue: false
394
+ queue: false,
395
+ time: new Date()
222
396
  });
223
397
 
224
398
  fire_event({
225
399
  type: "data",
226
- endpoint,
400
+ endpoint: _endpoint,
227
401
  fn_index,
228
- data: output.data
402
+ data: output.data,
403
+ time: new Date()
229
404
  });
230
405
  } else {
231
406
  fire_event({
232
407
  type: "status",
233
- status: "error",
234
- endpoint,
408
+ stage: "error",
409
+ endpoint: _endpoint,
235
410
  fn_index,
236
411
  message: output.error,
237
- queue: false
412
+ queue: false,
413
+ time: new Date()
238
414
  });
239
415
  }
240
416
  })
241
417
  .catch((e) => {
242
418
  fire_event({
243
419
  type: "status",
244
- status: "error",
420
+ stage: "error",
245
421
  message: e.message,
246
- endpoint,
422
+ endpoint: _endpoint,
247
423
  fn_index,
248
- queue: false
424
+ queue: false,
425
+ time: new Date()
249
426
  });
250
- throw new Error(e.message);
251
427
  });
252
428
  } else {
253
429
  fire_event({
254
430
  type: "status",
255
- status: "pending",
431
+ stage: "pending",
256
432
  queue: true,
257
- endpoint,
258
- fn_index
433
+ endpoint: _endpoint,
434
+ fn_index,
435
+ time: new Date()
259
436
  });
260
437
 
261
- const ws_endpoint = `${ws_protocol}://${
262
- host + config.path
263
- }/queue/join`;
438
+ let url = new URL(`${ws_protocol}://${host}${config.path}
439
+ /queue/join`);
440
+
441
+ if (jwt) {
442
+ url.searchParams.set("__sign", jwt);
443
+ }
264
444
 
265
- const websocket = new WebSocket(ws_endpoint);
445
+ websocket = new WebSocket(url);
266
446
 
267
- ws_map.set(fn_index, websocket);
268
447
  websocket.onclose = (evt) => {
269
448
  if (!evt.wasClean) {
270
449
  fire_event({
271
450
  type: "status",
272
- status: "error",
451
+ stage: "error",
273
452
  message: BROKEN_CONNECTION_MSG,
274
453
  queue: true,
275
- endpoint,
276
- fn_index
454
+ endpoint: _endpoint,
455
+ fn_index,
456
+ time: new Date()
277
457
  });
278
458
  }
279
459
  };
@@ -287,10 +467,15 @@ export async function client(
287
467
 
288
468
  if (type === "update" && status) {
289
469
  // call 'status' listeners
290
- fire_event({ type: "status", endpoint, fn_index, ...status });
291
- if (status.status === "error") {
470
+ fire_event({
471
+ type: "status",
472
+ endpoint: _endpoint,
473
+ fn_index,
474
+ time: new Date(),
475
+ ...status
476
+ });
477
+ if (status.stage === "error") {
292
478
  websocket.close();
293
- rej(status);
294
479
  }
295
480
  } else if (type === "hash") {
296
481
  websocket.send(JSON.stringify({ fn_index, session_hash }));
@@ -300,66 +485,505 @@ export async function client(
300
485
  } else if (type === "complete") {
301
486
  fire_event({
302
487
  type: "status",
488
+ time: new Date(),
303
489
  ...status,
304
- status: status?.status!,
490
+ stage: status?.stage!,
305
491
  queue: true,
306
- endpoint,
492
+ endpoint: _endpoint,
307
493
  fn_index
308
494
  });
309
495
  websocket.close();
310
496
  } else if (type === "generating") {
311
497
  fire_event({
312
498
  type: "status",
499
+ time: new Date(),
313
500
  ...status,
314
- status: status?.status!,
501
+ stage: status?.stage!,
315
502
  queue: true,
316
- endpoint,
503
+ endpoint: _endpoint,
317
504
  fn_index
318
505
  });
319
506
  }
320
507
  if (data) {
321
508
  fire_event({
322
509
  type: "data",
510
+ time: new Date(),
323
511
  data: data.data,
324
- endpoint,
512
+ endpoint: _endpoint,
325
513
  fn_index
326
514
  });
327
- res({ data: data.data });
328
515
  }
329
516
  };
517
+
518
+ // different ws contract for gradio versions older than 3.6.0
519
+ //@ts-ignore
520
+ if (semiver(config.version || "2.0.0", "3.6") < 0) {
521
+ addEventListener("open", () =>
522
+ websocket.send(JSON.stringify({ hash: session_hash }))
523
+ );
524
+ }
330
525
  }
331
526
  });
527
+
/** Dispatches `event` to every listener currently registered for its type. */
function fire_event<K extends EventType>(event: Event<K>) {
	const typed_map: ListenerMap<K> = listener_map;
	const current = typed_map[event.type] ?? [];
	current.forEach((listener) => listener(event));
}
533
+
/**
 * Subscribes `listener` to events of `eventType` and returns the submission
 * handle, so calls can be chained (`.on(...).on(...)`).
 */
function on<K extends EventType>(
	eventType: K,
	listener: EventListener<K>
) {
	const typed_map: ListenerMap<K> = listener_map;
	let registered = typed_map[eventType];
	if (!registered) {
		registered = [];
		typed_map[eventType] = registered;
	}
	registered.push(listener);

	return { on, off, cancel, destroy };
}
545
+
/**
 * Unsubscribes every occurrence of `listener` from `eventType` and returns the
 * submission handle, so calls can be chained.
 */
function off<K extends EventType>(
	eventType: K,
	listener: EventListener<K>
) {
	const typed_map: ListenerMap<K> = listener_map;
	const previous = typed_map[eventType] ?? [];
	typed_map[eventType] = previous.filter((l) => l !== listener);

	return { on, off, cancel, destroy };
}
557
+
/**
 * Cancels this submission. It reports the job as complete to status
 * listeners, asks the server to reset it, and closes the queue websocket if
 * one was opened. Finally it removes all listeners.
 */
async function cancel() {
	fire_event({
		type: "status",
		endpoint: _endpoint,
		fn_index: fn_index,
		stage: "complete",
		queue: false,
		time: new Date()
	});

	const reset_headers: {
		Authorization?: string;
		"Content-Type": "application/json";
	} = { "Content-Type": "application/json" };
	if (hf_token) {
		reset_headers.Authorization = `Bearer ${hf_token}`;
	}

	try {
		// NOTE(review): gradio's /reset handler is expected to take a JSON
		// `{ session_hash, fn_index }` body (the bare session-hash string sent
		// previously identified no function) — confirm against target versions.
		await fetch(`${http_protocol}//${host + config.path}/reset`, {
			method: "POST",
			headers: reset_headers,
			body: JSON.stringify({ fn_index, session_hash })
		});
	} catch (e) {
		console.warn(
			"The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
		);
	}

	// `websocket` is only assigned on the queued path; jobs that skip the
	// queue never open one, and calling `.close()` on undefined would throw.
	if (websocket) {
		if (websocket.readyState === 0) {
			// readyState 0 is CONNECTING: defer closing until it has opened.
			websocket.addEventListener("open", () => {
				websocket.close();
			});
		} else {
			websocket.close();
		}
	}

	destroy();
}
589
+
/**
 * Detaches every listener currently registered in `listener_map`.
 * NOTE(review): `listener_map` is declared in `client()`, so this also clears
 * listeners attached through other submissions on the same client.
 */
function destroy() {
	const event_types = Object.keys(listener_map) as Array<"data" | "status">;
	for (const event_type of event_types) {
		const registered = listener_map[event_type] ?? [];
		for (const listener of registered) {
			off(event_type, listener);
		}
	}
}

// The submission handle handed back to the caller.
return {
	on,
	off,
	cancel,
	destroy
};
332
604
  }
333
605
 
334
- /**
335
- * Run a prediction.
336
- * @param endpoint - The prediction endpoint to use.
337
- * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates.
338
- * @return Returns the data for the prediction or an error message.
339
- */
340
- function predict(endpoint: string, payload: Payload) {
341
- return make_predict(endpoint, payload);
/**
 * Describes the app's endpoints (parameters and return values) for JS callers.
 * A value already cached in `api` (set by `config_success`) is returned
 * immediately.
 *
 * For apps reporting a version below 3.30, the config is posted to a hosted
 * api-fetcher Space; otherwise the app's own `/info` route is queried.
 *
 * @param config - The app config whose endpoints should be described.
 * @returns The transformed API info, or `[{ error }, 500]` if it could not be
 *   fetched or parsed.
 */
async function view_api(
	config?: Config
): Promise<ApiInfo<JsApiData> | [{ error: string }, 500]> {
	if (api) return api;

	// Without a config there is nothing to describe (this used to surface as a
	// TypeError caught below; same result, made explicit).
	if (!config) {
		return [{ error: BROKEN_CONNECTION_MSG }, 500];
	}

	const headers: {
		Authorization?: string;
		"Content-Type": "application/json";
	} = { "Content-Type": "application/json" };
	if (hf_token) {
		headers.Authorization = `Bearer ${hf_token}`;
	}

	try {
		let response: Response;
		// @ts-ignore -- `version` may not be declared on the Config type.
		if (semiver(config.version || "2.0.0", "3.30") < 0) {
			// The fetcher receives the config in the request body, so don't
			// forward the user's HF token to this third-party Space.
			response = await fetch(
				"https://gradio-space-api-fetcher-v2.hf.space/api",
				{
					method: "POST",
					body: JSON.stringify({
						serialize: false,
						config: JSON.stringify(config)
					}),
					headers: { "Content-Type": "application/json" }
				}
			);
		} else {
			response = await fetch(`${http_protocol}//${host}/info`, {
				headers
			});
		}

		if (!response.ok) {
			return [{ error: BROKEN_CONNECTION_MSG }, 500];
		}

		let api_info = (await response.json()) as
			| ApiInfo<ApiData>
			| { api: ApiInfo<ApiData> };
		if ("api" in api_info) {
			api_info = api_info.api;
		}

		// Expose "/predict" as fn_index 0 when no unnamed endpoint 0 exists.
		if (
			api_info.named_endpoints["/predict"] &&
			!api_info.unnamed_endpoints["0"]
		) {
			api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"];
		}

		return transform_api_info(api_info, config, api_map);
	} catch (e) {
		return [{ error: BROKEN_CONNECTION_MSG }, 500];
	}
}
659
+ });
660
+ }
661
+
/** Raw description of one endpoint parameter or return value, as served by the app. */
interface ApiData {
	label: string;
	type: {
		type: any;
		description: string;
	};
	component: string;
	/** Serializer class name; drives the type/description mapping in `transform_api_info`. */
	serializer: string;
	example_input?: any;
}

/** JS-facing description of one parameter or return value (as returned by `view_api`). */
interface JsApiData {
	label: string;
	/** Display type string, e.g. "string" or "Blob | File | Buffer". */
	type: string;
	description: string;
	component: string;
	example_input?: any;
}

/** Inputs and outputs of a single endpoint. */
interface EndpointInfo<T extends ApiData | JsApiData> {
	parameters: T[];
	returns: T[];
	/** Copied from the matching `config.dependencies[...].types` by `transform_api_info`. */
	type?: unknown;
}

/** Endpoints keyed by route name (e.g. "/predict") or, for unnamed ones, by fn_index. */
interface ApiInfo<T extends ApiData | JsApiData> {
	named_endpoints: {
		[key: string]: EndpointInfo<T>;
	};
	unnamed_endpoints: {
		[key: string]: EndpointInfo<T>;
	};
}
691
+
/**
 * Maps a raw API type/serializer description to a TypeScript-style type
 * string for display in `view_api` output.
 *
 * @param type - Raw type object; a primitive `type.type` wins outright.
 * @param component - Component name (e.g. "Image").
 * @param serializer - Serializer class name reported by the app.
 * @param signature_type - Whether the value is an endpoint input ("parameter")
 *   or output ("return"); file-like inputs accept Blob/File/Buffer.
 * @returns The type string, or `undefined` when no rule matches.
 */
function get_type(
	type: { [key: string]: any },
	component: string,
	serializer: string,
	signature_type: "return" | "parameter"
): string | undefined {
	switch (type.type) {
		case "string":
			return "string";
		case "boolean":
			return "boolean";
		case "number":
			return "number";
	}

	// Shape of a file object returned by the server.
	const file_data =
		"{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}";

	if (
		serializer === "JSONSerializable" ||
		serializer === "StringSerializable"
	) {
		return "any";
	} else if (serializer === "ListStringSerializable") {
		return "string[]";
	} else if (component === "Image") {
		return signature_type === "parameter" ? "Blob | File | Buffer" : "string";
	} else if (serializer === "FileSerializable") {
		if (type?.type === "array") {
			return signature_type === "parameter"
				? "(Blob | File | Buffer)[]"
				: `${file_data}[]`;
		}
		return signature_type === "parameter" ? "Blob | File | Buffer" : file_data;
	} else if (serializer === "GallerySerializable") {
		// Fixed: the return form previously had an unbalanced ")" after `(string | null)`.
		return signature_type === "parameter"
			? "[(Blob | File | Buffer), (string | null)][]"
			: `[${file_data}, (string | null)][]`;
	}

	return undefined;
}
732
+
/**
 * Human-readable description of a parameter/return value. Gallery, list-of-
 * string and file serializers get a fixed wording; everything else falls back
 * to the description reported by the app.
 */
function get_description(
	type: { type: any; description: string },
	serializer: string
) {
	switch (serializer) {
		case "GallerySerializable":
			return "array of [file, label] tuples";
		case "ListStringSerializable":
			return "array of strings";
		case "FileSerializable":
			return "array of files or single file";
		default:
			return type.description;
	}
}
747
+
/**
 * Converts the raw API description into the JS-facing shape returned by
 * `view_api`. Each parameter/return value gets a display type string and a
 * description, and each endpoint carries its dependency's `types`.
 *
 * @param api_info - Raw info with `named_endpoints` and `unnamed_endpoints`.
 * @param config - App config; `dependencies` is indexed by fn_index.
 * @param api_map - Endpoint name (without leading "/") → fn_index.
 */
function transform_api_info(
	api_info: ApiInfo<ApiData>,
	config: Config,
	api_map: Record<string, number>
): ApiInfo<JsApiData> {
	const new_data = {
		named_endpoints: {},
		unnamed_endpoints: {}
	};

	// Walk only the two endpoint groups, so any other top-level field in the
	// response is not mistaken for a group of endpoints.
	for (const key of ["named_endpoints", "unnamed_endpoints"] as const) {
		const cat = api_info[key];
		if (!cat) continue;

		for (const endpoint in cat) {
			// Unnamed endpoints are keyed by fn_index and index `dependencies`
			// directly; named ones go through api_map with the leading "/" removed.
			const dep_index = config.dependencies[endpoint]
				? endpoint
				: api_map[endpoint.replace(/^\//, "")];
			const info = cat[endpoint];

			new_data[key][endpoint] = {
				parameters: info.parameters.map(
					({ label, component, type, serializer }) => ({
						label,
						component,
						type: get_type(type, component, serializer, "parameter"),
						description: get_description(type, serializer)
					})
				),
				returns: info.returns.map(
					({ label, component, type, serializer }) => ({
						label,
						component,
						type: get_type(type, component, serializer, "return"),
						description: get_description(type, serializer)
					})
				),
				// `?.` so an endpoint without a matching dependency yields
				// `type: undefined` instead of throwing.
				type: config.dependencies[dep_index]?.types
			};
		}
	}

	return new_data;
}
792
+
/**
 * Asks the Hub for a JWT scoped to `space`, authenticating with `token`.
 *
 * @returns The JWT string, or `false` when the response carries no token or
 *   the request/parsing fails (the error is logged to the console).
 */
async function get_jwt(
	space: string,
	token: `hf_${string}`
): Promise<string | false> {
	const jwt_url = `https://huggingface.co/api/spaces/${space}/jwt`;
	try {
		const response = await fetch(jwt_url, {
			headers: { Authorization: `Bearer ${token}` }
		});
		const { token: jwt } = await response.json();
		return jwt ? jwt : false;
	} catch (e) {
		console.error(e);
		return false;
	}
}
812
+
/**
 * Prepares endpoint input data for sending. It finds every Blob/File/Buffer
 * in `data` (via `walk_and_store_blobs`) and uploads the non-inlined ones to
 * `${endpoint}/upload`. Each one is then replaced in place:
 *  - inlined (image) values become their base64 / data-URI string;
 *  - Gallery values become the uploaded file URL;
 *  - other files become `{ is_file: true, name: <url>, data: null }`.
 *
 * Note: `data` is mutated in place and also returned.
 *
 * @param endpoint - Base URL of the app (host + config path).
 * @param data - Input values; may contain Blob/File/Buffer values at any depth.
 * @param api_info - Endpoint info used to look up each top-level parameter's component.
 * @param token - Optional HF token sent with the uploads.
 * @throws Error if any upload fails. Previously the failure was only logged
 *   and the returned promise never settled.
 */
export async function handle_blob(
	endpoint: string,
	data: unknown[],
	api_info,
	token?: `hf_${string}`
): Promise<unknown[]> {
	const blob_refs = await walk_and_store_blobs(
		data,
		undefined,
		[],
		true,
		api_info
	);

	// Upload all files concurrently; Promise.all keeps the results in ref order.
	const resolved = await Promise.all(
		blob_refs.map(async ({ path, blob, data: inline_data, type }) => {
			if (!blob) {
				return { path, type, base64: inline_data, file_url: undefined };
			}
			const upload = await upload_files(endpoint, [blob], token);
			if (upload.error || !upload.files) {
				throw new Error(upload.error || "File upload failed.");
			}
			return { path, type, base64: undefined, file_url: upload.files[0] };
		})
	);

	for (const { path, type, base64, file_url } of resolved) {
		if (base64) {
			update_object(data, base64, path);
		} else if (type === "Gallery") {
			update_object(data, file_url, path);
		} else if (file_url) {
			update_object(
				data,
				{ is_file: true, name: `${file_url}`, data: null },
				path
			);
		}
	}

	return data;
}
861
+
/**
 * Assigns `newValue` at the nested location described by `stack` inside
 * `object`; e.g. stack `[0, "image"]` assigns `object[0]["image"]`.
 *
 * Unlike the previous version, the `stack` array is left unmodified (it used
 * to be emptied via `shift()`).
 *
 * @throws Error if `stack` is empty: the root itself cannot be replaced in place
 *   (previously this silently wrote to an "undefined" key).
 */
function update_object(
	object: any,
	newValue: unknown,
	stack: ReadonlyArray<string | number>
) {
	if (stack.length === 0) {
		throw new Error("update_object: path must not be empty.");
	}

	let target = object;
	for (let i = 0; i < stack.length - 1; i++) {
		target = target[stack[i]];
	}
	target[stack[stack.length - 1]] = newValue;
}
869
+
/**
 * Recursively collects every Blob/File/Buffer inside `param` and records
 * where each one lives, so it can be replaced later with `update_object`.
 *
 * Values whose component type is "Image" are encoded here and returned as
 * `data`: a data URI (via FileReader) when `window` exists, or raw base64 from
 * a Buffer otherwise. Other binary values are returned as `blob`, for upload.
 *
 * @param param - Value to scan; arrays and (own keys of) objects are traversed.
 * @param type - Component name inherited from the parent, if any.
 * @param path - Keys/indices leading from the root to `param`.
 * @param root - True for the top-level call, where array index i is matched to
 *   `api_info.parameters[i].component`.
 * @param api_info - Endpoint info supplying parameter components (optional).
 * @returns A flat list of `{ path, blob?, data?, type }` references.
 */
export async function walk_and_store_blobs(
	param,
	type = undefined,
	path = [],
	root = false,
	api_info = undefined
) {
	if (Array.isArray(param)) {
		let blob_refs = [];

		await Promise.all(
			param.map(async (v, i) => {
				const array_refs = await walk_and_store_blobs(
					v,
					// `parameters?.` so api_info without a parameter list doesn't throw.
					root ? api_info?.parameters?.[i]?.component || undefined : type,
					[...path, i],
					false,
					api_info
				);

				blob_refs = blob_refs.concat(array_refs);
			})
		);

		return blob_refs;
	} else if (globalThis.Buffer && param instanceof globalThis.Buffer) {
		const is_image = type === "Image";
		// NodeBlob is only set once `client()` has run in Node; fall back to
		// the global Blob when present so direct callers don't crash.
		const BlobCtor = NodeBlob ?? globalThis.Blob;
		return [
			{
				path: path,
				blob: is_image ? false : new BlobCtor([param]),
				data: is_image ? `${param.toString("base64")}` : false,
				type
			}
		];
	} else if (
		param instanceof Blob ||
		(typeof window !== "undefined" && param instanceof File)
	) {
		if (type === "Image") {
			let data;

			if (typeof window !== "undefined") {
				// browser
				data = await image_to_data_uri(param);
			} else {
				const buffer = await param.arrayBuffer();
				data = Buffer.from(buffer).toString("base64");
			}

			return [{ path, data, type }];
		}
		return [{ path: path, blob: param, type }];
	} else if (typeof param === "object" && param !== null) {
		let blob_refs = [];
		for (const key in param) {
			// `.call` so objects without Object.prototype (Object.create(null))
			// are handled instead of throwing.
			if (Object.prototype.hasOwnProperty.call(param, key)) {
				blob_refs = blob_refs.concat(
					await walk_and_store_blobs(
						param[key],
						undefined,
						[...path, key],
						false,
						api_info
					)
				);
			}
		}
		return blob_refs;
	}

	return [];
}
949
+
/**
 * Reads `blob` and resolves with its contents as a data URI
 * (`FileReader.readAsDataURL`).
 *
 * @throws (rejects) with the reader's error if the read fails; previously the
 *   promise resolved with `null` on failure.
 */
function image_to_data_uri(blob: Blob) {
	return new Promise<string | ArrayBuffer | null>((resolve, reject) => {
		const reader = new FileReader();
		// "loadend" fires after both success and failure; `reader.error` tells them apart.
		reader.onloadend = () => {
			if (reader.error) {
				reject(reader.error);
			} else {
				resolve(reader.result);
			}
		};
		reader.readAsDataURL(blob);
	});
}
345
957
 
/**
 * Whether fn_index `id` should bypass the websocket queue and be sent as a
 * plain `/run` POST instead.
 *
 * The dependency's own `queue` flag decides; when it is `null` the app-wide
 * `enable_queue` setting applies. A missing dependency or flag means "skip".
 */
function skip_queue(id: number, config: Config) {
	const dependency_queue = config?.dependencies?.[id]?.queue;
	const uses_queue =
		dependency_queue === null ? config.enable_queue : dependency_queue;
	return !uses_queue;
}
353
965
 
354
- async function resolve_config(endpoint?: string): Promise<Config> {
355
- if (window.gradio_config && location.origin !== "http://localhost:9876") {
966
+ async function resolve_config(
967
+ endpoint?: string,
968
+ token?: `hf_${string}`
969
+ ): Promise<Config> {
970
+ const headers: { Authorization?: string } = {};
971
+ if (token) {
972
+ headers.Authorization = `Bearer ${token}`;
973
+ }
974
+ if (
975
+ typeof window !== "undefined" &&
976
+ window.gradio_config &&
977
+ location.origin !== "http://localhost:9876"
978
+ ) {
356
979
  const path = window.gradio_config.root;
357
980
  const config = window.gradio_config;
358
981
  config.root = endpoint + config.root;
359
982
  return { ...config, path: path };
360
983
  } else if (endpoint) {
361
- let response = await fetch(`${endpoint}/config`);
362
-
984
+ console.log(`${endpoint}/config`, headers);
985
+ let response = await fetch(`${endpoint}/config`, { headers });
986
+ console.log(response);
363
987
  if (response.status === 200) {
364
988
  const config = await response.json();
365
989
  config.path = config.path ?? "";
@@ -376,7 +1000,7 @@ async function resolve_config(endpoint?: string): Promise<Config> {
376
1000
  async function check_space_status(
377
1001
  id: string,
378
1002
  type: "subdomain" | "space_name",
379
- space_status_callback: SpaceStatusCallback
1003
+ status_callback: SpaceStatusCallback
380
1004
  ) {
381
1005
  let endpoint =
382
1006
  type === "subdomain"
@@ -392,7 +1016,7 @@ async function check_space_status(
392
1016
  }
393
1017
  response = await response.json();
394
1018
  } catch (e) {
395
- space_status_callback({
1019
+ status_callback({
396
1020
  status: "error",
397
1021
  load_status: "error",
398
1022
  message: "Could not get space status",
@@ -410,7 +1034,7 @@ async function check_space_status(
410
1034
  switch (stage) {
411
1035
  case "STOPPED":
412
1036
  case "SLEEPING":
413
- space_status_callback({
1037
+ status_callback({
414
1038
  status: "sleeping",
415
1039
  load_status: "pending",
416
1040
  message: "Space is asleep. Waking it up...",
@@ -418,13 +1042,13 @@ async function check_space_status(
418
1042
  });
419
1043
 
420
1044
  setTimeout(() => {
421
- check_space_status(id, type, space_status_callback);
1045
+ check_space_status(id, type, status_callback);
422
1046
  }, 1000);
423
1047
  break;
424
1048
  // poll for status
425
1049
  case "RUNNING":
426
1050
  case "RUNNING_BUILDING":
427
- space_status_callback({
1051
+ status_callback({
428
1052
  status: "running",
429
1053
  load_status: "complete",
430
1054
  message: "",
@@ -434,7 +1058,7 @@ async function check_space_status(
434
1058
  // launch
435
1059
  break;
436
1060
  case "BUILDING":
437
- space_status_callback({
1061
+ status_callback({
438
1062
  status: "building",
439
1063
  load_status: "pending",
440
1064
  message: "Space is building...",
@@ -442,11 +1066,11 @@ async function check_space_status(
442
1066
  });
443
1067
 
444
1068
  setTimeout(() => {
445
- check_space_status(id, type, space_status_callback);
1069
+ check_space_status(id, type, status_callback);
446
1070
  }, 1000);
447
1071
  break;
448
1072
  default:
449
- space_status_callback({
1073
+ status_callback({
450
1074
  status: "space_error",
451
1075
  load_status: "error",
452
1076
  message: "This space is experiencing an issue.",
@@ -459,7 +1083,7 @@ async function check_space_status(
459
1083
 
460
1084
  function handle_message(
461
1085
  data: any,
462
- last_status: Status["status"]
1086
+ last_status: Status["stage"]
463
1087
  ): {
464
1088
  type: "hash" | "data" | "update" | "complete" | "generating" | "none";
465
1089
  data?: any;
@@ -477,7 +1101,9 @@ function handle_message(
477
1101
  status: {
478
1102
  queue,
479
1103
  message: QUEUE_FULL_MSG,
480
- status: "error"
1104
+ stage: "error",
1105
+ code: data.code,
1106
+ success: data.success
481
1107
  }
482
1108
  };
483
1109
  case "estimation":
@@ -485,10 +1111,12 @@ function handle_message(
485
1111
  type: "update",
486
1112
  status: {
487
1113
  queue,
488
- status: last_status || "pending",
1114
+ stage: last_status || "pending",
1115
+ code: data.code,
489
1116
  size: data.queue_size,
490
1117
  position: data.rank,
491
- eta: data.rank_eta
1118
+ eta: data.rank_eta,
1119
+ success: data.success
492
1120
  }
493
1121
  };
494
1122
  case "progress":
@@ -496,8 +1124,10 @@ function handle_message(
496
1124
  type: "update",
497
1125
  status: {
498
1126
  queue,
499
- status: "pending",
500
- progress: data.progress_data
1127
+ stage: "pending",
1128
+ code: data.code,
1129
+ progress_data: data.progress_data,
1130
+ success: data.success
501
1131
  }
502
1132
  };
503
1133
  case "process_generating":
@@ -506,8 +1136,9 @@ function handle_message(
506
1136
  status: {
507
1137
  queue,
508
1138
  message: !data.success ? data.output.error : null,
509
- status: data.success ? "generating" : "error",
510
- progress: data.progress_data,
1139
+ stage: data.success ? "generating" : "error",
1140
+ code: data.code,
1141
+ progress_data: data.progress_data,
511
1142
  eta: data.average_duration
512
1143
  },
513
1144
  data: data.success ? data.output : null
@@ -518,8 +1149,9 @@ function handle_message(
518
1149
  status: {
519
1150
  queue,
520
1151
  message: !data.success ? data.output.error : undefined,
521
- status: data.success ? "complete" : "error",
522
- progress: data.progress_data,
1152
+ stage: data.success ? "complete" : "error",
1153
+ code: data.code,
1154
+ progress_data: data.progress_data,
523
1155
  eta: data.output.average_duration
524
1156
  },
525
1157
  data: data.success ? data.output : null
@@ -529,12 +1161,14 @@ function handle_message(
529
1161
  type: "update",
530
1162
  status: {
531
1163
  queue,
532
- status: "pending",
1164
+ stage: "pending",
1165
+ code: data.code,
533
1166
  size: data.rank,
534
- position: 0
1167
+ position: 0,
1168
+ success: data.success
535
1169
  }
536
1170
  };
537
1171
  }
538
1172
 
539
- return { type: "none", status: { status: "error", queue } };
1173
+ return { type: "none", status: { stage: "error", queue } };
540
1174
  }