agentscope-runtime 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentscope_runtime/engine/agents/agentscope_agent/agent.py +105 -50
- agentscope_runtime/engine/agents/agentscope_agent/hooks.py +16 -3
- agentscope_runtime/engine/helpers/helper.py +33 -0
- agentscope_runtime/engine/runner.py +33 -1
- agentscope_runtime/engine/schemas/agent_schemas.py +208 -13
- agentscope_runtime/engine/services/context_manager.py +34 -1
- agentscope_runtime/engine/services/rag_service.py +195 -0
- agentscope_runtime/engine/services/reme_personal_memory_service.py +106 -0
- agentscope_runtime/engine/services/reme_task_memory_service.py +11 -0
- agentscope_runtime/sandbox/box/browser/browser_sandbox.py +25 -0
- agentscope_runtime/sandbox/box/sandbox.py +60 -7
- agentscope_runtime/sandbox/box/shared/routers/mcp_utils.py +20 -2
- agentscope_runtime/sandbox/box/training_box/env_service.py +1 -1
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_dataprocess.py +216 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_env.py +380 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/env_handler.py +934 -0
- agentscope_runtime/sandbox/box/training_box/training_box.py +139 -9
- agentscope_runtime/sandbox/client/http_client.py +1 -1
- agentscope_runtime/sandbox/enums.py +2 -0
- agentscope_runtime/sandbox/manager/container_clients/docker_client.py +19 -9
- agentscope_runtime/sandbox/manager/container_clients/kubernetes_client.py +61 -6
- agentscope_runtime/sandbox/manager/sandbox_manager.py +95 -35
- agentscope_runtime/sandbox/manager/server/app.py +128 -17
- agentscope_runtime/sandbox/model/__init__.py +1 -5
- agentscope_runtime/sandbox/model/manager_config.py +2 -13
- agentscope_runtime/sandbox/tools/mcp_tool.py +1 -1
- agentscope_runtime/version.py +1 -1
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/METADATA +59 -3
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/RECORD +33 -27
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/WHEEL +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/entry_points.txt +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/top_level.txt +0 -0
```diff
--- a/agentscope_runtime/engine/schemas/agent_schemas.py
+++ b/agentscope_runtime/engine/schemas/agent_schemas.py
@@ -2,7 +2,7 @@
 # type: ignore
 from copy import deepcopy
 from datetime import datetime
-from typing import List, Dict, Optional, Any
+from typing import List, Dict, Optional, Any, Literal, TypeAlias, Annotated
 from typing import Union
 
 try:
@@ -12,6 +12,7 @@ except ImportError:
 from uuid import uuid4
 
 from pydantic import BaseModel, Field, field_validator
+from openai.types.chat import ChatCompletionChunk
 
 
 class MessageType:
@@ -29,6 +30,15 @@ class MessageType:
     HEARTBEAT = "heartbeat"
     ERROR = "error"
 
+    @classmethod
+    def all_values(cls):
+        """return all constants values in MessageType"""
+        return [
+            value
+            for name, value in vars(cls).items()
+            if not name.startswith("_") and isinstance(value, str)
+        ]
+
 
 class ContentType:
     TEXT = "text"
@@ -41,6 +51,7 @@ class Role:
     ASSISTANT = "assistant"
     USER = "user"
     SYSTEM = "system"
+    TOOL = "tool"
 
 
 class RunStatus:
@@ -70,7 +81,7 @@ class FunctionParameters(BaseModel):
 
 class FunctionTool(BaseModel):
     """
-    Model class for
+    Model class for message tool.
     """
 
     name: str
@@ -89,10 +100,10 @@ class FunctionTool(BaseModel):
 
 class Tool(BaseModel):
     """
-    Model class for assistant
+    Model class for assistant message tool call.
     """
 
-    type: Optional[str] =
+    type: Optional[str] = "function"
     """The type of the tool. Currently, only `function` is supported."""
 
     function: Optional[FunctionTool] = None
@@ -141,7 +152,7 @@ class Error(BaseModel):
 
 
 class Event(BaseModel):
-    sequence_number: Optional[
+    sequence_number: Optional[int] = None
     """sequence number of event"""
 
     object: str
@@ -213,9 +224,42 @@ class Content(Event):
     msg_id: Optional[str] = None
     """message unique id"""
 
+    @staticmethod
+    def from_chat_completion_chunk(
+        chunk: ChatCompletionChunk,
+        index: Optional[int] = None,
+    ) -> Optional[Union["TextContent", "DataContent", "ImageContent"]]:
+        if not chunk.choices:
+            return None
+
+        choice = chunk.choices[0]
+        if choice.delta.content is not None:
+            return TextContent(
+                delta=True,
+                text=choice.delta.content,
+                index=index,
+            )
+        elif choice.delta.tool_calls:
+            # TODO: support multiple tool calls output
+            tool_call = choice.delta.tool_calls[0]
+            if tool_call.function is not None:
+                return DataContent(
+                    delta=True,
+                    data={
+                        "call_id": tool_call.id,
+                        "name": tool_call.function.name,
+                        "arguments": tool_call.function.arguments,
+                    },
+                    index=index,
+                )
+            else:
+                return None
+        else:
+            return None
+
 
 class ImageContent(Content):
-    type:
+    type: Literal[ContentType.IMAGE] = ContentType.IMAGE
     """The type of the content part."""
 
     image_url: Optional[str] = None
```
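The new `Content.from_chat_completion_chunk` helper maps one streaming chunk to at most one content part: a text delta becomes a `TextContent`, the first tool-call delta becomes a `DataContent`, and everything else yields `None`. Below is a minimal sketch of the text path, assuming agentscope-runtime 0.1.3 and the v1 `openai` SDK are installed (the chunk `id` and `model` values are placeholders, not real API output); the agent_schemas.py hunks continue after it.

```python
# Sketch: hand-built streaming chunk -> runtime content part.
# Placeholder id/model values; only the delta payload matters here.
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta

from agentscope_runtime.engine.schemas.agent_schemas import Content

chunk = ChatCompletionChunk(
    id="chunk-0",
    choices=[Choice(index=0, delta=ChoiceDelta(content="Hello"))],
    created=0,
    model="placeholder-model",
    object="chat.completion.chunk",
)

part = Content.from_chat_completion_chunk(chunk, index=0)
print(type(part).__name__, part.text)  # expected: TextContent Hello
```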
```diff
@@ -223,7 +267,7 @@ class ImageContent(Content):
 
 
 class TextContent(Content):
-    type:
+    type: Literal[ContentType.TEXT] = ContentType.TEXT
     """The type of the content part."""
 
     text: Optional[str] = None
@@ -231,13 +275,27 @@ class TextContent(Content):
 
 
 class DataContent(Content):
-    type:
+    type: Literal[ContentType.DATA] = ContentType.DATA
    """The type of the content part."""
 
     data: Optional[Dict] = None
     """The data content."""
 
 
+AgentRole: TypeAlias = Literal[
+    Role.ASSISTANT,
+    Role.SYSTEM,
+    Role.USER,
+    Role.TOOL,
+]
+
+
+AgentContent = Annotated[
+    Union[TextContent, ImageContent, DataContent],
+    Field(discriminator="type"),
+]
+
+
 class Message(Event):
     id: str = Field(default_factory=lambda: "msg_" + str(uuid4()))
     """message unique id"""
```
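`AgentContent` is the structural payoff of the three `type: Literal[...]` fixes above: with a literal discriminator on every subclass, pydantic can route a raw dict straight to the right model instead of trying each union member in turn. A self-contained sketch of the same pattern with toy models (deliberately not the package's own classes, so it runs in isolation with pydantic v2):

```python
# Toy discriminated union mirroring the AgentContent pattern.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class TextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class DataPart(BaseModel):
    type: Literal["data"] = "data"
    data: dict


Part = Annotated[Union[TextPart, DataPart], Field(discriminator="type")]

adapter = TypeAdapter(Part)
print(adapter.validate_python({"type": "text", "text": "hi"}))
# type='text' text='hi'  (a TextPart, chosen by the discriminator)
print(adapter.validate_python({"type": "data", "data": {"k": 1}}))
# type='data' data={'k': 1}
```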
```diff
@@ -251,13 +309,11 @@ class Message(Event):
     status: str = RunStatus.Created
     """The status of the message. in_progress, completed, or incomplete"""
 
-    role: Optional[
+    role: Optional[AgentRole] = None
     """The role of the messages author, should be in `user`,`system`,
     'assistant'."""
 
-    content: Optional[
-        List[Union[TextContent, ImageContent, DataContent]]
-    ] = None
+    content: Optional[List[AgentContent]] = None
     """The contents of the message."""
 
     code: Optional[str] = None
@@ -266,6 +322,145 @@ class Message(Event):
     message: Optional[str] = None
     """The error message of the message."""
 
+    usage: Optional[Dict] = None
+    """response usage for output"""
+
+    @staticmethod
+    def from_openai_message(message: Union[BaseModel, dict]) -> "Message":
+        """Create a message object from an openai message."""
+
+        # in case message is a Message object
+        if isinstance(message, Message):
+            return message
+
+        # make sure operation on dict object
+        if isinstance(message, BaseModel):
+            message = message.model_dump()
+
+        # in case message is a Message format dict
+        if "type" in message and message["type"] in MessageType.all_values():
+            return Message(**message)
+
+        # handle message in openai message format
+        if message["role"] == Role.ASSISTANT and "tool_calls" in message:
+            _content_list = []
+            for tool_call in message["tool_calls"]:
+                _content = DataContent(
+                    data=FunctionCall(
+                        call_id=tool_call["id"],
+                        name=tool_call["function"]["name"],
+                        arguments=tool_call["function"]["arguments"],
+                    ).model_dump(),
+                )
+                _content_list.append(_content)
+            _message = Message(
+                type=MessageType.FUNCTION_CALL,
+                content=_content_list,
+            )
+        elif message["role"] == Role.TOOL:
+            _content = DataContent(
+                data=FunctionCallOutput(
+                    call_id=message["tool_call_id"],
+                    output=message["content"],
+                ).model_dump(),
+            )
+            _message = Message(
+                type=MessageType.FUNCTION_CALL_OUTPUT,
+                content=[_content],
+            )
+        # mainly focus on matching content
+        elif isinstance(message["content"], str):
+            _content = TextContent(text=message["content"])
+            _message = Message(
+                type=MessageType.MESSAGE,
+                role=message["role"],
+                content=[_content],
+            )
+        else:
+            _content_list = []
+            for content in message["content"]:
+                if content["type"] == "image_url":
+                    _content = ImageContent(
+                        image_url=content["image_url"]["url"],
+                    )
+                elif content["type"] == "text":
+                    _content = TextContent(text=content["text"])
+                else:
+                    _content = DataContent(data=content["text"])
+                _content_list.append(_content)
+            _message = Message(
+                type=MessageType.MESSAGE,
+                role=message["role"],
+                content=_content_list,
+            )
+        return _message
+
+    def get_text_content(self) -> Optional[str]:
+        """
+        Extract the first text content from the message.
+
+        :return:
+            First text string found in the content, or None if no text content
+        """
+        if self.content is None:
+            return None
+
+        for item in self.content:
+            if isinstance(item, TextContent):
+                return item.text
+        return None
+
+    def get_image_content(self) -> List[str]:
+        """
+        Extract all image content (URLs or base64 data) from the message.
+
+        :return:
+            List of image URLs or base64 encoded strings found in the content
+        """
+        images = []
+
+        if self.content is None:
+            return images
+
+        for item in self.content:
+            if isinstance(item, ImageContent):
+                images.append(item.image_url)
+        return images
+
+    def get_audio_content(self) -> List[str]:
+        """
+        Extract all audio content (URLs or base64 data) from the message.
+
+        :return:
+            List of audio URLs or base64 encoded strings found in the content
+        """
+        audios = []
+
+        if self.content is None:
+            return audios
+
+        for item in self.content:
+            if hasattr(item, "type"):
+                if item.type == "input_audio" and hasattr(
+                    item,
+                    "input_audio",
+                ):
+                    if hasattr(item.input_audio, "data"):
+                        audios.append(item.input_audio.data)
+                    elif hasattr(item.input_audio, "base64_data"):
+                        # Construct data URL for audio
+                        format_type = getattr(
+                            item.input_audio,
+                            "format",
+                            "mp3",
+                        )
+                        audios.append(
+                            f"data:{format_type};base64,"
+                            f"{item.input_audio.base64_data}",
+                        )
+
+        return audios
+
     def add_delta_content(
         self,
         new_content: Union[TextContent, ImageContent, DataContent],
```
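`from_openai_message` dispatches on the dict's shape: assistant messages carrying `tool_calls` become `FUNCTION_CALL` messages, `role == "tool"` becomes `FUNCTION_CALL_OUTPUT`, string content becomes a single `TextContent`, and list content is mapped part by part. A hedged usage sketch, assuming the 0.1.3 wheel is installed:

```python
# Sketch: OpenAI-style chat dicts -> runtime Messages.
from agentscope_runtime.engine.schemas.agent_schemas import (
    Message,
    MessageType,
)

user = Message.from_openai_message({"role": "user", "content": "What is RAG?"})
print(user.type == MessageType.MESSAGE)    # True
print(user.role, user.get_text_content())  # user What is RAG?

tool = Message.from_openai_message(
    {"role": "tool", "tool_call_id": "call_1", "content": "42"},
)
print(tool.type == MessageType.FUNCTION_CALL_OUTPUT)  # True
```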
```diff
@@ -500,7 +695,7 @@ def convert_to_openai_tool_call(function: FunctionCall):
     }
 
 
-def convert_to_openai_messages(messages: List[Message]) ->
+def convert_to_openai_messages(messages: List[Message]) -> List[Dict]:
     """
     Convert a generic message protocol to a model-specific protocol.
     Args:
```
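The return-annotation fix confirms the inverse direction: `convert_to_openai_messages` hands back plain `List[Dict]` payloads, so the two converters form a typed round trip. A sketch under the same installed-wheel assumption (the exact dict shape of the output is the converter's business, so it is not asserted here):

```python
# Sketch: runtime Messages -> OpenAI chat-completions dicts.
from agentscope_runtime.engine.schemas.agent_schemas import (
    Message,
    convert_to_openai_messages,
)

msgs = [Message.from_openai_message({"role": "user", "content": "hi"})]
print(convert_to_openai_messages(msgs))
# e.g. [{'role': 'user', 'content': ...}] -- shape depends on the converter
```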
|
@@ -4,12 +4,19 @@ from typing import List
|
|
|
4
4
|
|
|
5
5
|
from .manager import ServiceManager
|
|
6
6
|
from .memory_service import MemoryService, InMemoryMemoryService
|
|
7
|
+
from .rag_service import RAGService
|
|
7
8
|
from .session_history_service import (
|
|
8
9
|
SessionHistoryService,
|
|
9
10
|
Session,
|
|
10
11
|
InMemorySessionHistoryService,
|
|
11
12
|
)
|
|
12
|
-
from ..schemas.agent_schemas import
|
|
13
|
+
from ..schemas.agent_schemas import (
|
|
14
|
+
Message,
|
|
15
|
+
MessageType,
|
|
16
|
+
Role,
|
|
17
|
+
TextContent,
|
|
18
|
+
ContentType,
|
|
19
|
+
)
|
|
13
20
|
|
|
14
21
|
|
|
15
22
|
class ContextComposer:
|
|
@@ -19,6 +26,7 @@ class ContextComposer:
|
|
|
19
26
|
session: Session, # session
|
|
20
27
|
memory_service: MemoryService = None,
|
|
21
28
|
session_history_service: SessionHistoryService = None,
|
|
29
|
+
rag_service: RAGService = None,
|
|
22
30
|
):
|
|
23
31
|
# session
|
|
24
32
|
if session_history_service:
|
|
@@ -42,6 +50,24 @@ class ContextComposer:
|
|
|
42
50
|
)
|
|
43
51
|
session.messages = memories + session.messages
|
|
44
52
|
|
|
53
|
+
# rag
|
|
54
|
+
if rag_service:
|
|
55
|
+
query = await rag_service.get_query_text(request_input[-1])
|
|
56
|
+
docs = await rag_service.retrieve(query=query, k=5)
|
|
57
|
+
cooked_doc = "\n".join(docs)
|
|
58
|
+
message = Message(
|
|
59
|
+
type=MessageType.MESSAGE,
|
|
60
|
+
role=Role.SYSTEM,
|
|
61
|
+
content=[TextContent(type=ContentType.TEXT, text=cooked_doc)],
|
|
62
|
+
)
|
|
63
|
+
if len(session.messages) >= 1:
|
|
64
|
+
last_message = session.messages[-1]
|
|
65
|
+
session.messages.remove(last_message)
|
|
66
|
+
session.messages.append(message)
|
|
67
|
+
session.messages.append(last_message)
|
|
68
|
+
else:
|
|
69
|
+
session.messages.append(message)
|
|
70
|
+
|
|
45
71
|
|
|
46
72
|
class ContextManager(ServiceManager):
|
|
47
73
|
"""
|
|
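Note the ordering contract in the `# rag` block above: the retrieved documents are injected as a system message in the second-to-last slot, so the user's latest turn remains the final message the model sees. The remove/append/append sequence is equivalent to a single slice insert, as this plain-list sketch shows:

```python
# Equivalent second-to-last insert, shown on plain strings.
messages = ["sys-prompt", "earlier-turn", "latest-user-turn"]
rag_message = "retrieved-docs (system)"
messages[-1:-1] = [rag_message]  # same end state as remove/append/append
print(messages)
# ['sys-prompt', 'earlier-turn', 'retrieved-docs (system)', 'latest-user-turn']
```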
```diff
@@ -53,10 +79,12 @@ class ContextManager(ServiceManager):
         context_composer_cls=ContextComposer,
         session_history_service: SessionHistoryService = None,
         memory_service: MemoryService = None,
+        rag_service: RAGService = None,
     ):
         self._context_composer_cls = context_composer_cls
         self._session_history_service = session_history_service
         self._memory_service = memory_service
+        self._rag_service = rag_service
         super().__init__()
 
     def _register_default_services(self):
@@ -68,6 +96,8 @@ class ContextManager(ServiceManager):
 
         self.register_service("session", self._session_history_service)
         self.register_service("memory", self._memory_service)
+        if self._rag_service:
+            self.register_service("rag", self._rag_service)
 
     async def compose_context(
         self,
@@ -77,6 +107,7 @@ class ContextManager(ServiceManager):
         await self._context_composer_cls.compose(
             memory_service=self._memory_service,
             session_history_service=self._session_history_service,
+            rag_service=self._rag_service,
             session=session,
             request_input=request_input,
         )
@@ -119,10 +150,12 @@ class ContextManager(ServiceManager):
 async def create_context_manager(
     memory_service: MemoryService = None,
     session_history_service: SessionHistoryService = None,
+    rag_service: RAGService = None,
 ):
     manager = ContextManager(
         memory_service=memory_service,
         session_history_service=session_history_service,
+        rag_service=rag_service,
     )
 
     async with manager:
```
```diff
--- /dev/null
+++ b/agentscope_runtime/engine/services/rag_service.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+from .base import ServiceWithLifecycleManager
+from ..schemas.agent_schemas import Message, MessageType
+
+
+class RAGService(ServiceWithLifecycleManager):
+    """
+    RAG Service
+    """
+
+    async def get_query_text(self, message: Message) -> str:
+        """
+        Gets the query text from the messages.
+
+        Args:
+            message: A list of messages.
+
+        Returns:
+            The query text.
+        """
+        if message:
+            if message.type == MessageType.MESSAGE:
+                for content in message.content:
+                    if content.type == "text":
+                        return content.text
+        return ""
+
+    async def retrieve(self, query: str, k: int = 1) -> list[str]:
+        """
+        Retrieves similar documents based on the given query.
+
+        Args:
+            query (str): The query string to search for similar documents.
+            k (int, optional): The number of similar documents to retrieve.
+                Defaults to 1.
+
+        Returns:
+            list[str]: A list of document contents that are similar to
+            the query.
+        """
+        raise NotImplementedError
+
+    async def start(self) -> None:
+        """Starts the service."""
+
+    async def stop(self) -> None:
+        """Stops the service."""
+
+    async def health(self) -> bool:
+        """Checks the health of the service."""
+        return True
+
+
+DEFAULT_URI = "milvus_demo.db"
+
+
+class LangChainRAGService(RAGService):
+    """
+    RAG Service using LangChain
+    """
+
+    def __init__(
+        self,
+        vectorstore=None,
+        embedding=None,
+    ):
+        # set default embedding alg.
+        if embedding is None:
+            from langchain_community.embeddings import DashScopeEmbeddings
+
+            self.embeddings = DashScopeEmbeddings()
+        else:
+            self.embeddings = embedding
+
+        # set default vectorstore class.
+        if vectorstore is None:
+            from langchain_milvus import Milvus
+
+            self.vectorstore = Milvus.from_documents(
+                [],
+                embedding=self.embeddings,
+                connection_args={
+                    "uri": DEFAULT_URI,
+                },
+                drop_old=False,
+            )
+        else:
+            self.vectorstore = vectorstore
+
+    async def retrieve(self, query: str, k: int = 1) -> list[str]:
+        """
+        Retrieves similar documents based on the given query using LangChain.
+
+        Args:
+            query (str): The query string to search for similar documents.
+            k (int, optional): The number of similar documents to retrieve.
+                Defaults to 1.
+
+        Returns:
+            list[str]: A list of document contents that are similar to the
+            query.
+
+        Raises:
+            ValueError: If the vector store is not initialized.
+        """
+        if self.vectorstore is None:
+            raise ValueError(
+                "Vector store not initialized. Call build_index first.",
+            )
+        docs = self.vectorstore.similarity_search(query, k=k)
+        return [doc.page_content for doc in docs]
+
+    async def start(self) -> None:
+        """Starts the service."""
+
+    async def stop(self) -> None:
+        """Stops the service."""
+
+    async def health(self) -> bool:
+        """Checks the health of the service."""
+        return True
+
+
+class LlamaIndexRAGService(RAGService):
+    """
+    RAG Service using LlamaIndex
+    """
+
+    def __init__(
+        self,
+        vectorstore=None,
+        embedding=None,
+    ):
+        # set default embedding alg.
+        if embedding is None:
+            from langchain_community.embeddings import DashScopeEmbeddings
+
+            self.embeddings = DashScopeEmbeddings()
+        else:
+            self.embeddings = embedding
+
+        # set default vectorstore.
+        if vectorstore is None:
+            from llama_index.core import VectorStoreIndex
+            from llama_index.core.schema import Document
+            from llama_index.vector_stores.milvus import MilvusVectorStore
+
+            # Create empty documents list for initialization
+            documents = [Document(text="")]
+
+            # Initialize Milvus vector store
+            self.vector_store = MilvusVectorStore(
+                uri=DEFAULT_URI,
+                overwrite=False,
+            )
+
+            # Create index
+            self.index = VectorStoreIndex.from_documents(
+                documents=documents,
+                embed_model=self.embeddings,
+                vector_store=self.vector_store,
+            )
+        else:
+            self.index = vectorstore
+
+    async def retrieve(self, query: str, k: int = 1) -> list[str]:
+        """
+        Retrieves similar documents based on the given query using LlamaIndex.
+
+        Args:
+            query (str): The query string to search for similar documents.
+            k (int, optional): The number of similar documents to retrieve.
+                Defaults to 1.
+
+        Returns:
+            list[str]: A list of document contents that are similar to the
+            query.
+
+        Raises:
+            ValueError: If the index is not initialized.
+        """
+        if self.index is None:
+            raise ValueError(
+                "Index not initialized.",
+            )
+
+        # Create query engine and query
+        query_engine = self.index.as_retriever(similarity_top_k=k)
+        response = query_engine.retrieve(query)
+
+        # Extract text from nodes
+        if len(response) > 0:
+            return [node.node.get_content() for node in response]
+        else:
+            return [""]
```