isa_model-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0

isa_model/inference/services/llm/openai_llm_service.py
@@ -0,0 +1,129 @@
+ import logging
+ from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+ from isa_model.inference.services.base_service import BaseLLMService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.backends.openai_client import OpenAIBackendClient
+
+ logger = logging.getLogger(__name__)
+
+ class OpenAILLMService(BaseLLMService):
+     """OpenAI LLM service implementation"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo", backend: Optional[OpenAIBackendClient] = None):
+         super().__init__(provider, model_name)
+
+         # Use provided backend or create new one
+         if backend:
+             self.backend = backend
+         else:
+             api_key = self.config.get("api_key", "")
+             api_base = self.config.get("api_base", "https://api.openai.com/v1")
+             self.backend = OpenAIBackendClient(api_key, api_base)
+
+         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+         logger.info(f"Initialized OpenAILLMService with model {model_name}")
+
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+         """Universal invocation method"""
+         if isinstance(prompt, str):
+             return await self.acompletion(prompt)
+         elif isinstance(prompt, list):
+             return await self.achat(prompt)
+         else:
+             raise ValueError("Prompt must be string or list of messages")
+
+     async def achat(self, messages: List[Dict[str, str]]):
+         """Chat completion method"""
+         try:
+             temperature = self.config.get("temperature", 0.7)
+             max_tokens = self.config.get("max_tokens", 1024)
+
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "temperature": temperature,
+                 "max_tokens": max_tokens
+             }
+             response = await self.backend.post("/chat/completions", payload)
+
+             # Update token usage
+             self.last_token_usage = response.get("usage", {
+                 "prompt_tokens": 0,
+                 "completion_tokens": 0,
+                 "total_tokens": 0
+             })
+
+             return response["choices"][0]["message"]["content"]
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {e}")
+             raise
+
+     async def acompletion(self, prompt: str):
+         """Text completion method (using chat API since completions is deprecated)"""
+         try:
+             messages = [{"role": "user", "content": prompt}]
+             return await self.achat(messages)
+         except Exception as e:
+             logger.error(f"Error in text completion: {e}")
+             raise
+
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+         """Generate multiple completions"""
+         try:
+             temperature = self.config.get("temperature", 0.7)
+             max_tokens = self.config.get("max_tokens", 1024)
+
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "temperature": temperature,
+                 "max_tokens": max_tokens,
+                 "n": n
+             }
+             response = await self.backend.post("/chat/completions", payload)
+
+             # Update token usage
+             self.last_token_usage = response.get("usage", {
+                 "prompt_tokens": 0,
+                 "completion_tokens": 0,
+                 "total_tokens": 0
+             })
+
+             return [choice["message"]["content"] for choice in response["choices"]]
+         except Exception as e:
+             logger.error(f"Error in generate: {e}")
+             raise
+
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """Stream chat responses"""
+         try:
+             temperature = self.config.get("temperature", 0.7)
+             max_tokens = self.config.get("max_tokens", 1024)
+
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "temperature": temperature,
+                 "max_tokens": max_tokens,
+                 "stream": True
+             }
+
+             async for chunk in self.backend.stream_chat(payload):
+                 yield chunk
+
+         except Exception as e:
+             logger.error(f"Error in stream chat: {e}")
+             raise
+
+     def get_token_usage(self):
+         """Get total token usage statistics"""
+         return self.last_token_usage
+
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """Get token usage from last request"""
+         return self.last_token_usage
+
+     async def close(self):
+         """Close the backend client"""
+         await self.backend.close()
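
A minimal usage sketch for OpenAILLMService, assuming OpenAIProvider (from isa_model/inference/providers/openai_provider.py) supplies the config dict that BaseLLMService exposes as self.config; the provider constructor shown here is an assumption, not something this diff confirms:

import asyncio
from isa_model.inference.providers.openai_provider import OpenAIProvider
from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

async def main():
    # Hypothetical provider wiring: the config keys mirror what the service reads above
    provider = OpenAIProvider(config={"api_key": "sk-...", "api_base": "https://api.openai.com/v1"})
    service = OpenAILLMService(provider, model_name="gpt-3.5-turbo")
    try:
        # A plain string routes to acompletion(); a list of messages routes to achat()
        reply = await service.ainvoke([{"role": "user", "content": "Say hello."}])
        print(reply)
        print(service.get_last_token_usage())
    finally:
        await service.close()

asyncio.run(main())
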

isa_model/inference/services/llm/replicate_llm_service.py
@@ -0,0 +1,179 @@
+ import logging
+ from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+ from isa_model.inference.services.base_service import BaseLLMService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.backends.replicate_client import ReplicateBackendClient
+
+ logger = logging.getLogger(__name__)
+
+ class ReplicateLLMService(BaseLLMService):
+     """Replicate LLM service implementation"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "meta/llama-3-8b-instruct", backend: Optional[ReplicateBackendClient] = None):
+         super().__init__(provider, model_name)
+
+         # Use provided backend or create new one
+         if backend:
+             self.backend = backend
+         else:
+             api_token = self.config.get("api_token", "")
+             self.backend = ReplicateBackendClient(api_token)
+
+         # Parse model name for Replicate format (owner/model)
+         if "/" not in model_name:
+             logger.warning(f"Model name {model_name} is not in Replicate format (owner/model). Using as-is.")
+
+         # Store version separately if provided
+         self.model_version = None
+         if ":" in model_name:
+             self.model_name, self.model_version = model_name.split(":", 1)
+
+         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+         logger.info(f"Initialized ReplicateLLMService with model {model_name}")
+
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+         """Universal invocation method"""
+         if isinstance(prompt, str):
+             return await self.acompletion(prompt)
+         elif isinstance(prompt, list):
+             return await self.achat(prompt)
+         else:
+             raise ValueError("Prompt must be string or list of messages")
+
+     async def achat(self, messages: List[Dict[str, str]]):
+         """Chat completion method"""
+         try:
+             temperature = self.config.get("temperature", 0.7)
+             max_tokens = self.config.get("max_tokens", 1024)
+
+             # Convert to Replicate format
+             prompt = self._convert_messages_to_prompt(messages)
+
+             # Prepare input data
+             input_data = {
+                 "prompt": prompt,
+                 "temperature": temperature,
+                 "max_new_tokens": max_tokens,
+                 "system_prompt": self._extract_system_prompt(messages)
+             }
+
+             # Call Replicate API
+             prediction = await self.backend.create_prediction(
+                 self.model_name,
+                 self.model_version,
+                 input_data
+             )
+
+             # Get output - could be a list of strings or a single string
+             output = prediction.get("output", "")
+             if isinstance(output, list):
+                 output = "".join(output)
+
+             # Approximate token usage - Replicate doesn't provide token counts
+             approx_prompt_tokens = len(prompt) // 4  # Very rough approximation
+             approx_completion_tokens = len(output) // 4
+             self.last_token_usage = {
+                 "prompt_tokens": approx_prompt_tokens,
+                 "completion_tokens": approx_completion_tokens,
+                 "total_tokens": approx_prompt_tokens + approx_completion_tokens
+             }
+
+             return output
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {e}")
+             raise
+
+     def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+         """Convert chat messages to Replicate prompt format"""
+         prompt = ""
+
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+
+             # Skip system prompts - handled separately
+             if role == "system":
+                 continue
+
+             if role == "user":
+                 prompt += f"Human: {content}\n\n"
+             elif role == "assistant":
+                 prompt += f"Assistant: {content}\n\n"
+             else:
+                 # Default to user for unknown roles
+                 prompt += f"Human: {content}\n\n"
+
+         # Add final assistant prefix for the model to continue
+         prompt += "Assistant: "
+
+         return prompt
+
+     def _extract_system_prompt(self, messages: List[Dict[str, str]]) -> str:
+         """Extract system prompt from messages"""
+         for msg in messages:
+             if msg.get("role") == "system":
+                 return msg.get("content", "")
+         return ""
+
+     async def acompletion(self, prompt: str):
+         """Text completion method"""
+         try:
+             # For simple completion, use chat format with a single user message
+             messages = [{"role": "user", "content": prompt}]
+             return await self.achat(messages)
+         except Exception as e:
+             logger.error(f"Error in text completion: {e}")
+             raise
+
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+         """Generate multiple completions"""
+         # Replicate doesn't support multiple outputs in one call,
+         # so we make multiple calls
+         results = []
+         for _ in range(n):
+             result = await self.achat(messages)
+             results.append(result)
+         return results
+
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """Stream chat responses"""
+         try:
+             temperature = self.config.get("temperature", 0.7)
+             max_tokens = self.config.get("max_tokens", 1024)
+
+             # Convert to Replicate format
+             prompt = self._convert_messages_to_prompt(messages)
+
+             # Prepare input data
+             input_data = {
+                 "prompt": prompt,
+                 "temperature": temperature,
+                 "max_new_tokens": max_tokens,
+                 "system_prompt": self._extract_system_prompt(messages),
+                 "stream": True
+             }
+
+             # Call Replicate API with streaming
+             async for chunk in self.backend.stream_prediction(
+                 self.model_name,
+                 self.model_version,
+                 input_data
+             ):
+                 yield chunk
+
+         except Exception as e:
+             logger.error(f"Error in stream chat: {e}")
+             raise
+
+     def get_token_usage(self):
+         """Get total token usage statistics"""
+         return self.last_token_usage
+
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """Get token usage from last request"""
+         return self.last_token_usage
+
+     async def close(self):
+         """Close the backend client"""
+         await self.backend.close()
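
For reference, this is the prompt that _convert_messages_to_prompt builds from a typical message list; the system message is dropped from the prompt itself and sent separately through the system_prompt input:

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "What is Replicate?"},
    {"role": "assistant", "content": "A model hosting API."},
    {"role": "user", "content": "Thanks."},
]
# _convert_messages_to_prompt(messages) returns:
# "Human: What is Replicate?\n\nAssistant: A model hosting API.\n\nHuman: Thanks.\n\nAssistant: "
# _extract_system_prompt(messages) returns "You are terse."
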

isa_model/inference/services/llm/triton_llm_service.py
@@ -0,0 +1,230 @@
+ import json
+ import logging
+ import asyncio
+ from typing import Dict, List, Any, AsyncGenerator, Optional, Union
+
+ import numpy as np
+
+ from isa_model.inference.services.base_service import BaseLLMService
+ from isa_model.inference.providers.triton_provider import TritonProvider
+
+ logger = logging.getLogger(__name__)
+
+
+ class TritonLLMService(BaseLLMService):
+     """
+     LLM service that uses Triton Inference Server to run inference.
+     """
+
+     def __init__(self, provider: TritonProvider, model_name: str):
+         """
+         Initialize the Triton LLM service.
+
+         Args:
+             provider: The Triton provider
+             model_name: Name of the model in Triton (e.g., "Llama3-8B")
+         """
+         super().__init__(provider, model_name)
+         self.client = None
+         self.tokenizer = None
+         self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
+     async def _initialize_client(self):
+         """Initialize the Triton client"""
+         if self.client is None:
+             self.client = self.provider.create_client()
+
+             # Check if model is ready
+             if not self.provider.is_model_ready(self.model_name):
+                 logger.error(f"Model {self.model_name} is not ready on Triton server")
+                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+             logger.info(f"Initialized Triton client for model: {self.model_name}")
+
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
+         """
+         Universal invocation method.
+
+         Args:
+             prompt: Text prompt or chat messages
+
+         Returns:
+             Generated text
+         """
+         if isinstance(prompt, str):
+             return await self.acompletion(prompt)
+         elif isinstance(prompt, list) and all(isinstance(m, dict) for m in prompt):
+             return await self.achat(prompt)
+         else:
+             raise ValueError("Prompt must be either a string or a list of message dictionaries")
+
+     async def achat(self, messages: List[Dict[str, str]]) -> str:
+         """
+         Chat completion method.
+
+         Args:
+             messages: List of message dictionaries with 'role' and 'content'
+
+         Returns:
+             Generated chat response
+         """
+         # Format chat messages into a single prompt
+         formatted_prompt = self._format_chat_messages(messages)
+         return await self.acompletion(formatted_prompt)
+
+     async def acompletion(self, prompt: str) -> str:
+         """
+         Text completion method.
+
+         Args:
+             prompt: Text prompt
+
+         Returns:
+             Generated text completion
+         """
+         await self._initialize_client()
+
+         try:
+             import tritonclient.http as httpclient
+
+             # Create input tensors
+             input_text = np.array([prompt], dtype=np.object_)
+             inputs = [httpclient.InferInput("TEXT", input_text.shape, "BYTES")]
+             inputs[0].set_data_from_numpy(input_text)
+
+             # Default parameters
+             generation_params = {
+                 "max_new_tokens": self.config.get("max_new_tokens", 512),
+                 "temperature": self.config.get("temperature", 0.7),
+                 "top_p": self.config.get("top_p", 0.9),
+                 "do_sample": self.config.get("temperature", 0.7) > 0
+             }
+
+             # Add parameters as input tensor
+             param_json = json.dumps(generation_params)
+             param_data = np.array([param_json], dtype=np.object_)
+             param_input = httpclient.InferInput("PARAMETERS", param_data.shape, "BYTES")
+             param_input.set_data_from_numpy(param_data)
+             inputs.append(param_input)
+
+             # Create output tensor
+             outputs = [httpclient.InferRequestedOutput("TEXT")]
+
+             # Send the request
+             response = await asyncio.to_thread(
+                 self.client.infer,
+                 self.model_name,
+                 inputs,
+                 outputs=outputs
+             )
+
+             # Process the response
+             output = response.as_numpy("TEXT")
+             response_text = output[0].decode('utf-8')
+
+             # Update token usage (estimated since we don't have actual token counts)
+             prompt_tokens = len(prompt) // 4  # Rough estimate
+             completion_tokens = len(response_text) // 4  # Rough estimate
+             total_tokens = prompt_tokens + completion_tokens
+
+             self.last_token_usage = {
+                 "prompt_tokens": prompt_tokens,
+                 "completion_tokens": completion_tokens,
+                 "total_tokens": total_tokens
+             }
+
+             # Update total token usage
+             self.token_usage["prompt_tokens"] += prompt_tokens
+             self.token_usage["completion_tokens"] += completion_tokens
+             self.token_usage["total_tokens"] += total_tokens
+
+             return response_text
+
+         except Exception as e:
+             logger.error(f"Error during Triton inference: {str(e)}")
+             raise
+
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+         """
+         Generate multiple completions.
+
+         Args:
+             messages: List of message dictionaries
+             n: Number of completions to generate
+
+         Returns:
+             List of generated completions
+         """
+         results = []
+         for _ in range(n):
+             result = await self.achat(messages)
+             results.append(result)
+         return results
+
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """
+         Stream chat responses.
+
+         Args:
+             messages: List of message dictionaries
+
+         Yields:
+             Generated text chunks
+         """
+         # For Triton, we don't have true streaming, so we generate the full response
+         # and then simulate streaming
+         full_response = await self.achat(messages)
+
+         # Simulate streaming by yielding words
+         words = full_response.split()
+         for i in range(len(words)):
+             chunk = ' '.join(words[:i+1])
+             yield chunk
+             await asyncio.sleep(0.05)  # Small delay to simulate streaming
+
+     def get_token_usage(self) -> Dict[str, int]:
+         """
+         Get total token usage statistics.
+
+         Returns:
+             Dictionary with token usage statistics
+         """
+         return self.token_usage
+
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Get token usage from last request.
+
+         Returns:
+             Dictionary with token usage statistics from last request
+         """
+         return self.last_token_usage
+
+     def _format_chat_messages(self, messages: List[Dict[str, str]]) -> str:
+         """
+         Format chat messages into a single prompt for models that don't support chat natively.
+
+         Args:
+             messages: List of message dictionaries
+
+         Returns:
+             Formatted prompt
+         """
+         formatted_prompt = ""
+
+         for message in messages:
+             role = message.get("role", "user").lower()
+             content = message.get("content", "")
+
+             if role == "system":
+                 formatted_prompt += f"System: {content}\n\n"
+             elif role == "user":
+                 formatted_prompt += f"User: {content}\n\n"
+             elif role == "assistant":
+                 formatted_prompt += f"Assistant: {content}\n\n"
+             else:
+                 formatted_prompt += f"{role.capitalize()}: {content}\n\n"
+
+         formatted_prompt += "Assistant: "
+         return formatted_prompt
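
Note that the simulated streaming in astream_chat yields the growing prefix of the response rather than a delta, so a consumer that concatenates chunks will duplicate text. A small sketch of what a caller observes:

# If achat() returns "Hello there friend", astream_chat() yields, in order:
#   "Hello"
#   "Hello there"
#   "Hello there friend"
# A UI should therefore replace the displayed text with the latest chunk
# instead of appending chunks together.
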

isa_model/inference/services/others/table_transformer_service.py
@@ -0,0 +1,61 @@
+ from typing import Dict, Any, List, Union, Optional
+ from ...base_service import BaseService
+ from ...base_provider import BaseProvider
+ from transformers import TableTransformerForObjectDetection, DetrImageProcessor
+ import torch
+ from PIL import Image
+ import numpy as np
+
+ class TableTransformerService(BaseService):
+     """Table detection service using Microsoft's Table Transformer"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "microsoft/table-transformer-detection"):
+         super().__init__(provider, model_name)
+         self.processor = DetrImageProcessor.from_pretrained(model_name)
+         self.model = TableTransformerForObjectDetection.from_pretrained(model_name)
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+
+     async def detect_tables(self, image_path: str) -> Dict[str, Any]:
+         """Detect tables in image"""
+         try:
+             # Load and process image
+             image = Image.open(image_path)
+             inputs = self.processor(images=image, return_tensors="pt")
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+
+             # Run inference
+             outputs = self.model(**inputs)
+
+             # Convert outputs to image size
+             target_sizes = torch.tensor([image.size[::-1]])
+             results = self.processor.post_process_object_detection(
+                 outputs, target_sizes=target_sizes, threshold=0.7
+             )[0]
+
+             tables = []
+             for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+                 if label == 1:  # Table class
+                     tables.append({
+                         "confidence": score.item(),
+                         "bbox": box.tolist(),
+                         "type": "table"
+                     })
+
+             return {
+                 "tables": tables,
+                 "image_size": image.size
+             }
+
+         except Exception as e:
+             raise RuntimeError(f"Table detection failed: {e}")
+
+     async def close(self):
+         """Cleanup resources"""
+         if hasattr(self, 'model'):
+             del self.model
+         if hasattr(self, 'processor'):
+             del self.processor
+         torch.cuda.empty_cache()
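
A minimal usage sketch for TableTransformerService. What BaseService requires of the provider argument is not shown in this diff, so the `provider` value below is a placeholder; the model weights are fetched from the Hugging Face Hub on first construction:

import asyncio
from isa_model.inference.services.others.table_transformer_service import TableTransformerService

async def main():
    provider = ...  # placeholder: an object satisfying BaseService's constructor
    service = TableTransformerService(provider)
    result = await service.detect_tables("scanned_page.png")
    for table in result["tables"]:
        print(f"table at {table['bbox']} (confidence {table['confidence']:.2f})")
    await service.close()

asyncio.run(main())
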

isa_model/inference/services/vision/__init__.py
@@ -0,0 +1,12 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Vision service package
+ Contains all vision-related service modules
+ """
+
+ # Export ReplicateVisionService
+ from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService
+
+ __all__ = ["ReplicateVisionService"]

isa_model/inference/services/vision/helpers/image_utils.py
@@ -0,0 +1,58 @@
+ from io import BytesIO
+ from PIL import Image
+ from typing import Union
+ import base64
+ from app.config.config_manager import config_manager
+
+ logger = config_manager.get_logger(__name__)
+
+ def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
+     """Compress an image to reduce its size.
+
+     Args:
+         image_data: Image data, either bytes or BytesIO
+         max_size: Maximum dimension in pixels
+
+     Returns:
+         bytes: Compressed image data
+     """
+     try:
+         # If the input is bytes, wrap it in BytesIO
+         if isinstance(image_data, bytes):
+             image_data = BytesIO(image_data)
+
+         img = Image.open(image_data)
+
+         # Convert to RGB mode (if needed)
+         if img.mode in ('RGBA', 'P'):
+             img = img.convert('RGB')
+
+         # Compute the new size, preserving aspect ratio
+         ratio = max_size / max(img.size)
+         if ratio < 1:
+             new_size = tuple(int(dim * ratio) for dim in img.size)
+             img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+         # Save the compressed image
+         output = BytesIO()
+         img.save(output, format='JPEG', quality=85, optimize=True)
+         return output.getvalue()
+
+     except Exception as e:
+         logger.error(f"Error compressing image: {e}")
+         raise
+
+ def encode_image_to_base64(image_data: bytes) -> str:
+     """Encode image data as a base64 string.
+
+     Args:
+         image_data: Raw image bytes
+
+     Returns:
+         str: Base64-encoded string
+     """
+     try:
+         return base64.b64encode(image_data).decode('utf-8')
+     except Exception as e:
+         logger.error(f"Error encoding image to base64: {e}")
+         raise
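
A short sketch combining the two helpers above to prepare an image payload, for example for a vision request body. Note the module imports app.config.config_manager at import time, so it assumes the host application, not this package, provides that module:

from isa_model.inference.services.vision.helpers.image_utils import (
    compress_image,
    encode_image_to_base64,
)

with open("photo.png", "rb") as f:
    raw = f.read()

# Downscale to at most 1024 px on the longest side, re-encode as JPEG (quality 85),
# then base64-encode for embedding in a JSON request.
jpeg_bytes = compress_image(raw, max_size=1024)
payload = {"image": encode_image_to_base64(jpeg_bytes)}
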