@sjcrh/proteinpaint-rust 2.145.1 → 2.146.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1134 @@
1
+ // Compile: cd .. && cargo build --release
2
+ // Test: cd .. && export RUST_BACKTRACE=full && time cargo test -- --nocapture (runs all tests except those marked as "ignored")
3
+ // Ignored tests: cd .. && export RUST_BACKTRACE=full && time cargo test -- --ignored --nocapture
4
+ use async_stream::stream;
5
+ use futures::StreamExt;
6
+ use rig::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
7
+ use rig::completion::{GetTokenUsage, Usage};
8
+ use rig::message::ConvertMessage;
9
+ use rig::streaming::RawStreamingChoice;
10
+ use rig::{
11
+ Embed, OneOrMany,
12
+ completion::{self, CompletionError, CompletionRequest},
13
+ embeddings::{self, EmbeddingError, EmbeddingsBuilder},
14
+ impl_conversion_traits, message,
15
+ message::{ImageDetail, Text},
16
+ streaming,
17
+ };
18
+ use serde::{Deserialize, Serialize};
19
+ use serde_json::{Value, json};
20
+ use std::convert::TryInto;
21
+ use std::time::Duration;
22
+ use std::{convert::TryFrom, str::FromStr};
23
+ use url::Url;
24
+
25
+ // ---------- Main Client ----------
26
+ pub struct ClientBuilder<'a> {
27
+ base_url: &'a str,
28
+ http_client: Option<reqwest::Client>,
29
+ }
30
+
31
+ impl<'a> ClientBuilder<'a> {
32
+ #[allow(clippy::new_without_default)]
33
+ pub fn new() -> Self {
34
+ Self {
35
+ base_url: "",
36
+ http_client: None,
37
+ }
38
+ }
39
+
40
+ pub fn base_url(mut self, base_url: &'a str) -> Self {
41
+ //println!("base_url:{}", base_url);
42
+ self.base_url = base_url;
43
+ self
44
+ }
45
+
46
+ pub fn build(self) -> Result<Client, ClientBuilderError> {
47
+ let http_client = if let Some(http_client) = self.http_client {
48
+ http_client
49
+ } else {
50
+ reqwest::Client::builder().build()?
51
+ };
52
+
53
+ Ok(Client {
54
+ base_url: Url::parse(self.base_url).map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
55
+ http_client,
56
+ })
57
+ }
58
+ }
59
+
60
+ #[derive(Clone, Debug)]
61
+ pub struct Client {
62
+ base_url: Url,
63
+ http_client: reqwest::Client,
64
+ }
65
+
66
+ impl Default for Client {
67
+ fn default() -> Self {
68
+ Self::new()
69
+ }
70
+ }
71
+
72
+ impl Client {
73
+ pub fn builder() -> ClientBuilder<'static> {
74
+ ClientBuilder::new()
75
+ }
76
+ pub fn new() -> Self {
77
+ Self::builder().build().expect("Myprovider client should build")
78
+ }
79
+
80
+ pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
81
+ let url = self.base_url.join(path)?;
82
+ Ok(self.http_client.post(url))
83
+ }
84
+
85
+ pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
86
+ let url = self.base_url.join(path)?;
87
+ Ok(self.http_client.get(url))
88
+ }
89
+ }
90
+
91
+ impl ProviderClient for Client {
92
+ fn from_env() -> Self
93
+ where
94
+ Self: Sized,
95
+ {
96
+ let api_base = std::env::var("MYPROVIDER_API_BASE_URL").expect("MYPROVIDER_API_BASE_URL not set");
97
+ Self::builder().base_url(&api_base).build().unwrap()
98
+ }
99
+
100
+ fn from_val(input: rig::client::ProviderValue) -> Self {
101
+ let rig::client::ProviderValue::Simple(_) = input else {
102
+ panic!("Incorrect provider value type")
103
+ };
104
+
105
+ Self::new()
106
+ }
107
+ }
108
+
109
+ impl CompletionClient for Client {
110
+ type CompletionModel = CompletionModel;
111
+
112
+ fn completion_model(&self, model: &str) -> CompletionModel {
113
+ CompletionModel::new(self.clone(), model)
114
+ }
115
+ }
116
+
117
+ impl EmbeddingsClient for Client {
118
+ type EmbeddingModel = EmbeddingModel;
119
+ fn embedding_model(&self, model: &str) -> EmbeddingModel {
120
+ EmbeddingModel::new(self.clone(), model, 0, self.base_url.to_string())
121
+ }
122
+ fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
123
+ EmbeddingModel::new(self.clone(), model, ndims, self.base_url.to_string())
124
+ }
125
+ fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
126
+ EmbeddingsBuilder::new(self.embedding_model(model))
127
+ }
128
+ }
129
+
130
+ impl VerifyClient for Client {
131
+ async fn verify(&self) -> Result<(), VerifyError> {
132
+ let response = self.get("api/tags").expect("Failed to build request").send().await?;
133
+ match response.status() {
134
+ reqwest::StatusCode::OK => Ok(()),
135
+ _ => {
136
+ response.error_for_status()?;
137
+ Ok(())
138
+ }
139
+ }
140
+ }
141
+ }
142
+
143
+ impl_conversion_traits!(
144
+ AsTranscription,
145
+ AsImageGeneration,
146
+ AsAudioGeneration for Client
147
+ );
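A minimal usage sketch, not part of the released file: the base URL and model names below are placeholders, and the CompletionClient/EmbeddingsClient traits are already imported at the top of the file.

// Illustrative only: construct the client defined above and derive model handles.
fn example_client_setup() {
    let client = Client::builder()
        .base_url("http://localhost:8000") // placeholder endpoint
        .build()
        .expect("base_url should parse");
    let _completion = client.completion_model("my-completion-model"); // hypothetical model name
    let _embedding = client.embedding_model("my-embedding-model"); // hypothetical model name
}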
148
+
149
+ // ---------- API Error and Response Structures ----------
150
+
151
+ #[derive(Debug, Deserialize)]
152
+ struct ApiErrorResponse {
153
+ message: String,
154
+ }
155
+
156
+ #[derive(Debug, Deserialize)]
157
+ #[serde(untagged)]
158
+ enum ApiResponse<T> {
159
+ Ok(T),
160
+ Err(ApiErrorResponse),
161
+ }
162
+
163
+ // ---------- Embedding API ----------
164
+
165
+ #[derive(Debug, Serialize, Deserialize)]
166
+ pub struct EmbeddingResponse {
167
+ pub model: String,
168
+ pub embeddings: Vec<Vec<f64>>,
169
+ #[serde(default)]
170
+ pub total_duration: Option<u32>,
171
+ #[serde(default)]
172
+ pub load_duration: Option<u32>,
173
+ #[serde(default)]
174
+ pub prompt_eval_count: Option<u32>,
175
+ }
176
+
177
+ impl From<ApiErrorResponse> for EmbeddingError {
178
+ fn from(err: ApiErrorResponse) -> Self {
179
+ EmbeddingError::ProviderError(err.message)
180
+ }
181
+ }
182
+
183
+ impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
184
+ fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
185
+ match value {
186
+ ApiResponse::Ok(response) => Ok(response),
187
+ ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
188
+ }
189
+ }
190
+ }
191
+
192
+ // ---------- Embedding Model ----------
193
+
194
+ #[derive(Clone)]
195
+ pub struct EmbeddingModel {
196
+ base_url: String,
197
+ client: Client,
198
+ pub model: String,
199
+ ndims: usize,
200
+ }
201
+
202
+ impl EmbeddingModel {
203
+ pub fn new(client: Client, model: &str, ndims: usize, base_url: String) -> Self {
204
+ Self {
205
+ client,
206
+ model: model.to_owned(),
207
+ ndims,
208
+ base_url,
209
+ }
210
+ }
211
+ }
212
+
213
+ impl embeddings::EmbeddingModel for EmbeddingModel {
214
+ const MAX_DOCUMENTS: usize = 1024;
215
+ fn ndims(&self) -> usize {
216
+ self.ndims
217
+ }
218
+
219
+ async fn embed_texts(
220
+ &self,
221
+ documents: impl IntoIterator<Item = String>,
222
+ ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
223
+ let docs: Vec<String> = documents.into_iter().collect();
224
+
225
+ let mut embed_vec: Vec<Vec<f64>> = Vec::new();
226
+ for doc in &docs {
227
+ let payload = json!({
228
+ "inputs": [
229
+ {
230
+ "model_name": self.model,
231
+ "inputs": {"text": doc}
232
+ }
233
+ ]
234
+ }
235
+ );
236
+ //println!("embedding_payload:{}", payload);
237
+ //println!("self.base_url:{}", self.base_url);
238
+ let response = self
239
+ .client
240
+ .post(&self.base_url)?
241
+ .json(&payload)
242
+ .timeout(Duration::from_secs(2000))
243
+ .send()
244
+ .await
245
+ .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
246
+
247
+ // Get the headers
248
+ //let headers: HeaderMap = response.headers().clone();
249
+
250
+ //// Get the body as text
251
+ //let body = response.text().await?;
252
+
253
+ //// Print the headers
254
+ //println!("Headers:");
255
+ //for (key, value) in headers.iter() {
256
+ // println!("{}: {:?}", key, value);
257
+ //}
258
+
259
+ //// Print the body
260
+ //println!("\nBody:");
261
+ //println!("{}", body);
262
+
263
+ if response.status().is_success() {
264
+ //println!("response.json:{:?}", response.text().await?);
265
+ let json_data: Value = serde_json::from_str(&response.text().await?)?;
266
+ let emb = json_data["outputs"].as_array().unwrap();
267
+ //.unwrap_or(&vec![serde_json::Value::String(
268
+ // "No embeddings found in json output".to_string(),
269
+ //)]);
270
+ //println!("emb:{:?}", emb[0]["embeddings"].as_array().unwrap());
271
+
272
+ for item in emb[0]["embeddings"].as_array().unwrap() {
273
+ let item2 = item.as_array().unwrap();
274
+ let mut item3 = Vec::<f64>::new();
275
+ for item4 in item2 {
276
+ item3.push(item4.as_f64().unwrap())
277
+ }
278
+ embed_vec.push(item3)
279
+ }
280
+ } else {
281
+ return Err(EmbeddingError::ProviderError(response.text().await.unwrap_or_default()));
282
+ }
283
+ }
284
+
285
+ if !embed_vec.is_empty() {
286
+ let api_resp = EmbeddingResponse {
287
+ model: self.model.clone(),
288
+ embeddings: embed_vec,
289
+ total_duration: None,
290
+ load_duration: None,
291
+ prompt_eval_count: None,
292
+ };
293
+
294
+ //let api_resp: EmbeddingResponse = response
295
+ // .json()
296
+ // .await
297
+ // .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
298
+ if api_resp.embeddings.len() != docs.len() {
299
+ println!("Number of embeddings:{}", api_resp.embeddings.len());
300
+ println!("Number of docs:{}", docs.len());
301
+ return Err(EmbeddingError::ResponseError(
302
+ "Number of returned embeddings does not match input".into(),
303
+ ));
304
+ }
305
+ Ok(api_resp
306
+ .embeddings
307
+ .into_iter()
308
+ .zip(docs.into_iter())
309
+ .map(|(vec, document)| embeddings::Embedding { document, vec })
310
+ .collect())
311
+ } else {
312
+ panic!("No embeddings found") // If no embeddings are found, it should crash earlier. Still adding this panic statement for safety
313
+ }
314
+ }
315
+ //Ok(Vec::<Embedding>::new())
316
+ }
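A sketch of the response shape embed_texts() appears to assume, inferred from the json_data["outputs"][0]["embeddings"] accesses above rather than from a documented server contract; the values are placeholders.

// Illustrative only: one f64 vector is expected per embedded text.
fn example_embedding_response() -> serde_json::Value {
    serde_json::json!({
        "outputs": [
            {
                "embeddings": [[0.12, -0.34, 0.56]]
            }
        ]
    })
}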
317
+
318
+ // ---------- Completion API ----------
319
+
320
+ #[derive(Debug, Serialize, Deserialize)]
321
+ pub struct CompletionResponse {
322
+ pub model: String,
323
+ pub message: Message,
324
+ pub timestamp: String,
325
+ }
326
+ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
327
+ type Error = CompletionError;
328
+ fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
329
+ match resp.message {
330
+ // Process only if an assistant message is present.
331
+ Message::Assistant { content, id } => {
332
+ //let mut assistant_contents = Vec::new();
333
+ // Add the assistant's text content if any.
334
+ //if !content.is_empty() {
335
+ // assistant_contents.push(completion::AssistantContent2::text(&content));
336
+ //}
337
+
338
+ let choice = rig::one_or_many::OneOrMany::one(completion::AssistantContent::text(&content));
339
+ //let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
340
+ //let completion_tokens = resp.eval_count.unwrap_or(0);
341
+
342
+ let raw_response = CompletionResponse {
343
+ model: resp.model,
344
+ message: Message::Assistant { content, id },
345
+ timestamp: resp.timestamp,
346
+ };
347
+
348
+ Ok(completion::CompletionResponse {
349
+ choice,
350
+ usage: Usage {
351
+ input_tokens: 0, // Not provided by custom provider
352
+ output_tokens: 0, // Not provided by custom provider
353
+ total_tokens: 0, // Not provided by custom provider
354
+ },
355
+ raw_response,
356
+ })
357
+ }
358
+ _ => Err(CompletionError::ResponseError(
359
+ "Chat response does not include an assistant message".into(),
360
+ )),
361
+ }
362
+ }
363
+ }
364
+
365
+ // ---------- Completion Model ----------
366
+
367
+ #[derive(Clone)]
368
+ pub struct CompletionModel {
369
+ client: Client,
370
+ pub model: String,
371
+ }
372
+
373
+ impl CompletionModel {
374
+ pub fn new(client: Client, model: &str) -> Self {
375
+ Self {
376
+ client,
377
+ model: model.to_owned(),
378
+ }
379
+ }
380
+
381
+ fn create_completion_request(&self, completion_request: CompletionRequest) -> Result<Value, CompletionError> {
382
+ // Build up the order of messages (context, chat_history)
383
+
385
+ let mut partial_history = vec![];
386
+ if let Some(docs) = completion_request.normalized_documents() {
387
+ partial_history.push(docs);
388
+ }
389
+ partial_history.extend(completion_request.chat_history);
390
+
391
+ // Initialize full history with preamble (or empty if non-existent)
392
+ let mut full_history: Vec<Message> = completion_request
393
+ .preamble
394
+ .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
395
+
396
+ // Convert and extend the rest of the history
397
+ full_history.extend(
398
+ partial_history
399
+ .into_iter()
400
+ .map(Message::convert_from_message)
401
+ .collect::<Result<Vec<Vec<Message>>, _>>()?
402
+ .into_iter()
403
+ .flatten()
404
+ .collect::<Vec<Message>>(),
405
+ );
406
+
407
+ //let mut context: String = "".to_string();
408
+ //if let Some(docs) = completion_request.normalized_documents() {
409
+ // println!("docs:{:?}", docs);
410
+
411
+ // match docs {
412
+ // completion::message::Message::User { content: cont } => {
413
+ // println!("cont_first:{:?}", cont.first());
414
+ // match cont.first() {
415
+ // rig::completion::message::UserContent::Document(data) => {
416
+ // //println!("data:{:?}", std::any::type_name_of_val(&data.data));
417
+ // context += &data.data; // Need to get 2nd line, will do that later
418
+ // }
419
+ // rig::completion::message::UserContent::ToolResult { .. } => todo!(),
420
+ // rig::completion::message::UserContent::Text { .. } => todo!(),
421
+ // rig::completion::message::UserContent::Image { .. } => todo!(),
422
+ // rig::completion::message::UserContent::Audio { .. } => todo!(),
423
+ // rig::completion::message::UserContent::Video { .. } => todo!(),
424
+ // }
425
+
426
+ // for item in cont.rest() {
427
+ // match item {
428
+ // rig::completion::message::UserContent::Document(data) => {
429
+ // //println!("data:{:?}", std::any::type_name_of_val(&data.data));
430
+ // context += &data.data; // Need to get 2nd line, will do that later
431
+ // }
432
+ // rig::completion::message::UserContent::ToolResult { .. } => todo!(),
433
+ // rig::completion::message::UserContent::Text { .. } => todo!(),
434
+ // rig::completion::message::UserContent::Image { .. } => todo!(),
435
+ // rig::completion::message::UserContent::Audio { .. } => todo!(),
436
+ // rig::completion::message::UserContent::Video { .. } => todo!(),
437
+ // }
438
+ // }
439
+ // }
440
+ // completion::message::Message::Assistant { .. } => todo!(),
441
+ // }
442
+ // //partial_history.push(docs);
443
+ // //context = docs;
444
+ //}
445
+
446
+ //println!("context:{}", context);
447
+ //println!(
448
+ // "completion_request.chat_history:{:?}",
449
+ // std::any::type_name_of_val(&completion_request.chat_history)
450
+ //);
451
+ //let question: String = String::from("");
452
+ //partial_history.extend(completion_request.chat_history);
453
+
454
+ //// Initialize full history with preamble (or empty if non-existent)
455
+ //let mut full_history: Vec<Message> = completion_request
456
+ // .preamble
457
+ // .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
458
+
459
+ //// Convert and extend the rest of the history
460
+ //full_history.extend(
461
+ // partial_history
462
+ // .into_iter()
463
+ // .map(|msg| msg.try_into())
464
+ // .collect::<Result<Vec<Vec<Message>>, _>>()?
465
+ // .into_iter()
466
+ // .flatten()
467
+ // .collect::<Vec<Message>>(),
468
+ //);
469
+
470
+ //let mut full_history: String = completion_request.preamble.unwrap();
471
+ //full_history = full_history + &"\n\nContext:{" + &context + &"}\n\n";
472
+ //println!("full_history:{}", full_history);
473
+
474
+ // Pull the generation parameters (max_new_tokens, top_p) out of additional_params
475
+ let max_new_tokens: u64;
476
+ let top_p: f64;
477
+ if let Some(extra) = completion_request.additional_params {
478
+ max_new_tokens = extra["max_new_tokens"].as_u64().unwrap();
479
+ top_p = extra["top_p"].as_f64().unwrap();
480
+ } else {
481
+ panic!("max_new_tokens and top_p not found!");
482
+ };
483
+
484
+ let mut request_payload = json!({
485
+ "inputs":[
486
+ {
487
+ "model_name": self.model,
488
+ "inputs": {
489
+ "text": full_history,
490
+ "max_new_tokens": max_new_tokens,
491
+ "temperature": completion_request.temperature,
492
+ "top_p": top_p
493
+ }
494
+ }]
495
+ });
496
+ //println!("comp_request_payload:{}", request_payload);
497
+
498
+ if !completion_request.tools.is_empty() {
499
+ println!("completion_request.tools:{:?}", completion_request.tools);
500
+ request_payload["tools"] = json!(
501
+ completion_request
502
+ .tools
503
+ .into_iter()
504
+ .map(|tool| tool.into())
505
+ .collect::<Vec<ToolDefinition>>()
506
+ );
507
+ }
508
+
509
+ //tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
510
+ Ok(request_payload)
511
+ }
512
+ }
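For reference, create_completion_request() builds a body of roughly this shape; the values are placeholders and the optional "tools" array is attached only when the request carries tool definitions.

// Illustrative only: the request body assembled above, with placeholder values.
fn example_completion_payload() -> serde_json::Value {
    serde_json::json!({
        "inputs": [
            {
                "model_name": "my-completion-model",
                "inputs": {
                    // full_history serializes to role-tagged messages (see the Message enum below)
                    "text": [
                        { "role": "system", "content": "preamble text" },
                        { "role": "user", "content": "user prompt" }
                    ],
                    "max_new_tokens": 512,
                    "temperature": 0.01,
                    "top_p": 0.95
                }
            }
        ]
    })
}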
513
+
514
+ // ---------- CompletionModel Implementation ----------
515
+
516
+ #[derive(Clone, Serialize, Deserialize, Debug)]
517
+ pub struct StreamingCompletionResponse {
518
+ pub done_reason: Option<String>,
519
+ pub total_duration: Option<u64>,
520
+ pub load_duration: Option<u64>,
521
+ pub prompt_eval_count: Option<u64>,
522
+ pub prompt_eval_duration: Option<u64>,
523
+ pub eval_count: Option<u64>,
524
+ pub eval_duration: Option<u64>,
525
+ }
526
+
527
+ impl GetTokenUsage for StreamingCompletionResponse {
528
+ fn token_usage(&self) -> Option<rig::completion::Usage> {
529
+ let mut usage = rig::completion::Usage::new();
530
+ let input_tokens = self.prompt_eval_count.unwrap_or_default();
531
+ let output_tokens = self.eval_count.unwrap_or_default();
532
+ usage.input_tokens = input_tokens;
533
+ usage.output_tokens = output_tokens;
534
+ usage.total_tokens = input_tokens + output_tokens;
535
+
536
+ Some(usage)
537
+ }
538
+ }
539
+
540
+ impl completion::CompletionModel for CompletionModel {
541
+ type Response = CompletionResponse;
542
+ type StreamingResponse = StreamingCompletionResponse;
543
+
544
+ async fn completion(
545
+ &self,
546
+ completion_request: CompletionRequest,
547
+ ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
548
+ let request_payload = self.create_completion_request(completion_request)?;
549
+
550
+ let response = self
551
+ .client
552
+ .post(&self.client.base_url.to_string())?
553
+ .json(&request_payload)
554
+ .send()
555
+ .await
556
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
557
+ //println!("response:{:?}", response);
558
+ if response.status().is_success() {
559
+ let text = response
560
+ .text()
561
+ .await
562
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
563
+ let text_json: Value = serde_json::from_str(&text)?;
564
+ //println!("text:{:?}", text_json);
565
+ //tracing::debug!(target: "rig", "Myprovider chat response: {}", text);
566
+ let chat_resp: CompletionResponse = CompletionResponse {
567
+ model: text_json["model_name"].to_string(),
568
+ message: Message::Assistant {
569
+ id: text_json["id"].to_string(),
570
+ content: text_json["outputs"].to_string(),
571
+ },
572
+ timestamp: text_json["timestamp"].to_string(),
573
+ };
574
+ //println!("chat_resp:{:?}", chat_resp);
575
+ let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
576
+ Ok(conv)
577
+ } else {
578
+ let err_text = response
579
+ .text()
580
+ .await
581
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
582
+ Err(CompletionError::ProviderError(err_text))
583
+ }
584
+ }
585
+
586
+ async fn stream(
587
+ &self,
588
+ request: CompletionRequest,
589
+ ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
590
+ let mut request_payload = self.create_completion_request(request)?;
591
+ merge_inplace(&mut request_payload, json!({"stream": true}));
592
+
593
+ let response = self
594
+ .client
595
+ .post("api/chat")?
596
+ .json(&request_payload)
597
+ .send()
598
+ .await
599
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
600
+
601
+ if !response.status().is_success() {
602
+ let err_text = response
603
+ .text()
604
+ .await
605
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
606
+ return Err(CompletionError::ProviderError(err_text));
607
+ }
608
+
609
+ let stream = Box::pin(stream! {
610
+ let mut stream = response.bytes_stream();
611
+ while let Some(chunk_result) = stream.next().await {
612
+ let chunk = match chunk_result {
613
+ Ok(c) => c,
614
+ Err(e) => {
615
+ yield Err(CompletionError::from(e));
616
+ break;
617
+ }
618
+ };
619
+
620
+ let text = match String::from_utf8(chunk.to_vec()) {
621
+ Ok(t) => t,
622
+ Err(e) => {
623
+ yield Err(CompletionError::ResponseError(e.to_string()));
624
+ break;
625
+ }
626
+ };
627
+
628
+
629
+ for line in text.lines() {
630
+ let line = line.to_string();
631
+
632
+ let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
633
+ continue;
634
+ };
635
+
636
+ match response.message {
637
+ Message::Assistant{ content, .. } => {
638
+ if !content.is_empty() {
639
+ yield Ok(RawStreamingChoice::Message(content))
640
+ }
641
+ }
642
+ _ => {
643
+ continue;
644
+ }
645
+ }
646
+
647
+ //if response.message {
648
+ // yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
649
+ // total_duration: response.total_duration,
650
+ // load_duration: response.load_duration,
651
+ // prompt_eval_count: response.prompt_eval_count,
652
+ // prompt_eval_duration: response.prompt_eval_duration,
653
+ // eval_count: response.eval_count,
654
+ // eval_duration: response.eval_duration,
655
+ // done_reason: response.done_reason,
656
+ // }));
657
+ //}
658
+ }
659
+ }
660
+ });
661
+
662
+ Ok(streaming::StreamingCompletionResponse::stream(stream))
663
+ }
664
+ }
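The non-streaming completion() path assumes a provider response of roughly this shape, inferred from the text_json field accesses above and from the test at the bottom of the file; values are placeholders.

// Illustrative only: the response body completion() expects to parse.
fn example_completion_response_body() -> serde_json::Value {
    serde_json::json!({
        "model_name": "my-completion-model",
        "id": "req-0001",
        "outputs": [ { "generated_text": "{ \"answer\": \"dge\" }" } ],
        "timestamp": "2024-01-01T00:00:00Z"
    })
}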
665
+
666
+ // ---------- Tool Definition Conversion ----------
667
+
668
+ /// Myprovider-required tool definition format.
669
+ #[derive(Clone, Debug, Deserialize, Serialize)]
670
+ pub struct ToolDefinition {
671
+ #[serde(rename = "type")]
672
+ pub type_field: String, // Fixed as "function"
673
+ pub function: completion::ToolDefinition,
674
+ }
675
+
676
+ /// Convert internal ToolDefinition (from the completion module) into Myprovider's tool definition.
677
+ impl From<rig::completion::ToolDefinition> for ToolDefinition {
678
+ fn from(tool: rig::completion::ToolDefinition) -> Self {
679
+ ToolDefinition {
680
+ type_field: "function".to_owned(),
681
+ function: completion::ToolDefinition {
682
+ name: tool.name,
683
+ description: tool.description,
684
+ parameters: tool.parameters,
685
+ },
686
+ }
687
+ }
688
+ }
689
+
690
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
691
+ pub struct ToolCall {
692
+ // pub id: String,
693
+ #[serde(default, rename = "type")]
694
+ pub r#type: ToolType,
695
+ pub function: Function,
696
+ }
697
+ #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
698
+ #[serde(rename_all = "lowercase")]
699
+ pub enum ToolType {
700
+ #[default]
701
+ Function,
702
+ }
703
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
704
+ pub struct Function {
705
+ pub name: String,
706
+ pub arguments: Value,
707
+ }
708
+
709
+ // ---------- Provider Message Definition ----------
710
+
711
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
712
+ #[serde(tag = "role", rename_all = "lowercase")]
713
+ pub enum Message {
714
+ User {
715
+ content: String,
716
+ #[serde(skip_serializing_if = "Option::is_none")]
717
+ images: Option<Vec<String>>,
718
+ #[serde(skip_serializing_if = "Option::is_none")]
719
+ name: Option<String>,
720
+ },
721
+ Assistant {
722
+ #[serde(default)]
723
+ content: String,
724
+ #[serde(default)]
725
+ id: String,
726
+ //#[serde(skip_serializing_if = "Option::is_none")]
727
+ //thinking: Option<String>,
728
+ //#[serde(skip_serializing_if = "Option::is_none")]
729
+ //images: Option<Vec<String>>,
730
+ //#[serde(skip_serializing_if = "Option::is_none")]
731
+ //name: Option<String>,
732
+ //#[serde(default, deserialize_with = "json_utils::null_or_vec")]
733
+ //tool_calls: Vec<ToolCall>,
734
+ },
735
+ System {
736
+ content: String,
737
+ #[serde(skip_serializing_if = "Option::is_none")]
738
+ images: Option<Vec<String>>,
739
+ #[serde(skip_serializing_if = "Option::is_none")]
740
+ name: Option<String>,
741
+ },
742
+ #[serde(rename = "tool")]
743
+ ToolResult {
744
+ #[serde(rename = "tool_name")]
745
+ name: String,
746
+ content: String,
747
+ },
748
+ }
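Given the serde attributes on Message (tag = "role", lowercase variants, Options skipped when None), instances serialize to role-tagged JSON; a small sketch, not part of the released file.

// Illustrative only: serializing a User message defined by the enum above.
fn example_message_serialization() {
    let msg = Message::User {
        content: "What is DGE?".to_string(),
        images: None,
        name: None,
    };
    // Prints: {"role":"user","content":"What is DGE?"}
    println!("{}", serde_json::to_string(&msg).unwrap());
}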
749
+
750
+ /// -----------------------------
751
+ /// Provider Message Conversions
752
+ /// -----------------------------
753
+ /// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
754
+ /// (Only User and Assistant variants are supported.)
755
+ impl ConvertMessage for Message {
756
+ type Error = rig::message::MessageError;
757
+ fn convert_from_message(internal_msg: message::Message) -> Result<Vec<Self>, Self::Error> {
758
+ use rig::message::Message as InternalMessage;
759
+ match internal_msg {
760
+ InternalMessage::User { content, .. } => {
761
+ let (tool_results, other_content): (Vec<_>, Vec<_>) = content
762
+ .into_iter()
763
+ .partition(|content| matches!(content, rig::message::UserContent::ToolResult(_)));
764
+
765
+ if !tool_results.is_empty() {
766
+ tool_results
767
+ .into_iter()
768
+ .map(|content| match content {
769
+ rig::message::UserContent::ToolResult(rig::message::ToolResult { id, content, .. }) => {
770
+ // Ollama expects a single string for tool results, so we concatenate
771
+ let content_string = content
772
+ .into_iter()
773
+ .map(|content| match content {
774
+ rig::message::ToolResultContent::Text(text) => text.text,
775
+ _ => "[Non-text content]".to_string(),
776
+ })
777
+ .collect::<Vec<_>>()
778
+ .join("\n");
779
+
780
+ Ok::<_, rig::message::MessageError>(Message::ToolResult {
781
+ name: id,
782
+ content: content_string,
783
+ })
784
+ }
785
+ _ => unreachable!(),
786
+ })
787
+ .collect::<Result<Vec<_>, _>>()
788
+ } else {
789
+ // Ollama requires separate text content and images array
790
+ let (texts, _images) =
791
+ other_content
792
+ .into_iter()
793
+ .fold((Vec::new(), Vec::new()), |(mut texts, mut images), content| {
794
+ match content {
795
+ rig::message::UserContent::Text(rig::message::Text { text }) => texts.push(text),
796
+ rig::message::UserContent::Image(rig::message::Image { data, .. }) => {
797
+ images.push(data)
798
+ }
799
+ rig::message::UserContent::Document(rig::message::Document { data, .. }) => {
800
+ texts.push(data)
801
+ }
802
+ _ => {} // Audio not supported by Ollama
803
+ }
804
+ (texts, images)
805
+ });
806
+
807
+ Ok(vec![Message::User {
808
+ content: texts.join(" "),
809
+ images: None,
810
+ name: None,
811
+ }])
812
+ }
813
+ }
814
+ InternalMessage::Assistant { content, .. } => {
815
+ let mut thinking: Option<String> = None;
816
+ let (text_content, _tool_calls) =
817
+ content
818
+ .into_iter()
819
+ .fold((Vec::new(), Vec::new()), |(mut texts, mut tools), content| {
820
+ match content {
821
+ rig::message::AssistantContent::Text(text) => texts.push(text.text),
822
+ rig::message::AssistantContent::ToolCall(_tool_call) => tools.push(_tool_call),
823
+ rig::message::AssistantContent::Reasoning(rig::message::Reasoning {
824
+ reasoning,
825
+ ..
826
+ }) => {
827
+ thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
828
+ }
829
+ }
830
+ (texts, tools)
831
+ });
832
+
833
+ // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
834
+ // so either `content` or `tool_calls` will have some content.
836
+ Ok(vec![Message::Assistant {
837
+ content: text_content.join(" "),
838
+ id: String::new(),
839
+ }])
840
+ }
841
+ }
842
+ }
843
+ }
844
+
845
+ /// Conversion from provider Message to a completion message.
846
+ /// This is needed so that responses can be converted back into chat history.
847
+ impl From<Message> for rig::completion::Message {
848
+ fn from(msg: Message) -> Self {
849
+ match msg {
850
+ Message::User { content, .. } => rig::completion::Message::User {
851
+ content: OneOrMany::one(rig::completion::message::UserContent::Text(Text { text: content })),
852
+ },
853
+ Message::Assistant { content, .. } => {
854
+ let assistant_contents = vec![rig::completion::message::AssistantContent::Text(Text { text: content })];
855
+ rig::completion::Message::Assistant {
856
+ id: None,
857
+ content: OneOrMany::many(assistant_contents).unwrap(),
858
+ }
859
+ }
860
+ // System and ToolResult are converted to User message as needed.
861
+ Message::System { content, .. } => rig::completion::Message::User {
862
+ content: OneOrMany::one(rig::completion::message::UserContent::Text(Text { text: content })),
863
+ },
864
+ Message::ToolResult { name, content } => rig::completion::Message::User {
865
+ content: OneOrMany::one(message::UserContent::tool_result(
866
+ name,
867
+ OneOrMany::one(message::ToolResultContent::text(content)),
868
+ )),
869
+ },
870
+ }
871
+ }
872
+ }
873
+
874
+ impl Message {
875
+ /// Constructs a system message.
876
+ pub fn system(content: &str) -> Self {
877
+ Message::System {
878
+ content: content.to_owned(),
879
+ images: None,
880
+ name: None,
881
+ }
882
+ }
883
+ }
884
+
885
+ // ---------- Additional Message Types ----------
886
+
887
+ impl From<rig::message::ToolCall> for ToolCall {
888
+ fn from(tool_call: rig::message::ToolCall) -> Self {
889
+ Self {
890
+ r#type: ToolType::Function,
891
+ function: Function {
892
+ name: tool_call.function.name,
893
+ arguments: tool_call.function.arguments,
894
+ },
895
+ }
896
+ }
897
+ }
898
+
899
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
900
+ pub struct SystemContent {
901
+ #[serde(default)]
902
+ r#type: SystemContentType,
903
+ text: String,
904
+ }
905
+
906
+ #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
907
+ #[serde(rename_all = "lowercase")]
908
+ pub enum SystemContentType {
909
+ #[default]
910
+ Text,
911
+ }
912
+
913
+ impl From<String> for SystemContent {
914
+ fn from(s: String) -> Self {
915
+ SystemContent {
916
+ r#type: SystemContentType::default(),
917
+ text: s,
918
+ }
919
+ }
920
+ }
921
+
922
+ impl FromStr for SystemContent {
923
+ type Err = std::convert::Infallible;
924
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
925
+ Ok(SystemContent {
926
+ r#type: SystemContentType::default(),
927
+ text: s.to_string(),
928
+ })
929
+ }
930
+ }
931
+
932
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
933
+ pub struct AssistantContent {
934
+ pub text: String,
935
+ }
936
+
937
+ impl FromStr for AssistantContent {
938
+ type Err = std::convert::Infallible;
939
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
940
+ Ok(AssistantContent { text: s.to_owned() })
941
+ }
942
+ }
943
+
944
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
945
+ #[serde(tag = "type", rename_all = "lowercase")]
946
+ pub enum UserContent {
947
+ Text { text: String },
948
+ Image { image_url: ImageUrl },
949
+ // Audio variant removed as Ollama API does not support audio input.
950
+ }
951
+
952
+ impl FromStr for UserContent {
953
+ type Err = std::convert::Infallible;
954
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
955
+ Ok(UserContent::Text { text: s.to_owned() })
956
+ }
957
+ }
958
+
959
+ #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
960
+ pub struct ImageUrl {
961
+ pub url: String,
962
+ #[serde(default)]
963
+ pub detail: ImageDetail,
964
+ }
965
+
966
+ // ---------JSON utils functions -----------------------------
967
+
968
+ pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) {
969
+ if let (serde_json::Value::Object(a_map), serde_json::Value::Object(b_map)) = (a, b) {
970
+ b_map.into_iter().for_each(|(key, value)| {
971
+ a_map.insert(key, value);
972
+ });
973
+ }
974
+ }
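merge_inplace performs a shallow, top-level merge; a minimal sketch mirroring how stream() attaches the stream flag.

// Illustrative only: keys from the second value overwrite or extend the first.
fn example_merge() {
    let mut payload = serde_json::json!({ "inputs": [] });
    merge_inplace(&mut payload, serde_json::json!({ "stream": true }));
    assert_eq!(payload, serde_json::json!({ "inputs": [], "stream": true }));
}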
975
+
976
+ // =================================================================
977
+ // Tests
978
+ // =================================================================
979
+
980
+ #[cfg(test)]
981
+ mod tests {
982
+ use super::*;
983
+ use rig::agent::AgentBuilder;
984
+ use rig::completion::request::Prompt;
985
+ //use rig::providers::myprovider;
986
+ use rig::vector_store::in_memory_store::InMemoryVectorStore;
987
+ //use serde_json::json;
988
+ use serde_json;
989
+ use std::fs::{self};
990
+ use std::path::Path;
991
+
992
+ // Test deserialization and conversion for the /api/chat endpoint.
993
+ #[tokio::test]
994
+ #[ignore]
995
+
996
+ async fn test_myprovider_implementation() {
997
+ let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
998
+ let serverconfig_file_path = Path::new("../../serverconfig.json");
999
+ let absolute_path = serverconfig_file_path.canonicalize().unwrap();
1000
+
1001
+ // Read the file
1002
+ let data = fs::read_to_string(absolute_path).unwrap();
1003
+
1004
+ // Parse the JSON data
1005
+ let json: serde_json::Value = serde_json::from_str(&data).unwrap();
1006
+
1007
+ // Initialize Myprovider client
1008
+ let myprovider_host = json["sj_apilink"].as_str().unwrap();
1009
+ let myprovider_embedding_model = json["sj_embedding_model_name"].as_str().unwrap();
1010
+ let myprovider_comp_model = json["sj_comp_model_name"].as_str().unwrap();
1011
+ let myprovider_client = Client::builder()
1012
+ .base_url(myprovider_host)
1013
+ .build()
1014
+ .expect("myprovider server not found");
1015
+ //let myprovider_client = myprovider::Client::new();
1016
+ let embedding_model = myprovider_client.embedding_model(myprovider_embedding_model);
1017
+ let comp_model = myprovider_client.completion_model(myprovider_comp_model); // "granite3-dense:latest" "PetrosStav/gemma3-tools:12b" "llama3-groq-tool-use:latest" "PetrosStav/gemma3-tools:12b"
1018
+
1019
+ let contents = String::from("SNV/SNP or point mutations (single-nucleotide mutations) are very common forms of mutations which can often give rise to genetic diseases such as cancer, Alzheimer's disease etc. They can be due to substitution of a nucleotide, or insertion or deletion of a nucleotide. Indels are multi-nucleotide insertions/deletions/substitutions. Complex indels are indels where insertion and deletion have happened in the same genomic locus. Every genomic sample from each patient has its own set of mutations, therefore requiring personalized treatment.
1020
+
1021
+ If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
1022
+
1023
+ ---
1024
+
1025
+ Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
1026
+
1027
+ If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.
1028
+
1029
+ ---
1030
+
1031
+ Structural variants/fusions (SV) are genomic mutations in which a DNA region is either translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes it is called a gene fusion.
1032
+
1033
+ If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
1034
+ ---
1035
+
1036
+ Hierarchial clustering of gene expression is an unsupervised learning technique where a number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in the data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
1037
+
1038
+ If a ProteinPaint dataset contains hierarchial data then return JSON with single key, 'hierarchial'.
1039
+
1040
+ ---
1041
+
1042
+ Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change on the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulated genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) are at the top of the graph. Following differential gene expression, GeneSet Enrichment Analysis (GSEA) is generally carried out, where, based on the genes and their corresponding fold changes, the upregulation/downregulation of genesets (or pathways) is determined.
1043
+
1044
+ If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
1045
+
1046
+ ---
1047
+
1048
+ Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens.
1049
+
1050
+ There are two main methods of survival analysis:
1051
+
1052
+ 1) Kaplan-Meier (KM) analysis is a univariate test that only takes into account a single categorical variable.
1053
+ 2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
1054
+
1055
+ The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
1056
+ HR = 1: No effect
1057
+ HR < 1: Reduction in the hazard
1058
+ HR > 1: Increase in Hazard
1059
+
1060
+ If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
1061
+
1062
+ ---
1063
+
1064
+ Next generation sequencing (NGS) reads are mapped to a human genome using alignment algorithms such as the Burrows-Wheeler alignment algorithm. Then variants are called from these reads using variant calling algorithms such as GATK (Genome Analysis Toolkit). However, this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
1065
+
1066
+ If a user query asks about variant calling or mapping reads then return JSON with single key, 'variant_calling'.
1067
+
1068
+ ---
1069
+
1070
+ Summary plot in ProteinPaint shows the various facets of the dataset. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for visualizing all the different facets of the dataset. You can display a categorical variable, overlay another variable on top of it, and stratify (or divide) using a third variable simultaneously. You can also apply custom filters to the dataset so that you only study part of the dataset. If a user query asks about a summary plot then return JSON with single key, 'summary'.
1071
+
1072
+ Sample Query1: \"Show all fusions for patients with age less than 30\"
1073
+ Sample Answer1: { \"answer\": \"summary\" }
1074
+
1075
+ Sample Query1: \"List all molecular subtypes of leukemia\"
1076
+ Sample Answer1: { \"answer\": \"summary\" }
1077
+
1078
+ ---
1079
+
1080
+ If a query does not match any of the fields described above, then return JSON with single key, 'none'
1081
+ ");
1082
+
1083
+ // Split the contents by the delimiter "---"
1084
+ let parts: Vec<&str> = contents.split("---").collect();
1085
+
1086
+ //let schema_json: Value = serde_json::to_value(schemars::schema_for!(OutputJson)).unwrap(); // error handling here
1087
+
1088
+ //let additional = json!({
1089
+ // "format": schema_json
1090
+ //});
1091
+
1092
+ // Print the separated parts
1093
+ let mut rag_docs = Vec::<String>::new();
1094
+ for (_i, part) in parts.iter().enumerate() {
1095
+ //println!("Part {}: {}", i + 1, part.trim());
1096
+ rag_docs.push(part.trim().to_string())
1097
+ }
1098
+
1099
+ let top_k: usize = 3;
1100
+ // Create embeddings and add to vector store
1101
+ let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
1102
+ .documents(rag_docs)
1103
+ .expect("Reason1")
1104
+ .build()
1105
+ .await
1106
+ .unwrap();
1107
+
1108
+ // Create vector store
1109
+ let mut vector_store = InMemoryVectorStore::<String>::default();
1110
+ InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
1111
+
1112
+ let max_new_tokens: usize = 512;
1113
+ let top_p: f32 = 0.95;
1114
+ let temperature: f64 = 0.01;
1115
+ let additional = json!({
1116
+ "max_new_tokens": max_new_tokens,
1117
+ "top_p": top_p
1118
+ });
1119
+
1120
+ // Create RAG agent
1121
+ let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion, survival and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'survival' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
1122
+
1123
+ let response = agent.prompt(user_input).await.expect("Failed to prompt myprovider");
1124
+
1125
+ //println!("Myprovider: {}", response);
1126
+ let result = response.replace("json", "").replace("```", "");
1127
+ //println!("result:{}", result);
1128
+ let json_value: Value = serde_json::from_str(&result).expect("LLM response is not valid JSON");
1129
+ let json_value2: Value = serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("generated_text field is not valid JSON");
1130
+ //println!("json_value2:{}", json_value2.as_str().unwrap());
1131
+ let json_value3: Value = serde_json::from_str(json_value2.as_str().unwrap()).expect("generated_text does not contain a JSON object");
1132
+ assert_eq!(json_value3["answer"].to_string().replace("\"", ""), "dge");
1133
+ }
1134
+ }