ommlds-0.0.0.dev456-py3-none-any.whl → ommlds-0.0.0.dev485-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. ommlds/.omlish-manifests.json +314 -33
  2. ommlds/__about__.py +15 -9
  3. ommlds/_hacks/__init__.py +4 -0
  4. ommlds/_hacks/funcs.py +110 -0
  5. ommlds/_hacks/names.py +158 -0
  6. ommlds/_hacks/params.py +73 -0
  7. ommlds/_hacks/patches.py +0 -3
  8. ommlds/backends/anthropic/protocol/__init__.py +13 -1
  9. ommlds/backends/anthropic/protocol/_dataclasses.py +1625 -0
  10. ommlds/backends/anthropic/protocol/sse/assemble.py +22 -6
  11. ommlds/backends/anthropic/protocol/sse/events.py +13 -0
  12. ommlds/backends/google/protocol/__init__.py +13 -0
  13. ommlds/backends/google/protocol/_dataclasses.py +5997 -0
  14. ommlds/backends/google/protocol/types.py +5 -1
  15. ommlds/backends/groq/__init__.py +7 -0
  16. ommlds/backends/groq/_dataclasses.py +3901 -0
  17. ommlds/backends/groq/_marshal.py +23 -0
  18. ommlds/backends/groq/protocol.py +249 -0
  19. ommlds/backends/llamacpp/logging.py +4 -1
  20. ommlds/backends/mlx/caching.py +7 -3
  21. ommlds/backends/mlx/cli.py +10 -7
  22. ommlds/backends/mlx/generation.py +18 -16
  23. ommlds/backends/mlx/limits.py +10 -6
  24. ommlds/backends/mlx/loading.py +65 -5
  25. ommlds/backends/ollama/__init__.py +7 -0
  26. ommlds/backends/ollama/_dataclasses.py +3458 -0
  27. ommlds/backends/ollama/protocol.py +170 -0
  28. ommlds/backends/openai/protocol/__init__.py +15 -1
  29. ommlds/backends/openai/protocol/_dataclasses.py +7708 -0
  30. ommlds/backends/tavily/__init__.py +7 -0
  31. ommlds/backends/tavily/_dataclasses.py +1734 -0
  32. ommlds/backends/tavily/protocol.py +301 -0
  33. ommlds/backends/tinygrad/models/llama3/__init__.py +22 -14
  34. ommlds/backends/transformers/__init__.py +14 -0
  35. ommlds/backends/transformers/filecache.py +109 -0
  36. ommlds/backends/transformers/streamers.py +73 -0
  37. ommlds/cli/__init__.py +7 -0
  38. ommlds/cli/_dataclasses.py +2562 -0
  39. ommlds/cli/asyncs.py +30 -0
  40. ommlds/cli/backends/catalog.py +93 -0
  41. ommlds/cli/backends/configs.py +9 -0
  42. ommlds/cli/backends/inject.py +31 -36
  43. ommlds/cli/backends/injection.py +16 -0
  44. ommlds/cli/backends/types.py +46 -0
  45. ommlds/cli/content/messages.py +34 -0
  46. ommlds/cli/content/strings.py +42 -0
  47. ommlds/cli/inject.py +15 -32
  48. ommlds/cli/inputs/__init__.py +0 -0
  49. ommlds/cli/inputs/asyncs.py +32 -0
  50. ommlds/cli/inputs/sync.py +75 -0
  51. ommlds/cli/main.py +267 -128
  52. ommlds/cli/rendering/__init__.py +0 -0
  53. ommlds/cli/rendering/configs.py +9 -0
  54. ommlds/cli/rendering/inject.py +31 -0
  55. ommlds/cli/rendering/markdown.py +52 -0
  56. ommlds/cli/rendering/raw.py +73 -0
  57. ommlds/cli/rendering/types.py +21 -0
  58. ommlds/cli/secrets.py +21 -0
  59. ommlds/cli/sessions/base.py +1 -1
  60. ommlds/cli/sessions/chat/chat/__init__.py +0 -0
  61. ommlds/cli/sessions/chat/chat/ai/__init__.py +0 -0
  62. ommlds/cli/sessions/chat/chat/ai/configs.py +11 -0
  63. ommlds/cli/sessions/chat/chat/ai/inject.py +74 -0
  64. ommlds/cli/sessions/chat/chat/ai/injection.py +14 -0
  65. ommlds/cli/sessions/chat/chat/ai/rendering.py +70 -0
  66. ommlds/cli/sessions/chat/chat/ai/services.py +79 -0
  67. ommlds/cli/sessions/chat/chat/ai/tools.py +44 -0
  68. ommlds/cli/sessions/chat/chat/ai/types.py +28 -0
  69. ommlds/cli/sessions/chat/chat/state/__init__.py +0 -0
  70. ommlds/cli/sessions/chat/chat/state/configs.py +11 -0
  71. ommlds/cli/sessions/chat/chat/state/inject.py +36 -0
  72. ommlds/cli/sessions/chat/chat/state/inmemory.py +33 -0
  73. ommlds/cli/sessions/chat/chat/state/storage.py +52 -0
  74. ommlds/cli/sessions/chat/chat/state/types.py +38 -0
  75. ommlds/cli/sessions/chat/chat/user/__init__.py +0 -0
  76. ommlds/cli/sessions/chat/chat/user/configs.py +17 -0
  77. ommlds/cli/sessions/chat/chat/user/inject.py +62 -0
  78. ommlds/cli/sessions/chat/chat/user/interactive.py +31 -0
  79. ommlds/cli/sessions/chat/chat/user/oneshot.py +25 -0
  80. ommlds/cli/sessions/chat/chat/user/types.py +15 -0
  81. ommlds/cli/sessions/chat/configs.py +27 -0
  82. ommlds/cli/sessions/chat/driver.py +43 -0
  83. ommlds/cli/sessions/chat/inject.py +33 -65
  84. ommlds/cli/sessions/chat/phases/__init__.py +0 -0
  85. ommlds/cli/sessions/chat/phases/inject.py +27 -0
  86. ommlds/cli/sessions/chat/phases/injection.py +14 -0
  87. ommlds/cli/sessions/chat/phases/manager.py +29 -0
  88. ommlds/cli/sessions/chat/phases/types.py +29 -0
  89. ommlds/cli/sessions/chat/session.py +27 -0
  90. ommlds/cli/sessions/chat/tools/__init__.py +0 -0
  91. ommlds/cli/sessions/chat/tools/configs.py +22 -0
  92. ommlds/cli/sessions/chat/tools/confirmation.py +46 -0
  93. ommlds/cli/sessions/chat/tools/execution.py +66 -0
  94. ommlds/cli/sessions/chat/tools/fs/__init__.py +0 -0
  95. ommlds/cli/sessions/chat/tools/fs/configs.py +12 -0
  96. ommlds/cli/sessions/chat/tools/fs/inject.py +35 -0
  97. ommlds/cli/sessions/chat/tools/inject.py +88 -0
  98. ommlds/cli/sessions/chat/tools/injection.py +44 -0
  99. ommlds/cli/sessions/chat/tools/rendering.py +58 -0
  100. ommlds/cli/sessions/chat/tools/todo/__init__.py +0 -0
  101. ommlds/cli/sessions/chat/tools/todo/configs.py +12 -0
  102. ommlds/cli/sessions/chat/tools/todo/inject.py +31 -0
  103. ommlds/cli/sessions/chat/tools/weather/__init__.py +0 -0
  104. ommlds/cli/sessions/chat/tools/weather/configs.py +12 -0
  105. ommlds/cli/sessions/chat/tools/weather/inject.py +22 -0
  106. ommlds/cli/{tools/weather.py → sessions/chat/tools/weather/tools.py} +1 -1
  107. ommlds/cli/sessions/completion/configs.py +21 -0
  108. ommlds/cli/sessions/completion/inject.py +42 -0
  109. ommlds/cli/sessions/completion/session.py +35 -0
  110. ommlds/cli/sessions/embedding/configs.py +21 -0
  111. ommlds/cli/sessions/embedding/inject.py +42 -0
  112. ommlds/cli/sessions/embedding/session.py +33 -0
  113. ommlds/cli/sessions/inject.py +28 -11
  114. ommlds/cli/state/__init__.py +0 -0
  115. ommlds/cli/state/inject.py +28 -0
  116. ommlds/cli/{state.py → state/storage.py} +41 -24
  117. ommlds/minichain/__init__.py +46 -17
  118. ommlds/minichain/_dataclasses.py +15401 -0
  119. ommlds/minichain/backends/catalogs/base.py +20 -1
  120. ommlds/minichain/backends/catalogs/simple.py +2 -2
  121. ommlds/minichain/backends/catalogs/strings.py +10 -8
  122. ommlds/minichain/backends/impls/anthropic/chat.py +31 -65
  123. ommlds/minichain/backends/impls/anthropic/names.py +3 -4
  124. ommlds/minichain/backends/impls/anthropic/protocol.py +109 -0
  125. ommlds/minichain/backends/impls/anthropic/stream.py +53 -31
  126. ommlds/minichain/backends/impls/duckduckgo/search.py +5 -1
  127. ommlds/minichain/backends/impls/dummy/__init__.py +0 -0
  128. ommlds/minichain/backends/impls/dummy/chat.py +69 -0
  129. ommlds/minichain/backends/impls/google/chat.py +9 -2
  130. ommlds/minichain/backends/impls/google/search.py +6 -1
  131. ommlds/minichain/backends/impls/google/stream.py +122 -32
  132. ommlds/minichain/backends/impls/groq/__init__.py +0 -0
  133. ommlds/minichain/backends/impls/groq/chat.py +75 -0
  134. ommlds/minichain/backends/impls/groq/names.py +48 -0
  135. ommlds/minichain/backends/impls/groq/protocol.py +143 -0
  136. ommlds/minichain/backends/impls/groq/stream.py +125 -0
  137. ommlds/minichain/backends/impls/huggingface/repos.py +1 -5
  138. ommlds/minichain/backends/impls/llamacpp/chat.py +15 -3
  139. ommlds/minichain/backends/impls/llamacpp/completion.py +7 -3
  140. ommlds/minichain/backends/impls/llamacpp/stream.py +38 -19
  141. ommlds/minichain/backends/impls/mistral.py +9 -2
  142. ommlds/minichain/backends/impls/mlx/chat.py +100 -23
  143. ommlds/minichain/backends/impls/ollama/__init__.py +0 -0
  144. ommlds/minichain/backends/impls/ollama/chat.py +199 -0
  145. ommlds/minichain/backends/impls/openai/chat.py +14 -7
  146. ommlds/minichain/backends/impls/openai/completion.py +9 -2
  147. ommlds/minichain/backends/impls/openai/embedding.py +9 -2
  148. ommlds/minichain/backends/impls/openai/format.py +115 -109
  149. ommlds/minichain/backends/impls/openai/names.py +31 -5
  150. ommlds/minichain/backends/impls/openai/stream.py +33 -27
  151. ommlds/minichain/backends/impls/sentencepiece/tokens.py +9 -6
  152. ommlds/minichain/backends/impls/tavily.py +66 -0
  153. ommlds/minichain/backends/impls/tinygrad/chat.py +17 -14
  154. ommlds/minichain/backends/impls/tokenizers/tokens.py +9 -6
  155. ommlds/minichain/backends/impls/transformers/sentence.py +5 -2
  156. ommlds/minichain/backends/impls/transformers/tokens.py +10 -7
  157. ommlds/minichain/backends/impls/transformers/transformers.py +139 -20
  158. ommlds/minichain/backends/strings/parsing.py +1 -1
  159. ommlds/minichain/backends/strings/resolving.py +4 -1
  160. ommlds/minichain/chat/choices/stream/__init__.py +0 -0
  161. ommlds/minichain/chat/choices/stream/adapters.py +35 -0
  162. ommlds/minichain/chat/choices/stream/joining.py +31 -0
  163. ommlds/minichain/chat/choices/stream/services.py +45 -0
  164. ommlds/minichain/chat/choices/stream/types.py +43 -0
  165. ommlds/minichain/chat/stream/_marshal.py +4 -4
  166. ommlds/minichain/chat/stream/joining.py +85 -0
  167. ommlds/minichain/chat/stream/services.py +15 -15
  168. ommlds/minichain/chat/stream/types.py +24 -18
  169. ommlds/minichain/llms/types.py +4 -0
  170. ommlds/minichain/registries/globals.py +18 -4
  171. ommlds/minichain/resources.py +28 -3
  172. ommlds/minichain/search.py +1 -1
  173. ommlds/minichain/standard.py +8 -0
  174. ommlds/minichain/stream/services.py +19 -16
  175. ommlds/minichain/tools/reflect.py +5 -1
  176. ommlds/nanochat/LICENSE +21 -0
  177. ommlds/nanochat/__init__.py +0 -0
  178. ommlds/nanochat/rustbpe/LICENSE +21 -0
  179. ommlds/nanochat/tokenizers.py +406 -0
  180. ommlds/specs/__init__.py +0 -0
  181. ommlds/specs/mcp/__init__.py +0 -0
  182. ommlds/specs/mcp/_marshal.py +23 -0
  183. ommlds/specs/mcp/clients.py +146 -0
  184. ommlds/specs/mcp/protocol.py +371 -0
  185. ommlds/tools/git.py +13 -6
  186. ommlds/tools/ocr.py +1 -8
  187. ommlds/wiki/analyze.py +2 -2
  188. ommlds/wiki/text/mfh.py +1 -5
  189. ommlds/wiki/text/wtp.py +1 -3
  190. ommlds/wiki/utils/xml.py +5 -5
  191. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/METADATA +22 -19
  192. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/RECORD +198 -95
  193. ommlds/cli/backends/standard.py +0 -20
  194. ommlds/cli/sessions/chat/base.py +0 -42
  195. ommlds/cli/sessions/chat/code.py +0 -129
  196. ommlds/cli/sessions/chat/interactive.py +0 -71
  197. ommlds/cli/sessions/chat/printing.py +0 -97
  198. ommlds/cli/sessions/chat/prompt.py +0 -151
  199. ommlds/cli/sessions/chat/state.py +0 -110
  200. ommlds/cli/sessions/chat/tools.py +0 -100
  201. ommlds/cli/sessions/completion/completion.py +0 -44
  202. ommlds/cli/sessions/embedding/embedding.py +0 -42
  203. ommlds/cli/tools/config.py +0 -14
  204. ommlds/cli/tools/inject.py +0 -75
  205. ommlds/minichain/backends/impls/openai/format2.py +0 -210
  206. ommlds/minichain/chat/stream/adapters.py +0 -80
  207. /ommlds/{huggingface.py → backends/huggingface.py} +0 -0
  208. /ommlds/cli/{tools → content}/__init__.py +0 -0
  209. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/WHEEL +0 -0
  210. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/entry_points.txt +0 -0
  211. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/licenses/LICENSE +0 -0
  212. {ommlds-0.0.0.dev456.dist-info → ommlds-0.0.0.dev485.dist-info}/top_level.txt +0 -0
ommlds/backends/groq/_marshal.py
@@ -0,0 +1,23 @@
+from omlish import lang
+from omlish import marshal as msh
+
+from .protocol import ChatCompletionRequest
+
+
+##
+
+
+@lang.static_init
+def _install_standard_marshaling() -> None:
+    for root_cls, tag_field in [
+        (ChatCompletionRequest.Message, 'role'),
+    ]:
+        msh.install_standard_factories(*msh.standard_polymorphism_factories(
+            msh.polymorphism_from_subclasses(
+                root_cls,
+                naming=msh.Naming.SNAKE,
+                strip_suffix=msh.AutoStripSuffix,
+            ),
+            msh.FieldTypeTagging(tag_field),
+            unions='partial',
+        ))
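For orientation, a sketch of what this registration buys: with 'role' as the tag field, a plain message dict can be decoded into the matching ChatCompletionRequest.Message subclass and re-encoded with the snake_cased, suffix-stripped tag. The msh.unmarshal/msh.marshal helper names and call shapes below are assumptions about omlish.marshal, not taken from this diff.

# Hypothetical round trip - assumes omlish.marshal exposes marshal/unmarshal helpers with these shapes.
from omlish import marshal as msh

from ommlds.backends.groq.protocol import ChatCompletionRequest

# The 'role' value picks the concrete Message subclass registered above.
msg = msh.unmarshal(
    {'role': 'user', 'content': 'hello'},
    ChatCompletionRequest.Message,
)
assert isinstance(msg, ChatCompletionRequest.UserMessage)

# Marshaling back should emit the tag plus only the populated fields.
print(msh.marshal(msg, ChatCompletionRequest.Message))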
ommlds/backends/groq/protocol.py
@@ -0,0 +1,249 @@
+"""
+https://console.groq.com/docs/api-reference#chat-create
+"""
+import typing as ta
+
+from omlish import dataclasses as dc
+from omlish import lang
+from omlish import marshal as msh
+
+
+##
+
+
+def _set_class_marshal_options(cls):
+    msh.update_object_metadata(
+        cls,
+        field_defaults=msh.FieldMetadata(
+            options=msh.FieldOptions(
+                omit_if=lang.is_none,
+            ),
+        ),
+    )
+
+    return cls
+
+
+##
+
+
+@dc.dataclass(frozen=True, kw_only=True)
+@_set_class_marshal_options
+class ChatCompletionRequest(lang.Final):
+    @dc.dataclass(frozen=True, kw_only=True)
+    class Message(lang.Sealed, lang.Abstract):
+        pass
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class SystemMessage(Message, lang.Final):
+        content: str | ta.Sequence[str]
+        name: str | None = None
+        role: ta.Literal['system'] = 'system'
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class UserMessage(Message, lang.Final):
+        content: str | ta.Sequence[str]
+        name: str | None = None
+        role: ta.Literal['user'] = 'user'
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class AssistantMessage(Message, lang.Final):
+        content: str | ta.Sequence[str] | None = None
+        name: str | None = None
+        reasoning: str | None = None
+        role: ta.Literal['assistant'] = 'assistant'
+
+        @dc.dataclass(frozen=True, kw_only=True)
+        @_set_class_marshal_options
+        class ToolCall(lang.Final):
+            @dc.dataclass(frozen=True, kw_only=True)
+            @_set_class_marshal_options
+            class Function(lang.Final):
+                arguments: str
+                name: str
+
+            function: Function
+            id: str
+            type: ta.Literal['function'] = 'function'
+
+        tool_calls: ta.Sequence[ToolCall] | None = None
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class ToolMessage(Message, lang.Final):
+        content: str | ta.Sequence[str]
+        role: ta.Literal['tool'] = 'tool'
+        tool_call_id: str
+
+    messages: ta.Sequence[Message]
+    model: str
+    citation_options: ta.Literal['enabled', 'disabled'] | None = None
+    compound_custom: ta.Mapping[str, ta.Any] | None = None
+    disable_tool_validation: bool | None = None
+    documents: ta.Sequence[ta.Mapping[str, ta.Any]] | None = None
+    frequency_penalty: float | None = None
+    include_reasoning: bool | None = None
+    logit_bias: ta.Mapping[str, ta.Any] | None = None
+    logprobs: bool | None = None
+    max_completion_tokens: int | None = None
+    n: int | None = None
+    parallel_tool_calls: bool | None = None
+    presence_penalty: float | None = None
+    reasoning_effort: ta.Literal['none', 'default', 'low', 'medium', 'high'] | None = None
+    reasoning_format: ta.Literal['hidden', 'raw', 'parsed'] | None = None
+    response_format: ta.Any | None = None
+    search_settings: ta.Mapping[str, ta.Any] | None = None
+    seed: int | None = None
+    service_tier: ta.Literal['auto', 'on_demand', 'flex', 'performance', 'null'] | None = None
+    stop: str | ta.Sequence[str] | None = None
+    store: bool | None = None
+    stream: bool | None = None
+    stream_options: ta.Mapping[str, ta.Any] | None = None
+    temperature: float | None = None
+    tool_choice: str | None = None
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class Tool(lang.Final):
+        @dc.dataclass(frozen=True, kw_only=True)
+        @_set_class_marshal_options
+        class Function(lang.Final):
+            description: str | None = None
+            name: str
+            parameters: ta.Mapping[str, ta.Any] | None = None  # json schema
+            strict: bool | None = None
+
+        function: Function
+        type: ta.Literal['function', 'browser_search', 'code_interpreter'] = 'function'
+
+    tools: ta.Sequence[Tool] | None = None
+
+    top_logprobs: int | None = None
+    top_p: float | None = None
+    user: str | None = None
+
+
+@dc.dataclass(frozen=True, kw_only=True)
+@_set_class_marshal_options
+class ExecutedTool(lang.Final):
+    arguments: str
+    index: int
+    type: str
+    browser_results: ta.Sequence[ta.Any] | None = None
+    code_results: ta.Sequence[ta.Any] | None = None
+    output: str | None = None
+    search_results: ta.Any | None = None
+
+
+@dc.dataclass(frozen=True, kw_only=True)
+@_set_class_marshal_options
+class ChatCompletionResponse(lang.Final):
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class Choice(lang.Final):
+        finish_reason: ta.Literal['stop', 'length', 'tool_calls', 'function_call']
+        index: int
+        logprobs: ta.Mapping[str, ta.Any] | None = None
+
+        @dc.dataclass(frozen=True, kw_only=True)
+        @_set_class_marshal_options
+        class Message(lang.Final):
+            annotations: ta.Sequence[ta.Mapping[str, ta.Any]] | None = None
+            content: str | None = None
+
+            executed_tools: ta.Sequence[ExecutedTool] | None = None
+
+            reasoning: str | None = None
+            role: ta.Literal['assistant'] = 'assistant'
+
+            @dc.dataclass(frozen=True, kw_only=True)
+            @_set_class_marshal_options
+            class ToolCall(lang.Final):
+                id: str
+
+                @dc.dataclass(frozen=True, kw_only=True)
+                @_set_class_marshal_options
+                class Function(lang.Final):
+                    arguments: str
+                    name: str
+
+                function: Function
+                type: ta.Literal['function'] = 'function'
+
+            tool_calls: ta.Sequence[ToolCall] | None = None
+
+        message: Message
+
+    choices: ta.Sequence[Choice]
+    created: int
+    id: str
+    model: str
+    object: ta.Literal['chat.completion'] = 'chat.completion'
+    system_fingerprint: str
+    usage: ta.Mapping[str, ta.Any] | None = None
+    usage_breakdown: ta.Mapping[str, ta.Any] | None = None
+    x_groq: ta.Mapping[str, ta.Any] | None = None
+    service_tier: str | None = None
+
+
+@dc.dataclass(frozen=True, kw_only=True)
+@_set_class_marshal_options
+class ChatCompletionChunk(lang.Final):
+    id: str
+    object: ta.Literal['chat.completion.chunk'] = 'chat.completion.chunk'
+    created: int
+    model: str
+    system_fingerprint: str
+
+    @dc.dataclass(frozen=True, kw_only=True)
+    @_set_class_marshal_options
+    class Choice(lang.Final):
+        index: int
+
+        @dc.dataclass(frozen=True, kw_only=True)
+        @_set_class_marshal_options
+        class Delta(lang.Final):
+            role: str | None = None
+            content: str | None = None
+
+            channel: str | None = None
+            reasoning: str | None = None
+
+            @dc.dataclass(frozen=True, kw_only=True)
+            @_set_class_marshal_options
+            class ToolCall(lang.Final):
+                index: int
+                id: str | None = None
+
+                @dc.dataclass(frozen=True, kw_only=True)
+                @_set_class_marshal_options
+                class Function(lang.Final):
+                    arguments: str | None = None
+                    name: str | None = None
+
+                function: Function | None = None
+
+                type: ta.Literal['function'] = 'function'
+
+            tool_calls: ta.Sequence[ToolCall] | None = None
+
+            executed_tools: ta.Sequence[ExecutedTool] | None = None
+
+        delta: Delta
+        logprobs: ta.Mapping[str, ta.Any] | None = None
+        finish_reason: ta.Literal['stop', 'length', 'tool_calls', 'function_call'] | None = None
+
+    choices: ta.Sequence[Choice]
+
+    x_groq: ta.Mapping[str, ta.Any] | None = None
+    service_tier: str | None = None
+    usage: ta.Mapping[str, ta.Any] | None = None
+
+
+##
+
+
+msh.register_global_module_import('._marshal', __package__)
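Together with the _marshal module above, these request dataclasses are meant to be built in Python and serialized into the Groq wire format. A minimal construction sketch, assuming msh.marshal is the serialization entrypoint and using an illustrative model id:

# Construction sketch using the dataclasses above; the model id is illustrative.
import json

from omlish import marshal as msh  # assumed serialization entrypoint, as in _marshal.py above

from ommlds.backends.groq.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model='llama-3.3-70b-versatile',  # illustrative model id
    messages=[
        ChatCompletionRequest.SystemMessage(content='You are terse.'),
        ChatCompletionRequest.UserMessage(content='One word for happy?'),
    ],
    temperature=0.2,
    max_completion_tokens=16,
)

# Fields left as None should be omitted on the wire via the omit_if=lang.is_none option above.
payload = json.dumps(msh.marshal(req))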
ommlds/backends/llamacpp/logging.py
@@ -1,4 +1,7 @@
 """
+NOTE: This can't be cleaned up too much - the callback can't be a closure to hide its guts because it needs to be
+picklable for multiprocessing.
+
 FIXME:
  - it outputs newline-terminated so buffer and chop on newlines - DelimitingBuffer again
 """
@@ -27,4 +30,4 @@ def llama_log_callback(
 
 @lang.cached_function
 def install_logging_hook() -> None:
-    llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))
+    llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))  # noqa
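The added NOTE reflects a real constraint: multiprocessing pickles callbacks by importable name, so the log callback has to stay a plain module-level function rather than a closure. A generic standard-library illustration, not code from this package:

# Generic illustration of why the callback must stay a module-level function.
import pickle


def module_level_callback(level, text, user_data):  # picklable: referenced by importable name
    print(text, end='')


def make_closure_callback(prefix):
    def callback(level, text, user_data):  # not picklable: a local (closure) function
        print(prefix + text, end='')
    return callback


pickle.dumps(module_level_callback)  # fine
try:
    pickle.dumps(make_closure_callback('llama: '))
except (AttributeError, pickle.PicklingError) as e:
    print('cannot pickle a closure for multiprocessing:', e)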
ommlds/backends/mlx/caching.py
@@ -17,7 +17,11 @@
 # https://github.com/ml-explore/mlx-lm/blob/ce2358d297af245b002e690623f00195b6507da0/mlx_lm/generate.py
 import typing as ta
 
-import mlx_lm.models.cache
+from omlish import lang
+
+
+with lang.auto_proxy_import(globals()):
+    import mlx_lm.models.cache as mlx_lm_models_cache
 
 
 ##
@@ -32,13 +36,13 @@ def maybe_quantize_kv_cache(
 ) -> None:
     if not (
         kv_bits is not None and
-        not isinstance(prompt_cache[0], mlx_lm.models.cache.QuantizedKVCache) and
+        not isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache) and
         prompt_cache[0].offset > quantized_kv_start
     ):
         return
 
     for i in range(len(prompt_cache)):
-        if isinstance(prompt_cache[i], mlx_lm.models.cache.KVCache):
+        if isinstance(prompt_cache[i], mlx_lm_models_cache.KVCache):
            prompt_cache[i] = prompt_cache[i].to_quantized(
                bits=kv_bits,
                group_size=kv_group_size,
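The recurring change across these mlx backend files is deferring heavy imports: lang.auto_proxy_import (an omlish helper) keeps mlx/mlx_lm out of module import time, which is also why annotations naming those types become string literals below. A rough standard-library sketch of the same deferred-import idea (PEP 562 module __getattr__), shown only to illustrate the concept, not omlish's mechanism:

# lazy_mlx.py - deferred-import illustration using PEP 562 module __getattr__ (not the omlish implementation).
import importlib
import typing as ta

_LAZY_MODULES = {
    'mx': 'mlx.core',
    'mlx_lm_models_cache': 'mlx_lm.models.cache',
}


def __getattr__(name: str) -> ta.Any:
    # The real import happens only on first attribute access, keeping this module cheap to import.
    try:
        target = _LAZY_MODULES[name]
    except KeyError:
        raise AttributeError(name) from None
    module = importlib.import_module(target)
    globals()[name] = module  # cache so __getattr__ is not hit again
    return module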
ommlds/backends/mlx/cli.py
@@ -20,16 +20,19 @@ import json
 import sys
 import typing as ta
 
-import mlx.core as mx
-import mlx_lm.models.cache
-import mlx_lm.sample_utils
-import mlx_lm.utils
+from omlish import lang
 
 from .generation import GenerationParams
 from .generation import generate
 from .loading import load_model
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx_lm.models.cache as mlx_lm_models_cache
+    import mlx_lm.sample_utils as mlx_lm_sample_utils
+
+
 ##
 
 
@@ -214,11 +217,11 @@ def _main() -> None:
     # Load the prompt cache and metadata if a cache file is provided
     using_cache = args.prompt_cache_file is not None
     if using_cache:
-        prompt_cache, metadata = mlx_lm.models.cache.load_prompt_cache(
+        prompt_cache, metadata = mlx_lm_models_cache.load_prompt_cache(
            args.prompt_cache_file,
            return_metadata=True,
        )
-        if isinstance(prompt_cache[0], mlx_lm.models.cache.QuantizedKVCache):
+        if isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache):
            if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
                raise ValueError('--kv-bits does not match the kv cache loaded from --prompt-cache-file.')
            if args.kv_group_size != prompt_cache[0].group_size:
@@ -293,7 +296,7 @@ def _main() -> None:
     else:
         prompt = tokenizer.encode(prompt)
 
-    sampler = mlx_lm.sample_utils.make_sampler(
+    sampler = mlx_lm_sample_utils.make_sampler(
        args.temp,
        args.top_p,
        args.min_p,
ommlds/backends/mlx/generation.py
@@ -21,10 +21,6 @@ import io
 import sys
 import typing as ta
 
-import mlx.core as mx
-import mlx_lm.models.cache
-from mlx import nn
-
 from omlish import check
 from omlish import lang
 
@@ -33,6 +29,12 @@ from .limits import wired_limit_context
 from .tokenization import Tokenization
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx.nn as mlx_nn
+    import mlx_lm.models.cache as mlx_lm_models_cache
+
+
 ##
 
 
@@ -47,9 +49,9 @@ def _generation_stream():
 class LogitProcessor(ta.Protocol):
     def __call__(
         self,
-        tokens: mx.array,
-        logits: mx.array,
-    ) -> mx.array:
+        tokens: 'mx.array',
+        logits: 'mx.array',
+    ) -> 'mx.array':
         ...
 
 
@@ -99,12 +101,12 @@ class GenerationParams:
 
 class _GenerationStep(ta.NamedTuple):
     token: int
-    logprobs: mx.array
+    logprobs: 'mx.array'
 
 
 def _generate_step(
-    prompt: mx.array,
-    model: nn.Module,
+    prompt: 'mx.array',
+    model: 'mlx_nn.Module',
     params: GenerationParams = GenerationParams(),
 ) -> ta.Generator[_GenerationStep]:
     y = prompt
@@ -113,7 +115,7 @@ def _generate_step(
     # Create the Kv cache for generation
     prompt_cache = params.prompt_cache
     if prompt_cache is None:
-        prompt_cache = mlx_lm.models.cache.make_prompt_cache(
+        prompt_cache = mlx_lm_models_cache.make_prompt_cache(
            model,
            max_kv_size=params.max_kv_size,
        )
@@ -221,7 +223,7 @@ class GenerationOutput:
     token: int
 
     # A vector of log probabilities.
-    logprobs: mx.array
+    logprobs: 'mx.array'
 
     # The number of tokens in the prompt.
     prompt_tokens: int
@@ -234,9 +236,9 @@ class GenerationOutput:
 
 
 def stream_generate(
-    model: nn.Module,
+    model: 'mlx_nn.Module',
     tokenization: Tokenization,
-    prompt: str | mx.array,
+    prompt: ta.Union[str, 'mx.array'],
     params: GenerationParams = GenerationParams(),
 ) -> ta.Generator[GenerationOutput]:
     if not isinstance(prompt, mx.array):
@@ -308,9 +310,9 @@ def stream_generate(
 
 
 def generate(
-    model: nn.Module,
+    model: 'mlx_nn.Module',
     tokenization: Tokenization,
-    prompt: str | mx.array,
+    prompt: ta.Union[str, 'mx.array'],
     params: GenerationParams = GenerationParams(),
     *,
     verbose: bool = False,
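Note the switch from str | mx.array to ta.Union[str, 'mx.array'] once mx is proxied: the | operator is evaluated at import time and raises a TypeError when one operand is a plain string, whereas typing.Union stores the string as a ForwardRef and only resolves it if something later introspects the hints. A standalone illustration:

import typing as ta


def ok(prompt: ta.Union[str, 'mx.array']):  # 'mx.array' is stored as an unevaluated ForwardRef
    ...


# def bad(prompt: str | 'mx.array'): ...
# The | form is evaluated at import time and fails while mx is only a quoted name:
#   TypeError: unsupported operand type(s) for |: 'type' and 'str'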
ommlds/backends/mlx/limits.py
@@ -19,9 +19,13 @@ import contextlib
 import sys
 import typing as ta
 
-import mlx.core as mx
-import mlx.utils
-from mlx import nn
+from omlish import lang
+
+
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx.nn as mlx_nn
+    import mlx.utils as mlx_utils
 
 
 ##
@@ -29,8 +33,8 @@ from mlx import nn
 
 @contextlib.contextmanager
 def wired_limit_context(
-    model: nn.Module,
-    streams: ta.Iterable[mx.Stream] | None = None,
+    model: 'mlx_nn.Module',
+    streams: ta.Iterable['mx.Stream'] | None = None,
 ) -> ta.Generator[None]:
     """
     A context manager to temporarily change the wired limit.
@@ -43,7 +47,7 @@ def wired_limit_context(
         yield
         return
 
-    model_bytes = mlx.utils.tree_reduce(
+    model_bytes = mlx_utils.tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc,
        model,
        0,
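Call-site shape for the reworked context manager, as a sketch; model, tokenization, and prompt are assumed to be already in scope (loaded via loading.py and fed to the generation.py helpers above):

# Sketch: temporarily raise the wired memory limit around generation.
from ommlds.backends.mlx.limits import wired_limit_context

with wired_limit_context(model):  # model: an already-loaded mlx nn.Module, assumed in scope
    for output in stream_generate(model, tokenization, prompt):  # stream_generate from generation.py above
        ...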
ommlds/backends/mlx/loading.py
@@ -1,10 +1,8 @@
+# ruff: noqa: TC002
 import dataclasses as dc
 import pathlib
 import typing as ta
 
-import mlx_lm.utils
-from mlx import nn
-
 from omlish import check
 from omlish import lang
 
@@ -12,6 +10,68 @@ from .tokenization import Tokenization
 from .tokenization import load_tokenization
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.nn as mlx_nn
+    import mlx_lm.utils
+
+
+##
+
+
+def get_model_path(
+    path_or_hf_repo: str,
+    revision: str | None = None,
+) -> tuple[pathlib.Path, str | None]:
+    """
+    Ensures the model is available locally. If the path does not exist locally,
+    it is downloaded from the Hugging Face Hub.
+
+    Args:
+        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+
+    Returns:
+        Tuple[Path, str]: A tuple containing the local file path and the Hugging Face repo ID.
+    """
+
+    model_path = pathlib.Path(path_or_hf_repo)
+
+    if not model_path.exists():
+        from huggingface_hub import snapshot_download
+        hf_path = path_or_hf_repo
+        model_path = pathlib.Path(
+            snapshot_download(
+                path_or_hf_repo,
+                revision=revision,
+                allow_patterns=[
+                    '*.jinja',
+                    '*.json',
+                    '*.jsonl',
+                    '*.py',
+                    '*.txt',
+
+                    'model*.safetensors',
+
+                    '*.tiktoken',
+                    'tiktoken.model',
+                    'tokenizer.model',
+                ],
+            ),
+        )
+
+    else:
+        from huggingface_hub import ModelCard
+
+        card_path = model_path / 'README.md'
+        if card_path.is_file():
+            card = ModelCard.load(card_path)
+            hf_path = card.data.base_model
+        else:
+            hf_path = None
+
+    return model_path, hf_path
+
+
 ##
 
 
@@ -19,7 +79,7 @@ from .tokenization import load_tokenization
 class LoadedModel:
     path: pathlib.Path
 
-    model: nn.Module
+    model: 'mlx_nn.Module'
     config: dict
 
     #
@@ -46,7 +106,7 @@ def load_model(
 ) -> LoadedModel:
     # FIXME: get_model_path return annotation is wrong:
     #  https://github.com/ml-explore/mlx-lm/blob/9ee2b7358f5e258af7b31a8561acfbbe56ad5085/mlx_lm/utils.py#L82
-    model_path_res = ta.cast(ta.Any, mlx_lm.utils.get_model_path(path_or_hf_repo))
+    model_path_res = ta.cast(ta.Any, get_model_path(path_or_hf_repo))
     if isinstance(model_path_res, tuple):
         model_path = check.isinstance(model_path_res[0], pathlib.Path)
     else:
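A small usage sketch of the vendored helper and the loader that now calls it; the repo id is illustrative, huggingface_hub is assumed to be installed for the download path, and the single-argument load_model call is assumed from the signature shown here:

# Usage sketch: resolve a Hugging Face repo id (or local directory) and load the model.
from ommlds.backends.mlx.loading import get_model_path
from ommlds.backends.mlx.loading import load_model

model_path, hf_repo = get_model_path('mlx-community/Llama-3.2-1B-Instruct-4bit')  # illustrative repo id
print(model_path, hf_repo)

loaded = load_model('mlx-community/Llama-3.2-1B-Instruct-4bit')
print(type(loaded.model), loaded.path)  # LoadedModel exposes path, model, and config per this diff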
@@ -0,0 +1,7 @@
+from omlish import dataclasses as _dc  # noqa
+
+
+_dc.init_package(
+    globals(),
+    codegen=True,
+)