draive 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. draive/__init__.py +186 -0
  2. draive/conversation/__init__.py +18 -0
  3. draive/conversation/call.py +87 -0
  4. draive/conversation/completion.py +67 -0
  5. draive/conversation/lmm.py +159 -0
  6. draive/conversation/message.py +16 -0
  7. draive/conversation/state.py +17 -0
  8. draive/embedding/__init__.py +11 -0
  9. draive/embedding/call.py +15 -0
  10. draive/embedding/embedded.py +18 -0
  11. draive/embedding/embedder.py +23 -0
  12. draive/embedding/state.py +10 -0
  13. draive/generation/__init__.py +15 -0
  14. draive/generation/image/__init__.py +9 -0
  15. draive/generation/image/call.py +16 -0
  16. draive/generation/image/generator.py +17 -0
  17. draive/generation/image/state.py +10 -0
  18. draive/generation/model/__init__.py +9 -0
  19. draive/generation/model/call.py +34 -0
  20. draive/generation/model/generator.py +29 -0
  21. draive/generation/model/lmm.py +85 -0
  22. draive/generation/model/state.py +13 -0
  23. draive/generation/text/__init__.py +9 -0
  24. draive/generation/text/call.py +26 -0
  25. draive/generation/text/generator.py +22 -0
  26. draive/generation/text/lmm.py +63 -0
  27. draive/generation/text/state.py +13 -0
  28. draive/helpers/__init__.py +13 -0
  29. draive/helpers/env.py +139 -0
  30. draive/helpers/logs.py +59 -0
  31. draive/helpers/split_sequence.py +20 -0
  32. draive/lmm/__init__.py +18 -0
  33. draive/lmm/call.py +73 -0
  34. draive/lmm/completion.py +64 -0
  35. draive/lmm/message.py +50 -0
  36. draive/lmm/state.py +10 -0
  37. draive/mistral/__init__.py +11 -0
  38. draive/mistral/chat_response.py +92 -0
  39. draive/mistral/chat_stream.py +130 -0
  40. draive/mistral/chat_tools.py +111 -0
  41. draive/mistral/client.py +112 -0
  42. draive/mistral/config.py +56 -0
  43. draive/mistral/errors.py +7 -0
  44. draive/mistral/lmm.py +213 -0
  45. draive/openai/__init__.py +23 -0
  46. draive/openai/chat_response.py +97 -0
  47. draive/openai/chat_stream.py +120 -0
  48. draive/openai/chat_tools.py +139 -0
  49. draive/openai/client.py +212 -0
  50. draive/openai/config.py +122 -0
  51. draive/openai/embedding.py +33 -0
  52. draive/openai/errors.py +7 -0
  53. draive/openai/images.py +30 -0
  54. draive/openai/lmm.py +236 -0
  55. draive/openai/tokenization.py +22 -0
  56. draive/py.typed +0 -0
  57. draive/scope/__init__.py +16 -0
  58. draive/scope/access.py +330 -0
  59. draive/scope/dependencies.py +63 -0
  60. draive/scope/errors.py +17 -0
  61. draive/scope/metrics.py +462 -0
  62. draive/scope/state.py +60 -0
  63. draive/similarity/__init__.py +7 -0
  64. draive/similarity/cosine.py +35 -0
  65. draive/similarity/mmr.py +67 -0
  66. draive/similarity/similarity.py +32 -0
  67. draive/splitters/__init__.py +5 -0
  68. draive/splitters/basic.py +130 -0
  69. draive/tokenization/__init__.py +10 -0
  70. draive/tokenization/call.py +18 -0
  71. draive/tokenization/state.py +10 -0
  72. draive/tokenization/text.py +14 -0
  73. draive/tools/__init__.py +19 -0
  74. draive/tools/errors.py +7 -0
  75. draive/tools/state.py +31 -0
  76. draive/tools/tool.py +184 -0
  77. draive/tools/toolbox.py +51 -0
  78. draive/tools/update.py +18 -0
  79. draive/types/__init__.py +45 -0
  80. draive/types/images.py +18 -0
  81. draive/types/memory.py +55 -0
  82. draive/types/missing.py +28 -0
  83. draive/types/model.py +50 -0
  84. draive/types/multimodal.py +8 -0
  85. draive/types/parameters.py +847 -0
  86. draive/types/specification.py +394 -0
  87. draive/types/state.py +16 -0
  88. draive/types/updates.py +22 -0
  89. draive/utils/__init__.py +13 -0
  90. draive/utils/cache.py +177 -0
  91. draive/utils/early_exit.py +125 -0
  92. draive/utils/retry.py +167 -0
  93. draive/utils/stream.py +105 -0
  94. draive-0.9.1.dist-info/LICENSE +21 -0
  95. draive-0.9.1.dist-info/METADATA +76 -0
  96. draive-0.9.1.dist-info/RECORD +98 -0
  97. draive-0.9.1.dist-info/WHEEL +5 -0
  98. draive-0.9.1.dist-info/top_level.txt +1 -0
