@sjcrh/proteinpaint-server 2.174.1 → 2.175.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,81 @@ import { mayLog } from "#src/helpers.ts";
7
7
  import Database from "better-sqlite3";
8
8
  import { formatElapsedTime } from "#shared";
9
9
  const num_filter_cutoff = 3;
10
/**
 * JSON-schema definitions describing a single filter term.
 *
 * These definitions are spread into the `definitions` section of the
 * per-agent schemas in this file, so the key names ("FilterTerm",
 * "CategoricalFilterTerm", "NumericFilterTerm") must stay in sync with the
 * "#/definitions/..." references used there.
 *
 * - FilterTerm: either a categorical or a numeric filter term.
 * - CategoricalFilterTerm: requires "term" and "category".
 * - NumericFilterTerm: requires only "term"; "start" and "stop" are optional.
 */
const FILTER_TERM_DEFINITIONS = (() => {
  // Both term kinds accept the same optional "join" property. It is produced
  // by a factory (rather than shared by reference) so the two definitions stay
  // textually identical without sharing one mutable object.
  const joinProperty = () => ({
    type: "string",
    enum: ["and", "or"],
    description: "join term to be used only when there is more than one filter term and should be placed from the 2nd filter term onwards describing how it connects to the previous term"
  });
  return {
    FilterTerm: {
      anyOf: [{ $ref: "#/definitions/CategoricalFilterTerm" }, { $ref: "#/definitions/NumericFilterTerm" }]
    },
    CategoricalFilterTerm: {
      type: "object",
      properties: {
        term: { type: "string", description: "Name of categorical term" },
        category: { type: "string", description: "The category of the term" },
        join: joinProperty()
      },
      required: ["term", "category"],
      additionalProperties: false
    },
    NumericFilterTerm: {
      type: "object",
      properties: {
        term: { type: "string", description: "Name of numeric term" },
        start: { type: "number", description: "start position (or lower limit) of numeric term" },
        stop: { type: "number", description: "stop position (or upper limit) of numeric term" },
        join: joinProperty()
      },
      required: ["term"],
      additionalProperties: false
    }
  };
})();
44
/**
 * Render training examples as one prompt fragment.
 *
 * Each entry becomes
 * "Example question<N>: <question> Example answer<N>:<JSON of answer>",
 * numbered from 1, and the entries are joined with a single space.
 *
 * @param {Array<{question: *, answer: *}>} trainingData - examples to render
 * @returns {string} the concatenated examples ("" for an empty array)
 */
function formatTrainingExamples(trainingData) {
  return trainingData
    .map((example, index) => {
      const number = index + 1;
      return `Example question${number}: ${example.question} Example answer${number}:${JSON.stringify(example.answer)}`;
    })
    .join(" ");
}
49
+ const FILTER_DESCRIPTION = 'There are two kinds of filter variables: "Categorical" and "Numeric". "Categorical" variables are those variables which can have a fixed set of values e.g. gender, race. They are defined by the "CategoricalFilterTerm" which consists of "term" (a field from the sqlite3 db) and "category" (a value of the field from the sqlite db). "Numeric" variables are those which can have any numeric value. They are defined by "NumericFilterTerm" and contain the subfields "term" (a field from the sqlite3 db), "start" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and "stop" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. ';
50
/**
 * Find the words of a prompt that are known gene names.
 *
 * The prompt is stripped of every character that is not a letter, digit or
 * whitespace, split on whitespace and lower-cased; each resulting word is kept
 * if it appears in `genes_list`. Because words are lower-cased before the
 * lookup, only lower-case entries of `genes_list` can match. Order and
 * duplicates from the prompt are preserved.
 *
 * @param {string} prompt - free-text user prompt
 * @param {string[]} genes_list - known gene names
 * @returns {string[]} lower-cased prompt words that are present in `genes_list`
 */
function extractGenesFromPrompt(prompt, genes_list) {
  // Build the lookup once: a gene list can be large, and scanning it with
  // Array#includes for every word of the prompt is needlessly quadratic.
  const knownGenes = new Set(genes_list);
  return prompt
    .replace(/[^a-zA-Z0-9\s]/g, "")
    .split(/\s+/)
    .map((word) => word.toLowerCase())
    .filter((word) => knownGenes.has(word));
}
54
// Default plot (child) type for each "<term1 kind>:<term2 kind>" combination.
// "undefined" stands for a term that was not supplied.
const CHILD_TYPE_DEFAULTS = {
  "categorical:undefined": "barchart",
  "numeric:undefined": "violin",
  "categorical:categorical": "barchart",
  "numeric:categorical": "violin",
  "categorical:numeric": "violin",
  "numeric:numeric": "sampleScatter"
};
// Plot types that are rejected for a given combination. Combinations that are
// not listed here accept any requested plot type.
const CHILD_TYPE_INVALID = {
  "categorical:undefined": new Set(["violin", "boxplot", "sampleScatter"]),
  "categorical:categorical": new Set(["violin", "boxplot", "sampleScatter"])
};
/**
 * Decide which plot (child) type to use for one or two terms.
 *
 * "float" and "integer" are treated as "numeric"; a missing/empty category is
 * treated as "undefined" (term not supplied).
 *
 * @param {string} [cat1] - type of the first term
 * @param {string} [cat2] - type of the second term, if any
 * @param {string} [llmChildType] - plot type explicitly requested, if any
 * @returns {{childType: string, bothNumeric?: boolean} | {error: string}}
 *   - `{ childType: "barchart" }` when the combination has no default entry;
 *   - `{ error }` when the requested type is listed as invalid for the combination;
 *   - otherwise `{ childType, bothNumeric }`, where `childType` is the requested
 *     type or, if none was requested, the default for the combination.
 */
