draive 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- draive/__init__.py +186 -0
- draive/conversation/__init__.py +18 -0
- draive/conversation/call.py +87 -0
- draive/conversation/completion.py +67 -0
- draive/conversation/lmm.py +159 -0
- draive/conversation/message.py +16 -0
- draive/conversation/state.py +17 -0
- draive/embedding/__init__.py +11 -0
- draive/embedding/call.py +15 -0
- draive/embedding/embedded.py +18 -0
- draive/embedding/embedder.py +23 -0
- draive/embedding/state.py +10 -0
- draive/generation/__init__.py +15 -0
- draive/generation/image/__init__.py +9 -0
- draive/generation/image/call.py +16 -0
- draive/generation/image/generator.py +17 -0
- draive/generation/image/state.py +10 -0
- draive/generation/model/__init__.py +9 -0
- draive/generation/model/call.py +34 -0
- draive/generation/model/generator.py +29 -0
- draive/generation/model/lmm.py +85 -0
- draive/generation/model/state.py +13 -0
- draive/generation/text/__init__.py +9 -0
- draive/generation/text/call.py +26 -0
- draive/generation/text/generator.py +22 -0
- draive/generation/text/lmm.py +63 -0
- draive/generation/text/state.py +13 -0
- draive/helpers/__init__.py +13 -0
- draive/helpers/env.py +139 -0
- draive/helpers/logs.py +59 -0
- draive/helpers/split_sequence.py +20 -0
- draive/lmm/__init__.py +18 -0
- draive/lmm/call.py +73 -0
- draive/lmm/completion.py +64 -0
- draive/lmm/message.py +50 -0
- draive/lmm/state.py +10 -0
- draive/mistral/__init__.py +11 -0
- draive/mistral/chat_response.py +92 -0
- draive/mistral/chat_stream.py +130 -0
- draive/mistral/chat_tools.py +111 -0
- draive/mistral/client.py +112 -0
- draive/mistral/config.py +56 -0
- draive/mistral/errors.py +7 -0
- draive/mistral/lmm.py +213 -0
- draive/openai/__init__.py +23 -0
- draive/openai/chat_response.py +97 -0
- draive/openai/chat_stream.py +120 -0
- draive/openai/chat_tools.py +139 -0
- draive/openai/client.py +212 -0
- draive/openai/config.py +122 -0
- draive/openai/embedding.py +33 -0
- draive/openai/errors.py +7 -0
- draive/openai/images.py +30 -0
- draive/openai/lmm.py +236 -0
- draive/openai/tokenization.py +22 -0
- draive/py.typed +0 -0
- draive/scope/__init__.py +16 -0
- draive/scope/access.py +330 -0
- draive/scope/dependencies.py +63 -0
- draive/scope/errors.py +17 -0
- draive/scope/metrics.py +462 -0
- draive/scope/state.py +60 -0
- draive/similarity/__init__.py +7 -0
- draive/similarity/cosine.py +35 -0
- draive/similarity/mmr.py +67 -0
- draive/similarity/similarity.py +32 -0
- draive/splitters/__init__.py +5 -0
- draive/splitters/basic.py +130 -0
- draive/tokenization/__init__.py +10 -0
- draive/tokenization/call.py +18 -0
- draive/tokenization/state.py +10 -0
- draive/tokenization/text.py +14 -0
- draive/tools/__init__.py +19 -0
- draive/tools/errors.py +7 -0
- draive/tools/state.py +31 -0
- draive/tools/tool.py +184 -0
- draive/tools/toolbox.py +51 -0
- draive/tools/update.py +18 -0
- draive/types/__init__.py +45 -0
- draive/types/images.py +18 -0
- draive/types/memory.py +55 -0
- draive/types/missing.py +28 -0
- draive/types/model.py +50 -0
- draive/types/multimodal.py +8 -0
- draive/types/parameters.py +847 -0
- draive/types/specification.py +394 -0
- draive/types/state.py +16 -0
- draive/types/updates.py +22 -0
- draive/utils/__init__.py +13 -0
- draive/utils/cache.py +177 -0
- draive/utils/early_exit.py +125 -0
- draive/utils/retry.py +167 -0
- draive/utils/stream.py +105 -0
- draive-0.9.1.dist-info/LICENSE +21 -0
- draive-0.9.1.dist-info/METADATA +76 -0
- draive-0.9.1.dist-info/RECORD +98 -0
- draive-0.9.1.dist-info/WHEEL +5 -0
- draive-0.9.1.dist-info/top_level.txt +1 -0
draive/__init__.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from draive.conversation import (
|
|
2
|
+
Conversation,
|
|
3
|
+
ConversationCompletion,
|
|
4
|
+
ConversationCompletionStream,
|
|
5
|
+
ConversationMessage,
|
|
6
|
+
ConversationMessageContent,
|
|
7
|
+
conversation_completion,
|
|
8
|
+
lmm_conversation_completion,
|
|
9
|
+
)
|
|
10
|
+
from draive.embedding import Embedded, Embedder, Embedding, embed_text
|
|
11
|
+
from draive.generation import (
|
|
12
|
+
ImageGeneration,
|
|
13
|
+
ImageGenerator,
|
|
14
|
+
ModelGeneration,
|
|
15
|
+
ModelGenerator,
|
|
16
|
+
TextGeneration,
|
|
17
|
+
TextGenerator,
|
|
18
|
+
generate_image,
|
|
19
|
+
generate_model,
|
|
20
|
+
generate_text,
|
|
21
|
+
)
|
|
22
|
+
from draive.helpers import (
|
|
23
|
+
getenv_bool,
|
|
24
|
+
getenv_float,
|
|
25
|
+
getenv_int,
|
|
26
|
+
getenv_str,
|
|
27
|
+
load_env,
|
|
28
|
+
setup_logging,
|
|
29
|
+
split_sequence,
|
|
30
|
+
)
|
|
31
|
+
from draive.lmm import (
|
|
32
|
+
LMM,
|
|
33
|
+
LMMCompletion,
|
|
34
|
+
LMMCompletionContent,
|
|
35
|
+
LMMCompletionMessage,
|
|
36
|
+
LMMCompletionStream,
|
|
37
|
+
LMMCompletionStreamingUpdate,
|
|
38
|
+
lmm_completion,
|
|
39
|
+
)
|
|
40
|
+
from draive.mistral import (
|
|
41
|
+
MistralChatConfig,
|
|
42
|
+
MistralClient,
|
|
43
|
+
MistralException,
|
|
44
|
+
mistral_lmm_completion,
|
|
45
|
+
)
|
|
46
|
+
from draive.openai import (
|
|
47
|
+
OpenAIChatConfig,
|
|
48
|
+
OpenAIClient,
|
|
49
|
+
OpenAIEmbeddingConfig,
|
|
50
|
+
OpenAIException,
|
|
51
|
+
OpenAIImageGenerationConfig,
|
|
52
|
+
openai_embed_text,
|
|
53
|
+
openai_generate_image,
|
|
54
|
+
openai_lmm_completion,
|
|
55
|
+
openai_tokenize_text,
|
|
56
|
+
)
|
|
57
|
+
from draive.scope import (
|
|
58
|
+
DependenciesScope,
|
|
59
|
+
MetricsScope,
|
|
60
|
+
ScopeDependency,
|
|
61
|
+
ScopeMetric,
|
|
62
|
+
StateScope,
|
|
63
|
+
TokenUsage,
|
|
64
|
+
ctx,
|
|
65
|
+
)
|
|
66
|
+
from draive.similarity import mmr_similarity, similarity
|
|
67
|
+
from draive.splitters import split_text
|
|
68
|
+
from draive.tokenization import TextTokenizer, Tokenization, count_text_tokens, tokenize_text
|
|
69
|
+
from draive.tools import (
|
|
70
|
+
Tool,
|
|
71
|
+
Toolbox,
|
|
72
|
+
ToolCallContext,
|
|
73
|
+
ToolCallStatus,
|
|
74
|
+
ToolCallUpdate,
|
|
75
|
+
ToolException,
|
|
76
|
+
ToolsUpdatesContext,
|
|
77
|
+
tool,
|
|
78
|
+
)
|
|
79
|
+
from draive.types import (
|
|
80
|
+
MISSING,
|
|
81
|
+
Argument,
|
|
82
|
+
Field,
|
|
83
|
+
ImageBase64Content,
|
|
84
|
+
ImageContent,
|
|
85
|
+
ImageURLContent,
|
|
86
|
+
Memory,
|
|
87
|
+
MissingValue,
|
|
88
|
+
Model,
|
|
89
|
+
MultimodalContent,
|
|
90
|
+
ReadOnlyMemory,
|
|
91
|
+
State,
|
|
92
|
+
UpdateSend,
|
|
93
|
+
)
|
|
94
|
+
from draive.utils import allowing_early_exit, auto_retry, cache, with_early_exit
|
|
95
|
+
|
|
96
|
+
# Public API re-exported from the draive subpackages; each name is listed exactly once.
__all__ = [
    "ImageContent",
    "ImageBase64Content",
    "ImageURLContent",
    "UpdateSend",
    "Embedded",
    "Embedder",
    "Embedding",
    "Model",
    "Memory",
    "ModelGeneration",
    "ModelGenerator",
    "ReadOnlyMemory",
    "DependenciesScope",
    "ScopeDependency",
    "ScopeMetric",
    "MetricsScope",
    "StateScope",
    "State",
    "Field",
    "TextGeneration",
    "TextGenerator",
    "ImageGeneration",
    "ImageGenerator",
    "generate_image",
    "TextTokenizer",
    "Tokenization",
    "TokenUsage",
    "auto_retry",
    "cache",
    "tokenize_text",
    "count_text_tokens",
    "ctx",
    "embed_text",
    "generate_model",
    "generate_text",
    "load_env",
    "getenv_bool",
    "getenv_float",
    "getenv_int",
    "getenv_str",
    "mmr_similarity",
    "similarity",
    "split_sequence",
    "split_text",
    "tool",
    "Argument",
    "allowing_early_exit",
    "with_early_exit",
    "setup_logging",
    "MissingValue",
    "MISSING",
    "MultimodalContent",
    "Tool",
    "ToolException",
    "Toolbox",
    "ToolCallContext",
    "ToolCallStatus",
    "ToolCallUpdate",
    "ToolsUpdatesContext",
    "LMM",
    "LMMCompletion",
    "LMMCompletionContent",
    "LMMCompletionMessage",
    "LMMCompletionStreamingUpdate",
    "LMMCompletionStream",
    "lmm_completion",
    "ConversationMessage",
    "ConversationMessageContent",
    "Conversation",
    "ConversationCompletionStream",
    "ConversationCompletion",
    "conversation_completion",
    "lmm_conversation_completion",
    "OpenAIException",
    "OpenAIChatConfig",
    "OpenAIClient",
    "OpenAIEmbeddingConfig",
    "OpenAIImageGenerationConfig",
    "openai_tokenize_text",
    "openai_embed_text",
    "openai_lmm_completion",
    "openai_generate_image",
    "MistralException",
    "MistralClient",
    "MistralChatConfig",
    "mistral_lmm_completion",
]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from draive.conversation.call import conversation_completion
|
|
2
|
+
from draive.conversation.completion import ConversationCompletion, ConversationCompletionStream
|
|
3
|
+
from draive.conversation.lmm import lmm_conversation_completion
|
|
4
|
+
from draive.conversation.message import (
|
|
5
|
+
ConversationMessage,
|
|
6
|
+
ConversationMessageContent,
|
|
7
|
+
)
|
|
8
|
+
from draive.conversation.state import Conversation
|
|
9
|
+
|
|
10
|
+
# Public API of the `draive.conversation` package.
__all__ = [
    "ConversationMessage", "ConversationMessageContent", "Conversation",
    "ConversationCompletionStream", "ConversationCompletion",
    "conversation_completion", "lmm_conversation_completion",
]
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import Literal, overload
|
|
2
|
+
|
|
3
|
+
from draive.conversation.completion import ConversationCompletionStream
|
|
4
|
+
from draive.conversation.message import (
|
|
5
|
+
ConversationMessage,
|
|
6
|
+
ConversationMessageContent,
|
|
7
|
+
ConversationStreamingUpdate,
|
|
8
|
+
)
|
|
9
|
+
from draive.conversation.state import Conversation
|
|
10
|
+
from draive.scope import ctx
|
|
11
|
+
from draive.tools import Toolbox
|
|
12
|
+
from draive.types import Memory, UpdateSend
|
|
13
|
+
|
|
14
|
+
__all__ = [
    "conversation_completion",
]