draive/__init__.py ADDED
@@ -0,0 +1,186 @@
1
+ from draive.conversation import (
2
+ Conversation,
3
+ ConversationCompletion,
4
+ ConversationCompletionStream,
5
+ ConversationMessage,
6
+ ConversationMessageContent,
7
+ conversation_completion,
8
+ lmm_conversation_completion,
9
+ )
10
+ from draive.embedding import Embedded, Embedder, Embedding, embed_text
11
+ from draive.generation import (
12
+ ImageGeneration,
13
+ ImageGenerator,
14
+ ModelGeneration,
15
+ ModelGenerator,
16
+ TextGeneration,
17
+ TextGenerator,
18
+ generate_image,
19
+ generate_model,
20
+ generate_text,
21
+ )
22
+ from draive.helpers import (
23
+ getenv_bool,
24
+ getenv_float,
25
+ getenv_int,
26
+ getenv_str,
27
+ load_env,
28
+ setup_logging,
29
+ split_sequence,
30
+ )
31
+ from draive.lmm import (
32
+ LMM,
33
+ LMMCompletion,
34
+ LMMCompletionContent,
35
+ LMMCompletionMessage,
36
+ LMMCompletionStream,
37
+ LMMCompletionStreamingUpdate,
38
+ lmm_completion,
39
+ )
40
+ from draive.mistral import (
41
+ MistralChatConfig,
42
+ MistralClient,
43
+ MistralException,
44
+ mistral_lmm_completion,
45
+ )
46
+ from draive.openai import (
47
+ OpenAIChatConfig,
48
+ OpenAIClient,
49
+ OpenAIEmbeddingConfig,
50
+ OpenAIException,
51
+ OpenAIImageGenerationConfig,
52
+ openai_embed_text,
53
+ openai_generate_image,
54
+ openai_lmm_completion,
55
+ openai_tokenize_text,
56
+ )
57
+ from draive.scope import (
58
+ DependenciesScope,
59
+ MetricsScope,
60
+ ScopeDependency,
61
+ ScopeMetric,
62
+ StateScope,
63
+ TokenUsage,
64
+ ctx,
65
+ )
66
+ from draive.similarity import mmr_similarity, similarity
67
+ from draive.splitters import split_text
68
+ from draive.tokenization import TextTokenizer, Tokenization, count_text_tokens, tokenize_text
69
+ from draive.tools import (
70
+ Tool,
71
+ Toolbox,
72
+ ToolCallContext,
73
+ ToolCallStatus,
74
+ ToolCallUpdate,
75
+ ToolException,
76
+ ToolsUpdatesContext,
77
+ tool,
78
+ )
79
+ from draive.types import (
80
+ MISSING,
81
+ Argument,
82
+ Field,
83
+ ImageBase64Content,
84
+ ImageContent,
85
+ ImageURLContent,
86
+ Memory,
87
+ MissingValue,
88
+ Model,
89
+ MultimodalContent,
90
+ ReadOnlyMemory,
91
+ State,
92
+ UpdateSend,
93
+ )
94
+ from draive.utils import allowing_early_exit, auto_retry, cache, with_early_exit
95
+
96
# Public API of the ``draive`` package, grouped by the subpackage each name is
# re-exported from (mirrors the import statements above). Each name must be
# listed exactly once.
__all__ = [
    # draive.conversation
    "Conversation",
    "ConversationCompletion",
    "ConversationCompletionStream",
    "ConversationMessage",
    "ConversationMessageContent",
    "conversation_completion",
    "lmm_conversation_completion",
    # draive.embedding
    "Embedded",
    "Embedder",
    "Embedding",
    "embed_text",
    # draive.generation
    "ImageGeneration",
    "ImageGenerator",
    "ModelGeneration",
    "ModelGenerator",
    "TextGeneration",
    "TextGenerator",
    "generate_image",
    "generate_model",
    "generate_text",
    # draive.helpers
    "getenv_bool",
    "getenv_float",
    "getenv_int",
    "getenv_str",
    "load_env",
    "setup_logging",
    "split_sequence",
    # draive.lmm
    "LMM",
    "LMMCompletion",
    "LMMCompletionContent",
    "LMMCompletionMessage",
    "LMMCompletionStream",
    "LMMCompletionStreamingUpdate",
    "lmm_completion",
    # draive.mistral
    "MistralChatConfig",
    "MistralClient",
    "MistralException",
    "mistral_lmm_completion",
    # draive.openai
    "OpenAIChatConfig",
    "OpenAIClient",
    "OpenAIEmbeddingConfig",
    "OpenAIException",
    "OpenAIImageGenerationConfig",
    "openai_embed_text",
    "openai_generate_image",
    "openai_lmm_completion",
    "openai_tokenize_text",
    # draive.scope
    "DependenciesScope",
    "MetricsScope",
    "ScopeDependency",
    "ScopeMetric",
    "StateScope",
    "TokenUsage",
    "ctx",
    # draive.similarity
    "mmr_similarity",
    "similarity",
    # draive.splitters
    "split_text",
    # draive.tokenization
    "TextTokenizer",
    "Tokenization",
    "count_text_tokens",
    "tokenize_text",
    # draive.tools
    "Tool",
    "ToolCallContext",
    "ToolCallStatus",
    "ToolCallUpdate",
    "ToolException",
    "Toolbox",
    "ToolsUpdatesContext",
    "tool",
    # draive.types
    "MISSING",
    "Argument",
    "Field",
    "ImageBase64Content",
    "ImageContent",
    "ImageURLContent",
    "Memory",
    "MissingValue",
    "Model",
    "MultimodalContent",
    "ReadOnlyMemory",
    "State",
    "UpdateSend",
    # draive.utils
    "allowing_early_exit",
    "auto_retry",
    "cache",
    "with_early_exit",
]
@@ -0,0 +1,18 @@
1
+ from draive.conversation.call import conversation_completion
2
+ from draive.conversation.completion import ConversationCompletion, ConversationCompletionStream
3
+ from draive.conversation.lmm import lmm_conversation_completion
4
+ from draive.conversation.message import (
5
+ ConversationMessage,
6
+ ConversationMessageContent,
7
+ )
8
+ from draive.conversation.state import Conversation
9
+
10
# Names re-exported by ``draive.conversation``; order kept stable.
__all__ = [
    "ConversationMessage",
    "ConversationMessageContent",
    "Conversation",
    "ConversationCompletionStream",
    "ConversationCompletion",
    "conversation_completion",
    "lmm_conversation_completion",
]
@@ -0,0 +1,87 @@
1
+ from typing import Literal, overload
2
+
3
+ from draive.conversation.completion import ConversationCompletionStream
4
+ from draive.conversation.message import (
5
+ ConversationMessage,
6
+ ConversationMessageContent,
7
+ ConversationStreamingUpdate,
8
+ )
9
+ from draive.conversation.state import Conversation
10
+ from draive.scope import ctx
11
+ from draive.tools import Toolbox
12
+ from draive.types import Memory, UpdateSend
13
+
14
# Only the scope-aware entry point is public from this module.
__all__ = ["conversation_completion"]
17
+
18
+
19
@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: Literal[True],
) -> ConversationCompletionStream: ...


