@sjcrh/proteinpaint-rust 2.149.0 → 2.152.1-0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/ollama.rs ADDED
@@ -0,0 +1,1108 @@
1
+ // Compile: cd .. && cargo build --release
2
+ // Test: cd .. && export RUST_BACKTRACE=full && time cargo test -- --nocapture (runs all test except those marked as "ignored")
3
+ // Ignored tests: cd .. && export RUST_BACKTRACE=full && time cargo test -- --ignored --nocapture
4
+ use async_stream::stream;
5
+ use futures::StreamExt;
6
+ use rig::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
7
+ use rig::completion::{GetTokenUsage, Usage};
8
+ use rig::message::ConvertMessage;
9
+ use rig::streaming::RawStreamingChoice;
10
+ use rig::{
11
+ Embed, OneOrMany,
12
+ completion::{self, CompletionError, CompletionRequest},
13
+ embeddings::{self, EmbeddingError, EmbeddingsBuilder},
14
+ impl_conversion_traits, message,
15
+ message::{ImageDetail, Text},
16
+ streaming,
17
+ };
18
+ use serde::{Deserialize, Serialize};
19
+ use serde_json::{Value, json};
20
+ use std::convert::TryInto;
21
+ //use std::time::Duration;
22
+ use std::{convert::TryFrom, str::FromStr};
23
+ use url::Url;
24
+
25
+ // ---------- Main Client ----------
26
+ pub struct ClientBuilder<'a> {
27
+ base_url: &'a str,
28
+ http_client: Option<reqwest::Client>,
29
+ }
30
+
31
+ impl<'a> ClientBuilder<'a> {
32
+ #[allow(clippy::new_without_default)]
33
+ pub fn new() -> Self {
34
+ Self {
35
+ base_url: "",
36
+ http_client: None,
37
+ }
38
+ }
39
+
40
+ pub fn base_url(mut self, base_url: &'a str) -> Self {
41
+ //println!("base_url:{}", base_url);
42
+ self.base_url = base_url;
43
+ self
44
+ }
45
+
46
+ pub fn build(self) -> Result<Client, ClientBuilderError> {
47
+ let http_client = if let Some(http_client) = self.http_client {
48
+ http_client
49
+ } else {
50
+ reqwest::Client::builder().build()?
51
+ };
52
+
53
+ Ok(Client {
54
+ base_url: Url::parse(self.base_url).map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
55
+ http_client,
56
+ })
57
+ }
58
+ }
59
+
60
+ #[derive(Clone, Debug)]
61
+ pub struct Client {
62
+ base_url: Url,
63
+ http_client: reqwest::Client,
64
+ }
65
+
66
+ impl Default for Client {
67
+ fn default() -> Self {
68
+ Self::new()
69
+ }
70
+ }
71
+
72
+ impl Client {
73
+ pub fn builder() -> ClientBuilder<'static> {
74
+ ClientBuilder::new()
75
+ }
76
+
77
+ pub fn completion_model(&self, model: &str) -> CompletionModel {
78
+ CompletionModel::new(self.clone(), model)
79
+ }
80
+
81
+ pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
82
+ EmbeddingModel::new(self.clone(), model, 0)
83
+ }
84
+
85
+ pub fn new() -> Self {
86
+ Self::builder().build().expect("Myprovider client should build")
87
+ }
88
+
89
+ pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
90
+ let url = self.base_url.join(path)?;
91
+ Ok(self.http_client.post(url))
92
+ }
93
+
94
+ pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
95
+ let url = self.base_url.join(path)?;
96
+ Ok(self.http_client.get(url))
97
+ }
98
+ }
99
+
100
+ impl ProviderClient for Client {
101
+ fn from_env() -> Self
102
+ where
103
+ Self: Sized,
104
+ {
105
+ let api_base = std::env::var("MYPROVIDER_API_BASE_URL").expect("MYPROVIDER_API_BASE_URL not set");
106
+ Self::builder().base_url(&api_base).build().unwrap()
107
+ }
108
+
109
+ fn from_val(input: rig::client::ProviderValue) -> Self {
110
+ let rig::client::ProviderValue::Simple(_) = input else {
111
+ panic!("Incorrect provider value type")
112
+ };
113
+
114
+ Self::new()
115
+ }
116
+ }
117
+
118
+ impl CompletionClient for Client {
119
+ type CompletionModel = CompletionModel;
120
+
121
+ fn completion_model(&self, model: &str) -> CompletionModel {
122
+ CompletionModel::new(self.clone(), model)
123
+ }
124
+ }
125
+
126
+ impl EmbeddingsClient for Client {
127
+ type EmbeddingModel = EmbeddingModel;
128
+ fn embedding_model(&self, model: &str) -> EmbeddingModel {
129
+ EmbeddingModel::new(self.clone(), model, 0)
130
+ }
131
+ fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
132
+ EmbeddingModel::new(self.clone(), model, ndims)
133
+ }
134
+ fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
135
+ EmbeddingsBuilder::new(self.embedding_model(model))
136
+ }
137
+ }
138
+
139
+ impl VerifyClient for Client {
140
+ async fn verify(&self) -> Result<(), VerifyError> {
141
+ let response = self
142
+ .get("api/tags")
143
+ .expect("Failed to build request")
144
+ .send()
145
+ .await
146
+ .unwrap();
147
+ match response.status() {
148
+ reqwest::StatusCode::OK => Ok(()),
149
+ _ => {
150
+ response.error_for_status().unwrap();
151
+ Ok(())
152
+ }
153
+ }
154
+ }
155
+ }
156
+
157
+ impl_conversion_traits!(
158
+ AsTranscription,
159
+ AsImageGeneration,
160
+ AsAudioGeneration for Client
161
+ );
162
+
163
+ // ---------- API Error and Response Structures ----------
164
+
165
+ #[derive(Debug, Deserialize)]
166
+ struct ApiErrorResponse {
167
+ message: String,
168
+ }
169
+
170
+ #[derive(Debug, Deserialize)]
171
+ #[serde(untagged)]
172
+ enum ApiResponse<T> {
173
+ Ok(T),
174
+ Err(ApiErrorResponse),
175
+ }
176
+
177
+ // ---------- Embedding API ----------
178
+
179
+ #[derive(Debug, Serialize, Deserialize)]
180
+ pub struct EmbeddingResponse {
181
+ pub model: String,
182
+ pub embeddings: Vec<Vec<f64>>,
183
+ #[serde(default)]
184
+ pub total_duration: Option<u32>,
185
+ #[serde(default)]
186
+ pub load_duration: Option<u32>,
187
+ #[serde(default)]
188
+ pub prompt_eval_count: Option<u32>,
189
+ }
190
+
191
+ impl From<ApiErrorResponse> for EmbeddingError {
192
+ fn from(err: ApiErrorResponse) -> Self {
193
+ EmbeddingError::ProviderError(err.message)
194
+ }
195
+ }
196
+
197
+ impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
198
+ fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
199
+ match value {
200
+ ApiResponse::Ok(response) => Ok(response),
201
+ ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
202
+ }
203
+ }
204
+ }
205
+
206
+ // ---------- Embedding Model ----------
207
+
208
+ #[derive(Clone)]
209
+ pub struct EmbeddingModel {
210
+ client: Client,
211
+ pub model: String,
212
+ ndims: usize,
213
+ }
214
+
215
+ impl EmbeddingModel {
216
+ pub fn new(client: Client, model: &str, ndims: usize) -> Self {
217
+ Self {
218
+ client,
219
+ model: model.to_owned(),
220
+ ndims,
221
+ }
222
+ }
223
+ }
224
+
225
+ impl embeddings::EmbeddingModel for EmbeddingModel {
226
+ const MAX_DOCUMENTS: usize = 1024;
227
+ fn ndims(&self) -> usize {
228
+ self.ndims
229
+ }
230
+
231
+ async fn embed_texts(
232
+ &self,
233
+ documents: impl IntoIterator<Item = String>,
234
+ ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
235
+ let docs: Vec<String> = documents.into_iter().collect();
236
+ let payload = json!({
237
+ "model": self.model,
238
+ "input": docs,
239
+ });
240
+ let response = self
241
+ .client
242
+ .post("api/embed")?
243
+ .json(&payload)
244
+ .send()
245
+ .await
246
+ .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
247
+ if response.status().is_success() {
248
+ let api_resp: EmbeddingResponse = response
249
+ .json()
250
+ .await
251
+ .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
252
+ if api_resp.embeddings.len() != docs.len() {
253
+ return Err(EmbeddingError::ResponseError(
254
+ "Number of returned embeddings does not match input".into(),
255
+ ));
256
+ }
257
+ Ok(api_resp
258
+ .embeddings
259
+ .into_iter()
260
+ .zip(docs.into_iter())
261
+ .map(|(vec, document)| embeddings::Embedding { document, vec })
262
+ .collect())
263
+ } else {
264
+ Err(EmbeddingError::ProviderError(response.text().await.unwrap()))
265
+ }
266
+ }
267
+ //Ok(Vec::<Embedding>::new())
268
+ }
269
+
270
+ // ---------- Completion API ----------
271
+
272
+ #[derive(Debug, Serialize, Deserialize)]
273
+ pub struct CompletionResponse {
274
+ pub model: String,
275
+ pub message: Message,
276
+ pub timestamp: String,
277
+ }
278
+ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
279
+ type Error = CompletionError;
280
+ fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
281
+ match resp.message {
282
+ // Process only if an assistant message is present.
283
+ Message::Assistant { content, id } => {
284
+ //let mut assistant_contents = Vec::new();
285
+ // Add the assistant's text content if any.
286
+ //if !content.is_empty() {
287
+ // assistant_contents.push(completion::AssistantContent2::text(&content));
288
+ //}
289
+
290
+ let choice = rig::one_or_many::OneOrMany::one(completion::AssistantContent::text(&content));
291
+ //let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
292
+ //let completion_tokens = resp.eval_count.unwrap_or(0);
293
+
294
+ let raw_response = CompletionResponse {
295
+ model: resp.model,
296
+ message: Message::Assistant { content, id },
297
+ timestamp: resp.timestamp,
298
+ };
299
+
300
+ Ok(completion::CompletionResponse {
301
+ choice,
302
+ usage: Usage {
303
+ input_tokens: 0, // Not provided by custom provider
304
+ output_tokens: 0, // Not provided by custom provider
305
+ total_tokens: 0, // Not provided by custom provider
306
+ },
307
+ raw_response,
308
+ })
309
+ }
310
+ _ => Err(CompletionError::ResponseError(
311
+ "Chat response does not include an assistant message".into(),
312
+ )),
313
+ }
314
+ }
315
+ }
316
+
317
+ // ---------- Completion Model ----------
318
+
319
+ #[derive(Clone)]
320
+ pub struct CompletionModel {
321
+ client: Client,
322
+ pub model: String,
323
+ }
324
+
325
+ impl CompletionModel {
326
+ pub fn new(client: Client, model: &str) -> Self {
327
+ Self {
328
+ client,
329
+ model: model.to_owned(),
330
+ }
331
+ }
332
+
333
+ fn create_completion_request(&self, completion_request: CompletionRequest) -> Result<Value, CompletionError> {
334
+ let mut partial_history = vec![];
335
+ if let Some(docs) = completion_request.normalized_documents() {
336
+ partial_history.push(docs);
337
+ }
338
+ partial_history.extend(completion_request.chat_history);
339
+
340
+ // Initialize full history with preamble (or empty if non-existent)
341
+ let mut full_history: Vec<Message> = completion_request
342
+ .preamble
343
+ .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
344
+
345
+ // Convert and extend the rest of the history
346
+ full_history.extend(
347
+ partial_history
348
+ .into_iter()
349
+ .map(|msg| Message::convert_from_message(msg))
350
+ .collect::<Result<Vec<Vec<Message>>, _>>()?
351
+ .into_iter()
352
+ .flatten()
353
+ .collect::<Vec<Message>>(),
354
+ );
355
+
356
+ // Convert internal prompt into a provider Message
357
+ //let max_new_tokens: u64;
358
+ let top_p: f64;
359
+ let mut schema_json_string: Option<String> = None;
360
+ match completion_request.additional_params {
361
+ Some(extra) => {
362
+ top_p = extra["top_p"].as_f64().unwrap();
363
+ if let Value::Object(obj) = extra {
364
+ if obj.contains_key("schema_json") {
365
+ schema_json_string = Some(String::from(obj["schema_json"].as_str().unwrap()));
366
+ //println!("schema_json_string:{:?}", schema_json_string);
367
+ }
368
+ }
369
+ }
370
+ None => {
371
+ panic!("top_p not found!");
372
+ }
373
+ }
374
+
375
+ let mut user_query = "";
376
+ let mut system_prompt = "";
377
+ for message in &full_history {
378
+ match message {
379
+ self::Message::User {
380
+ content: text,
381
+ images: _,
382
+ name: _,
383
+ } => {
384
+ //println!("User:{:?}", text);
385
+ user_query = text;
386
+ }
387
+ self::Message::System {
388
+ content: text,
389
+ images: _,
390
+ name: _,
391
+ } => {
392
+ system_prompt = text;
393
+ //println!("System:{:?}", text);
394
+ }
395
+ self::Message::Assistant { content: _, id: _ } => {}
396
+ self::Message::ToolResult { content: _, name: _ } => {}
397
+ }
398
+ }
399
+ let final_text = system_prompt.replace(&"{question}", &user_query);
400
+
401
+ // Convert and extend the rest of the history
402
+ //full_history.extend(
403
+ // partial_history
404
+ // .into_iter()
405
+ // .map(Message::convert_from_message)
406
+ // .collect::<Result<Vec<Vec<Message>>, _>>()?
407
+ // .into_iter()
408
+ // .flatten()
409
+ // .collect::<Vec<Message>>(),
410
+ //);
411
+
412
+ let mut request_payload;
413
+ match schema_json_string {
414
+ // JSON schema is only added if its provided
415
+ Some(_schema) => {
416
+ request_payload = json!({
417
+ "model": self.model,
418
+ "messages": [{"role": "user", "content": final_text}],
419
+ "raw": false,
420
+ "stream": false,
421
+ "keep_alive": 15, // Keep the LLM loaded for 15mins
422
+ //"format":schema,
423
+ "options": {
424
+ "top_p": top_p,
425
+ "temperature": completion_request.temperature,
426
+ "num_ctx": 10000
427
+ }
428
+ });
429
+ }
430
+ None => {
431
+ request_payload = json!({
432
+ "model": self.model,
433
+ "messages": [{"role": "user", "content": final_text}],
434
+ "raw": false,
435
+ "stream": false,
436
+ "keep_alive": 15, // Keep the LLM loaded for 15mins
437
+ "options": {
438
+ "top_p": top_p,
439
+ "temperature": completion_request.temperature,
440
+ "num_ctx": 10000
441
+ }
442
+ });
443
+ }
444
+ }
445
+ //let mut request_payload = json!({
446
+ //"model": "llama3.3:70b",
447
+ //"messages": [{"role": "user", "content": "Tell me about Canada."}],
448
+ //"stream": false,
449
+ //"format": {
450
+ // "type": "object",
451
+ // "properties": {
452
+ // "name": {
453
+ // "type": "string"
454
+ // },
455
+ // "capital": {
456
+ // "type": "string"
457
+ // },
458
+ // "languages": {
459
+ // "type": "array",
460
+ // "items": {
461
+ // "type": "string"
462
+ // }
463
+ // }
464
+ // },
465
+ // "required": [
466
+ // "name",
467
+ // "capital",
468
+ // "languages"
469
+ // ]
470
+ //}
471
+ //});
472
+
473
+ //println!("request_payload:{}", request_payload);
474
+ if !completion_request.tools.is_empty() {
475
+ request_payload["tools"] = json!(
476
+ completion_request
477
+ .tools
478
+ .into_iter()
479
+ .map(|tool| tool.into())
480
+ .collect::<Vec<ToolDefinition>>()
481
+ );
482
+ }
483
+
484
+ //tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
485
+
486
+ Ok(request_payload)
487
+ }
488
+ }
489
+
490
+ // ---------- CompletionModel Implementation ----------
491
+
492
+ #[derive(Clone, Serialize, Deserialize, Debug)]
493
+ pub struct StreamingCompletionResponse {
494
+ pub done_reason: Option<String>,
495
+ pub total_duration: Option<u64>,
496
+ pub load_duration: Option<u64>,
497
+ pub prompt_eval_count: Option<u64>,
498
+ pub prompt_eval_duration: Option<u64>,
499
+ pub eval_count: Option<u64>,
500
+ pub eval_duration: Option<u64>,
501
+ }
502
+
503
+ impl GetTokenUsage for StreamingCompletionResponse {
504
+ fn token_usage(&self) -> Option<rig::completion::Usage> {
505
+ let mut usage = rig::completion::Usage::new();
506
+ let input_tokens = self.prompt_eval_count.unwrap_or_default();
507
+ let output_tokens = self.eval_count.unwrap_or_default();
508
+ usage.input_tokens = input_tokens;
509
+ usage.output_tokens = output_tokens;
510
+ usage.total_tokens = input_tokens + output_tokens;
511
+
512
+ Some(usage)
513
+ }
514
+ }
515
+
516
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    /// Sends a non-streaming chat request to `api/chat` and converts the
    /// JSON body into a rig completion response.
    ///
    /// NOTE(review): the success path reads `model_name` / `id` / `message` /
    /// `created_at` with `Value::to_string()`, which serializes the raw JSON
    /// node (quotes included; `"null"` when the key is absent) rather than
    /// extracting the string content. Stock Ollama returns `model` and a
    /// `message` *object* with a `content` field — confirm this mapping
    /// against the actual server's response schema.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        // Build the provider payload (may fail on history conversion / params).
        let request_payload = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post(&"api/chat")?
            .json(&request_payload)
            .send()
            .await
            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
        //println!("response:{:?}", response);
        if response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            let text_json: Value = serde_json::from_str(&text)?;
            //tracing::debug!(target: "rig", "Myprovider chat response: {}", text);
            // Re-pack selected response fields as an assistant message; the
            // TryFrom impl below rejects anything but an assistant message.
            let chat_resp: CompletionResponse = CompletionResponse {
                model: text_json["model_name"].to_string(),
                message: Message::Assistant {
                    id: text_json["id"].to_string(),
                    content: text_json["message"].to_string(),
                },
                timestamp: text_json["created_at"].to_string(),
            };
            //println!("chat_resp:{:?}", chat_resp);
            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
            Ok(conv)
        } else {
            // Non-success status: surface the error body verbatim.
            let err_text = response
                .text()
                .await
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            Err(CompletionError::ProviderError(err_text))
        }
    }

    /// Sends a streaming chat request (`"stream": true`) and yields each
    /// assistant-content chunk as it arrives.
    ///
    /// Lines that fail to parse as a `CompletionResponse` are silently
    /// skipped; the final metadata chunk is currently not forwarded (see the
    /// commented-out `FinalResponse` block), so callers never receive a
    /// `StreamingCompletionResponse` payload.
    ///
    /// NOTE(review): chunk boundaries are assumed to align with complete
    /// UTF-8 JSON lines; a JSON object split across two network chunks would
    /// fail to parse and be dropped — confirm against server behavior.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request_payload = self.create_completion_request(request)?;
        // Flip the non-streaming payload into streaming mode.
        merge_inplace(&mut request_payload, json!({"stream": true}));

        let response = self
            .client
            .post("api/chat")?
            .json(&request_payload)
            .send()
            .await
            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;

        if !response.status().is_success() {
            let err_text = response
                .text()
                .await
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            return Err(CompletionError::ProviderError(err_text));
        }

        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        // Transport failure mid-stream: emit the error and stop.
                        yield Err(CompletionError::RequestError(e.into()));
                        break;
                    }
                };

                let text = match String::from_utf8(chunk.to_vec()) {
                    Ok(t) => t,
                    Err(e) => {
                        yield Err(CompletionError::ResponseError(e.to_string()));
                        break;
                    }
                };


                // The server sends newline-delimited JSON objects.
                for line in text.lines() {
                    let line = line.to_string();

                    // Unparseable lines (e.g. partial JSON) are skipped.
                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
                        continue;
                    };

                    match response.message {
                        Message::Assistant{ content, .. } => {
                            if !content.is_empty() {
                                yield Ok(RawStreamingChoice::Message(content))
                            }
                        }
                        _ => {
                            continue;
                        }
                    }

                    //if response.message {
                    //    yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                    //        total_duration: response.total_duration,
                    //        load_duration: response.load_duration,
                    //        prompt_eval_count: response.prompt_eval_count,
                    //        prompt_eval_duration: response.prompt_eval_duration,
                    //        eval_count: response.eval_count,
                    //        eval_duration: response.eval_duration,
                    //        done_reason: response.done_reason,
                    //    }));
                    //}
                }
            }
        });

        Ok(streaming::StreamingCompletionResponse::stream(stream))
    }
}
640
+
641
+ // ---------- Tool Definition Conversion ----------
642
+
643
+ /// Myprovider-required tool definition format.
644
+ #[derive(Clone, Debug, Deserialize, Serialize)]
645
+ pub struct ToolDefinition {
646
+ #[serde(rename = "type")]
647
+ pub type_field: String, // Fixed as "function"
648
+ pub function: completion::ToolDefinition,
649
+ }
650
+
651
+ /// Convert internal ToolDefinition (from the completion module) into Myprovider's tool definition.
652
+ impl From<rig::completion::ToolDefinition> for ToolDefinition {
653
+ fn from(tool: rig::completion::ToolDefinition) -> Self {
654
+ ToolDefinition {
655
+ type_field: "function".to_owned(),
656
+ function: completion::ToolDefinition {
657
+ name: tool.name,
658
+ description: tool.description,
659
+ parameters: tool.parameters,
660
+ },
661
+ }
662
+ }
663
+ }
664
+
665
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
666
+ pub struct ToolCall {
667
+ // pub id: String,
668
+ #[serde(default, rename = "type")]
669
+ pub r#type: ToolType,
670
+ pub function: Function,
671
+ }
672
+ #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
673
+ #[serde(rename_all = "lowercase")]
674
+ pub enum ToolType {
675
+ #[default]
676
+ Function,
677
+ }
678
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
679
+ pub struct Function {
680
+ pub name: String,
681
+ pub arguments: Value,
682
+ }
683
+
684
+ // ---------- Provider Message Definition ----------
685
+
686
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
687
+ #[serde(tag = "role", rename_all = "lowercase")]
688
+ pub enum Message {
689
+ User {
690
+ content: String,
691
+ #[serde(skip_serializing_if = "Option::is_none")]
692
+ images: Option<Vec<String>>,
693
+ #[serde(skip_serializing_if = "Option::is_none")]
694
+ name: Option<String>,
695
+ },
696
+ Assistant {
697
+ #[serde(default)]
698
+ content: String,
699
+ #[serde(default)]
700
+ id: String,
701
+ //#[serde(skip_serializing_if = "Option::is_none")]
702
+ //thinking: Option<String>,
703
+ //#[serde(skip_serializing_if = "Option::is_none")]
704
+ //images: Option<Vec<String>>,
705
+ //#[serde(skip_serializing_if = "Option::is_none")]
706
+ //name: Option<String>,
707
+ //#[serde(default, deserialize_with = "json_utils::null_or_vec")]
708
+ //tool_calls: Vec<ToolCall>,
709
+ },
710
+ System {
711
+ content: String,
712
+ #[serde(skip_serializing_if = "Option::is_none")]
713
+ images: Option<Vec<String>>,
714
+ #[serde(skip_serializing_if = "Option::is_none")]
715
+ name: Option<String>,
716
+ },
717
+ #[serde(rename = "tool")]
718
+ ToolResult {
719
+ #[serde(rename = "tool_name")]
720
+ name: String,
721
+ content: String,
722
+ },
723
+ }
724
+
725
+ /// -----------------------------
726
+ /// Provider Message Conversions
727
+ /// -----------------------------
728
+ /// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
729
+ /// (Only User and Assistant variants are supported.)
730
+ impl ConvertMessage for Message {
731
+ type Error = rig::message::MessageError;
732
+ fn convert_from_message(internal_msg: message::Message) -> Result<Vec<Self>, Self::Error> {
733
+ use rig::message::Message as InternalMessage;
734
+ match internal_msg {
735
+ InternalMessage::User { content, .. } => {
736
+ let (tool_results, other_content): (Vec<_>, Vec<_>) = content
737
+ .into_iter()
738
+ .partition(|content| matches!(content, rig::message::UserContent::ToolResult(_)));
739
+
740
+ if !tool_results.is_empty() {
741
+ tool_results
742
+ .into_iter()
743
+ .map(|content| match content {
744
+ rig::message::UserContent::ToolResult(rig::message::ToolResult { id, content, .. }) => {
745
+ // Ollama expects a single string for tool results, so we concatenate
746
+ let content_string = content
747
+ .into_iter()
748
+ .map(|content| match content {
749
+ rig::message::ToolResultContent::Text(text) => text.text,
750
+ _ => "[Non-text content]".to_string(),
751
+ })
752
+ .collect::<Vec<_>>()
753
+ .join("\n");
754
+
755
+ Ok::<_, rig::message::MessageError>(Message::ToolResult {
756
+ name: id,
757
+ content: content_string,
758
+ })
759
+ }
760
+ _ => unreachable!(),
761
+ })
762
+ .collect::<Result<Vec<_>, _>>()
763
+ } else {
764
+ // Ollama requires separate text content and images array
765
+ let (texts, _images) =
766
+ other_content
767
+ .into_iter()
768
+ .fold((Vec::new(), Vec::new()), |(mut texts, mut images), content| {
769
+ match content {
770
+ rig::message::UserContent::Text(rig::message::Text { text }) => texts.push(text),
771
+ rig::message::UserContent::Image(rig::message::Image { data, .. }) => {
772
+ images.push(data)
773
+ }
774
+ rig::message::UserContent::Document(rig::message::Document { data, .. }) => {
775
+ texts.push(data.to_string())
776
+ }
777
+ _ => {} // Audio not supported by Ollama
778
+ }
779
+ (texts, images)
780
+ });
781
+
782
+ Ok(vec![Message::User {
783
+ content: texts.join(" "),
784
+ images: None,
785
+ name: None,
786
+ }])
787
+ }
788
+ }
789
+ InternalMessage::Assistant { content, .. } => {
790
+ let mut thinking: Option<String> = None;
791
+ let (text_content, _tool_calls) =
792
+ content
793
+ .into_iter()
794
+ .fold((Vec::new(), Vec::new()), |(mut texts, mut tools), content| {
795
+ match content {
796
+ rig::message::AssistantContent::Text(text) => texts.push(text.text),
797
+ rig::message::AssistantContent::ToolCall(_tool_call) => tools.push(_tool_call),
798
+ rig::message::AssistantContent::Reasoning(rig::message::Reasoning {
799
+ reasoning,
800
+ ..
801
+ }) => {
802
+ thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
803
+ }
804
+ }
805
+ (texts, tools)
806
+ });
807
+
808
+ // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
809
+ // so either `content` or `tool_calls` will have some content.
810
+ #[allow(unreachable_code)]
811
+ Ok(vec![Message::Assistant {
812
+ content: text_content.join(" "),
813
+ id: todo!(),
814
+ }])
815
+ }
816
+ }
817
+ }
818
+ }
819
+
820
+ /// Conversion from provider Message to a completion message.
821
+ /// This is needed so that responses can be converted back into chat history.
822
+ impl From<Message> for rig::completion::Message {
823
+ fn from(msg: Message) -> Self {
824
+ match msg {
825
+ Message::User { content, .. } => rig::completion::Message::User {
826
+ content: OneOrMany::one(rig::completion::message::UserContent::Text(Text { text: content })),
827
+ },
828
+ Message::Assistant { content, .. } => {
829
+ let assistant_contents = vec![rig::completion::message::AssistantContent::Text(Text { text: content })];
830
+ rig::completion::Message::Assistant {
831
+ id: None,
832
+ content: OneOrMany::many(assistant_contents).unwrap(),
833
+ }
834
+ }
835
+ // System and ToolResult are converted to User message as needed.
836
+ Message::System { content, .. } => rig::completion::Message::User {
837
+ content: OneOrMany::one(rig::completion::message::UserContent::Text(Text { text: content })),
838
+ },
839
+ Message::ToolResult { name, content } => rig::completion::Message::User {
840
+ content: OneOrMany::one(message::UserContent::tool_result(
841
+ name,
842
+ OneOrMany::one(message::ToolResultContent::text(content)),
843
+ )),
844
+ },
845
+ }
846
+ }
847
+ }
848
+
849
+ impl Message {
850
+ /// Constructs a system message.
851
+ pub fn system(content: &str) -> Self {
852
+ Message::System {
853
+ content: content.to_owned(),
854
+ images: None,
855
+ name: None,
856
+ }
857
+ }
858
+ }
859
+
860
+ // ---------- Additional Message Types ----------
861
+
862
+ impl From<rig::message::ToolCall> for ToolCall {
863
+ fn from(tool_call: rig::message::ToolCall) -> Self {
864
+ Self {
865
+ r#type: ToolType::Function,
866
+ function: Function {
867
+ name: tool_call.function.name,
868
+ arguments: tool_call.function.arguments,
869
+ },
870
+ }
871
+ }
872
+ }
873
+
874
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
875
+ pub struct SystemContent {
876
+ #[serde(default)]
877
+ r#type: SystemContentType,
878
+ text: String,
879
+ }
880
+
881
+ #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
882
+ #[serde(rename_all = "lowercase")]
883
+ pub enum SystemContentType {
884
+ #[default]
885
+ Text,
886
+ }
887
+
888
+ impl From<String> for SystemContent {
889
+ fn from(s: String) -> Self {
890
+ SystemContent {
891
+ r#type: SystemContentType::default(),
892
+ text: s,
893
+ }
894
+ }
895
+ }
896
+
897
+ impl FromStr for SystemContent {
898
+ type Err = std::convert::Infallible;
899
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
900
+ Ok(SystemContent {
901
+ r#type: SystemContentType::default(),
902
+ text: s.to_string(),
903
+ })
904
+ }
905
+ }
906
+
907
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
908
+ pub struct AssistantContent {
909
+ pub text: String,
910
+ }
911
+
912
+ impl FromStr for AssistantContent {
913
+ type Err = std::convert::Infallible;
914
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
915
+ Ok(AssistantContent { text: s.to_owned() })
916
+ }
917
+ }
918
+
919
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
920
+ #[serde(tag = "type", rename_all = "lowercase")]
921
+ pub enum UserContent {
922
+ Text { text: String },
923
+ Image { image_url: ImageUrl },
924
+ // Audio variant removed as Ollama API does not support audio input.
925
+ }
926
+
927
+ impl FromStr for UserContent {
928
+ type Err = std::convert::Infallible;
929
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
930
+ Ok(UserContent::Text { text: s.to_owned() })
931
+ }
932
+ }
933
+
934
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
935
+ pub struct ImageUrl {
936
+ pub url: String,
937
+ #[serde(default)]
938
+ pub detail: ImageDetail,
939
+ }
940
+
941
+ // ---------JSON utils functions -----------------------------
942
+
943
+ pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) {
944
+ if let (serde_json::Value::Object(a_map), serde_json::Value::Object(b_map)) = (a, b) {
945
+ b_map.into_iter().for_each(|(key, value)| {
946
+ a_map.insert(key, value);
947
+ });
948
+ }
949
+ }
950
+
951
+ // =================================================================
952
+ // Tests
953
+ // =================================================================
954
+
955
+ #[cfg(test)]
956
+ mod tests {
957
+ use super::*;
958
+ use rig::agent::AgentBuilder;
959
+ use rig::completion::request::Prompt;
960
+ //use rig::providers::myprovider;
961
+ use rig::vector_store::in_memory_store::InMemoryVectorStore;
962
+ //use serde_json::json;
963
+ use serde_json;
964
+ use std::fs::{self};
965
+ use std::path::Path;
966
+
967
+ // Test deserialization and conversion for the /api/chat endpoint.
968
+ #[tokio::test]
969
+ #[ignore]
970
+
971
+ async fn test_ollama_implementation() {
972
+ let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
973
+ let serverconfig_file_path = Path::new("../../serverconfig.json");
974
+ let absolute_path = serverconfig_file_path.canonicalize().unwrap();
975
+
976
+ // Read the file
977
+ let data = fs::read_to_string(absolute_path).unwrap();
978
+
979
+ // Parse the JSON data
980
+ let json: serde_json::Value = serde_json::from_str(&data).unwrap();
981
+
982
+ // Initialize Myprovider client
983
+ let myprovider_host = json["ollama_apilink"].as_str().unwrap();
984
+ let myprovider_embedding_model = json["ollama_embedding_model_name"].as_str().unwrap();
985
+ let myprovider_comp_model = json["ollama_comp_model_name"].as_str().unwrap();
986
+ let myprovider_client = Client::builder()
987
+ .base_url(myprovider_host)
988
+ .build()
989
+ .expect("myprovider server not found");
990
+ //let myprovider_client = myprovider::Client::new();
991
+ let embedding_model = myprovider_client.embedding_model(myprovider_embedding_model);
992
+ let comp_model = myprovider_client.completion_model(myprovider_comp_model); // "granite3-dense:latest" "PetrosStav/gemma3-tools:12b" "llama3-groq-tool-use:latest" "PetrosStav/gemma3-tools:12b"
993
+
994
+ let contents = String::from("SNV/SNP or point mutations nucleotide mutations are very common forms of mutations which can often give rise to genetic diseases such as cancer, Alzheimer's disease etc. They can be duw to substitution of nucleotide, or insertion or deletion of a nucleotide. Indels are multi-nucleotide insertion/deletion/substitutions. Complex indels are indels where insertion and deletion have happened in the same genomic locus. Every genomic sample from each patient has its own set of mutations therefore requiring personalized treatment.
995
+
996
+ If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
997
+
998
+ ---
999
+
1000
+ Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
1001
+
1002
+ If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.
1003
+
1004
+ ---
1005
+
1006
+ Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.
1007
+
1008
+ If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
1009
+ ---
1010
+
1011
+ Hierarchial clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
1012
+
1013
+ If a ProteinPaint dataset contains hierarchial data then return JSON with single key, 'hierarchial'.
1014
+
1015
+ ---
1016
+
1017
+ Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
1018
+
1019
+ If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
1020
+
1021
+ ---
1022
+
1023
+ Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens.
1024
+
1025
+ There are two main methods of survival analysis:
1026
+
1027
+ 1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
1028
+ 2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
1029
+
1030
+ The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
1031
+ HR = 1: No effect
1032
+ HR < 1: Reduction in the hazard
1033
+ HR > 1: Increase in Hazard
1034
+
1035
+ If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
1036
+
1037
+ ---
1038
+
1039
+ Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
1040
+
1041
+ If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.
1042
+
1043
+ ---
1044
+
1045
+ Summary plot in ProteinPaint shows the various facets of the datasets. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for visualizing all the different facets of the dataset. You can display a categorical variable and overlay another variable on top it and stratify (or divide) using a third variable simultaneously. You can also custom filters to the dataset so that you can only study part of the dataset. If a user query asks about variant calling or mapping reads then JSON with single key, 'summary'.
1046
+
1047
+ Sample Query1: \"Show all fusions for patients with age less than 30\"
1048
+ Sample Answer1: { \"answer\": \"summary\" }
1049
+
1050
+ Sample Query1: \"List all molecular subtypes of leukemia\"
1051
+ Sample Answer1: { \"answer\": \"summary\" }
1052
+
1053
+ ---
1054
+
1055
+ If a query does not match any of the fields described above, then return JSON with single key, 'none'
1056
+ ");
1057
+
1058
+ // Split the contents by the delimiter "---"
1059
+ let parts: Vec<&str> = contents.split("---").collect();
1060
+
1061
+ //let schema_json: Value = serde_json::to_value(schemars::schema_for!(OutputJson)).unwrap(); // error handling here
1062
+
1063
+ //let additional = json!({
1064
+ // "format": schema_json
1065
+ //});
1066
+
1067
+ // Print the separated parts
1068
+ let mut rag_docs = Vec::<String>::new();
1069
+ for (_i, part) in parts.iter().enumerate() {
1070
+ //println!("Part {}: {}", i + 1, part.trim());
1071
+ rag_docs.push(part.trim().to_string())
1072
+ }
1073
+
1074
+ let top_k: usize = 3;
1075
+ // Create embeddings and add to vector store
1076
+ let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
1077
+ .documents(rag_docs)
1078
+ .expect("Reason1")
1079
+ .build()
1080
+ .await
1081
+ .unwrap();
1082
+
1083
+ // Create vector store
1084
+ let mut vector_store = InMemoryVectorStore::<String>::default();
1085
+ InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
1086
+
1087
+ let max_new_tokens: usize = 512;
1088
+ let top_p: f32 = 0.95;
1089
+ let temperature: f64 = 0.01;
1090
+ let additional = json!({
1091
+ "max_new_tokens": max_new_tokens,
1092
+ "top_p": top_p
1093
+ });
1094
+
1095
+ // Create RAG agent
1096
+ let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case. \nQuestion= {question} \nanswer").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
1097
+
1098
+ let response = agent.prompt(user_input).await.expect("Failed to prompt myprovider");
1099
+
1100
+ //println!("Myprovider: {}", response);
1101
+ let result = response.replace("json", "").replace("```", "");
1102
+ //println!("result:{}", result);
1103
+ let json_value: Value = serde_json::from_str(&result).expect("REASON2");
1104
+ let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON3");
1105
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON4");
1106
+ assert_eq!(json_value3["answer"].to_string().replace("\"", ""), "dge");
1107
+ }
1108
+ }