ommlds 0.0.0.dev480__py3-none-any.whl → 0.0.0.dev503__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (277)
  1. ommlds/.omlish-manifests.json +100 -33
  2. ommlds/README.md +11 -0
  3. ommlds/__about__.py +9 -6
  4. ommlds/backends/anthropic/protocol/__init__.py +13 -1
  5. ommlds/backends/anthropic/protocol/_dataclasses.py +1625 -0
  6. ommlds/backends/anthropic/protocol/sse/events.py +2 -0
  7. ommlds/backends/cerebras/__init__.py +7 -0
  8. ommlds/backends/cerebras/_dataclasses.py +4254 -0
  9. ommlds/backends/cerebras/_marshal.py +24 -0
  10. ommlds/backends/cerebras/protocol.py +312 -0
  11. ommlds/backends/google/protocol/__init__.py +13 -0
  12. ommlds/backends/google/protocol/_dataclasses.py +5997 -0
  13. ommlds/backends/groq/__init__.py +7 -0
  14. ommlds/backends/groq/_dataclasses.py +3901 -0
  15. ommlds/backends/groq/clients.py +9 -0
  16. ommlds/backends/llamacpp/logging.py +4 -1
  17. ommlds/backends/mlx/caching.py +7 -3
  18. ommlds/backends/mlx/cli.py +10 -7
  19. ommlds/backends/mlx/generation.py +18 -16
  20. ommlds/backends/mlx/limits.py +10 -6
  21. ommlds/backends/mlx/loading.py +7 -4
  22. ommlds/backends/ollama/__init__.py +7 -0
  23. ommlds/backends/ollama/_dataclasses.py +3488 -0
  24. ommlds/backends/ollama/protocol.py +3 -0
  25. ommlds/backends/openai/protocol/__init__.py +15 -1
  26. ommlds/backends/openai/protocol/_dataclasses.py +7708 -0
  27. ommlds/backends/tavily/__init__.py +7 -0
  28. ommlds/backends/tavily/_dataclasses.py +1734 -0
  29. ommlds/backends/transformers/__init__.py +14 -0
  30. ommlds/cli/__init__.py +7 -0
  31. ommlds/cli/_dataclasses.py +3515 -0
  32. ommlds/cli/backends/catalog.py +0 -5
  33. ommlds/cli/backends/inject.py +70 -7
  34. ommlds/cli/backends/meta.py +82 -0
  35. ommlds/cli/content/messages.py +1 -1
  36. ommlds/cli/inject.py +11 -3
  37. ommlds/cli/main.py +137 -68
  38. ommlds/cli/rendering/types.py +6 -0
  39. ommlds/cli/secrets.py +2 -1
  40. ommlds/cli/sessions/base.py +1 -10
  41. ommlds/cli/sessions/chat/configs.py +9 -17
  42. ommlds/cli/sessions/chat/{chat → drivers}/ai/configs.py +3 -1
  43. ommlds/cli/sessions/chat/drivers/ai/events.py +57 -0
  44. ommlds/cli/sessions/chat/{chat → drivers}/ai/inject.py +10 -3
  45. ommlds/cli/sessions/chat/{chat → drivers}/ai/rendering.py +1 -1
  46. ommlds/cli/sessions/chat/{chat → drivers}/ai/services.py +1 -1
  47. ommlds/cli/sessions/chat/{chat → drivers}/ai/tools.py +4 -8
  48. ommlds/cli/sessions/chat/{chat → drivers}/ai/types.py +9 -0
  49. ommlds/cli/sessions/chat/drivers/configs.py +25 -0
  50. ommlds/cli/sessions/chat/drivers/events/inject.py +27 -0
  51. ommlds/cli/sessions/chat/drivers/events/injection.py +14 -0
  52. ommlds/cli/sessions/chat/drivers/events/manager.py +16 -0
  53. ommlds/cli/sessions/chat/drivers/events/types.py +38 -0
  54. ommlds/cli/sessions/chat/drivers/impl.py +50 -0
  55. ommlds/cli/sessions/chat/drivers/inject.py +70 -0
  56. ommlds/cli/sessions/chat/{chat → drivers}/state/configs.py +2 -0
  57. ommlds/cli/sessions/chat/drivers/state/ids.py +25 -0
  58. ommlds/cli/sessions/chat/drivers/state/inject.py +83 -0
  59. ommlds/cli/sessions/chat/{chat → drivers}/state/inmemory.py +0 -4
  60. ommlds/cli/sessions/chat/{chat → drivers}/state/storage.py +17 -10
  61. ommlds/cli/sessions/chat/{chat → drivers}/state/types.py +10 -5
  62. ommlds/cli/sessions/chat/{tools → drivers/tools}/configs.py +2 -2
  63. ommlds/cli/sessions/chat/drivers/tools/confirmation.py +44 -0
  64. ommlds/cli/sessions/chat/drivers/tools/errorhandling.py +39 -0
  65. ommlds/cli/sessions/chat/{tools → drivers/tools}/execution.py +3 -4
  66. ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/inject.py +3 -3
  67. ommlds/cli/sessions/chat/{tools → drivers/tools}/inject.py +7 -12
  68. ommlds/cli/sessions/chat/{tools → drivers/tools}/injection.py +5 -5
  69. ommlds/cli/sessions/chat/{tools → drivers/tools}/rendering.py +3 -3
  70. ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/inject.py +3 -3
  71. ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/tools.py +1 -1
  72. ommlds/cli/sessions/chat/drivers/types.py +31 -0
  73. ommlds/cli/sessions/chat/{chat → drivers}/user/configs.py +0 -3
  74. ommlds/cli/sessions/chat/drivers/user/inject.py +41 -0
  75. ommlds/cli/sessions/chat/facades/__init__.py +0 -0
  76. ommlds/cli/sessions/chat/facades/commands/__init__.py +0 -0
  77. ommlds/cli/sessions/chat/facades/commands/base.py +83 -0
  78. ommlds/cli/sessions/chat/facades/commands/configs.py +9 -0
  79. ommlds/cli/sessions/chat/facades/commands/inject.py +41 -0
  80. ommlds/cli/sessions/chat/facades/commands/injection.py +15 -0
  81. ommlds/cli/sessions/chat/facades/commands/manager.py +59 -0
  82. ommlds/cli/sessions/chat/facades/commands/simple.py +34 -0
  83. ommlds/cli/sessions/chat/facades/commands/types.py +13 -0
  84. ommlds/cli/sessions/chat/facades/configs.py +11 -0
  85. ommlds/cli/sessions/chat/facades/facade.py +26 -0
  86. ommlds/cli/sessions/chat/facades/inject.py +35 -0
  87. ommlds/cli/sessions/chat/facades/ui.py +34 -0
  88. ommlds/cli/sessions/chat/inject.py +8 -31
  89. ommlds/cli/sessions/chat/interfaces/__init__.py +0 -0
  90. ommlds/cli/sessions/chat/interfaces/bare/__init__.py +0 -0
  91. ommlds/cli/sessions/chat/interfaces/bare/configs.py +15 -0
  92. ommlds/cli/sessions/chat/interfaces/bare/inject.py +69 -0
  93. ommlds/cli/sessions/chat/interfaces/bare/interactive.py +49 -0
  94. ommlds/cli/sessions/chat/interfaces/bare/oneshot.py +21 -0
  95. ommlds/cli/sessions/chat/{tools/confirmation.py → interfaces/bare/tools.py} +3 -22
  96. ommlds/cli/sessions/chat/interfaces/base.py +13 -0
  97. ommlds/cli/sessions/chat/interfaces/configs.py +11 -0
  98. ommlds/cli/sessions/chat/interfaces/inject.py +29 -0
  99. ommlds/cli/sessions/chat/interfaces/textual/__init__.py +0 -0
  100. ommlds/cli/sessions/chat/interfaces/textual/app.py +310 -0
  101. ommlds/cli/sessions/chat/interfaces/textual/configs.py +11 -0
  102. ommlds/cli/sessions/chat/interfaces/textual/facades.py +19 -0
  103. ommlds/cli/sessions/chat/interfaces/textual/inject.py +97 -0
  104. ommlds/cli/sessions/chat/interfaces/textual/interface.py +24 -0
  105. ommlds/cli/sessions/chat/interfaces/textual/styles/__init__.py +29 -0
  106. ommlds/cli/sessions/chat/interfaces/textual/styles/input.tcss +53 -0
  107. ommlds/cli/sessions/chat/interfaces/textual/styles/markdown.tcss +7 -0
  108. ommlds/cli/sessions/chat/interfaces/textual/styles/messages.tcss +157 -0
  109. ommlds/cli/sessions/chat/interfaces/textual/tools.py +38 -0
  110. ommlds/cli/sessions/chat/interfaces/textual/widgets/__init__.py +0 -0
  111. ommlds/cli/sessions/chat/interfaces/textual/widgets/input.py +36 -0
  112. ommlds/cli/sessions/chat/interfaces/textual/widgets/messages.py +197 -0
  113. ommlds/cli/sessions/chat/session.py +8 -13
  114. ommlds/cli/sessions/completion/configs.py +3 -4
  115. ommlds/cli/sessions/completion/inject.py +1 -2
  116. ommlds/cli/sessions/completion/session.py +4 -8
  117. ommlds/cli/sessions/configs.py +10 -0
  118. ommlds/cli/sessions/embedding/configs.py +3 -4
  119. ommlds/cli/sessions/embedding/inject.py +1 -2
  120. ommlds/cli/sessions/embedding/session.py +4 -8
  121. ommlds/cli/sessions/inject.py +15 -15
  122. ommlds/cli/state/storage.py +7 -1
  123. ommlds/minichain/__init__.py +161 -38
  124. ommlds/minichain/_dataclasses.py +20452 -0
  125. ommlds/minichain/_typedvalues.py +11 -4
  126. ommlds/minichain/backends/impls/anthropic/names.py +3 -3
  127. ommlds/minichain/backends/impls/anthropic/protocol.py +2 -2
  128. ommlds/minichain/backends/impls/anthropic/stream.py +1 -1
  129. ommlds/minichain/backends/impls/cerebras/__init__.py +0 -0
  130. ommlds/minichain/backends/impls/cerebras/chat.py +80 -0
  131. ommlds/minichain/backends/impls/cerebras/names.py +45 -0
  132. ommlds/minichain/backends/impls/cerebras/protocol.py +143 -0
  133. ommlds/minichain/backends/impls/cerebras/stream.py +125 -0
  134. ommlds/minichain/backends/impls/duckduckgo/search.py +5 -1
  135. ommlds/minichain/backends/impls/google/names.py +6 -0
  136. ommlds/minichain/backends/impls/google/stream.py +1 -1
  137. ommlds/minichain/backends/impls/google/tools.py +2 -2
  138. ommlds/minichain/backends/impls/groq/chat.py +2 -0
  139. ommlds/minichain/backends/impls/groq/protocol.py +2 -2
  140. ommlds/minichain/backends/impls/groq/stream.py +3 -1
  141. ommlds/minichain/backends/impls/huggingface/repos.py +1 -5
  142. ommlds/minichain/backends/impls/llamacpp/chat.py +6 -3
  143. ommlds/minichain/backends/impls/llamacpp/completion.py +7 -3
  144. ommlds/minichain/backends/impls/llamacpp/stream.py +6 -3
  145. ommlds/minichain/backends/impls/mlx/chat.py +6 -3
  146. ommlds/minichain/backends/impls/ollama/chat.py +51 -57
  147. ommlds/minichain/backends/impls/ollama/protocol.py +144 -0
  148. ommlds/minichain/backends/impls/openai/format.py +4 -3
  149. ommlds/minichain/backends/impls/openai/names.py +3 -1
  150. ommlds/minichain/backends/impls/openai/stream.py +33 -1
  151. ommlds/minichain/backends/impls/sentencepiece/tokens.py +9 -6
  152. ommlds/minichain/backends/impls/tinygrad/chat.py +7 -4
  153. ommlds/minichain/backends/impls/tokenizers/tokens.py +9 -6
  154. ommlds/minichain/backends/impls/transformers/sentence.py +5 -2
  155. ommlds/minichain/backends/impls/transformers/tokens.py +9 -6
  156. ommlds/minichain/backends/impls/transformers/transformers.py +10 -8
  157. ommlds/minichain/backends/strings/resolving.py +1 -1
  158. ommlds/minichain/chat/content.py +42 -0
  159. ommlds/minichain/chat/messages.py +43 -39
  160. ommlds/minichain/chat/stream/joining.py +36 -12
  161. ommlds/minichain/chat/stream/types.py +1 -1
  162. ommlds/minichain/chat/templating.py +3 -3
  163. ommlds/minichain/content/__init__.py +19 -3
  164. ommlds/minichain/content/_marshal.py +181 -55
  165. ommlds/minichain/content/code.py +26 -0
  166. ommlds/minichain/content/composite.py +28 -0
  167. ommlds/minichain/content/content.py +27 -0
  168. ommlds/minichain/content/dynamic.py +12 -0
  169. ommlds/minichain/content/emphasis.py +27 -0
  170. ommlds/minichain/content/images.py +2 -2
  171. ommlds/minichain/content/json.py +2 -2
  172. ommlds/minichain/content/link.py +13 -0
  173. ommlds/minichain/content/markdown.py +12 -0
  174. ommlds/minichain/content/metadata.py +10 -0
  175. ommlds/minichain/content/namespaces.py +8 -0
  176. ommlds/minichain/content/placeholders.py +10 -9
  177. ommlds/minichain/content/quote.py +26 -0
  178. ommlds/minichain/content/raw.py +49 -0
  179. ommlds/minichain/content/recursive.py +12 -0
  180. ommlds/minichain/content/section.py +26 -0
  181. ommlds/minichain/content/sequence.py +17 -3
  182. ommlds/minichain/content/standard.py +32 -0
  183. ommlds/minichain/content/tag.py +28 -0
  184. ommlds/minichain/content/templates.py +13 -0
  185. ommlds/minichain/content/text.py +2 -2
  186. ommlds/minichain/content/transform/__init__.py +0 -0
  187. ommlds/minichain/content/transform/json.py +55 -0
  188. ommlds/minichain/content/transform/markdown.py +8 -0
  189. ommlds/minichain/content/transform/materialize.py +51 -0
  190. ommlds/minichain/content/transform/metadata.py +16 -0
  191. ommlds/minichain/content/{prepare.py → transform/prepare.py} +10 -15
  192. ommlds/minichain/content/transform/recursive.py +97 -0
  193. ommlds/minichain/content/transform/standard.py +43 -0
  194. ommlds/minichain/content/{transforms → transform}/stringify.py +1 -7
  195. ommlds/minichain/content/transform/strings.py +33 -0
  196. ommlds/minichain/content/transform/templates.py +25 -0
  197. ommlds/minichain/content/visitors.py +231 -0
  198. ommlds/minichain/lib/fs/tools/read.py +1 -1
  199. ommlds/minichain/lib/fs/tools/recursivels/rendering.py +1 -1
  200. ommlds/minichain/lib/fs/tools/recursivels/running.py +1 -1
  201. ommlds/minichain/lib/todo/tools/write.py +2 -1
  202. ommlds/minichain/lib/todo/types.py +1 -1
  203. ommlds/minichain/metadata.py +56 -2
  204. ommlds/minichain/resources.py +22 -1
  205. ommlds/minichain/services/README.md +154 -0
  206. ommlds/minichain/services/__init__.py +6 -2
  207. ommlds/minichain/services/_marshal.py +46 -10
  208. ommlds/minichain/services/_origclasses.py +11 -0
  209. ommlds/minichain/services/_typedvalues.py +8 -3
  210. ommlds/minichain/services/requests.py +73 -3
  211. ommlds/minichain/services/responses.py +73 -3
  212. ommlds/minichain/services/services.py +9 -0
  213. ommlds/minichain/stream/services.py +24 -1
  214. ommlds/minichain/text/applypatch.py +2 -1
  215. ommlds/minichain/text/toolparsing/llamacpp/types.py +1 -1
  216. ommlds/minichain/tokens/specials.py +1 -1
  217. ommlds/minichain/tools/execution/catalog.py +1 -1
  218. ommlds/minichain/tools/execution/errorhandling.py +36 -0
  219. ommlds/minichain/tools/execution/errors.py +2 -2
  220. ommlds/minichain/tools/execution/executors.py +1 -1
  221. ommlds/minichain/tools/fns.py +1 -1
  222. ommlds/minichain/tools/jsonschema.py +2 -2
  223. ommlds/minichain/tools/reflect.py +6 -6
  224. ommlds/minichain/tools/types.py +12 -15
  225. ommlds/minichain/vectors/_marshal.py +1 -1
  226. ommlds/minichain/vectors/embeddings.py +1 -1
  227. ommlds/minichain/wrappers/__init__.py +7 -0
  228. ommlds/minichain/wrappers/firstinwins.py +144 -0
  229. ommlds/minichain/wrappers/instrument.py +146 -0
  230. ommlds/minichain/wrappers/retry.py +168 -0
  231. ommlds/minichain/wrappers/services.py +98 -0
  232. ommlds/minichain/wrappers/stream.py +57 -0
  233. ommlds/nanochat/rustbpe/README.md +9 -0
  234. ommlds/nanochat/tokenizers.py +40 -6
  235. ommlds/specs/mcp/clients.py +146 -0
  236. ommlds/specs/mcp/protocol.py +123 -18
  237. ommlds/tools/git.py +82 -65
  238. {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/METADATA +13 -11
  239. ommlds-0.0.0.dev503.dist-info/RECORD +520 -0
  240. ommlds/cli/sessions/chat/chat/state/inject.py +0 -36
  241. ommlds/cli/sessions/chat/chat/user/inject.py +0 -62
  242. ommlds/cli/sessions/chat/chat/user/interactive.py +0 -31
  243. ommlds/cli/sessions/chat/chat/user/oneshot.py +0 -25
  244. ommlds/cli/sessions/chat/chat/user/types.py +0 -15
  245. ommlds/cli/sessions/chat/driver.py +0 -43
  246. ommlds/minichain/content/materialize.py +0 -196
  247. ommlds/minichain/content/simple.py +0 -47
  248. ommlds/minichain/content/transforms/base.py +0 -46
  249. ommlds/minichain/content/transforms/interleave.py +0 -70
  250. ommlds/minichain/content/transforms/squeeze.py +0 -72
  251. ommlds/minichain/content/transforms/strings.py +0 -24
  252. ommlds/minichain/content/types.py +0 -43
  253. ommlds/minichain/stream/wrap.py +0 -62
  254. ommlds-0.0.0.dev480.dist-info/RECORD +0 -427
  255. /ommlds/cli/sessions/chat/{chat → drivers}/__init__.py +0 -0
  256. /ommlds/cli/sessions/chat/{chat → drivers}/ai/__init__.py +0 -0
  257. /ommlds/cli/sessions/chat/{chat → drivers}/ai/injection.py +0 -0
  258. /ommlds/cli/sessions/chat/{chat/state → drivers/events}/__init__.py +0 -0
  259. /ommlds/cli/sessions/chat/{chat/user → drivers/phases}/__init__.py +0 -0
  260. /ommlds/cli/sessions/chat/{phases → drivers/phases}/inject.py +0 -0
  261. /ommlds/cli/sessions/chat/{phases → drivers/phases}/injection.py +0 -0
  262. /ommlds/cli/sessions/chat/{phases → drivers/phases}/manager.py +0 -0
  263. /ommlds/cli/sessions/chat/{phases → drivers/phases}/types.py +0 -0
  264. /ommlds/cli/sessions/chat/{phases → drivers/state}/__init__.py +0 -0
  265. /ommlds/cli/sessions/chat/{tools → drivers/tools}/__init__.py +0 -0
  266. /ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/__init__.py +0 -0
  267. /ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/configs.py +0 -0
  268. /ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/__init__.py +0 -0
  269. /ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/configs.py +0 -0
  270. /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/__init__.py +0 -0
  271. /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/configs.py +0 -0
  272. /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/inject.py +0 -0
  273. /ommlds/{minichain/content/transforms → cli/sessions/chat/drivers/user}/__init__.py +0 -0
  274. {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/WHEEL +0 -0
  275. {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/entry_points.txt +0 -0
  276. {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/licenses/LICENSE +0 -0
  277. {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,9 @@
1
+ import typing as ta
2
+
3
+
4
+ ##
5
+
6
+
7
+ REQUIRED_HTTP_HEADERS: ta.Mapping[bytes, bytes] = {
8
+ b'User-Agent': b'python-httpx/0.28.1', # required or it 403's lol
9
+ }
@@ -1,4 +1,7 @@
1
1
  """
2
+ NOTE: This can't be cleaned up too much - the callback can't be a closure to hide its guts because it needs to be
3
+ picklable for multiprocessing.
4
+
2
5
  FIXME:
3
6
  - it outputs newline-terminated so buffer and chop on newlines - DelimitingBuffer again
4
7
  """
@@ -27,4 +30,4 @@ def llama_log_callback(
27
30
 
28
31
  @lang.cached_function
29
32
  def install_logging_hook() -> None:
30
- llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))
33
+ llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0)) # noqa
@@ -17,7 +17,11 @@
17
17
  # https://github.com/ml-explore/mlx-lm/blob/ce2358d297af245b002e690623f00195b6507da0/mlx_lm/generate.py
18
18
  import typing as ta
19
19
 
20
- import mlx_lm.models.cache
20
+ from omlish import lang
21
+
22
+
23
+ with lang.auto_proxy_import(globals()):
24
+ import mlx_lm.models.cache as mlx_lm_models_cache
21
25
 
22
26
 
23
27
  ##
@@ -32,13 +36,13 @@ def maybe_quantize_kv_cache(
32
36
  ) -> None:
33
37
  if not (
34
38
  kv_bits is not None and
35
- not isinstance(prompt_cache[0], mlx_lm.models.cache.QuantizedKVCache) and
39
+ not isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache) and
36
40
  prompt_cache[0].offset > quantized_kv_start
37
41
  ):
38
42
  return
39
43
 
40
44
  for i in range(len(prompt_cache)):
41
- if isinstance(prompt_cache[i], mlx_lm.models.cache.KVCache):
45
+ if isinstance(prompt_cache[i], mlx_lm_models_cache.KVCache):
42
46
  prompt_cache[i] = prompt_cache[i].to_quantized(
43
47
  bits=kv_bits,
44
48
  group_size=kv_group_size,
@@ -20,16 +20,19 @@ import json
20
20
  import sys
21
21
  import typing as ta
22
22
 
23
- import mlx.core as mx
24
- import mlx_lm.models.cache
25
- import mlx_lm.sample_utils
26
- import mlx_lm.utils
23
+ from omlish import lang
27
24
 
28
25
  from .generation import GenerationParams
29
26
  from .generation import generate
30
27
  from .loading import load_model
31
28
 
32
29
 
30
+ with lang.auto_proxy_import(globals()):
31
+ import mlx.core as mx
32
+ import mlx_lm.models.cache as mlx_lm_models_cache
33
+ import mlx_lm.sample_utils as mlx_lm_sample_utils
34
+
35
+
33
36
  ##
34
37
 
35
38
 
@@ -214,11 +217,11 @@ def _main() -> None:
214
217
  # Load the prompt cache and metadata if a cache file is provided
215
218
  using_cache = args.prompt_cache_file is not None
216
219
  if using_cache:
217
- prompt_cache, metadata = mlx_lm.models.cache.load_prompt_cache(
220
+ prompt_cache, metadata = mlx_lm_models_cache.load_prompt_cache(
218
221
  args.prompt_cache_file,
219
222
  return_metadata=True,
220
223
  )
221
- if isinstance(prompt_cache[0], mlx_lm.models.cache.QuantizedKVCache):
224
+ if isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache):
222
225
  if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
223
226
  raise ValueError('--kv-bits does not match the kv cache loaded from --prompt-cache-file.')
224
227
  if args.kv_group_size != prompt_cache[0].group_size:
@@ -293,7 +296,7 @@ def _main() -> None:
293
296
  else:
294
297
  prompt = tokenizer.encode(prompt)
295
298
 
296
- sampler = mlx_lm.sample_utils.make_sampler(
299
+ sampler = mlx_lm_sample_utils.make_sampler(
297
300
  args.temp,
298
301
  args.top_p,
299
302
  args.min_p,
@@ -21,10 +21,6 @@ import io
21
21
  import sys
22
22
  import typing as ta
23
23
 
24
- import mlx.core as mx
25
- import mlx_lm.models.cache
26
- from mlx import nn
27
-
28
24
  from omlish import check
29
25
  from omlish import lang
30
26
 
@@ -33,6 +29,12 @@ from .limits import wired_limit_context
33
29
  from .tokenization import Tokenization
34
30
 
35
31
 
32
+ with lang.auto_proxy_import(globals()):
33
+ import mlx.core as mx
34
+ import mlx.nn as mlx_nn
35
+ import mlx_lm.models.cache as mlx_lm_models_cache
36
+
37
+
36
38
  ##
37
39
 
38
40
 
@@ -47,9 +49,9 @@ def _generation_stream():
47
49
  class LogitProcessor(ta.Protocol):
48
50
  def __call__(
49
51
  self,
50
- tokens: mx.array,
51
- logits: mx.array,
52
- ) -> mx.array:
52
+ tokens: 'mx.array',
53
+ logits: 'mx.array',
54
+ ) -> 'mx.array':
53
55
  ...
54
56
 
55
57
 
@@ -99,12 +101,12 @@ class GenerationParams:
99
101
 
100
102
  class _GenerationStep(ta.NamedTuple):
101
103
  token: int
102
- logprobs: mx.array
104
+ logprobs: 'mx.array'
103
105
 
104
106
 
105
107
  def _generate_step(
106
- prompt: mx.array,
107
- model: nn.Module,
108
+ prompt: 'mx.array',
109
+ model: 'mlx_nn.Module',
108
110
  params: GenerationParams = GenerationParams(),
109
111
  ) -> ta.Generator[_GenerationStep]:
110
112
  y = prompt
@@ -113,7 +115,7 @@ def _generate_step(
113
115
  # Create the Kv cache for generation
114
116
  prompt_cache = params.prompt_cache
115
117
  if prompt_cache is None:
116
- prompt_cache = mlx_lm.models.cache.make_prompt_cache(
118
+ prompt_cache = mlx_lm_models_cache.make_prompt_cache(
117
119
  model,
118
120
  max_kv_size=params.max_kv_size,
119
121
  )
@@ -221,7 +223,7 @@ class GenerationOutput:
221
223
  token: int
222
224
 
223
225
  # A vector of log probabilities.
224
- logprobs: mx.array
226
+ logprobs: 'mx.array'
225
227
 
226
228
  # The number of tokens in the prompt.
227
229
  prompt_tokens: int
@@ -234,9 +236,9 @@ class GenerationOutput:
234
236
 
235
237
 
236
238
  def stream_generate(
237
- model: nn.Module,
239
+ model: 'mlx_nn.Module',
238
240
  tokenization: Tokenization,
239
- prompt: str | mx.array,
241
+ prompt: ta.Union[str, 'mx.array'],
240
242
  params: GenerationParams = GenerationParams(),
241
243
  ) -> ta.Generator[GenerationOutput]:
242
244
  if not isinstance(prompt, mx.array):
@@ -308,9 +310,9 @@ def stream_generate(
308
310
 
309
311
 
310
312
  def generate(
311
- model: nn.Module,
313
+ model: 'mlx_nn.Module',
312
314
  tokenization: Tokenization,
313
- prompt: str | mx.array,
315
+ prompt: ta.Union[str, 'mx.array'],
314
316
  params: GenerationParams = GenerationParams(),
315
317
  *,
316
318
  verbose: bool = False,
@@ -19,9 +19,13 @@ import contextlib
19
19
  import sys
20
20
  import typing as ta
21
21
 
22
- import mlx.core as mx
23
- import mlx.utils
24
- from mlx import nn
22
+ from omlish import lang
23
+
24
+
25
+ with lang.auto_proxy_import(globals()):
26
+ import mlx.core as mx
27
+ import mlx.nn as mlx_nn
28
+ import mlx.utils as mlx_utils
25
29
 
26
30
 
27
31
  ##
@@ -29,8 +33,8 @@ from mlx import nn
29
33
 
30
34
  @contextlib.contextmanager
31
35
  def wired_limit_context(
32
- model: nn.Module,
33
- streams: ta.Iterable[mx.Stream] | None = None,
36
+ model: 'mlx_nn.Module',
37
+ streams: ta.Iterable['mx.Stream'] | None = None,
34
38
  ) -> ta.Generator[None]:
35
39
  """
36
40
  A context manager to temporarily change the wired limit.
@@ -43,7 +47,7 @@ def wired_limit_context(
43
47
  yield
44
48
  return
45
49
 
46
- model_bytes = mlx.utils.tree_reduce(
50
+ model_bytes = mlx_utils.tree_reduce(
47
51
  lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc,
48
52
  model,
49
53
  0,
@@ -1,10 +1,8 @@
1
+ # ruff: noqa: TC002
1
2
  import dataclasses as dc
2
3
  import pathlib
3
4
  import typing as ta
4
5
 
5
- import mlx_lm.utils
6
- from mlx import nn
7
-
8
6
  from omlish import check
9
7
  from omlish import lang
10
8
 
@@ -12,6 +10,11 @@ from .tokenization import Tokenization
12
10
  from .tokenization import load_tokenization
13
11
 
14
12
 
13
+ with lang.auto_proxy_import(globals()):
14
+ import mlx.nn as mlx_nn
15
+ import mlx_lm.utils
16
+
17
+
15
18
  ##
16
19
 
17
20
 
@@ -76,7 +79,7 @@ def get_model_path(
76
79
  class LoadedModel:
77
80
  path: pathlib.Path
78
81
 
79
- model: nn.Module
82
+ model: 'mlx_nn.Module'
80
83
  config: dict
81
84
 
82
85
  #
@@ -0,0 +1,7 @@
1
+ from omlish import dataclasses as _dc # noqa
2
+
3
+
4
+ _dc.init_package(
5
+ globals(),
6
+ codegen=True,
7
+ )