@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: UpdateSend[ConversationStreamingUpdate],
) -> ConversationMessage: ...


@overload
async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
) -> ConversationMessage: ...


async def conversation_completion(
    *,
    instruction: str,
    input: ConversationMessage | ConversationMessageContent,  # noqa: A002
    memory: Memory[ConversationMessage] | None = None,
    tools: Toolbox | None = None,
    stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
) -> ConversationCompletionStream | ConversationMessage:
    """Delegate to the completion configured on the scoped ``Conversation`` state.

    :param instruction: instruction text forwarded to the completion.
    :param input: the new message, or raw content for it.
    :param memory: memory to use; when falsy, ``Conversation.memory`` is used.
    :param tools: toolbox to use; when falsy, ``Conversation.tools`` is used.
    :param stream: ``False`` for a single result, ``True`` for a stream object,
        or an update sender that receives streaming updates.
    :returns: whatever the configured ``Conversation.completion`` returns.
    """
    current: Conversation = ctx.state(Conversation)
    effective_memory = memory or current.memory
    effective_tools = tools or current.tools

    # ``False``/``True`` are matched by identity, just like literal patterns in
    # ``match``; the non-streaming call deliberately omits ``stream``.
    if stream is False:
        return await current.completion(
            instruction=instruction,
            input=input,
            memory=effective_memory,
            tools=effective_tools,
        )

    if stream is True:
        return await current.completion(
            instruction=instruction,
            input=input,
            memory=effective_memory,
            tools=effective_tools,
            stream=True,
        )

    # Anything else is an update sender forwarded as-is.
    return await current.completion(
        instruction=instruction,
        input=input,
        memory=effective_memory,
        tools=effective_tools,
        stream=stream,
    )