@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: Literal[True],
) -> ConversationCompletionStream: ...


@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: UpdateSend[ConversationStreamingUpdate],
) -> ConversationMessage: ...


@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
) -> ConversationMessage: ...


async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
) -> ConversationCompletionStream | ConversationMessage:
    """Run a conversation completion through the scoped ``Conversation`` state.

    Falsy ``memory`` / ``tools`` arguments fall back to the values configured on
    the ``Conversation`` state of the current scope.

    Args:
        instruction: system instruction forwarded to the completion.
        input: the user message, or raw content to send as the user message.
        memory: memory override for this call.
        tools: toolbox override for this call.
        stream: ``True`` to get a ``ConversationCompletionStream``; an
            ``UpdateSend`` callback to receive streaming updates while awaiting
            the final message; ``False`` (default) for a plain awaited message.

    Returns:
        The result of the scoped ``Conversation.completion`` for the chosen mode.
    """
    conversation: Conversation = ctx.state(Conversation)
    # `or` (not an `is None` check): any falsy override falls back to the scope state.
    chosen_memory = memory or conversation.memory
    chosen_tools = tools or conversation.tools

    if stream is False:
        # `stream` is deliberately omitted so the completion's own default applies.
        return await conversation.completion(
            instruction=instruction,
            input=input,
            memory=chosen_memory,
            tools=chosen_tools,
        )

    if stream is True:
        return await conversation.completion(
            instruction=instruction,
            input=input,
            memory=chosen_memory,
            tools=chosen_tools,
            stream=True,
        )

    # Anything else is treated as a streaming-update callback.
    return await conversation.completion(
        instruction=instruction,
        input=input,
        memory=chosen_memory,
        tools=chosen_tools,
        stream=stream,
    )
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Literal, Protocol, overload, runtime_checkable
|
|
2
|
+
|
|
3
|
+
from draive.conversation.message import (
|
|
4
|
+
ConversationMessage,
|
|
5
|
+
ConversationMessageContent,
|
|
6
|
+
ConversationStreamingUpdate,
|
|
7
|
+
)
|
|
8
|
+
from draive.lmm import LMMCompletionStream
|
|
9
|
+
from draive.tools import Toolbox
|
|
10
|
+
from draive.types import Memory, UpdateSend
|
|
11
|
+
|
|
12
|
+
__all__ = [
    "ConversationCompletionStream",
    "ConversationCompletion",
]


