@sjcrh/proteinpaint-server 2.172.0 → 2.173.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3):
  1. package/package.json +2 -2
  2. package/routes/termdb.chat.js +361 -108
  3. package/src/app.js +377 -124
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@sjcrh/proteinpaint-server",
3
- "version": "2.172.0",
3
+ "version": "2.173.0",
4
4
  "type": "module",
5
5
  "description": "a genomics visualization tool for exploring a cohort's genotype and phenotype data",
6
6
  "main": "src/app.js",
@@ -66,7 +66,7 @@
66
66
  "@sjcrh/proteinpaint-r": "2.152.1-0",
67
67
  "@sjcrh/proteinpaint-rust": "2.171.0",
68
68
  "@sjcrh/proteinpaint-shared": "2.171.0-0",
69
- "@sjcrh/proteinpaint-types": "2.172.0",
69
+ "@sjcrh/proteinpaint-types": "2.173.0",
70
70
  "@types/express": "^5.0.0",
71
71
  "@types/express-session": "^1.18.1",
72
72
  "better-sqlite3": "^12.4.1",
@@ -1,8 +1,10 @@
1
+ import fs from "fs";
2
+ import { ezFetch } from "#shared";
1
3
  import { ChatPayload } from "#types/checkers";
2
- import { run_rust } from "@sjcrh/proteinpaint-rust";
3
4
  import serverconfig from "../src/serverconfig.js";
4
5
  import { mayLog } from "#src/helpers.ts";