@@ -0,0 +1,67 @@
1
+ from typing import Literal, Protocol, overload, runtime_checkable
2
+
3
+ from draive.conversation.message import (
4
+ ConversationMessage,
5
+ ConversationMessageContent,
6
+ ConversationStreamingUpdate,
7
+ )
8
+ from draive.lmm import LMMCompletionStream
9
+ from draive.tools import Toolbox
10
+ from draive.types import Memory, UpdateSend
11
+
12
__all__ = ["ConversationCompletionStream", "ConversationCompletion"]


# Conversation streaming reuses the LMM completion stream type unchanged.
ConversationCompletionStream = LMMCompletionStream
19
+
20
+
21
@runtime_checkable
class ConversationCompletion(Protocol):
    """Structural type for async conversation completion callables.

    The return type depends on ``stream``: ``True`` yields a
    ``ConversationCompletionStream``; an update sender or no ``stream``
    yields a ``ConversationMessage``.
    """

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: Literal[True],
    ) -> ConversationCompletionStream:
        """Streaming variant returning a completion stream."""

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: UpdateSend[ConversationStreamingUpdate],
    ) -> ConversationMessage:
        """Variant that pushes updates to ``stream`` and returns the final message."""

    @overload
    async def __call__(
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
    ) -> ConversationMessage:
        """Non-streaming variant returning the response message."""

    async def __call__(  # noqa: PLR0913
        self,
        *,
        instruction: str,
        input: ConversationMessage | ConversationMessageContent,  # noqa: A002
        memory: Memory[ConversationMessage] | None = None,
        tools: Toolbox | None = None,
        stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
    ) -> ConversationCompletionStream | ConversationMessage:
        """Combined signature covering all overloads above."""
