@gradio/client 0.1.3 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/client.ts CHANGED
@@ -58,72 +58,18 @@ type SubmitReturn = {
58
58
  const QUEUE_FULL_MSG = "This application is too busy. Keep trying!";
59
59
  const BROKEN_CONNECTION_MSG = "Connection errored out.";
60
60
 
61
- export async function post_data(
62
- url: string,
63
- body: unknown,
64
- token?: `hf_${string}`
65
- ): Promise<[PostResponse, number]> {
66
- const headers: {
67
- Authorization?: string;
68
- "Content-Type": "application/json";
69
- } = { "Content-Type": "application/json" };
70
- if (token) {
71
- headers.Authorization = `Bearer ${token}`;
72
- }
73
- try {
74
- var response = await fetch(url, {
75
- method: "POST",
76
- body: JSON.stringify(body),
77
- headers
78
- });
79
- } catch (e) {
80
- return [{ error: BROKEN_CONNECTION_MSG }, 500];
81
- }
82
- const output: PostResponse = await response.json();
83
- return [output, response.status];
84
- }
85
-
86
61
  export let NodeBlob;
87
62
 
88
- export async function upload_files(
89
- root: string,
90
- files: Array<File>,
91
- token?: `hf_${string}`
92
- ): Promise<UploadResponse> {
93
- const headers: {
94
- Authorization?: string;
95
- } = {};
96
- if (token) {
97
- headers.Authorization = `Bearer ${token}`;
98
- }
99
-
100
- const formData = new FormData();
101
- files.forEach((file) => {
102
- formData.append("files", file);
103
- });
104
- try {
105
- var response = await fetch(`${root}/upload`, {
106
- method: "POST",
107
- body: formData,
108
- headers
109
- });
110
- } catch (e) {
111
- return { error: BROKEN_CONNECTION_MSG };
112
- }
113
- const output: UploadResponse["files"] = await response.json();
114
- return { files: output };
115
- }
116
-
117
63
  export async function duplicate(
118
64
  app_reference: string,
119
65
  options: {
120
66
  hf_token: `hf_${string}`;
121
67
  private?: boolean;
122
68
  status_callback: SpaceStatusCallback;
123
- hardware?: typeof hardware_types[number];
69
+ hardware?: (typeof hardware_types)[number];
124
70
  timeout?: number;
125
71
  }
126
- ) {
72
+ ): Promise<client_return> {
127
73
  const { hf_token, private: _private, hardware, timeout } = options;
128
74
 
129
75
  if (hardware && !hardware_types.includes(hardware)) {
@@ -169,517 +115,688 @@ export async function duplicate(
169
115
 
170
116
  if (response.status === 409) {
171
117
  return client(`${user}/${space_name}`, options);
172
- } else {
173
- const duplicated_space = await response.json();
118
+ }
119
+ const duplicated_space = await response.json();
174
120
 
175
- let original_hardware;
121
+ let original_hardware;
176
122
 
177
- if (!hardware) {
178
- original_hardware = await get_space_hardware(app_reference, hf_token);
179
- }
123
+ if (!hardware) {
124
+ original_hardware = await get_space_hardware(app_reference, hf_token);
125
+ }
180
126
 
181
- const requested_hardware = hardware || original_hardware || "cpu-basic";
182
- await set_space_hardware(
183
- `${user}/${space_name}`,
184
- requested_hardware,
185
- hf_token
186
- );
127
+ const requested_hardware = hardware || original_hardware || "cpu-basic";
128
+ await set_space_hardware(
129
+ `${user}/${space_name}`,
130
+ requested_hardware,
131
+ hf_token
132
+ );
187
133
 
188
- await set_space_timeout(
189
- `${user}/${space_name}`,
190
- timeout || 300,
191
- hf_token
192
- );
193
- return client(duplicated_space.url, options);
194
- }
134
+ await set_space_timeout(`${user}/${space_name}`, timeout || 300, hf_token);
135
+ return client(duplicated_space.url, options);
195
136
  } catch (e: any) {
196
137
  throw new Error(e);
197
138
  }
198
139
  }
199
140
 
200
- export async function client(
201
- app_reference: string,
202
- options: {
203
- hf_token?: `hf_${string}`;
204
- status_callback?: SpaceStatusCallback;
205
- normalise_files?: boolean;
206
- } = { normalise_files: true }
207
- ): Promise<client_return> {
208
- return new Promise(async (res) => {
209
- const { status_callback, hf_token, normalise_files } = options;
210
- const return_obj = {
211
- predict,
212
- submit,
213
- view_api
214
- // duplicate
215
- };
216
-
217
- const transform_files = normalise_files ?? true;
218
- if (typeof window === "undefined" || !("WebSocket" in window)) {
219
- const ws = await import("ws");
220
- NodeBlob = (await import("node:buffer")).Blob;
221
- //@ts-ignore
222
- global.WebSocket = ws.WebSocket;
141
+ interface Client {
142
+ post_data: (
143
+ url: string,
144
+ body: unknown,
145
+ token?: `hf_${string}`
146
+ ) => Promise<[PostResponse, number]>;
147
+ upload_files: (
148
+ root: string,
149
+ files: File[],
150
+ token?: `hf_${string}`
151
+ ) => Promise<UploadResponse>;
152
+ client: (
153
+ app_reference: string,
154
+ options: {
155
+ hf_token?: `hf_${string}`;
156
+ status_callback?: SpaceStatusCallback;
157
+ normalise_files?: boolean;
223
158
  }
159
+ ) => Promise<client_return>;
160
+ handle_blob: (
161
+ endpoint: string,
162
+ data: unknown[],
163
+ api_info: ApiInfo<JsApiData>,
164
+ token?: `hf_${string}`
165
+ ) => Promise<unknown[]>;
166
+ }
224
167
 
225
- const { ws_protocol, http_protocol, host, space_id } =
226
- await process_endpoint(app_reference, hf_token);
227
-
228
- const session_hash = Math.random().toString(36).substring(2);
229
- const last_status: Record<string, Status["stage"]> = {};
230
- let config: Config;
231
- let api_map: Record<string, number> = {};
232
-
233
- let jwt: false | string = false;
234
-
235
- if (hf_token && space_id) {
236
- jwt = await get_jwt(space_id, hf_token);
168
+ export function api_factory(fetch_implementation: typeof fetch): Client {
169
+ return { post_data, upload_files, client, handle_blob };
170
+
171
+ async function post_data(
172
+ url: string,
173
+ body: unknown,
174
+ token?: `hf_${string}`
175
+ ): Promise<[PostResponse, number]> {
176
+ const headers: {
177
+ Authorization?: string;
178
+ "Content-Type": "application/json";
179
+ } = { "Content-Type": "application/json" };
180
+ if (token) {
181
+ headers.Authorization = `Bearer ${token}`;
182
+ }
183
+ try {
184
+ var response = await fetch_implementation(url, {
185
+ method: "POST",
186
+ body: JSON.stringify(body),
187
+ headers
188
+ });
189
+ } catch (e) {
190
+ return [{ error: BROKEN_CONNECTION_MSG }, 500];
237
191
  }
192
+ const output: PostResponse = await response.json();
193
+ return [output, response.status];
194
+ }
238
195
 
239
- async function config_success(_config: Config) {
240
- config = _config;
241
- api_map = map_names_to_ids(_config?.dependencies || []);
196
+ async function upload_files(
197
+ root: string,
198
+ files: (Blob | File)[],
199
+ token?: `hf_${string}`
200
+ ): Promise<UploadResponse> {
201
+ const headers: {
202
+ Authorization?: string;
203
+ } = {};
204
+ if (token) {
205
+ headers.Authorization = `Bearer ${token}`;
206
+ }
207
+ const chunkSize = 1000;
208
+ const uploadResponses = [];
209
+ for (let i = 0; i < files.length; i += chunkSize) {
210
+ const chunk = files.slice(i, i + chunkSize);
211
+ const formData = new FormData();
212
+ chunk.forEach((file) => {
213
+ formData.append("files", file);
214
+ });
242
215
  try {
243
- api = await view_api(config);
216
+ var response = await fetch_implementation(`${root}/upload`, {
217
+ method: "POST",
218
+ body: formData,
219
+ headers
220
+ });
244
221
  } catch (e) {
245
- console.error(`Could not get api details: ${e.message}`);
222
+ return { error: BROKEN_CONNECTION_MSG };
246
223
  }
224
+ const output: UploadResponse["files"] = await response.json();
225
+ uploadResponses.push(...output);
226
+ }
227
+ return { files: uploadResponses };
228
+ }
247
229
 
248
- return {
249
- config,
250
- ...return_obj
230
+ async function client(
231
+ app_reference: string,
232
+ options: {
233
+ hf_token?: `hf_${string}`;
234
+ status_callback?: SpaceStatusCallback;
235
+ normalise_files?: boolean;
236
+ } = { normalise_files: true }
237
+ ): Promise<client_return> {
238
+ return new Promise(async (res) => {
239
+ const { status_callback, hf_token, normalise_files } = options;
240
+ const return_obj = {
241
+ predict,
242
+ submit,
243
+ view_api
244
+ // duplicate
251
245
  };
252
- }
253
- let api: ApiInfo<JsApiData>;
254
- async function handle_space_sucess(status: SpaceStatus) {
255
- if (status_callback) status_callback(status);
256
- if (status.status === "running")
257
- try {
258
- config = await resolve_config(`${http_protocol}//${host}`, hf_token);
259
246
 
260
- const _config = await config_success(config);
261
- res(_config);
247
+ const transform_files = normalise_files ?? true;
248
+ if (typeof window === "undefined" || !("WebSocket" in window)) {
249
+ const ws = await import("ws");
250
+ NodeBlob = (await import("node:buffer")).Blob;
251
+ //@ts-ignore
252
+ global.WebSocket = ws.WebSocket;
253
+ }
254
+
255
+ const { ws_protocol, http_protocol, host, space_id } =
256
+ await process_endpoint(app_reference, hf_token);
257
+
258
+ const session_hash = Math.random().toString(36).substring(2);
259
+ const last_status: Record<string, Status["stage"]> = {};
260
+ let config: Config;
261
+ let api_map: Record<string, number> = {};
262
+
263
+ let jwt: false | string = false;
264
+
265
+ if (hf_token && space_id) {
266
+ jwt = await get_jwt(space_id, hf_token);
267
+ }
268
+
269
+ async function config_success(_config: Config): Promise<client_return> {
270
+ config = _config;
271
+ api_map = map_names_to_ids(_config?.dependencies || []);
272
+ try {
273
+ api = await view_api(config);
262
274
  } catch (e) {
263
- if (status_callback) {
275
+ console.error(`Could not get api details: ${e.message}`);
276
+ }
277
+
278
+ return {
279
+ config,
280
+ ...return_obj
281
+ };
282
+ }
283
+ let api: ApiInfo<JsApiData>;
284
+ async function handle_space_sucess(status: SpaceStatus): Promise<void> {
285
+ if (status_callback) status_callback(status);
286
+ if (status.status === "running")
287
+ try {
288
+ config = await resolve_config(
289
+ fetch_implementation,
290
+ `${http_protocol}//${host}`,
291
+ hf_token
292
+ );
293
+
294
+ const _config = await config_success(config);
295
+ res(_config);
296
+ } catch (e) {
297
+ console.error(e);
298
+ if (status_callback) {
299
+ status_callback({
300
+ status: "error",
301
+ message: "Could not load this space.",
302
+ load_status: "error",
303
+ detail: "NOT_FOUND"
304
+ });
305
+ }
306
+ }
307
+ }
308
+
309
+ try {
310
+ config = await resolve_config(
311
+ fetch_implementation,
312
+ `${http_protocol}//${host}`,
313
+ hf_token
314
+ );
315
+
316
+ const _config = await config_success(config);
317
+ res(_config);
318
+ } catch (e) {
319
+ console.error(e);
320
+ if (space_id) {
321
+ check_space_status(
322
+ space_id,
323
+ RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
324
+ handle_space_sucess
325
+ );
326
+ } else {
327
+ if (status_callback)
264
328
  status_callback({
265
329
  status: "error",
266
330
  message: "Could not load this space.",
267
331
  load_status: "error",
268
332
  detail: "NOT_FOUND"
269
333
  });
270
- }
271
334
  }
272
- }
273
-
274
- try {
275
- config = await resolve_config(`${http_protocol}//${host}`, hf_token);
276
-
277
- const _config = await config_success(config);
278
- res(_config);
279
- } catch (e) {
280
- if (space_id) {
281
- check_space_status(
282
- space_id,
283
- RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
284
- handle_space_sucess
285
- );
286
- } else {
287
- if (status_callback)
288
- status_callback({
289
- status: "error",
290
- message: "Could not load this space.",
291
- load_status: "error",
292
- detail: "NOT_FOUND"
293
- });
294
335
  }
295
- }
296
336
 
297
- /**
298
- * Run a prediction.
299
- * @param endpoint - The prediction endpoint to use.
300
- * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates.
301
- * @return Returns the data for the prediction or an error message.
302
- */
303
- function predict(endpoint: string, data: unknown[], event_data?: unknown) {
304
- let data_returned = false;
305
- let status_complete = false;
306
- return new Promise((res, rej) => {
307
- const app = submit(endpoint, data, event_data);
308
-
309
- app
310
- .on("data", (d) => {
311
- data_returned = true;
312
- if (status_complete) {
313
- app.destroy();
314
- }
315
- res(d);
316
- })
317
- .on("status", (status) => {
318
- if (status.stage === "error") rej(status);
319
- if (status.stage === "complete" && data_returned) {
320
- app.destroy();
321
- }
322
- if (status.stage === "complete") {
323
- status_complete = true;
324
- }
325
- });
326
- });
327
- }
337
+ function predict(
338
+ endpoint: string,
339
+ data: unknown[],
340
+ event_data?: unknown
341
+ ): Promise<unknown> {
342
+ let data_returned = false;
343
+ let status_complete = false;
344
+ let dependency;
345
+ if (typeof endpoint === "number") {
346
+ dependency = config.dependencies[endpoint];
347
+ } else {
348
+ const trimmed_endpoint = endpoint.replace(/^\//, "");
349
+ dependency = config.dependencies[api_map[trimmed_endpoint]];
350
+ }
328
351
 
329
- function submit(
330
- endpoint: string | number,
331
- data: unknown[],
332
- event_data?: unknown
333
- ): SubmitReturn {
334
- let fn_index: number;
335
- let api_info;
336
-
337
- if (typeof endpoint === "number") {
338
- fn_index = endpoint;
339
- api_info = api.unnamed_endpoints[fn_index];
340
- } else {
341
- const trimmed_endpoint = endpoint.replace(/^\//, "");
352
+ if (dependency.types.continuous) {
353
+ throw new Error(
354
+ "Cannot call predict on this function as it may run forever. Use submit instead"
355
+ );
356
+ }
342
357
 
343
- fn_index = api_map[trimmed_endpoint];
344
- api_info = api.named_endpoints[endpoint.trim()];
345
- }
358
+ return new Promise((res, rej) => {
359
+ const app = submit(endpoint, data, event_data);
360
+ let result;
346
361
 
347
- if (typeof fn_index !== "number") {
348
- throw new Error(
349
- "There is no endpoint matching that name of fn_index matching that number."
350
- );
362
+ app
363
+ .on("data", (d) => {
364
+ // if complete message comes before data, resolve here
365
+ if (status_complete) {
366
+ app.destroy();
367
+ res(d);
368
+ }
369
+ data_returned = true;
370
+ result = d;
371
+ })
372
+ .on("status", (status) => {
373
+ if (status.stage === "error") rej(status);
374
+ if (status.stage === "complete") {
375
+ status_complete = true;
376
+ app.destroy();
377
+ // if complete message comes after data, resolve here
378
+ if (data_returned) {
379
+ res(result);
380
+ }
381
+ }
382
+ });
383
+ });
351
384
  }
352
385
 
353
- let websocket: WebSocket;
354
-
355
- const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
356
- let payload: Payload;
357
- let complete: false | Record<string, any> = false;
358
- const listener_map: ListenerMap<EventType> = {};
359
-
360
- handle_blob(
361
- `${http_protocol}//${host + config.path}`,
362
- data,
363
- api_info,
364
- hf_token
365
- ).then((_payload) => {
366
- payload = { data: _payload || [], event_data, fn_index };
367
- if (skip_queue(fn_index, config)) {
368
- fire_event({
369
- type: "status",
370
- endpoint: _endpoint,
371
- stage: "pending",
372
- queue: false,
373
- fn_index,
374
- time: new Date()
375
- });
386
+ function submit(
387
+ endpoint: string | number,
388
+ data: unknown[],
389
+ event_data?: unknown
390
+ ): SubmitReturn {
391
+ let fn_index: number;
392
+ let api_info;
393
+
394
+ if (typeof endpoint === "number") {
395
+ fn_index = endpoint;
396
+ api_info = api.unnamed_endpoints[fn_index];
397
+ } else {
398
+ const trimmed_endpoint = endpoint.replace(/^\//, "");
376
399
 
377
- post_data(
378
- `${http_protocol}//${host + config.path}/run${
379
- _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
380
- }`,
381
- {
382
- ...payload,
383
- session_hash
384
- },
385
- hf_token
386
- )
387
- .then(([output, status_code]) => {
388
- const data = transform_files
389
- ? transform_output(
390
- output.data,
391
- api_info,
392
- config.root,
393
- config.root_url
394
- )
395
- : output.data;
396
- if (status_code == 200) {
397
- fire_event({
398
- type: "data",
399
- endpoint: _endpoint,
400
- fn_index,
401
- data: data,
402
- time: new Date()
403
- });
400
+ fn_index = api_map[trimmed_endpoint];
401
+ api_info = api.named_endpoints[endpoint.trim()];
402
+ }
404
403
 
405
- fire_event({
406
- type: "status",
407
- endpoint: _endpoint,
408
- fn_index,
409
- stage: "complete",
410
- eta: output.average_duration,
411
- queue: false,
412
- time: new Date()
413
- });
414
- } else {
404
+ if (typeof fn_index !== "number") {
405
+ throw new Error(
406
+ "There is no endpoint matching that name of fn_index matching that number."
407
+ );
408
+ }
409
+
410
+ let websocket: WebSocket;
411
+
412
+ const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
413
+ let payload: Payload;
414
+ let complete: false | Record<string, any> = false;
415
+ const listener_map: ListenerMap<EventType> = {};
416
+
417
+ handle_blob(
418
+ `${http_protocol}//${host + config.path}`,
419
+ data,
420
+ api_info,
421
+ hf_token
422
+ ).then((_payload) => {
423
+ payload = { data: _payload || [], event_data, fn_index };
424
+ if (skip_queue(fn_index, config)) {
425
+ fire_event({
426
+ type: "status",
427
+ endpoint: _endpoint,
428
+ stage: "pending",
429
+ queue: false,
430
+ fn_index,
431
+ time: new Date()
432
+ });
433
+
434
+ post_data(
435
+ `${http_protocol}//${host + config.path}/run${
436
+ _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
437
+ }`,
438
+ {
439
+ ...payload,
440
+ session_hash
441
+ },
442
+ hf_token
443
+ )
444
+ .then(([output, status_code]) => {
445
+ const data = transform_files
446
+ ? transform_output(
447
+ output.data,
448
+ api_info,
449
+ config.root,
450
+ config.root_url
451
+ )
452
+ : output.data;
453
+ if (status_code == 200) {
454
+ fire_event({
455
+ type: "data",
456
+ endpoint: _endpoint,
457
+ fn_index,
458
+ data: data,
459
+ time: new Date()
460
+ });
461
+
462
+ fire_event({
463
+ type: "status",
464
+ endpoint: _endpoint,
465
+ fn_index,
466
+ stage: "complete",
467
+ eta: output.average_duration,
468
+ queue: false,
469
+ time: new Date()
470
+ });
471
+ } else {
472
+ fire_event({
473
+ type: "status",
474
+ stage: "error",
475
+ endpoint: _endpoint,
476
+ fn_index,
477
+ message: output.error,
478
+ queue: false,
479
+ time: new Date()
480
+ });
481
+ }
482
+ })
483
+ .catch((e) => {
415
484
  fire_event({
416
485
  type: "status",
417
486
  stage: "error",
487
+ message: e.message,
418
488
  endpoint: _endpoint,
419
489
  fn_index,
420
- message: output.error,
421
490
  queue: false,
422
491
  time: new Date()
423
492
  });
424
- }
425
- })
426
- .catch((e) => {
427
- fire_event({
428
- type: "status",
429
- stage: "error",
430
- message: e.message,
431
- endpoint: _endpoint,
432
- fn_index,
433
- queue: false,
434
- time: new Date()
435
493
  });
494
+ } else {
495
+ fire_event({
496
+ type: "status",
497
+ stage: "pending",
498
+ queue: true,
499
+ endpoint: _endpoint,
500
+ fn_index,
501
+ time: new Date()
436
502
  });
437
- } else {
438
- fire_event({
439
- type: "status",
440
- stage: "pending",
441
- queue: true,
442
- endpoint: _endpoint,
443
- fn_index,
444
- time: new Date()
445
- });
446
503
 
447
- let url = new URL(`${ws_protocol}://${host}${config.path}
448
- /queue/join`);
504
+ let url = new URL(`${ws_protocol}://${host}${config.path}
505
+ /queue/join`);
449
506
 
450
- if (jwt) {
451
- url.searchParams.set("__sign", jwt);
452
- }
453
-
454
- websocket = new WebSocket(url);
455
-
456
- websocket.onclose = (evt) => {
457
- if (!evt.wasClean) {
458
- fire_event({
459
- type: "status",
460
- stage: "error",
461
- message: BROKEN_CONNECTION_MSG,
462
- queue: true,
463
- endpoint: _endpoint,
464
- fn_index,
465
- time: new Date()
466
- });
507
+ if (jwt) {
508
+ url.searchParams.set("__sign", jwt);
467
509
  }
468
- };
469
510
 
470
- websocket.onmessage = function (event) {
471
- const _data = JSON.parse(event.data);
472
- const { type, status, data } = handle_message(
473
- _data,
474
- last_status[fn_index]
475
- );
511
+ websocket = new WebSocket(url);
476
512
 
477
- if (type === "update" && status && !complete) {
478
- // call 'status' listeners
479
- fire_event({
480
- type: "status",
481
- endpoint: _endpoint,
482
- fn_index,
483
- time: new Date(),
484
- ...status
485
- });
486
- if (status.stage === "error") {
487
- websocket.close();
513
+ websocket.onclose = (evt) => {
514
+ if (!evt.wasClean) {
515
+ fire_event({
516
+ type: "status",
517
+ stage: "error",
518
+ broken: true,
519
+ message: BROKEN_CONNECTION_MSG,
520
+ queue: true,
521
+ endpoint: _endpoint,
522
+ fn_index,
523
+ time: new Date()
524
+ });
488
525
  }
489
- } else if (type === "hash") {
490
- websocket.send(JSON.stringify({ fn_index, session_hash }));
491
- return;
492
- } else if (type === "data") {
493
- websocket.send(JSON.stringify({ ...payload, session_hash }));
494
- } else if (type === "complete") {
495
- complete = status;
496
- } else if (type === "generating") {
497
- fire_event({
498
- type: "status",
499
- time: new Date(),
500
- ...status,
501
- stage: status?.stage!,
502
- queue: true,
503
- endpoint: _endpoint,
504
- fn_index
505
- });
506
- }
507
- if (data) {
508
- fire_event({
509
- type: "data",
510
- time: new Date(),
511
- data: transform_files
512
- ? transform_output(
513
- data.data,
514
- api_info,
515
- config.root,
516
- config.root_url
517
- )
518
- : data.data,
519
- endpoint: _endpoint,
520
- fn_index
521
- });
526
+ };
527
+
528
+ websocket.onmessage = function (event) {
529
+ const _data = JSON.parse(event.data);
530
+ const { type, status, data } = handle_message(
531
+ _data,
532
+ last_status[fn_index]
533
+ );
522
534
 
523
- if (complete) {
535
+ if (type === "update" && status && !complete) {
536
+ // call 'status' listeners
524
537
  fire_event({
525
538
  type: "status",
539
+ endpoint: _endpoint,
540
+ fn_index,
526
541
  time: new Date(),
527
- ...complete,
542
+ ...status
543
+ });
544
+ if (status.stage === "error") {
545
+ websocket.close();
546
+ }
547
+ } else if (type === "hash") {
548
+ websocket.send(JSON.stringify({ fn_index, session_hash }));
549
+ return;
550
+ } else if (type === "data") {
551
+ websocket.send(JSON.stringify({ ...payload, session_hash }));
552
+ } else if (type === "complete") {
553
+ complete = status;
554
+ } else if (type === "log") {
555
+ fire_event({
556
+ type: "log",
557
+ log: data.log,
558
+ level: data.level,
559
+ endpoint: _endpoint,
560
+ fn_index
561
+ });
562
+ } else if (type === "generating") {
563
+ fire_event({
564
+ type: "status",
565
+ time: new Date(),
566
+ ...status,
528
567
  stage: status?.stage!,
529
568
  queue: true,
530
569
  endpoint: _endpoint,
531
570
  fn_index
532
571
  });
533
- websocket.close();
534
572
  }
535
- }
536
- };
573
+ if (data) {
574
+ fire_event({
575
+ type: "data",
576
+ time: new Date(),
577
+ data: transform_files
578
+ ? transform_output(
579
+ data.data,
580
+ api_info,
581
+ config.root,
582
+ config.root_url
583
+ )
584
+ : data.data,
585
+ endpoint: _endpoint,
586
+ fn_index
587
+ });
537
588
 
538
- // different ws contract for gradio versions older than 3.6.0
539
- //@ts-ignore
540
- if (semiver(config.version || "2.0.0", "3.6") < 0) {
541
- addEventListener("open", () =>
542
- websocket.send(JSON.stringify({ hash: session_hash }))
543
- );
589
+ if (complete) {
590
+ fire_event({
591
+ type: "status",
592
+ time: new Date(),
593
+ ...complete,
594
+ stage: status?.stage!,
595
+ queue: true,
596
+ endpoint: _endpoint,
597
+ fn_index
598
+ });
599
+ websocket.close();
600
+ }
601
+ }
602
+ };
603
+
604
+ // different ws contract for gradio versions older than 3.6.0
605
+ //@ts-ignore
606
+ if (semiver(config.version || "2.0.0", "3.6") < 0) {
607
+ addEventListener("open", () =>
608
+ websocket.send(JSON.stringify({ hash: session_hash }))
609
+ );
610
+ }
544
611
  }
545
- }
546
- });
612
+ });
547
613
 
548
- function fire_event<K extends EventType>(event: Event<K>) {
549
- const narrowed_listener_map: ListenerMap<K> = listener_map;
550
- const listeners = narrowed_listener_map[event.type] || [];
551
- listeners?.forEach((l) => l(event));
552
- }
614
+ function fire_event<K extends EventType>(event: Event<K>): void {
615
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
616
+ const listeners = narrowed_listener_map[event.type] || [];
617
+ listeners?.forEach((l) => l(event));
618
+ }
553
619
 
554
- function on<K extends EventType>(
555
- eventType: K,
556
- listener: EventListener<K>
557
- ) {
558
- const narrowed_listener_map: ListenerMap<K> = listener_map;
559
- const listeners = narrowed_listener_map[eventType] || [];
560
- narrowed_listener_map[eventType] = listeners;
561
- listeners?.push(listener);
620
+ function on<K extends EventType>(
621
+ eventType: K,
622
+ listener: EventListener<K>
623
+ ): SubmitReturn {
624
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
625
+ const listeners = narrowed_listener_map[eventType] || [];
626
+ narrowed_listener_map[eventType] = listeners;
627
+ listeners?.push(listener);
562
628
 
563
- return { on, off, cancel, destroy };
564
- }
629
+ return { on, off, cancel, destroy };
630
+ }
565
631
 
566
- function off<K extends EventType>(
567
- eventType: K,
568
- listener: EventListener<K>
569
- ) {
570
- const narrowed_listener_map: ListenerMap<K> = listener_map;
571
- let listeners = narrowed_listener_map[eventType] || [];
572
- listeners = listeners?.filter((l) => l !== listener);
573
- narrowed_listener_map[eventType] = listeners;
632
+ function off<K extends EventType>(
633
+ eventType: K,
634
+ listener: EventListener<K>
635
+ ): SubmitReturn {
636
+ const narrowed_listener_map: ListenerMap<K> = listener_map;
637
+ let listeners = narrowed_listener_map[eventType] || [];
638
+ listeners = listeners?.filter((l) => l !== listener);
639
+ narrowed_listener_map[eventType] = listeners;
574
640
 
575
- return { on, off, cancel, destroy };
576
- }
641
+ return { on, off, cancel, destroy };
642
+ }
577
643
 
578
- async function cancel() {
579
- const _status: Status = {
580
- stage: "complete",
581
- queue: false,
582
- time: new Date()
583
- };
584
- complete = _status;
585
- fire_event({
586
- ..._status,
587
- type: "status",
588
- endpoint: _endpoint,
589
- fn_index: fn_index
590
- });
644
+ async function cancel(): Promise<void> {
645
+ const _status: Status = {
646
+ stage: "complete",
647
+ queue: false,
648
+ time: new Date()
649
+ };
650
+ complete = _status;
651
+ fire_event({
652
+ ..._status,
653
+ type: "status",
654
+ endpoint: _endpoint,
655
+ fn_index: fn_index
656
+ });
591
657
 
592
- if (websocket && websocket.readyState === 0) {
593
- websocket.addEventListener("open", () => {
658
+ if (websocket && websocket.readyState === 0) {
659
+ websocket.addEventListener("open", () => {
660
+ websocket.close();
661
+ });
662
+ } else {
594
663
  websocket.close();
595
- });
596
- } else {
597
- websocket.close();
664
+ }
665
+
666
+ try {
667
+ await fetch_implementation(
668
+ `${http_protocol}//${host + config.path}/reset`,
669
+ {
670
+ headers: { "Content-Type": "application/json" },
671
+ method: "POST",
672
+ body: JSON.stringify({ fn_index, session_hash })
673
+ }
674
+ );
675
+ } catch (e) {
676
+ console.warn(
677
+ "The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
678
+ );
679
+ }
598
680
  }
599
681
 
600
- try {
601
- await fetch(`${http_protocol}//${host + config.path}/reset`, {
602
- headers: { "Content-Type": "application/json" },
603
- method: "POST",
604
- body: JSON.stringify({ fn_index, session_hash })
605
- });
606
- } catch (e) {
607
- console.warn(
608
- "The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
609
- );
682
+ function destroy(): void {
683
+ for (const event_type in listener_map) {
684
+ listener_map[event_type as "data" | "status"].forEach((fn) => {
685
+ off(event_type as "data" | "status", fn);
686
+ });
687
+ }
610
688
  }
689
+
690
+ return {
691
+ on,
692
+ off,
693
+ cancel,
694
+ destroy
695
+ };
611
696
  }
612
697
 
613
- function destroy() {
614
- for (const event_type in listener_map) {
615
- listener_map[event_type as "data" | "status"].forEach((fn) => {
616
- off(event_type as "data" | "status", fn);
698
+ async function view_api(config?: Config): Promise<ApiInfo<JsApiData>> {
699
+ if (api) return api;
700
+
701
+ const headers: {
702
+ Authorization?: string;
703
+ "Content-Type": "application/json";
704
+ } = { "Content-Type": "application/json" };
705
+ if (hf_token) {
706
+ headers.Authorization = `Bearer ${hf_token}`;
707
+ }
708
+ let response: Response;
709
+ // @ts-ignore
710
+ if (semiver(config.version || "2.0.0", "3.30") < 0) {
711
+ response = await fetch_implementation(
712
+ "https://gradio-space-api-fetcher-v2.hf.space/api",
713
+ {
714
+ method: "POST",
715
+ body: JSON.stringify({
716
+ serialize: false,
717
+ config: JSON.stringify(config)
718
+ }),
719
+ headers
720
+ }
721
+ );
722
+ } else {
723
+ response = await fetch_implementation(`${config.root}/info`, {
724
+ headers
617
725
  });
618
726
  }
619
- }
620
727
 
621
- return {
622
- on,
623
- off,
624
- cancel,
625
- destroy
626
- };
627
- }
728
+ if (!response.ok) {
729
+ throw new Error(BROKEN_CONNECTION_MSG);
730
+ }
628
731
 
629
- async function view_api(config?: Config): Promise<ApiInfo<JsApiData>> {
630
- if (api) return api;
732
+ let api_info = (await response.json()) as
733
+ | ApiInfo<ApiData>
734
+ | { api: ApiInfo<ApiData> };
735
+ if ("api" in api_info) {
736
+ api_info = api_info.api;
737
+ }
631
738
 
632
- const headers: {
633
- Authorization?: string;
634
- "Content-Type": "application/json";
635
- } = { "Content-Type": "application/json" };
636
- if (hf_token) {
637
- headers.Authorization = `Bearer ${hf_token}`;
638
- }
639
- let response: Response;
640
- // @ts-ignore
641
- if (semiver(config.version || "2.0.0", "3.30") < 0) {
642
- response = await fetch(
643
- "https://gradio-space-api-fetcher-v2.hf.space/api",
644
- {
645
- method: "POST",
646
- body: JSON.stringify({
647
- serialize: false,
648
- config: JSON.stringify(config)
649
- }),
650
- headers
651
- }
652
- );
653
- } else {
654
- response = await fetch(`${config.root}/info`, {
655
- headers
656
- });
657
- }
739
+ if (
740
+ api_info.named_endpoints["/predict"] &&
741
+ !api_info.unnamed_endpoints["0"]
742
+ ) {
743
+ api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"];
744
+ }
658
745
 
659
- if (!response.ok) {
660
- throw new Error(BROKEN_CONNECTION_MSG);
746
+ const x = transform_api_info(api_info, config, api_map);
747
+ return x;
661
748
  }
749
+ });
750
+ }
662
751
 
663
- let api_info = (await response.json()) as
664
- | ApiInfo<ApiData>
665
- | { api: ApiInfo<ApiData> };
666
- if ("api" in api_info) {
667
- api_info = api_info.api;
668
- }
752
+ async function handle_blob(
753
+ endpoint: string,
754
+ data: unknown[],
755
+ api_info: ApiInfo<JsApiData>,
756
+ token?: `hf_${string}`
757
+ ): Promise<unknown[]> {
758
+ const blob_refs = await walk_and_store_blobs(
759
+ data,
760
+ undefined,
761
+ [],
762
+ true,
763
+ api_info
764
+ );
669
765
 
670
- if (
671
- api_info.named_endpoints["/predict"] &&
672
- !api_info.unnamed_endpoints["0"]
673
- ) {
674
- api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"];
675
- }
766
+ return Promise.all(
767
+ blob_refs.map(async ({ path, blob, data, type }) => {
768
+ if (blob) {
769
+ const file_url = (await upload_files(endpoint, [blob], token))
770
+ .files[0];
771
+ return { path, file_url, type };
772
+ }
773
+ return { path, base64: data, type };
774
+ })
775
+ ).then((r) => {
776
+ r.forEach(({ path, file_url, base64, type }) => {
777
+ if (base64) {
778
+ update_object(data, base64, path);
779
+ } else if (type === "Gallery") {
780
+ update_object(data, file_url, path);
781
+ } else if (file_url) {
782
+ const o = {
783
+ is_file: true,
784
+ name: `${file_url}`,
785
+ data: null
786
+ // orig_name: "file.csv"
787
+ };
788
+ update_object(data, o, path);
789
+ }
790
+ });
676
791
 
677
- const x = transform_api_info(api_info, config, api_map);
678
- return x;
679
- }
680
- });
792
+ return data;
793
+ });
794
+ }
681
795
  }
682
796
 
797
+ export const { post_data, upload_files, client, handle_blob } =
798
+ api_factory(fetch);
799
+
683
800
  function transform_output(
684
801
  data: any[],
685
802
  api_info: any,
@@ -687,27 +804,26 @@ function transform_output(
687
804
  remote_url?: string
688
805
  ): unknown[] {
689
806
  return data.map((d, i) => {
690
- if (api_info.returns?.[i]?.component === "File") {
807
+ if (api_info?.returns?.[i]?.component === "File") {
691
808
  return normalise_file(d, root_url, remote_url);
692
- } else if (api_info.returns?.[i]?.component === "Gallery") {
809
+ } else if (api_info?.returns?.[i]?.component === "Gallery") {
693
810
  return d.map((img) => {
694
811
  return Array.isArray(img)
695
812
  ? [normalise_file(img[0], root_url, remote_url), img[1]]
696
813
  : [normalise_file(img, root_url, remote_url), null];
697
814
  });
698
- } else if (typeof d === "object" && d.is_file) {
815
+ } else if (typeof d === "object" && d?.is_file) {
699
816
  return normalise_file(d, root_url, remote_url);
700
- } else {
701
- return d;
702
817
  }
818
+ return d;
703
819
  });
704
820
  }
705
821
 
706
822
  function normalise_file(
707
- file: Array<FileData>,
823
+ file: FileData[],
708
824
  root: string,
709
825
  root_url: string | null
710
- ): Array<FileData>;
826
+ ): FileData[];
711
827
  function normalise_file(
712
828
  file: FileData | string,
713
829
  root: string,
@@ -718,11 +834,7 @@ function normalise_file(
718
834
  root: string,
719
835
  root_url: string | null
720
836
  ): null;
721
- function normalise_file(
722
- file,
723
- root,
724
- root_url
725
- ): Array<FileData> | FileData | null {
837
+ function normalise_file(file, root, root_url): FileData[] | FileData | null {
726
838
  if (file == null) return null;
727
839
  if (typeof file === "string") {
728
840
  return {
@@ -730,7 +842,7 @@ function normalise_file(
730
842
  data: file
731
843
  };
732
844
  } else if (Array.isArray(file)) {
733
- const normalized_file: Array<FileData | null> = [];
845
+ const normalized_file: (FileData | null)[] = [];
734
846
 
735
847
  for (const x of file) {
736
848
  if (x === null) {
@@ -740,7 +852,7 @@ function normalise_file(
740
852
  }
741
853
  }
742
854
 
743
- return normalized_file as Array<FileData>;
855
+ return normalized_file as FileData[];
744
856
  } else if (file.is_file) {
745
857
  if (!root_url) {
746
858
  file.data = root + "/file=" + file.name;
@@ -786,7 +898,7 @@ function get_type(
786
898
  component: string,
787
899
  serializer: string,
788
900
  signature_type: "return" | "parameter"
789
- ) {
901
+ ): string {
790
902
  switch (type.type) {
791
903
  case "string":
792
904
  return "string";
@@ -810,11 +922,10 @@ function get_type(
810
922
  return signature_type === "parameter"
811
923
  ? "(Blob | File | Buffer)[]"
812
924
  : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}[]`;
813
- } else {
814
- return signature_type === "parameter"
815
- ? "Blob | File | Buffer"
816
- : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}`;
817
925
  }
926
+ return signature_type === "parameter"
927
+ ? "Blob | File | Buffer"
928
+ : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}`;
818
929
  } else if (serializer === "GallerySerializable") {
819
930
  return signature_type === "parameter"
820
931
  ? "[(Blob | File | Buffer), (string | null)][]"
@@ -825,16 +936,15 @@ function get_type(
825
936
  function get_description(
826
937
  type: { type: any; description: string },
827
938
  serializer: string
828
- ) {
939
+ ): string {
829
940
  if (serializer === "GallerySerializable") {
830
941
  return "array of [file, label] tuples";
831
942
  } else if (serializer === "ListStringSerializable") {
832
943
  return "array of strings";
833
944
  } else if (serializer === "FileSerializable") {
834
945
  return "array of files or single file";
835
- } else {
836
- return type.description;
837
946
  }
947
+ return type.description;
838
948
  }
839
949
 
840
950
  function transform_api_info(
@@ -902,51 +1012,7 @@ async function get_jwt(
902
1012
  }
903
1013
  }
904
1014
 
905
- export async function handle_blob(
906
- endpoint: string,
907
- data: unknown[],
908
- api_info,
909
- token?: `hf_${string}`
910
- ): Promise<unknown[]> {
911
- const blob_refs = await walk_and_store_blobs(
912
- data,
913
- undefined,
914
- [],
915
- true,
916
- api_info
917
- );
918
-
919
- return Promise.all(
920
- blob_refs.map(async ({ path, blob, data, type }) => {
921
- if (blob) {
922
- const file_url = (await upload_files(endpoint, [blob], token)).files[0];
923
- return { path, file_url, type };
924
- } else {
925
- return { path, base64: data, type };
926
- }
927
- })
928
- ).then((r) => {
929
- r.forEach(({ path, file_url, base64, type }) => {
930
- if (base64) {
931
- update_object(data, base64, path);
932
- } else if (type === "Gallery") {
933
- update_object(data, file_url, path);
934
- } else if (file_url) {
935
- const o = {
936
- is_file: true,
937
- name: `${file_url}`,
938
- data: null
939
- // orig_name: "file.csv"
940
- };
941
- update_object(data, o, path);
942
- }
943
- });
944
-
945
- return data;
946
- });
947
- }
948
-
949
- function update_object(object, newValue, stack) {
1015
+ function update_object(object, newValue, stack): void {
950
1016
  while (stack.length > 1) {
951
1017
  object = object[stack.shift()];
952
1018
  }
@@ -960,7 +1026,14 @@ export async function walk_and_store_blobs(
960
1026
  path = [],
961
1027
  root = false,
962
1028
  api_info = undefined
963
- ) {
1029
+ ): Promise<
1030
+ {
1031
+ path: string[];
1032
+ data: string | false;
1033
+ type: string;
1034
+ blob: Blob | false;
1035
+ }[]
1036
+ > {
964
1037
  if (Array.isArray(param)) {
965
1038
  let blob_refs = [];
966
1039
 
@@ -1007,10 +1080,9 @@ export async function walk_and_store_blobs(
1007
1080
  data = Buffer.from(buffer).toString("base64");
1008
1081
  }
1009
1082
 
1010
- return [{ path, data, type }];
1011
- } else {
1012
- return [{ path: path, blob: param, type }];
1083
+ return [{ path, data, type, blob: false }];
1013
1084
  }
1085
+ return [{ path: path, blob: param, type, data: false }];
1014
1086
  } else if (typeof param === "object") {
1015
1087
  let blob_refs = [];
1016
1088
  for (let key in param) {
@@ -1029,12 +1101,11 @@ export async function walk_and_store_blobs(
1029
1101
  }
1030
1102
  }
1031
1103
  return blob_refs;
1032
- } else {
1033
- return [];
1034
1104
  }
1105
+ return [];
1035
1106
  }
1036
1107
 
1037
- function image_to_data_uri(blob: Blob) {
1108
+ function image_to_data_uri(blob: Blob): Promise<string | ArrayBuffer> {
1038
1109
  return new Promise((resolve, _) => {
1039
1110
  const reader = new FileReader();
1040
1111
  reader.onloadend = () => resolve(reader.result);
@@ -1042,7 +1113,7 @@ function image_to_data_uri(blob: Blob) {
1042
1113
  });
1043
1114
  }
1044
1115
 
1045
- function skip_queue(id: number, config: Config) {
1116
+ function skip_queue(id: number, config: Config): boolean {
1046
1117
  return (
1047
1118
  !(config?.dependencies?.[id]?.queue === null
1048
1119
  ? config.enable_queue
@@ -1051,6 +1122,7 @@ function skip_queue(id: number, config: Config) {
1051
1122
  }
1052
1123
 
1053
1124
  async function resolve_config(
1125
+ fetch_implementation: typeof fetch,
1054
1126
  endpoint?: string,
1055
1127
  token?: `hf_${string}`
1056
1128
  ): Promise<Config> {
@@ -1068,16 +1140,17 @@ async function resolve_config(
1068
1140
  config.root = endpoint + config.root;
1069
1141
  return { ...config, path: path };
1070
1142
  } else if (endpoint) {
1071
- let response = await fetch(`${endpoint}/config`, { headers });
1143
+ let response = await fetch_implementation(`${endpoint}/config`, {
1144
+ headers
1145
+ });
1072
1146
 
1073
1147
  if (response.status === 200) {
1074
1148
  const config = await response.json();
1075
1149
  config.path = config.path ?? "";
1076
1150
  config.root = endpoint;
1077
1151
  return config;
1078
- } else {
1079
- throw new Error("Could not get config.");
1080
1152
  }
1153
+ throw new Error("Could not get config.");
1081
1154
  }
1082
1155
 
1083
1156
  throw new Error("No config or app endpoint found");
@@ -1087,7 +1160,7 @@ async function check_space_status(
1087
1160
  id: string,
1088
1161
  type: "subdomain" | "space_name",
1089
1162
  status_callback: SpaceStatusCallback
1090
- ) {
1163
+ ): Promise<void> {
1091
1164
  let endpoint =
1092
1165
  type === "subdomain"
1093
1166
  ? `https://huggingface.co/api/spaces/by-subdomain/${id}`
@@ -1180,7 +1253,7 @@ function handle_message(
1180
1253
  data: any,
1181
1254
  last_status: Status["stage"]
1182
1255
  ): {
1183
- type: "hash" | "data" | "update" | "complete" | "generating" | "none";
1256
+ type: "hash" | "data" | "update" | "complete" | "generating" | "log" | "none";
1184
1257
  data?: any;
1185
1258
  status?: Status;
1186
1259
  } {
@@ -1225,6 +1298,8 @@ function handle_message(
1225
1298
  success: data.success
1226
1299
  }
1227
1300
  };
1301
+ case "log":
1302
+ return { type: "log", data: data };
1228
1303
  case "process_generating":
1229
1304
  return {
1230
1305
  type: "generating",
@@ -1250,20 +1325,19 @@ function handle_message(
1250
1325
  success: data.success
1251
1326
  }
1252
1327
  };
1253
- } else {
1254
- return {
1255
- type: "complete",
1256
- status: {
1257
- queue,
1258
- message: !data.success ? data.output.error : undefined,
1259
- stage: data.success ? "complete" : "error",
1260
- code: data.code,
1261
- progress_data: data.progress_data,
1262
- eta: data.output.average_duration
1263
- },
1264
- data: data.success ? data.output : null
1265
- };
1266
1328
  }
1329
+ return {
1330
+ type: "complete",
1331
+ status: {
1332
+ queue,
1333
+ message: !data.success ? data.output.error : undefined,
1334
+ stage: data.success ? "complete" : "error",
1335
+ code: data.code,
1336
+ progress_data: data.progress_data,
1337
+ eta: data.output.average_duration
1338
+ },
1339
+ data: data.success ? data.output : null
1340
+ };
1267
1341
 
1268
1342
  case "process_starts":
1269
1343
  return {