5
- import { run_python } from "@sjcrh/proteinpaint-python";
6
+ import Database from "better-sqlite3";
7
+ import { formatElapsedTime } from "#shared";
6
8
  const api = {
7
9
  endpoint: "termdb/chat",
8
10
  methods: {
@@ -24,138 +26,389 @@ function init({ genomes }) {
24
26
  if (!g) throw "invalid genome";
25
27
  const ds = g.datasets?.[q.dslabel];
26
28
  if (!ds) throw "invalid dslabel";
27
- if (serverconfig.features.pythonChatBot) {
28
- const chatbot_input2 = {
29
- prompt: q.prompt,
30
- genome: q.genome,
31
- dslabel: q.dslabel
32
- //terms_tsv_path: df
33
- };
34
- try {
35
- const ai_output_data2 = await run_python("chatBot.py", JSON.stringify(chatbot_input2));
36
- res.send(ai_output_data2);
37
- } catch (error) {
38
- const errmsg = "Error running chatBot Python script:" + error;
39
- throw new Error(errmsg);
40
- }
41
- return;
42
- }
43
29
  const serverconfig_ds_entries = serverconfig.genomes.find((genome) => genome.name == q.genome).datasets.find((dslabel) => dslabel.name == ds.label);
44
30
  if (!serverconfig_ds_entries.aifiles) {
45
31
  throw "aifiles are missing for chatbot to work";
46
32
  }
47
33
  let apilink;
48
34
  let comp_model_name;
49
- let embedding_model_name;
50
35
  if (serverconfig.llm_backend == "SJ") {
51
36
  apilink = serverconfig.sj_apilink;
52
37
  comp_model_name = serverconfig.sj_comp_model_name;
53
- embedding_model_name = serverconfig.sj_embedding_model_name;
54
38
  } else if (serverconfig.llm_backend == "ollama") {
55
39
  apilink = serverconfig.ollama_apilink;
56
40
  comp_model_name = serverconfig.ollama_comp_model_name;
57
- embedding_model_name = serverconfig.ollama_embedding_model_name;
58
41
  } else {
59
42
  throw "llm_backend either needs to be 'SJ' or 'ollama'";
60
43
  }
61
- const chatbot_input = {
62
- user_input: q.prompt,
63
- apilink,
64
- tpmasterdir: serverconfig.tpmasterdir,
65
- comp_model_name,
66
- embedding_model_name,
67
- dataset_db: ds.cohort.db.file,
68
- genedb: g.genedb.dbfile,
69
- aiRoute: serverconfig.aiRoute,
70
- // Route file for classifying chat request into various routes
71
- llm_backend_name: serverconfig.llm_backend,
72
- // The type of backend (engine) used for running the embedding and completion model. Currently "SJ" and "Ollama" are supported
73
- aifiles: serverconfig_ds_entries.aifiles,
74
- // Dataset specific data containing data-specific routes, system prompts for agents and few-shot examples
75
- binpath: serverconfig.binpath
76
- };
44
+ const dataset_db = serverconfig.tpmasterdir + "/" + ds.cohort.db.file;
45
+ const genedb = serverconfig.tpmasterdir + "/" + g.genedb.dbfile;
46
+ const dataset_json = await readJSONFile(serverconfig_ds_entries.aifiles);
77
47
  const time1 = (/* @__PURE__ */ new Date()).valueOf();
78
- const classResult = JSON.parse(await run_rust("query_classification", JSON.stringify(chatbot_input)));
79
- const time2 = (/* @__PURE__ */ new Date()).valueOf();
80
- mayLog("Time taken for classification:", time2 - time1, "ms");
81
- let ai_output_data;
48
+ const class_response = await classify_query_by_dataset_type(
49
+ q.prompt,
50
+ comp_model_name,
51
+ serverconfig.llm_backend,
52
+ apilink,
53
+ serverconfig.aiRoute,
54
+ dataset_json
55
+ );
82
56
  let ai_output_json;
83
- if (classResult.route == "summary") {
84
- const time12 = (/* @__PURE__ */ new Date()).valueOf();
85
- ai_output_data = await run_rust("summary_agent", JSON.stringify(chatbot_input));
86
- const time22 = (/* @__PURE__ */ new Date()).valueOf();
87
- mayLog("Time taken for running summary agent:", time22 - time12, "ms");
88
- for (const line of ai_output_data.split("\n")) {
89
- if (line.startsWith("final_output:") == true) {
90
- ai_output_json = JSON.parse(JSON.parse(line.replace("final_output:", "")));
91
- } else {
92
- mayLog(line);
93
- }
57
+ mayLog("Time taken for classification:", formatElapsedTime(Date.now() - time1));
58
+ if (class_response.type == "html") {
59
+ ai_output_json = class_response;
60
+ } else if (class_response.type == "plot") {
61
+ const classResult = class_response.plot;
62
+ mayLog("classResult:", classResult);
63
+ if (classResult == "summary") {
64
+ const time12 = (/* @__PURE__ */ new Date()).valueOf();
65
+ ai_output_json = await extract_summary_terms(
66
+ q.prompt,
67
+ serverconfig.llm_backend,
68
+ comp_model_name,
69
+ apilink,
70
+ dataset_db,
71
+ dataset_json,
72
+ genedb,
73
+ ds
74
+ );
75
+ mayLog("Time taken for summary agent:", formatElapsedTime(Date.now() - time12));
76
+ } else if (classResult == "dge") {
77
+ ai_output_json = { type: "html", html: "DE agent not implemented yet" };
78
+ } else {
79
+ ai_output_json = { type: "html", html: "Unknown classification value" };
94
80
  }
95
- } else if (classResult.route == "dge") {
96
- ai_output_json = { type: "html", html: "DE agent not implemented yet" };
97
81
  } else {
98
- ai_output_json = { type: "html", html: "Unknown classification value" };
99
- }
100
- if (ai_output_json.type == "plot") {
101
- if (typeof ai_output_json.plot != "object") throw ".plot{} missing when .type=plot";
102
- if (ai_output_json.plot.simpleFilter) {
103
- if (!Array.isArray(ai_output_json.plot.simpleFilter)) throw "ai_output_json.plot.simpleFilter is not array";
104
- const localfilter = { type: "tvslst", in: true, join: "", lst: [] };
105
- if (ai_output_json.plot.simpleFilter.length > 1) localfilter.join = "and";
106
- for (const f of ai_output_json.plot.simpleFilter) {
107
- const term = ds.cohort.termdb.q.termjsonByOneid(f.term);
108
- if (!term) throw "invalid term id from simpleFilter[].term";
109
- if (term.type == "categorical") {
110
- let cat;
111
- for (const ck in term.values) {
112
- if (ck == f.category) cat = ck;
113
- else if (term.values[ck].label == f.category) cat = ck;
114
- }
115
- if (!cat) throw "invalid category from " + JSON.stringify(f);
116
- localfilter.lst.push({
117
- type: "tvs",
118
- tvs: {
119
- term,
120
- values: [{ key: cat }]
121
- }
122
- });
123
- } else if (term.type == "float" || term.type == "integer") {
124
- const numeric = {
125
- type: "tvs",
126
- tvs: {
127
- term,
128
- ranges: []
129
- }
130
- };
131
- const range = {};
132
- if (f.gt && !f.lt) {
133
- range.start = Number(f.gt);
134
- range.stopunbounded = true;
135
- } else if (f.lt && !f.gt) {
136
- range.stop = Number(f.lt);
137
- range.startunbounded = true;
138
- } else if (f.gt && f.lt) {
139
- range.start = Number(f.gt);
140
- range.stop = Number(f.lt);
141
- } else {
142
- throw "Neither greater or lesser defined";
143
- }
144
- numeric.tvs.ranges.push(range);
145
- localfilter.lst.push(numeric);
146
- }
147
- }
148
- delete ai_output_json.plot.simpleFilter;
149
- ai_output_json.plot.filter = localfilter;
150
- }
82
+ ai_output_json = {
83
+ type: "html",
84
+ html: "Unknown classification type"
85
+ };
151
86
  }
152
87
  res.send(ai_output_json);
153
88
  } catch (e) {
154
- if (e.stack) console.log(e.stack);
89
+ if (e.stack) mayLog(e.stack);
155
90
  res.send({ error: e?.message || e });
156
91
  }
157
92
  };
158
93
  }
94
/**
 * Send a single-turn chat request to an Ollama server and return the reply text.
 *
 * @param {string} prompt - complete prompt, sent as the only user message
 * @param {string} model_name - Ollama model to run
 * @param {string} apilink - base URL of the Ollama server; "/api/chat" is appended
 * @returns {Promise<string>} the non-empty `message.content` of the response
 * @throws {string} "Ollama API request failed:..." on transport errors or an unexpected response shape
 */
async function call_ollama(prompt, model_name, apilink) {
  const temperature = 0.01;
  const top_p = 0.95;
  const timeout = 2e5; // request timeout; original comment says ezFetch takes milliseconds (200 s)
  const payload = {
    model: model_name,
    messages: [{ role: "user", content: prompt }],
    raw: false,
    stream: false,
    // Ollama interprets a bare number as SECONDS; use a duration string so the
    // model really stays loaded for 15 minutes as intended.
    keep_alive: "15m",
    options: {
      top_p,
      temperature,
      num_ctx: 1e4
    }
  };
  try {
    const result = await ezFetch(apilink + "/api/chat", {
      method: "POST",
      body: payload, // ezFetch stringifies object bodies
      headers: { "Content-Type": "application/json" },
      timeout: { request: timeout }
    });
    const content = result?.message?.content;
    if (typeof content === "string" && content.length > 0) return content;
    // stringify so the error shows the payload instead of "[object Object]"
    throw "Error: Received an unexpected response format:" + JSON.stringify(result);
  } catch (error) {
    throw "Ollama API request failed:" + error;
  }
}
129
/**
 * Send a completion request to the SJ-hosted LLM endpoint and return the generated text.
 *
 * @param {string} prompt - complete prompt text
 * @param {string} model_name - model to run on the SJ endpoint
 * @param {string} apilink - full URL of the SJ inference endpoint
 * @returns {Promise<string>} `outputs[0].generated_text` from the response
 * @throws {string} "SJ API request failed:..." on transport errors or an unexpected response shape
 */
async function call_sj_llm(prompt, model_name, apilink) {
  const temperature = 0.01;
  const top_p = 0.95;
  const timeout = 2e5; // request timeout; original comment says ezFetch takes milliseconds (200 s)
  const max_new_tokens = 512;
  const payload = {
    inputs: [
      {
        model_name,
        inputs: {
          text: prompt,
          max_new_tokens,
          temperature,
          top_p
        }
      }
    ]
  };
  try {
    const response = await ezFetch(apilink, {
      method: "POST",
      body: payload, // ezFetch stringifies object bodies
      headers: { "Content-Type": "application/json" },
      timeout: { request: timeout }
    });
    // optional chaining: a null/empty response should hit the descriptive error below, not a TypeError
    const generated = response?.outputs?.[0]?.generated_text;
    if (generated) return generated;
    throw "Error: Received an unexpected response format:" + JSON.stringify(response);
  } catch (error) {
    throw "SJ API request failed:" + error;
  }
}
166
/**
 * Read a UTF-8 text file and parse it as JSON.
 *
 * @param {string} file - path of the JSON file
 * @returns {Promise<any>} the parsed value
 * @throws if the file cannot be read or is not valid JSON
 */
async function readJSONFile(file) {
  const text = await fs.promises.readFile(file, "utf8");
  return JSON.parse(text);
}
170
/**
 * Ask the LLM to classify a user question, e.g. into a plot type or a direct html answer.
 *
 * The prompt is the route file's "general" section, followed by every other section of
 * the route file, followed (when the dataset defines a "Classification" chart with
 * training data) by that chart's SystemPrompt and its question/answer examples.
 *
 * @param {string} user_prompt - user question
 * @param {string} comp_model_name - completion model name
 * @param {string} llm_backend_type - "SJ" or "ollama"
 * @param {string} apilink - LLM endpoint
 * @param {string} aiRoute - path to the JSON route file
 * @param {object} dataset_json - parsed dataset aifiles JSON (reads .charts)
 * @returns {Promise<any>} the LLM response parsed as JSON
 * @throws {string} "Unknown LLM backend" for an unsupported backend; JSON.parse errors on non-JSON output
 */
async function classify_query_by_dataset_type(user_prompt, comp_model_name, llm_backend_type, apilink, aiRoute, dataset_json) {
  const data = await readJSONFile(aiRoute);
  // "general" instructions first, then every other route section in key order
  let contents = data["general"] ?? "";
  for (const key of Object.keys(data)) {
    if (key != "general") {
      contents += data[key];
    }
  }
  const classification_ds = dataset_json.charts.filter((chart) => chart.type == "Classification");
  const classification_chart = classification_ds[0];
  let training_data = "";
  if (classification_chart?.TrainingData?.length > 0) {
    // classification_ds is an array: the prompt lives on the chart entry. Reading
    // `.SystemPrompt` off the array injected the literal "undefined" into the prompt.
    contents += classification_chart.SystemPrompt ?? "";
    let train_iter = 0;
    for (const train_data of classification_chart.TrainingData) {
      train_iter += 1;
      training_data += "Example question" + train_iter.toString() + ": " + train_data.question + " Example answer" + train_iter.toString() + ":" + JSON.stringify(train_data.answer) + " ";
    }
  }
  const template = contents + " training data is as follows:" + training_data + " Question: {" + user_prompt + "} Answer: {answer}";
  let response;
  if (llm_backend_type == "SJ") {
    response = await call_sj_llm(template, comp_model_name, apilink);
  } else if (llm_backend_type == "ollama") {
    response = await call_ollama(template, comp_model_name, apilink);
  } else {
    throw "Unknown LLM backend";
  }
  mayLog("response:", response);
  return JSON.parse(response);
}
200
/**
 * Ask the LLM to turn a natural-language summary request into a summary-plot spec.
 *
 * The prompt combines fixed instructions with a JSON schema, the dataset's Summary-chart
 * SystemPrompt, one RAG document per described term from the dataset db, the
 * Summary-chart training examples, and, for datasets with gene expression, the
 * coding-gene names that appear in the user's prompt.
 *
 * @param {string} prompt - user question
 * @param {string} llm_backend_type - "SJ" or "ollama"
 * @param {string} comp_model_name - completion model name
 * @param {string} apilink - LLM endpoint
 * @param {string} dataset_db - path to the dataset termdb sqlite file
 * @param {object} dataset_json - parsed dataset aifiles JSON (reads .charts, .hasGeneExpression)
 * @param {string} genedb - path to the gene sqlite file
 * @param {object} ds - server dataset object, passed through to validation
 * @returns {Promise<object>} validate_summary_response result: {type:"plot",plot} or {type:"html",html}
 * @throws {string} when the Summary chart or its training data is missing, or the backend is unknown
 */
async function extract_summary_terms(prompt, llm_backend_type, comp_model_name, apilink, dataset_db, dataset_json, genedb, ds) {
  const rag_docs = await parse_dataset_db(dataset_db);
  const genes_list = await parse_geneset_db(genedb);
  const StringifiedSchema = '{"$schema":"http://json-schema.org/draft-07/schema#","$ref":"#/definitions/SummaryType","definitions":{"SummaryType":{"type":"object","properties":{"term":{"type":"string"},"term2":{"type":"string"},"simpleFilter":{"type":"array","items":{"$ref":"#/definitions/FilterTerm"}}},"required":["term","simpleFilter"],"additionalProperties":false},"FilterTerm":{"anyOf":[{"$ref":"#/definitions/CategoricalFilterTerm"},{"$ref":"#/definitions/NumericFilterTerm"}]},"CategoricalFilterTerm":{"type":"object","properties":{"term":{"type":"string"},"category":{"type":"string"}},"required":["term","category"],"additionalProperties":false},"NumericFilterTerm":{"type":"object","properties":{"term":{"type":"string"},"start":{"type":"number"},"stop":{"type":"number"}},"required":["term"],"additionalProperties":false}}}';
  // Set lookup: the coding-gene list is large and is probed once per prompt word
  const gene_set = new Set(genes_list);
  const words = prompt.replace(/[^a-zA-Z0-9\s]/g, "").split(/\s+/).map((str) => str.toLowerCase());
  const common_genes = words.filter((item) => gene_set.has(item));
  const summary_ds = dataset_json.charts.filter((chart) => chart.type == "Summary");
  if (summary_ds.length == 0) throw "summary information not present in dataset file";
  const summary_chart = summary_ds[0];
  if (!summary_chart.TrainingData?.length) throw "no training data provided for summary agent";
  let train_iter = 0;
  let training_data = "";
  for (const train_data of summary_chart.TrainingData) {
    train_iter += 1;
    training_data += "Example question" + train_iter.toString() + ": " + train_data.question + " Example answer" + train_iter.toString() + ":" + JSON.stringify(train_data.answer) + " ";
  }
  // summary_ds is an array: the dataset prompt lives on the chart entry. Reading
  // `summary_ds.SystemPrompt` injected the literal "undefined" into the prompt.
  const dataset_system_prompt = summary_chart.SystemPrompt ?? "";
  let system_prompt = "I am an assistant that extracts the summary terms from user query. The final output must be in the following JSON format with NO extra comments. The JSON schema is as follows: " + StringifiedSchema + ' term and term2 (if present) should ONLY contain names of the fields from the sqlite db. The "simpleFilter" field is optional and should contain an array of JSON terms with which the dataset will be filtered. A variable simultaneously CANNOT be part of both "term"/"term2" and "simpleFilter". There are two kinds of filter variables: "Categorical" and "Numeric". "Categorical" variables are those variables which can have a fixed set of values e.g. gender, race. They are defined by the "CategoricalFilterTerm" which consists of "term" (a field from the sqlite3 db) and "category" (a value of the field from the sqlite db). "Numeric" variables are those which can have any numeric value. They are defined by "NumericFilterTerm" and contain the subfields "term" (a field from the sqlite3 db), "start" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and "stop" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. ' + dataset_system_prompt + rag_docs.join(",") + " training data is as follows:" + training_data;
  if (dataset_json.hasGeneExpression && common_genes.length > 0) {
    system_prompt += "\n List of relevant genes are as follows (separated by comma(,)):" + common_genes.join(",");
  }
  system_prompt += " Question: {" + prompt + "} answer:";
  let response;
  if (llm_backend_type == "SJ") {
    response = await call_sj_llm(system_prompt, comp_model_name, apilink);
  } else if (llm_backend_type == "ollama") {
    response = await call_ollama(system_prompt, comp_model_name, apilink);
  } else {
    throw "Unknown LLM backend";
  }
  return validate_summary_response(response, common_genes, dataset_json, ds);
}
232
/**
 * Validate the LLM's summary-agent JSON and convert it into a summary plot spec.
 *
 * @param {string} response - raw LLM output; must be JSON
 * @param {string[]} common_genes - lowercase gene names found in the user prompt
 * @param {object} dataset_json - parsed dataset aifiles JSON
 * @param {object} ds - server dataset object used for term lookup
 * @returns {{type:"plot",plot:object}|{type:"html",html:string}} a plot when every part
 *   validates, otherwise an html message listing the problems
 * @throws JSON.parse errors when the response is not JSON
 */
function validate_summary_response(response, common_genes, dataset_json, ds) {
  // `?? {}` so a literal "null" response reports a missing term instead of crashing
  const response_type = JSON.parse(response) ?? {};
  const pp_plot_json = { chartType: "summary" };
  let html = "";
  if (response_type.html) html = response_type.html;
  if (!response_type.term) {
    // Do not validate a missing term: validate_term would crash on undefined.toLowerCase()
    html += "term type is not present in summary output";
  } else {
    const term1_validation = validate_term(response_type.term, common_genes, dataset_json, ds);
    if (term1_validation.html.length > 0) {
      html += term1_validation.html;
    } else {
      pp_plot_json.term = term1_validation.term_type;
    }
  }
  if (response_type.term2) {
    const term2_validation = validate_term(response_type.term2, common_genes, dataset_json, ds);
    if (term2_validation.html.length > 0) {
      html += term2_validation.html;
    } else {
      pp_plot_json.term2 = term2_validation.term_type;
    }
  }
  if (response_type.simpleFilter && response_type.simpleFilter.length > 0) {
    const validated_filters = validate_filter(response_type.simpleFilter, ds);
    if (validated_filters.html.length > 0) {
      html += validated_filters.html;
    } else {
      pp_plot_json.filter = validated_filters.simplefilter;
    }
  }
  if (html.length > 0) {
    return { type: "html", html };
  }
  return { type: "plot", plot: pp_plot_json };
}
266
/**
 * Resolve a term name emitted by the LLM to a plot term.
 *
 * Lookup order: a dataset term via termjsonByOneid; otherwise a gene named in the user
 * prompt (only when the dataset has gene expression).
 *
 * @param {*} response_term - term id/name from the LLM output (expected to be a string)
 * @param {string[]} common_genes - lowercase gene names found in the user prompt
 * @param {object} dataset_json - parsed dataset aifiles JSON (reads .hasGeneExpression)
 * @param {object} ds - server dataset object (uses ds.cohort.termdb.q.termjsonByOneid)
 * @returns {{term_type: object|undefined, html: string}} term_type on success; otherwise a
 *   non-empty html error and term_type undefined
 */
function validate_term(response_term, common_genes, dataset_json, ds) {
  // The schema requires a string; a number/object from the LLM would otherwise crash on .toLowerCase()
  if (typeof response_term !== "string" || response_term.length === 0) {
    return { term_type: undefined, html: "invalid term id:" + response_term };
  }
  const term = ds.cohort.termdb.q.termjsonByOneid(response_term);
  if (term) {
    return { term_type: { id: term.id }, html: "" };
  }
  if (!common_genes.includes(response_term.toLowerCase())) {
    return { term_type: undefined, html: "invalid term id:" + response_term };
  }
  if (!dataset_json.hasGeneExpression) {
    return { term_type: undefined, html: "Dataset does not support gene expression" };
  }
  return { term_type: { term: { gene: response_term.toUpperCase(), type: "geneExpression" } }, html: "" };
}
286
/**
 * Convert the LLM's simpleFilter array into a tvslst filter.
 *
 * Categorical entries match `category` against a value key or its label. Numeric
 * (float/integer) entries use optional `start`/`stop` bounds, and a missing bound is
 * unbounded. Terms of any other type are skipped without an error message.
 *
 * @param {Array<object>} filters - simpleFilter entries from the LLM output
 * @param {object} ds - server dataset object (uses ds.cohort.termdb.q.termjsonByOneid)
 * @returns {{simplefilter: object, html: string}} the filter and accumulated error text;
 *   callers should ignore the filter when html is non-empty
 * @throws {string} "filter is not array" when filters is not an array
 */
function validate_filter(filters, ds) {
  if (!Array.isArray(filters)) throw "filter is not array";
  let invalid_html = "";
  const localfilter = { type: "tvslst", in: true, join: "", lst: [] };
  // A bound is present unless null/undefined/"" -- NOT by truthiness, so 0 is a valid bound
  const hasBound = (v) => v !== undefined && v !== null && v !== "";
  for (const f of filters) {
    const term = ds.cohort.termdb.q.termjsonByOneid(f.term);
    if (!term) {
      invalid_html += "invalid filter id:" + f.term;
      continue;
    }
    if (term.type == "categorical") {
      let cat;
      for (const ck in term.values) {
        if (ck == f.category || term.values[ck].label == f.category) cat = ck;
      }
      if (cat === undefined) {
        invalid_html += "invalid category from " + JSON.stringify(f);
        continue;
      }
      localfilter.lst.push({
        type: "tvs",
        tvs: {
          term,
          values: [{ key: cat }]
        }
      });
    } else if (term.type == "float" || term.type == "integer") {
      const has_start = hasBound(f.start);
      const has_stop = hasBound(f.stop);
      if (!has_start && !has_stop) {
        invalid_html += "Neither greater or lesser defined";
        continue;
      }
      const range = {};
      if (has_start) range.start = Number(f.start);
      else range.startunbounded = true;
      if (has_stop) range.stop = Number(f.stop);
      else range.stopunbounded = true;
      if (Number.isNaN(range.start) || Number.isNaN(range.stop)) {
        invalid_html += "invalid numeric range from " + JSON.stringify(f);
        continue;
      }
      localfilter.lst.push({
        type: "tvs",
        tvs: {
          term,
          ranges: [range]
        }
      });
    }
  }
  // decide join from what was actually emitted, not from the raw input count
  if (localfilter.lst.length > 1) localfilter.join = "and";
  return { simplefilter: localfilter, html: invalid_html };
}
338
/**
 * Load every coding gene name from the gene sqlite db, lowercased.
 *
 * @param {string} genedb - path to the gene sqlite file (reads table codingGenes)
 * @returns {Promise<string[]>} lowercase gene names in query order
 * @throws {string} "Could not parse geneDB..." when the query fails
 */
async function parse_geneset_db(genedb) {
  // readonly: a writable open would silently create an empty db file at a mistyped path
  const db = new Database(genedb, { readonly: true });
  try {
    return db
      .prepare("SELECT name from codingGenes")
      .all()
      .map((row) => row.name.toLowerCase());
  } catch (error) {
    throw "Could not parse geneDB" + error;
  } finally {
    db.close();
  }
}
354
/**
 * Build one RAG text document per dataset term that has an html description.
 *
 * Reads term descriptions from `termhtmldef` (jsonhtml.description[0].value) and term
 * metadata from `terms`; terms without a description are left out.
 *
 * @param {string} dataset_db - path to the dataset termdb sqlite file
 * @returns {Promise<string[]>} parse_db_rows output for each described term, in `terms` row order
 * @throws {string} "Error in parsing dataset DB:..." on query or JSON errors
 */
async function parse_dataset_db(dataset_db) {
  // readonly: a writable open would silently create an empty db file at a mistyped path
  const db = new Database(dataset_db, { readonly: true });
  const rag_docs = [];
  try {
    // term id -> description; a Map replaces the original per-term linear scans (O(n^2))
    const descriptions = new Map();
    for (const row of db.prepare("SELECT * from termhtmldef").all()) {
      const jsonhtml = JSON.parse(row.jsonhtml);
      const description = jsonhtml?.description?.[0]?.value;
      // skip definitions without a description instead of failing the whole request;
      // keep the first definition per id, as the original find() did
      if (description !== undefined && !descriptions.has(row.id)) descriptions.set(row.id, description);
    }
    for (const row of db.prepare("SELECT * from terms").all()) {
      if (!descriptions.has(row.id)) continue;
      const jsondata = JSON.parse(row.jsondata);
      const values = jsondata.values ? Object.entries(jsondata.values).map(([key, value]) => ({ key, value })) : [];
      rag_docs.push(
        parse_db_rows({
          name: row.id,
          description: descriptions.get(row.id),
          values,
          term_type: row.type
        })
      );
    }
  } catch (error) {
    throw "Error in parsing dataset DB:" + error;
  } finally {
    db.close();
  }
  return rag_docs;
}
400
/**
 * Render one dataset term as a plain-text RAG document for the LLM prompt.
 *
 * @param {{name:string, term_type:string, description:*, values:Array<{key:string, value:*}>}} db_row
 * @returns {string} sentences naming the field, its type, its description and, when
 *   present, each value key with its label (values without a label are skipped)
 */
function parse_db_rows(db_row) {
  let output_string = "Name of the field is:" + db_row.name + ". This field is of the type:" + db_row.term_type + ". Description: " + db_row.description;
  if (db_row.values.length > 0) {
    // leading spaces keep sentences from running together in the generated prompt
    output_string += " This field contains the following possible values.";
    for (const { key, value } of db_row.values) {
      if (value?.label) {
        output_string += " The key is " + key + " and the label is " + value.label + ".";
      }
    }
  }
  return output_string;
}
159
412
  export {
160
413
  api
161
414
  };