@@ -0,0 +1,159 @@
1
+ from datetime import UTC, datetime
2
+ from typing import Literal, overload
3
+
4
+ from draive.conversation.completion import ConversationCompletionStream
5
+ from draive.conversation.message import (
6
+ ConversationMessage,
7
+ ConversationMessageContent,
8
+ ConversationStreamingUpdate,
9
+ )
10
+ from draive.lmm import LMMCompletionMessage, lmm_completion
11
+ from draive.tools import Toolbox
12
+ from draive.types import Memory, UpdateSend
13
+ from draive.utils import AsyncStreamTask
14
+
15
# Public name of this module: the LMM-backed conversation completion.
__all__: list[str] = ["lmm_conversation_completion"]
18
+
19
+
20
+ @overload
21
+ async def lmm_conversation_completion(
22
+ *,
23
+ instruction: str,
24
+ input: ConversationMessage | ConversationMessageContent, # noqa: A002
25
+ memory: Memory[ConversationMessage] | None = None,
26
+ tools: Toolbox | None = None,
27
+ stream: Literal[True],
28
+ ) -> ConversationCompletionStream:
29
+ ...
30
+
31
+
32
+ @overload
33
+ async def lmm_conversation_completion(
34
+ *,
35
+ instruction: str,
36
+ input: ConversationMessage | ConversationMessageContent, # noqa: A002
37
+ memory: Memory[ConversationMessage] | None = None,
38
+ tools: Toolbox | None = None,
39
+ stream: UpdateSend[ConversationStreamingUpdate],
40
+ ) -> ConversationMessage:
41
+ ...
42
+
43
+
44
+ @overload
45
+ async def lmm_conversation_completion(
46
+ *,
47
+ instruction: str,
48
+ input: ConversationMessage | ConversationMessageContent, # noqa: A002
49
+ memory: Memory[ConversationMessage] | None = None,
50
+ tools: Toolbox | None = None,
51
+ ) -> ConversationMessage:
52
+ ...
53
+
54
+
55
+ async def lmm_conversation_completion(
56
+ *,
57
+ instruction: str,
58
+ input: ConversationMessage | ConversationMessageContent, # noqa: A002
59
+ memory: Memory[ConversationMessage] | None = None,
60
+ tools: Toolbox | None = None,
61
+ stream: UpdateSend[ConversationStreamingUpdate] | bool = False,
62
+ ) -> ConversationCompletionStream | ConversationMessage:
63
+ system_message: LMMCompletionMessage = LMMCompletionMessage(
64
+ role="system",
65
+ content=instruction,
66
+ )
67
+ user_message: ConversationMessage
68
+ if isinstance(input, ConversationMessage):
69
+ user_message = input
70
+
71
+ else:
72
+ user_message = ConversationMessage(
73
+ timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
74
+ role="user",
75
+ content=input,
76
+ )
77
+
78
+ context: list[LMMCompletionMessage]
79
+
80
+ if memory:
81
+ context = [
82
+ system_message,
83
+ *await memory.recall(),
84
+ user_message,
85
+ ]
86
+
87
+ else:
88
+ context = [
89
+ system_message,
90
+ user_message,
91
+ ]
92
+
93
+ match stream:
94
+ case True:
95
+
96
+ async def stream_task(
97
+ update: UpdateSend[ConversationStreamingUpdate],
98
+ ) -> None:
99
+ nonlocal memory
100
+ completion: LMMCompletionMessage = await lmm_completion(
101
+ context=context,
102
+ tools=tools,
103
+ stream=update,
104
+ )
105
+ response_message: ConversationMessage = ConversationMessage(
106
+ timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
107
+ role=completion.role,
108
+ content=completion.content,
109
+ )
110
+ if memory := memory:
111
+ await memory.remember(
112
+ [
113
+ user_message,
114
+ response_message,
115
+ ],
116
+ )
117
+
118
+ return AsyncStreamTask(job=stream_task)
119
+
120
+ case False:
121
+ completion: LMMCompletionMessage = await lmm_completion(
122
+ context=context,
123
+ tools=tools,
124
+ )
125
+ response_message: ConversationMessage = ConversationMessage(
126
+ timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
127
+ role=completion.role,
128
+ content=completion.content,
129
+ )
130
+ if memory := memory:
131
+ await memory.remember(
132
+ [
133
+ user_message,
134
+ response_message,
135
+ ],
136
+ )
137
+
138
+ return response_message
139
+
140
+ case update:
141
+ completion: LMMCompletionMessage = await lmm_completion(
142
+ context=context,
143
+ tools=tools,
144
+ stream=update,
145
+ )
146
+ response_message: ConversationMessage = ConversationMessage(
147
+ timestamp=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z"),
148
+ role=completion.role,
149
+ content=completion.content,
150
+ )
151
+ if memory := memory:
152
+ await memory.remember(
153
+ [
154
+ user_message,
155
+ response_message,
156
+ ],
157
+ )
158
+
159
+ return response_message
@@ -0,0 +1,16 @@
1
+ from draive.lmm import LMMCompletionContent, LMMCompletionMessage, LMMCompletionStreamingUpdate
2
+
3
__all__ = [
    "ConversationMessage",
    "ConversationMessageContent",
    "ConversationStreamingUpdate",
]


class ConversationMessage(LMMCompletionMessage):
    """LMM completion message extended with optional conversation metadata."""

    # Optional author name; unset by default.
    author: str | None = None
    # Optional timestamp string; unset by default.
    timestamp: str | None = None


# Conversation-level aliases for the underlying LMM types.
ConversationMessageContent = LMMCompletionContent
ConversationStreamingUpdate = LMMCompletionStreamingUpdate
@@ -0,0 +1,17 @@
1
+ from draive.conversation.completion import ConversationCompletion
2
+ from draive.conversation.lmm import lmm_conversation_completion
3
+ from draive.conversation.message import (
4
+ ConversationMessage,
5
+ )
6
+ from draive.tools import Toolbox
7
+ from draive.types import Memory, State
8
+
9
__all__: list[str] = ["Conversation"]


class Conversation(State):
    """Scope state configuring conversation completions."""

    # Completion implementation; defaults to the LMM-backed one.
    completion: ConversationCompletion = lmm_conversation_completion
    # Optional default conversation memory.
    memory: Memory[ConversationMessage] | None = None
    # Optional default toolbox.
    tools: Toolbox | None = None