function resolveChildType(cat1, cat2, llmChildType) {
  const normalize = (category) => category === "float" || category === "integer" ? "numeric" : category || "undefined";
  const norm1 = normalize(cat1);
  const norm2 = normalize(cat2);
  const key = norm1 + ":" + norm2;
  const defaultType = CHILD_TYPE_DEFAULTS[key];
  if (!defaultType) {
    return { childType: "barchart" };
  }
  const invalid = CHILD_TYPE_INVALID[key];
  if (llmChildType && invalid && invalid.has(llmChildType)) {
    // Do not leak the internal "undefined" placeholder into the user-facing
    // message when only one term was supplied.
    const subject = norm2 === "undefined" ? "a single " + norm1 + " variable" : norm1 + " and " + norm2 + " variables";
    return {
      error: "Invalid plot type supplied by the user: " + llmChildType + ". For " + subject + " the plot type should always be " + defaultType
    };
  }
  return {
    childType: llmChildType || defaultType,
    bothNumeric: norm1 === "numeric" && norm2 === "numeric"
  };
}
10
85
  const api = {
11
86
  endpoint: "termdb/chat",
12
87
  methods: {
@@ -32,72 +107,25 @@ function init({ genomes }) {
32
107
  if (!serverconfig_ds_entries.aifiles) {
33
108
  throw "aifiles are missing for chatbot to work";
34
109
  }
35
- let apilink;
36
- let comp_model_name;
37
- if (serverconfig.llm_backend == "SJ") {
38
- apilink = serverconfig.sj_apilink;
39
- comp_model_name = serverconfig.sj_comp_model_name;
40
- } else if (serverconfig.llm_backend == "ollama") {
41
- apilink = serverconfig.ollama_apilink;
42
- comp_model_name = serverconfig.ollama_comp_model_name;
43
- } else {
44
- throw "llm_backend either needs to be 'SJ' or 'ollama'";
110
+ const llm = serverconfig.llm;
111
+ if (!llm) throw "serverconfig.llm is not configured";
112
+ if (llm.provider !== "SJ" && llm.provider !== "ollama") {
113
+ throw "llm.provider must be 'SJ' or 'ollama'";
45
114
  }
46
115
  const dataset_db = serverconfig.tpmasterdir + "/" + ds.cohort.db.file;
47
116
  const genedb = serverconfig.tpmasterdir + "/" + g.genedb.dbfile;
48
117
  const dataset_json = await readJSONFile(serverconfig_ds_entries.aifiles);
49
- const time1 = (/* @__PURE__ */ new Date()).valueOf();
50
- const class_response = await classify_query_by_dataset_type(
118
+ const testing = false;
119
+ const ai_output_json = await run_chat_pipeline(
51
120
  q.prompt,
52
- comp_model_name,
53
- serverconfig.llm_backend,
54
- apilink,
121
+ llm,
55
122
  serverconfig.aiRoute,
56
- dataset_json
123
+ dataset_json,
124
+ testing,
125
+ dataset_db,
126
+ genedb,
127
+ ds
57
128
  );
58
- let ai_output_json;
59
- mayLog("Time taken for classification:", formatElapsedTime(Date.now() - time1));
60
- if (class_response.type == "html") {
61
- ai_output_json = class_response;
62
- } else if (class_response.type == "plot") {
63
- const classResult = class_response.plot;
64
- mayLog("classResult:", classResult);
65
- if (classResult == "summary") {
66
- const time12 = (/* @__PURE__ */ new Date()).valueOf();
67
- ai_output_json = await extract_summary_terms(
68
- q.prompt,
69
- serverconfig.llm_backend,
70
- comp_model_name,
71
- apilink,
72
- dataset_db,
73
- dataset_json,
74
- genedb,
75
- ds
76
- );
77
- mayLog("Time taken for summary agent:", formatElapsedTime(Date.now() - time12));
78
- } else if (classResult == "dge") {
79
- const time12 = (/* @__PURE__ */ new Date()).valueOf();
80
- ai_output_json = await extract_DE_search_terms_from_query(
81
- q.prompt,
82
- serverconfig.llm_backend,
83
- comp_model_name,
84
- apilink,
85
- dataset_db,
86
- dataset_json,
87
- ds
88
- );
89
- mayLog("Time taken for DE agent:", formatElapsedTime(Date.now() - time12));
90
- } else if (classResult == "survival") {
91
- ai_output_json = { type: "html", html: "survival agent has not been implemented yet" };
92
- } else {
93
- ai_output_json = { type: "html", html: "Unknown classification value" };
94
- }
95
- } else {
96
- ai_output_json = {
97
- type: "html",
98
- html: "Unknown classification type"
99
- };
100
- }
101
129
  res.send(ai_output_json);
102
130
  } catch (e) {
103
131
  if (e.stack) mayLog(e.stack);
@@ -105,6 +133,72 @@ function init({ genomes }) {
105
133
  }
106
134
  };
107
135
  }
136
/**
 * Classify a user prompt and dispatch it to the matching agent.
 *
 * Flow:
 *  1. `classify_query_by_dataset_type` decides whether the prompt yields an
 *     html answer or a plot, and which plot ("summary", "dge", "survival",
 *     "matrix").
 *  2. For a plot, the dataset DB (and, where needed, the gene DB) is parsed and
 *     the corresponding agent is invoked.
 *
 * @param {string} user_prompt - the user's question
 * @param {object} llm - LLM settings, forwarded unchanged to the classifier and agents
 * @param {string} aiRoute - path forwarded to the classifier
 * @param {object} dataset_json - parsed dataset AI configuration (`hasGeneExpression` is read here)
 * @param {boolean} testing - forwarded to the classifier and agents; when true they
 *   return `{ action, response }` objects instead of validated output
 * @param {string} dataset_db - path passed to `parse_dataset_db`
 * @param {string} genedb - path passed to `parse_geneset_db`
 * @param {object} ds - dataset object, forwarded to the agents
 * @returns {Promise<object>} the classifier/agent output, or `{ type: "html", html }`
 *   for unsupported or unknown classifications
 */
async function run_chat_pipeline(user_prompt, llm, aiRoute, dataset_json, testing, dataset_db, genedb, ds) {
  const classificationStart = Date.now();
  const class_response = await classify_query_by_dataset_type(user_prompt, llm, aiRoute, dataset_json, testing);
  mayLog("Time taken for classification:", formatElapsedTime(Date.now() - classificationStart));

  // In testing mode the classifier wraps the parsed LLM answer as
  // { action: "html", response }. That wrapper has no `type`, so without
  // unwrapping every testing run ended in "Unknown classification type" and no
  // agent was ever reached.
  const classification = testing && class_response.response ? class_response.response : class_response;

  if (classification.type == "html") return class_response;
  if (classification.type != "plot") {
    return {
      type: "html",
      html: "Unknown classification type"
    };
  }

  const classResult = classification.plot;
  mayLog("classResult:", classResult);

  // Runs one agent and logs how long the agent call itself took.
  const runTimed = async (label, runAgent) => {
    const agentStart = Date.now();
    const output = await runAgent();
    mayLog("Time taken for " + label + " agent:", formatElapsedTime(Date.now() - agentStart));
    return output;
  };

  // The databases are parsed only inside the branches that use them, so
  // "survival" and unknown classifications no longer pay for (or fail on) them.
  switch (classResult) {
    case "summary": {
      const dataset_db_output = await parse_dataset_db(dataset_db);
      const genes_list = dataset_json.hasGeneExpression ? await parse_geneset_db(genedb) : [];
      return runTimed(
        "summary",
        () => extract_summary_terms(user_prompt, llm, dataset_db_output, dataset_json, genes_list, ds, testing)
      );
    }
    case "dge": {
      const dataset_db_output = await parse_dataset_db(dataset_db);
      return runTimed(
        "DE",
        () => extract_DE_search_terms_from_query(user_prompt, llm, dataset_db_output, dataset_json, ds, testing)
      );
    }
    case "survival":
      return { type: "html", html: "survival agent has not been implemented yet" };
    case "matrix": {
      const dataset_db_output = await parse_dataset_db(dataset_db);
      // The matrix validator emits geneVariant rows when the dataset has no gene
      // expression, but it only accepts genes found in the gene list. Passing an
      // empty list for such datasets made that branch unreachable, so the gene
      // list is always loaded here.
      // NOTE(review): assumes geneVariant rows are valid for datasets without
      // gene expression, as the validator's branch implies - confirm.
      const genes_list = await parse_geneset_db(genedb);
      return runTimed(
        "matrix",
        () => extract_matrix_search_terms_from_query(user_prompt, llm, dataset_db_output, dataset_json, genes_list, ds, testing)
      );
    }
    default:
      return { type: "html", html: "Unknown classification value" };
  }
}
108
202
  async function call_ollama(prompt, model_name, apilink) {
109
203
  const temperature = 0.01;
110
204
  const top_p = 0.95;
@@ -177,6 +271,17 @@ async function call_sj_llm(prompt, model_name, apilink) {
177
271
  throw "SJ API request failed:" + error;
178
272
  }
179
273
  }
274
/**
 * Send a prompt to the LLM backend selected by `llm.provider`.
 *
 * @param {string} template - full prompt text to send
 * @param {{provider: string, modelName: string, api: string}} llm - backend
 *   selection; `modelName` and `api` are forwarded to the backend call
 * @returns {Promise<*>} whatever the selected backend call resolves to
 * @throws {string} "Unknown LLM provider" when `llm.provider` is neither
 *   "SJ" nor "ollama" (surfaces as a rejected promise)
 */
async function route_to_appropriate_llm_provider(template, llm) {
  switch (llm.provider) {
    case "SJ":
      return call_sj_llm(template, llm.modelName, llm.api);
    case "ollama":
      return call_ollama(template, llm.modelName, llm.api);
    default:
      throw "Unknown LLM provider";
  }
}
180
285
  function checkField(sentence) {
181
286
  if (!sentence) return "";
182
287
  else return sentence;
@@ -185,7 +290,7 @@ async function readJSONFile(file) {
185
290
  const json_file = await fs.promises.readFile(file);
186
291
  return JSON.parse(json_file.toString());
187
292
  }
188
- async function classify_query_by_dataset_type(user_prompt, comp_model_name, llm_backend_type, apilink, aiRoute, dataset_json) {
293
+ async function classify_query_by_dataset_type(user_prompt, llm, aiRoute, dataset_json, testing) {
189
294
  const data = await readJSONFile(aiRoute);
190
295
  let contents = data["general"];
191
296
  for (const key of Object.keys(data)) {
@@ -193,32 +298,24 @@ async function classify_query_by_dataset_type(user_prompt, comp_model_name, llm_
193
298
  contents += data[key];
194
299
  }
195
300
  }
196
- const classification_ds = dataset_json.charts.filter((chart) => chart.type == "Classification");
197
- if (classification_ds.length == 0) throw "Classification information is not present in the dataset file.";
198
- if (classification_ds[0].TrainingData.length == 0) throw "No training data is provided for the classification agent.";
199
- let train_iter = 0;
301
+ const classification_ds = dataset_json.charts.find((chart) => chart.type == "Classification");
302
+ if (!classification_ds) throw "Classification information is not present in the dataset file.";
303
+ if (classification_ds.TrainingData.length == 0) throw "No training data is provided for the classification agent.";
200
304
  let training_data = "";
201
- if (classification_ds.length > 0 && classification_ds[0].TrainingData.length > 0) {
202
- contents += checkField(dataset_json.DatasetPrompt) + checkField(classification_ds[0].SystemPrompt);
203
- for (const train_data of classification_ds[0].TrainingData) {
204
- train_iter += 1;
205
- training_data += "Example question" + train_iter.toString() + ": " + train_data.question + " Example answer" + train_iter.toString() + ":" + JSON.stringify(train_data.answer) + " ";
206
- }
305
+ if (classification_ds && classification_ds.TrainingData.length > 0) {
306
+ contents += checkField(dataset_json.DatasetPrompt) + checkField(classification_ds.SystemPrompt);
307
+ training_data = formatTrainingExamples(classification_ds.TrainingData);
207
308
  }
208
309
  const template = contents + " training data is as follows:" + training_data + " Question: {" + user_prompt + "} Answer: {answer}";
209
- let response;
210
- if (llm_backend_type == "SJ") {
211
- response = await call_sj_llm(template, comp_model_name, apilink);
212
- } else if (llm_backend_type == "ollama") {
213
- response = await call_ollama(template, comp_model_name, apilink);
310
+ const response = await route_to_appropriate_llm_provider(template, llm);
311
+ if (testing) {
312
+ return { action: "html", response: JSON.parse(response) };
214
313
  } else {
215
- throw "Unknown LLM backend";
314
+ return JSON.parse(response);
216
315
  }
217
- return JSON.parse(response);
218
316
  }
219
- async function extract_DE_search_terms_from_query(prompt, llm_backend_type, comp_model_name, apilink, dataset_db, dataset_json, ds) {
317
+ async function extract_DE_search_terms_from_query(prompt, llm, dataset_db_output, dataset_json, ds, testing) {
220
318
  if (dataset_json.hasDE) {
221
- const dataset_db_output = await parse_dataset_db(dataset_db);
222
319
  const Schema = {
223
320
  $schema: "http://json-schema.org/draft-07/schema#",
224
321
  $ref: "#/definitions/DEType",
@@ -245,60 +342,20 @@ async function extract_DE_search_terms_from_query(prompt, llm_backend_type, comp
245
342
  required: ["group1", "group2"],
246
343
  additionalProperties: false
247
344
  },
248
- FilterTerm: {
249
- anyOf: [{ $ref: "#/definitions/CategoricalFilterTerm" }, { $ref: "#/definitions/NumericFilterTerm" }]
250
- },
251
- CategoricalFilterTerm: {
252
- type: "object",
253
- properties: {
254
- term: { type: "string", description: "Name of categorical term" },
255
- category: { type: "string", description: "The category of the term" },
256
- join: {
257
- type: "string",
258
- enum: ["and", "or"],
259
- description: "join term to be used only when there is more than one filter term and should be placed from the 2nd filter term onwards describing how it connects to the previous term"
260
- }
261
- },
262
- required: ["term", "category"],
263
- additionalProperties: false
264
- },
265
- NumericFilterTerm: {
266
- type: "object",
267
- properties: {
268
- term: { type: "string", description: "Name of numeric term" },
269
- start: { type: "number", description: "start position (or lower limit) of numeric term" },
270
- stop: { type: "number", description: "stop position (or upper limit) of numeric term" },
271
- join: {
272
- type: "string",
273
- enum: ["and", "or"],
274
- description: "join term to be used only when there is more than one filter term and should be placed from the 2nd filter term onwards describing how it connects to the previous term"
275
- }
276
- },
277
- required: ["term"],
278
- additionalProperties: false
279
- }
345
+ ...FILTER_TERM_DEFINITIONS
280
346
  }
281
347
  };
282
- const DE_ds = dataset_json.charts.filter((chart) => chart.type == "DE");
283
- if (DE_ds.length == 0) throw "DE information is not present in the dataset file.";
284
- if (DE_ds[0].TrainingData.length == 0) throw "No training data is provided for the DE agent.";
285
- let train_iter = 0;
286
- let training_data = "";
287
- for (const train_data of DE_ds[0].TrainingData) {
288
- train_iter += 1;
289
- training_data += "Example question" + train_iter.toString() + ": " + train_data.question + " Example answer" + train_iter.toString() + ":" + JSON.stringify(train_data.answer) + " ";
290
- }
291
- const system_prompt = "I am an assistant that extracts the groups from the user prompt to carry out differential gene expression. The final output must be in the following JSON with NO extra comments. The schema is as follows: " + JSON.stringify(Schema) + ' . "group1" and "group2" fields are compulsory. Both "group1" and "group2" consist of an array of filter variables. There are two kinds of filter variables: "Categorical" and "Numeric". "Categorical" variables are those variables which can have a fixed set of values e.g. gender, race. They are defined by the "CategoricalFilterTerm" which consists of "term" (a field from the sqlite3 db) and "category" (a value of the field from the sqlite db). "Numeric" variables are those which can have any numeric value. They are defined by "NumericFilterTerm" and contain the subfields "term" (a field from the sqlite3 db), "start" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and "stop" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. ' + // May consider deprecating this natural language description after units tests are implemented
292
- checkField(dataset_json.DatasetPrompt) + checkField(DE_ds[0].SystemPrompt) + "The sqlite db in plain language is as follows:\n" + dataset_db_output.rag_docs.join(",") + " training data is as follows:" + training_data + " Question: {" + prompt + "} answer:";
293
- let response;
294
- if (llm_backend_type == "SJ") {
295
- response = await call_sj_llm(system_prompt, comp_model_name, apilink);
296
- } else if (llm_backend_type == "ollama") {
297
- response = await call_ollama(system_prompt, comp_model_name, apilink);
348
+ const DE_ds = dataset_json.charts.find((chart) => chart.type == "DE");
349
+ if (!DE_ds) throw "DE information is not present in the dataset file.";
350
+ if (DE_ds.TrainingData.length == 0) throw "No training data is provided for the DE agent.";
351
+ const training_data = formatTrainingExamples(DE_ds.TrainingData);
352
+ const system_prompt = "I am an assistant that extracts the groups from the user prompt to carry out differential gene expression. The final output must be in the following JSON with NO extra comments. The schema is as follows: " + JSON.stringify(Schema) + ' . "group1" and "group2" fields are compulsory. Both "group1" and "group2" consist of an array of filter variables. ' + FILTER_DESCRIPTION + checkField(dataset_json.DatasetPrompt) + checkField(DE_ds.SystemPrompt) + "The sqlite db in plain language is as follows:\n" + dataset_db_output.rag_docs.join(",") + " training data is as follows:" + training_data + " Question: {" + prompt + "} answer:";
353
+ const response = await route_to_appropriate_llm_provider(system_prompt, llm);
354
+ if (testing) {
355
+ return { action: "dge", response: JSON.parse(response) };
298
356
  } else {
299
- throw "Unknown LLM backend";
357
+ return await validate_DE_response(response, ds, dataset_db_output.db_rows);
300
358
  }
301
- return await validate_DE_response(response, ds, dataset_db_output.db_rows);
302
359
  } else {
303
360
  return { type: "html", html: "Differential gene expression not supported for this dataset" };
304
361
  }
@@ -447,9 +504,7 @@ function find_label(filter, db_rows) {
447
504
  }
448
505
  return label;
449
506
  }
450
- async function extract_summary_terms(prompt, llm_backend_type, comp_model_name, apilink, dataset_db, dataset_json, genedb, ds) {
451
- const dataset_db_output = await parse_dataset_db(dataset_db);
452
- const genes_list = await parse_geneset_db(genedb);
507
+ async function extract_summary_terms(prompt, llm, dataset_db_output, dataset_json, genes_list, ds, testing) {
453
508
  const Schema = {
454
509
  $schema: "http://json-schema.org/draft-07/schema#",
455
510
  $ref: "#/definitions/SummaryType",
@@ -463,73 +518,37 @@ async function extract_summary_terms(prompt, llm_backend_type, comp_model_name,
463
518
  type: "array",
464
519
  items: { $ref: "#/definitions/FilterTerm" },
465
520
  description: "Optional simple filter terms"
466
- }
467
- },
468
- required: ["term", "simpleFilter"],
469
- additionalProperties: false
470
- },
471
- FilterTerm: {
472
- anyOf: [{ $ref: "#/definitions/CategoricalFilterTerm" }, { $ref: "#/definitions/NumericFilterTerm" }]
473
- },
474
- CategoricalFilterTerm: {
475
- type: "object",
476
- properties: {
477
- term: { type: "string", description: "Name of categorical term" },
478
- category: { type: "string", description: "The category of the term" },
479
- join: {
521
+ },
522
+ childType: {
480
523
  type: "string",
481
- enum: ["and", "or"],
482
- description: "join term to be used only when there there is more than one filter term and should be placed in the 2nd filter term describing how it connects to the 1st term"
524
+ enum: ["violin", "boxplot", "sampleScatter", "barchart"],
525
+ description: "Optional explicit child type requested by the user. If omitted, the logic of the data types picks the child type."
483
526
  }
484
527
  },
485
- required: ["term", "category"],
528
+ required: ["term", "simpleFilter"],
486
529
  additionalProperties: false
487
530
  },
488
- NumericFilterTerm: {
489
- type: "object",
490
- properties: {
491
- term: { type: "string", description: "Name of numeric term" },
492
- start: { type: "number", description: "start position (or lower limit) of numeric term" },
493
- stop: { type: "number", description: "stop position (or upper limit) of numeric term" },
494
- join: {
495
- type: "string",
496
- enum: ["and", "or"],
497
- description: "join term to be used only when there there is more than one filter term and should be placed in the 2nd filter term describing how it connects to the 1st term"
498
- }
499
- },
500
- required: ["term"],
501
- additionalProperties: false
502
- }
531
+ ...FILTER_TERM_DEFINITIONS
503
532
  }
504
533
  };
505
- const words = prompt.replace(/[^a-zA-Z0-9\s]/g, "").split(/\s+/).map((str) => str.toLowerCase());
506
- const common_genes = words.filter((item) => genes_list.includes(item));
507
- const summary_ds = dataset_json.charts.filter((chart) => chart.type == "Summary");
508
- if (summary_ds.length == 0) throw "Summary information is not present in the dataset file.";
509
- if (summary_ds[0].TrainingData.length == 0) throw "No training data is provided for the summary agent.";
510
- let train_iter = 0;
511
- let training_data = "";
512
- for (const train_data of summary_ds[0].TrainingData) {
513
- train_iter += 1;
514
- training_data += "Example question" + train_iter.toString() + ": " + train_data.question + " Example answer" + train_iter.toString() + ":" + JSON.stringify(train_data.answer) + " ";
515
- }
516
- let system_prompt = "I am an assistant that extracts the summary terms from user query. The final output must be in the following JSON format with NO extra comments. The JSON schema is as follows: " + JSON.stringify(Schema) + ' term and term2 (if present) should ONLY contain names of the fields from the sqlite db. The "simpleFilter" field is optional and should contain an array of JSON terms with which the dataset will be filtered. A variable simultaneously CANNOT be part of both "term"/"term2" and "simpleFilter". There are two kinds of filter variables: "Categorical" and "Numeric". "Categorical" variables are those variables which can have a fixed set of values e.g. gender, race. They are defined by the "CategoricalFilterTerm" which consists of "term" (a field from the sqlite3 db) and "category" (a value of the field from the sqlite db). "Numeric" variables are those which can have any numeric value. They are defined by "NumericFilterTerm" and contain the subfields "term" (a field from the sqlite3 db), "start" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and "stop" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. ' + // May consider deprecating this natural language description after unit tests are implemented
517
- checkField(dataset_json.DatasetPrompt) + checkField(summary_ds[0].SystemPrompt) + "\n The DB content is as follows: " + dataset_db_output.rag_docs.join(",") + " training data is as follows:" + training_data;
534
+ const common_genes = extractGenesFromPrompt(prompt, genes_list);
535
+ const summary_ds = dataset_json.charts.find((chart) => chart.type == "Summary");
536
+ if (!summary_ds) throw "Summary information is not present in the dataset file.";
537
+ if (summary_ds.TrainingData.length == 0) throw "No training data is provided for the summary agent.";
538
+ const training_data = formatTrainingExamples(summary_ds.TrainingData);
539
+ let system_prompt = "I am an assistant that extracts the summary terms from user query. The final output must be in the following JSON format with NO extra comments. The JSON schema is as follows: " + JSON.stringify(Schema) + ' term and term2 (if present) should ONLY contain names of the fields from the sqlite db. The "simpleFilter" field is optional and should contain an array of JSON terms with which the dataset will be filtered. ' + FILTER_DESCRIPTION + checkField(dataset_json.DatasetPrompt) + checkField(summary_ds.SystemPrompt) + "\n The DB content is as follows: " + dataset_db_output.rag_docs.join(",") + " training data is as follows:" + training_data;
518
540
  if (dataset_json.hasGeneExpression) {
519
541
  if (common_genes.length > 0) {
520
542
  system_prompt += "\n List of relevant genes are as follows (separated by comma(,)):" + common_genes.join(",");
521
543
  }
522
544
  }
523
545
  system_prompt += " Question: {" + prompt + "} answer:";
524
- let response;
525
- if (llm_backend_type == "SJ") {
526
- response = await call_sj_llm(system_prompt, comp_model_name, apilink);
527
- } else if (llm_backend_type == "ollama") {
528
- response = await call_ollama(system_prompt, comp_model_name, apilink);
546
+ const response = await route_to_appropriate_llm_provider(system_prompt, llm);
547
+ if (testing) {
548
+ return { action: "summary", response: JSON.parse(response) };
529
549
  } else {
530
- throw "Unknown LLM backend";
550
+ return validate_summary_response(response, common_genes, dataset_json, ds);
531
551
  }
532
- return validate_summary_response(response, common_genes, dataset_json, ds);
533
552
  }
534
553
  function validate_summary_response(response, common_genes, dataset_json, ds) {
535
554
  const response_type = JSON.parse(response);
@@ -542,6 +561,10 @@ function validate_summary_response(response, common_genes, dataset_json, ds) {
542
561
  html += term1_validation.html;
543
562
  } else {
544
563
  pp_plot_json.term = term1_validation.term_type;
564
+ if (term1_validation.category == "float" || term1_validation.category == "integer") {
565
+ pp_plot_json.term.q = { mode: "continuous" };
566
+ }
567
+ pp_plot_json.category = term1_validation.category;
545
568
  }
546
569
  if (response_type.term2) {
547
570
  const term2_validation = validate_term(response_type.term2, common_genes, dataset_json, ds);
@@ -549,6 +572,118 @@ function validate_summary_response(response, common_genes, dataset_json, ds) {
549
572
  html += term2_validation.html;
550
573
  } else {
551
574
  pp_plot_json.term2 = term2_validation.term_type;
575
+ if (term2_validation.category == "float" || term2_validation.category == "integer") {
576
+ pp_plot_json.term2.q = { mode: "continuous" };
577
+ }
578
+ pp_plot_json.category2 = term2_validation.category;
579
+ }
580
+ }
581
+ const llmChildType = response_type.childType && ["violin", "boxplot", "sampleScatter", "barchart"].includes(response_type.childType) ? response_type.childType : void 0;
582
+ const resolved = resolveChildType(pp_plot_json.category, pp_plot_json.category2, llmChildType);
583
+ if (resolved.error) {
584
+ html += resolved.error;
585
+ } else {
586
+ pp_plot_json.childType = resolved.childType;
587
+ if (resolved.bothNumeric && (resolved.childType == "violin" || resolved.childType == "boxplot")) {
588
+ pp_plot_json.term2.q = { mode: "discrete" };
589
+ }
590
+ }
591
+ delete pp_plot_json.category;
592
+ if (pp_plot_json.category2) delete pp_plot_json.category2;
593
+ if (response_type.simpleFilter && response_type.simpleFilter.length > 0) {
594
+ const validated_filters = validate_filter(response_type.simpleFilter, ds, "");
595
+ if (validated_filters.html.length > 0) {
596
+ html += validated_filters.html;
597
+ } else {
598
+ pp_plot_json.filter = validated_filters.simplefilter;
599
+ }
600
+ }
601
+ if (html.length > 0) {
602
+ return { type: "html", html };
603
+ } else {
604
+ return { type: "plot", plot: pp_plot_json };
605
+ }
606
+ }
607
/**
 * Ask the LLM to extract matrix-plot inputs (dictionary terms, gene names and
 * optional filters) from a user prompt.
 *
 * The system prompt is assembled from the JSON schema below, the shared filter
 * description, the dataset prompt, the "Matrix" chart's system prompt, the
 * dataset DB description and the "Matrix" chart's training examples. Genes
 * recognised in the user prompt are appended as a hint only when the dataset
 * has gene expression.
 *
 * @param {string} prompt - the user's question
 * @param {object} llm - LLM settings forwarded to the provider router
 * @param {{rag_docs: string[]}} dataset_db_output - parsed dataset DB; `rag_docs` is joined into the prompt
 * @param {object} dataset_json - dataset AI configuration (`charts`, `DatasetPrompt`, `hasGeneExpression` are read)
 * @param {string[]} genes_list - known gene names used to spot genes in the prompt
 * @param {object} ds - dataset object, forwarded to the validator
 * @param {boolean} testing - when true, return the raw parsed LLM answer instead of validating it
 * @returns {Promise<object>} `{ action: "matrix", response }` in testing mode,
 *   otherwise the result of `validate_matrix_response`
 * @throws {string} when the dataset file has no "Matrix" chart or that chart has no training data
 */
async function extract_matrix_search_terms_from_query(prompt, llm, dataset_db_output, dataset_json, genes_list, ds, testing) {
  const Schema = {
    $schema: "http://json-schema.org/draft-07/schema#",
    $ref: "#/definitions/MatrixType",
    definitions: {
      MatrixType: {
        type: "object",
        properties: {
          terms: {
            type: "array",
            items: { type: "string" },
            description: "Names of dictionary/clinical terms to include as rows in the matrix"
          },
          geneNames: {
            type: "array",
            items: { type: "string" },
            description: "Names of genes to include as gene variant rows in the matrix"
          },
          simpleFilter: {
            type: "array",
            items: { $ref: "#/definitions/FilterTerm" },
            description: "Optional simple filter terms to restrict the sample set"
          }
        },
        additionalProperties: false
      },
      ...FILTER_TERM_DEFINITIONS
    }
  };
  const common_genes = extractGenesFromPrompt(prompt, genes_list);
  // Same lookup style as the other agents: take the first chart of this type.
  const matrix_ds = dataset_json.charts.find((chart) => chart.type == "Matrix");
  if (!matrix_ds) throw "Matrix information is not present in the dataset file.";
  if (matrix_ds.TrainingData.length == 0) throw "No training data is provided for the matrix agent.";
  const training_data = formatTrainingExamples(matrix_ds.TrainingData);
  let system_prompt = "I am an assistant that extracts terms and gene names from the user query to create a matrix plot. A matrix plot displays multiple genes and/or clinical variables across samples in a grid layout. The final output must be in the following JSON format with NO extra comments. The JSON schema is as follows: " + JSON.stringify(Schema) + ' The "terms" field should ONLY contain names of clinical/dictionary fields from the sqlite db. The "geneNames" field should ONLY contain gene names. At least one of "terms" or "geneNames" must be provided. The "simpleFilter" field is optional and should contain an array of JSON terms with which the dataset will be filtered. ' + FILTER_DESCRIPTION + checkField(dataset_json.DatasetPrompt) + checkField(matrix_ds.SystemPrompt) + "\n The DB content is as follows: " + dataset_db_output.rag_docs.join(",") + " training data is as follows:" + training_data;
  if (dataset_json.hasGeneExpression && common_genes.length > 0) {
    system_prompt += "\n List of relevant genes are as follows (separated by comma(,)):" + common_genes.join(",");
  }
  system_prompt += " Question: {" + prompt + "} answer:";
  const response = await route_to_appropriate_llm_provider(system_prompt, llm);
  if (testing) {
    return { action: "matrix", response: JSON.parse(response) };
  }
  return validate_matrix_response(response, common_genes, dataset_json, ds);
}
655
+ function validate_matrix_response(response, common_genes, dataset_json, ds) {
656
+ const response_type = JSON.parse(response);
657
+ const pp_plot_json = { chartType: "matrix" };
658
+ let html = "";
659
+ if (response_type.html) html = response_type.html;
660
+ if ((!response_type.terms || response_type.terms.length == 0) && (!response_type.geneNames || response_type.geneNames.length == 0)) {
661
+ html += "At least one clinical term or gene name is required for a matrix plot";
662
+ }
663
+ const twLst = [];
664
+ if (response_type.terms && Array.isArray(response_type.terms)) {
665
+ for (const t of response_type.terms) {
666
+ const term = ds.cohort.termdb.q.termjsonByOneid(t);
667
+ if (!term) {
668
+ html += "invalid term id:" + t + " ";
669
+ } else {
670
+ twLst.push({ id: term.id });
671
+ }
672
+ }
673
+ }
674
+ if (response_type.geneNames && Array.isArray(response_type.geneNames)) {
675
+ for (const g of response_type.geneNames) {
676
+ const gene_hits = common_genes.filter((gene) => gene == g.toLowerCase());
677
+ if (gene_hits.length == 0) {
678
+ html += "invalid gene name:" + g + " ";
679
+ } else {
680
+ const geneName = g.toUpperCase();
681
+ if (dataset_json.hasGeneExpression) {
682
+ twLst.push({ term: { gene: geneName, type: "geneExpression" } });
683
+ } else {
684
+ twLst.push({ term: { gene: geneName, name: geneName, type: "geneVariant" } });
685
+ }
686
+ }
552
687
  }
553
688
  }
554
689
  if (response_type.simpleFilter && response_type.simpleFilter.length > 0) {
@@ -562,12 +697,14 @@ function validate_summary_response(response, common_genes, dataset_json, ds) {
562
697
  if (html.length > 0) {
563
698
  return { type: "html", html };
564
699
  } else {
700
+ pp_plot_json.termgroups = [{ name: "", lst: twLst }];
565
701
  return { type: "plot", plot: pp_plot_json };
566
702
  }
567
703
  }
568
704
  function validate_term(response_term, common_genes, dataset_json, ds) {
569
705
  let html = "";
570
706
  let term_type;
707
+ let category = "";
571
708
  const term = ds.cohort.termdb.q.termjsonByOneid(response_term);
572
709
  if (!term) {
573
710
  const gene_hits = common_genes.filter((gene) => gene == response_term.toLowerCase());
@@ -576,14 +713,16 @@ function validate_term(response_term, common_genes, dataset_json, ds) {
576
713
  } else {
577
714
  if (dataset_json.hasGeneExpression) {
578
715
  term_type = { term: { gene: response_term.toUpperCase(), type: "geneExpression" } };
716
+ category = "float";
579
717
  } else {
580
718
  html += "Dataset does not support gene expression";
581
719
  }
582
720
  }
583
721
  } else {
584
722
  term_type = { id: term.id };
723
+ category = term.type;
585
724
  }
586
- return { term_type, html };
725
+ return { term_type, html, category };
587
726
  }
588
727
  function countOccurrences(str, word) {
589
728
  if (word === "") return 0;
@@ -768,5 +907,7 @@ function parse_db_rows(db_row) {
768
907
  return output_string;
769
908
  }
770
909
  export {
771
- api
910
+ api,
911
+ readJSONFile,
912
+ run_chat_pipeline
772
913
  };