# Stream returned by a conversation completion; an alias of the LMM completion stream.
ConversationCompletionStream = LMMCompletionStream


@runtime_checkable
class ConversationCompletion(Protocol):
    """Callable interface implemented by conversation completion backends.

    The overloads tie ``stream`` to the result type: ``stream=True`` yields a
    ``ConversationCompletionStream``, while an ``UpdateSend`` callback or an
    omitted ``stream`` yields a single ``ConversationMessage``.
    """

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: Literal[True],
    ) -> ConversationCompletionStream: ...

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: UpdateSend[ConversationStreamingUpdate],
    ) -> ConversationMessage: ...

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
    ) -> ConversationMessage: ...

    async def __call__(  # noqa: PLR0913
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
    ) -> ConversationCompletionStream | ConversationMessage:
        """Complete the conversation; the result type depends on ``stream``."""
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from datetime import UTC, datetime
|
|
2
|
+
from typing import Literal, overload
|
|
3
|
+
|
|
4
|
+
from draive.conversation.completion import ConversationCompletionStream
|
|
5
|
+
from draive.conversation.message import (
|
|
6
|
+
ConversationMessage,
|
|
7
|
+
ConversationMessageContent,
|
|
8
|
+
ConversationStreamingUpdate,
|
|
9
|
+
)
|
|
10
|
+
from draive.lmm import LMMCompletionMessage, lmm_completion
|
|
11
|
+
from draive.tools import Toolbox
|
|
12
|
+
from draive.types import Memory, UpdateSend
|
|
13
|
+
from draive.utils import AsyncStreamTask
|
|
14
|
+
|
|
15
|
+
__all__: list[str] = [
|
|
16
|
+
"lmm_conversation_completion",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@overload
|
|
21
|
+
async def lmm_conversation_completion(
|
|
22
|
+
*,
|
|
23
|
+
instruction: str,
|
|
24
|
+
input: ConversationMessage | ConversationMessageContent, # noqa: A002
|
|
25
|
+
memory: Memory[ConversationMessage] | None = None,
|
|
26
|
+
tools: Toolbox | None = None,
|
|
27
|
+
stream: Literal[True],
|
|
28
|
+
) -> ConversationCompletionStream:
|
|
29
|
+
...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@overload
|
|
33
|
+
async def lmm_conversation_completion(
|
|
34
|
+
*,
|
|
35
|
+
instruction: str,
|
|
36
|
+
input: ConversationMessage | ConversationMessageContent, # noqa: A002
|
|
37
|
+
memory: Memory[ConversationMessage] | None = None,
|
|
38
|
+
tools: Toolbox | None = None,
|
|
39
|
+
stream: UpdateSend[ConversationStreamingUpdate],
|
|
40
|
+
) -> ConversationMessage:
|
|
41
|
+
...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@overload
|
|
45
|
+
async def lmm_conversation_completion(
|
|
46
|
+
*,
|
|
47
|
+
instruction: str,
|
|
48
|
+
input: ConversationMessage | ConversationMessageContent, # noqa: A002
|
|
49
|
+
memory: Memory[ConversationMessage] | None = None,
|
|
50
|
+
tools: Toolbox | None = None,
|
|
51
|
+
) -> ConversationMessage:
|
|
52
|
+
...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
async def lmm_conversation_completion(
|
|
56
|
+
*,
|
|
57
|
+
instruction: str,
|
|
58
|
+
input: ConversationMessage | ConversationMessageContent, # noqa: A002
|
|
59
|
+
memory: Memory[ConversationMessage] | None = None,
|
|
60
|
+
tools: Toolbox | None = None,
|
|
61
|
+
stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
|
|
62
|
+
) -> ConversationCompletionStream | ConversationMessage:
|
|
63
|
+
system_message: LMMCompletionMessage = LMMCompletionMessage(
|
|
64
|
+
role="system",
|
|
65
|
+
content=instruction,
|
|
66
|
+
)
|
|
67
|
+
user_message: ConversationMessage
|
|
68
|
+
if isinstance(input, ConversationMessage):
|
|
69
|
+
user_message = input
|
|
70
|
+
|
|
71
|
+
else:
|
|
72
|
+
user_message = ConversationMessage(
|
|
73
|
+
timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
|
|
74
|
+
role="user",
|
|
75
|
+
content=input,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
context: list[LMMCompletionMessage]
|
|
79
|
+
|
|
80
|
+
if memory:
|
|
81
|
+
context = [
|
|
82
|
+
system_message,
|
|
83
|
+
*await memory.recall(),
|
|
84
|
+
user_message,
|
|
85
|
+
]
|
|
86
|
+
|
|
87
|
+
else:
|
|
88
|
+
context = [
|
|
89
|
+
system_message,
|
|
90
|
+
user_message,
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
match stream:
|
|
94
|
+
case True:
|
|
95
|
+
|
|
96
|
+
async def stream_task(
|
|
97
|
+
update: UpdateSend[ConversationStreamingUpdate],
|
|
98
|
+
) -> None:
|
|
99
|
+
nonlocal memory
|
|
100
|
+
completion: LMMCompletionMessage = await lmm_completion(
|
|
101
|
+
context=context,
|
|
102
|
+
tools=tools,
|
|
103
|
+
stream=update,
|
|
104
|
+
)
|
|
105
|
+
response_message: ConversationMessage = ConversationMessage(
|
|
106
|
+
timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
|
|
107
|
+
role=completion.role,
|
|
108
|
+
content=completion.content,
|
|
109
|
+
)
|
|
110
|
+
if memory := memory:
|
|
111
|
+
await memory.remember(
|
|
112
|
+
[
|
|
113
|
+
user_message,
|
|
114
|
+
response_message,
|
|
115
|
+
],
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
return AsyncStreamTask(job=stream_task)
|
|
119
|
+
|
|
120
|
+
case False:
|
|
121
|
+
completion: LMMCompletionMessage = await lmm_completion(
|
|
122
|
+
context=context,
|
|
123
|
+
tools=tools,
|
|
124
|
+
)
|
|
125
|
+
response_message: ConversationMessage = ConversationMessage(
|
|
126
|
+
timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
|
|
127
|
+
role=completion.role,
|
|
128
|
+
content=completion.content,
|
|
129
|
+
)
|
|
130
|
+
if memory := memory:
|
|
131
|
+
await memory.remember(
|
|
132
|
+
[
|
|
133
|
+
user_message,
|
|
134
|
+
response_message,
|
|
135
|
+
],
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
return response_message
|
|
139
|
+
|
|
140
|
+
case update:
|
|
141
|
+
completion: LMMCompletionMessage = await lmm_completion(
|
|
142
|
+
context=context,
|
|
143
|
+
tools=tools,
|
|
144
|
+
stream=update,
|
|
145
|
+
)
|
|
146
|
+
response_message: ConversationMessage = ConversationMessage(
|
|
147
|
+
timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
|
|
148
|
+
role=completion.role,
|
|
149
|
+
content=completion.content,
|
|
150
|
+
)
|
|
151
|
+
if memory := memory:
|
|
152
|
+
await memory.remember(
|
|
153
|
+
[
|
|
154
|
+
user_message,
|
|
155
|
+
response_message,
|
|
156
|
+
],
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
return response_message
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from draive.lmm import LMMCompletionContent, LMMCompletionMessage, LMMCompletionStreamingUpdate
|
|
2
|
+
|
|
3
|
+
__all__ = [
    "ConversationMessage",
    "ConversationMessageContent",
    "ConversationStreamingUpdate",
]


class ConversationMessage(LMMCompletionMessage):
    """An LMM completion message extended with optional conversation metadata."""

    # Optional identifier of whoever authored the message.
    author: str | None = None
    # Optional timestamp string; no particular format is enforced by this class.
    timestamp: str | None = None


# Content accepted by conversation messages (same type as LMM completion content).
ConversationMessageContent = LMMCompletionContent
# Streaming update emitted during a conversation completion (same as the LMM update type).
ConversationStreamingUpdate = LMMCompletionStreamingUpdate
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from draive.conversation.completion import ConversationCompletion
|
|
2
|
+
from draive.conversation.lmm import lmm_conversation_completion
|
|
3
|
+
from draive.conversation.message import (
|
|
4
|
+
ConversationMessage,
|
|
5
|
+
)
|
|
6
|
+
from draive.tools import Toolbox
|
|
7
|
+
from draive.types import Memory, State
|
|
8
|
+
|
|
9
|
+
__all__: list[str] = [
    "Conversation",
]


class Conversation(State):
    """Scoped state that configures how conversation completions are performed.

    By default it uses the LMM-backed completion, with no memory and no tools.
    """

    # Completion implementation invoked for conversation turns.
    completion: ConversationCompletion = lmm_conversation_completion
    # Fallback memory, used when a call does not supply its own.
    memory: Memory[ConversationMessage] | None = None
    # Fallback toolbox, used when a call does not supply its own.
    tools: Toolbox | None = None
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from draive.embedding.call import embed_text
|
|
2
|
+
from draive.embedding.embedded import Embedded
|
|
3
|
+
from draive.embedding.embedder import Embedder
|
|
4
|
+
from draive.embedding.state import Embedding
|
|
5
|
+
|
|
6
|
+
# Public API of the `draive.embedding` package.
__all__ = [
    "Embedder", "Embedding", "Embedded", "embed_text",
]
|
draive/embedding/call.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
|
|
3
|
+
from draive.embedding.embedded import Embedded
|
|
4
|
+
from draive.embedding.state import Embedding
|
|
5
|
+
from draive.scope import ctx
|
|
6
|
+
|
|
7
|
+
__all__ = [
    "embed_text",
]


async def embed_text(
    values: Iterable[str],
) -> list[Embedded[str]]:
    """Embed text values using the ``Embedding`` state of the current scope.

    Args:
        values: the texts to embed.

    Returns:
        The list produced by the scoped ``Embedding.embed_text``.
    """
    embedding: Embedding = ctx.state(Embedding)
    return await embedding.embed_text(values=values)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import Generic, TypeVar
|
|
2
|
+
|
|
3
|
+
from draive.types.state import State
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "Embedded",
]


# Type of the value that was embedded.
_Embedded = TypeVar("_Embedded", bound=object)


class Embedded(State, Generic[_Embedded]):
    """A value paired with its embedding vector."""

    # The original value that was embedded.
    value: _Embedded
    # Embedding vector associated with `value`.
    vector: list[float]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from typing import Generic, Protocol, TypeVar, runtime_checkable
|
|
3
|
+
|
|
4
|
+
from draive.embedding.embedded import Embedded
|
|
5
|
+
|
|
6
|
+
__all__ = [
    "Embedder",
]


# Type of values an embedder accepts.
_Embeddable = TypeVar("_Embeddable", bound=object)


@runtime_checkable
class Embedder(Protocol, Generic[_Embeddable]):
    """Callable interface that turns values into ``Embedded`` results."""

    async def __call__(
        self,
        values: Iterable[_Embeddable],
    ) -> list[Embedded[_Embeddable]]:
        """Return the embedded results for ``values``."""
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from draive.generation.image import ImageGeneration, ImageGenerator, generate_image
|
|
2
|
+
from draive.generation.model import ModelGeneration, ModelGenerator, generate_model
|
|
3
|
+
from draive.generation.text import TextGeneration, TextGenerator, generate_text
|
|
4
|
+
|
|
5
|
+
# Public API of the `draive.generation` package: model, text and image generation.
__all__ = [
    "ModelGenerator", "ModelGeneration", "generate_model",
    "TextGenerator", "TextGeneration", "generate_text",
    "ImageGeneration", "ImageGenerator", "generate_image",
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from draive.generation.image.state import ImageGeneration
|
|
2
|
+
from draive.scope import ctx
|
|
3
|
+
from draive.types import ImageContent
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "generate_image",
]


async def generate_image(
    *,
    instruction: str,
) -> ImageContent:
    """Generate an image using the ``ImageGeneration`` state of the current scope.

    Args:
        instruction: description of the image to generate.

    Returns:
        The image content produced by the scoped ``ImageGeneration.generate``.
    """
    generation: ImageGeneration = ctx.state(ImageGeneration)
    return await generation.generate(instruction=instruction)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Protocol, runtime_checkable
|
|
2
|
+
|
|
3
|
+
from draive.types import ImageContent
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "ImageGenerator",
]


@runtime_checkable
class ImageGenerator(Protocol):
    """Callable interface for image generation backends."""

    async def __call__(
        self,
        *,
        instruction: str,
    ) -> ImageContent:
        """Produce an image for the given ``instruction``."""
|