@@ -0,0 +1,11 @@
1
+ from draive.embedding.call import embed_text
2
+ from draive.embedding.embedded import Embedded
3
+ from draive.embedding.embedder import Embedder
4
+ from draive.embedding.state import Embedding
5
+
6
# Names re-exported by ``draive.embedding``.
__all__ = ["Embedder", "Embedding", "Embedded", "embed_text"]
@@ -0,0 +1,15 @@
1
+ from collections.abc import Iterable
2
+
3
+ from draive.embedding.embedded import Embedded
4
+ from draive.embedding.state import Embedding
5
+ from draive.scope import ctx
6
+
7
__all__ = ["embed_text"]


async def embed_text(
    values: Iterable[str],
) -> list[Embedded[str]]:
    """Embed texts using the ``Embedding`` state from the current scope.

    :param values: texts to embed.
    :returns: the result of the scoped ``Embedding.embed_text`` embedder.
    """
    embedding: Embedding = ctx.state(Embedding)
    return await embedding.embed_text(values=values)
@@ -0,0 +1,18 @@
1
+ from typing import Generic, TypeVar
2
+
3
+ from draive.types.state import State
4
+
5
__all__ = ["Embedded"]


# Type of the value carried alongside its vector.
_Embedded = TypeVar("_Embedded", bound=object)


class Embedded(State, Generic[_Embedded]):
    """A value together with its embedding vector."""

    value: _Embedded  # the value that was embedded
    vector: list[float]  # embedding vector for ``value``
@@ -0,0 +1,23 @@
1
+ from collections.abc import Iterable
2
+ from typing import Generic, Protocol, TypeVar, runtime_checkable
3
+
4
+ from draive.embedding.embedded import Embedded
5
+
6
__all__ = ["Embedder"]


# Type of the values an embedder accepts.
_Embeddable = TypeVar("_Embeddable", bound=object)


@runtime_checkable
class Embedder(Protocol, Generic[_Embeddable]):
    """Structural type for async callables that embed a batch of values."""

    async def __call__(
        self,
        values: Iterable[_Embeddable],
    ) -> list[Embedded[_Embeddable]]:
        """Return the embedded ``values``."""
@@ -0,0 +1,10 @@
1
+ from draive.embedding.embedder import Embedder
2
+ from draive.types import State
3
+
4
__all__ = ["Embedding"]


class Embedding(State):
    """Scope state holding the text embedder."""

    embed_text: Embedder[str]  # no default value is declared
@@ -0,0 +1,15 @@
1
+ from draive.generation.image import ImageGeneration, ImageGenerator, generate_image
2
+ from draive.generation.model import ModelGeneration, ModelGenerator, generate_model
3
+ from draive.generation.text import TextGeneration, TextGenerator, generate_text
4
+
5
# Names re-exported by ``draive.generation`` (model, text, image).
__all__ = [
    "ModelGenerator", "ModelGeneration", "generate_model",
    "TextGenerator", "TextGeneration", "generate_text",
    "ImageGeneration", "ImageGenerator", "generate_image",
]
@@ -0,0 +1,9 @@
1
+ from draive.generation.image.call import generate_image
2
+ from draive.generation.image.generator import ImageGenerator
3
+ from draive.generation.image.state import ImageGeneration
4
+
5
# Names re-exported by ``draive.generation.image``.
__all__ = ["generate_image", "ImageGenerator", "ImageGeneration"]
@@ -0,0 +1,16 @@
1
+ from draive.generation.image.state import ImageGeneration
2
+ from draive.scope import ctx
3
+ from draive.types import ImageContent
4
+
5
__all__ = ["generate_image"]


async def generate_image(
    *,
    instruction: str,
) -> ImageContent:
    """Generate an image using the ``ImageGeneration`` state from the current scope.

    :param instruction: text forwarded to ``ImageGeneration.generate``.
    :returns: the image content produced by the scoped generator.
    """
    generation: ImageGeneration = ctx.state(ImageGeneration)
    return await generation.generate(instruction=instruction)
@@ -0,0 +1,17 @@
1
+ from typing import Protocol, runtime_checkable
2
+
3
+ from draive.types import ImageContent
4
+
5
__all__ = ["ImageGenerator"]


@runtime_checkable
class ImageGenerator(Protocol):
    """Structural type for async callables producing an image from an instruction."""

    async def __call__(
        self,
        *,
        instruction: str,
    ) -> ImageContent:
        """Return image content for ``instruction``."""