xinference 0.16.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (54)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/constants.py +1 -0
  5. xinference/core/model.py +10 -3
  6. xinference/core/supervisor.py +8 -2
  7. xinference/core/utils.py +67 -2
  8. xinference/model/audio/model_spec.json +1 -1
  9. xinference/model/image/stable_diffusion/core.py +5 -2
  10. xinference/model/llm/llm_family.json +176 -4
  11. xinference/model/llm/llm_family_modelscope.json +211 -0
  12. xinference/model/llm/mlx/core.py +45 -2
  13. xinference/model/rerank/core.py +11 -4
  14. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  15. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  16. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  17. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  18. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  19. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  20. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  21. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  22. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  23. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  24. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  25. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  26. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  27. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  28. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  29. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  30. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  31. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  32. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  33. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  34. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  35. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  36. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  37. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  38. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
  39. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
  40. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  42. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  46. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  50. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  52. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  53. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  54. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -8205,6 +8205,16 @@
     ],
     "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
     "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-0.5B"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": "1_5",
@@ -8213,8 +8223,17 @@
           "8-bit",
           "none"
         ],
-        "model_id": "Qwen/Qwen2.5-Coder-1.5B",
-        "model_revision": "d3586cfe793730945f8e4d7ef31032a3ee50247d"
+        "model_id": "Qwen/Qwen2.5-Coder-1.5B"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-3B"
       },
       {
         "model_format": "pytorch",
@@ -8224,8 +8243,27 @@
           "8-bit",
           "none"
         ],
-        "model_id": "Qwen/Qwen2.5-Coder-7B",
-        "model_revision": "30b6a7e874a78d46b80fa1db3194ea427dd41b08"
+        "model_id": "Qwen/Qwen2.5-Coder-7B"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-14B"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-32B"
       }
     ]
   },
@@ -8243,6 +8281,16 @@
     ],
     "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
     "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": "1_5",
@@ -8253,6 +8301,16 @@
         ],
         "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct"
       },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": 7,
@@ -8263,6 +8321,53 @@
         ],
         "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct"
       },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-{quantization}"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "1_5",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-{quantization}"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-{quantization}"
+      },
       {
         "model_format": "gptq",
         "model_size_in_billions": "7",
@@ -8272,6 +8377,73 @@
         ],
         "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-{quantization}"
       },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "14",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-{quantization}"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "32",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-{quantization}"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "1_5",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct-AWQ"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "7",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct-AWQ"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "14",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct-AWQ"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "32",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct-AWQ"
+      },
+
       {
         "model_format": "ggufv2",
         "model_size_in_billions": "1_5",
--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -5907,6 +5907,18 @@
     ],
     "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
     "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-0.5B",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": "1_5",
@@ -5919,6 +5931,18 @@
         "model_revision": "master",
         "model_hub": "modelscope"
       },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-3B",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": 7,
@@ -5930,6 +5954,30 @@
         "model_id": "qwen/Qwen2.5-Coder-7B",
         "model_revision": "master",
         "model_hub": "modelscope"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-14B",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-32B",
+        "model_revision": "master",
+        "model_hub": "modelscope"
       }
     ]
   },
@@ -5947,6 +5995,18 @@
     ],
     "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
     "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": "1_5",
@@ -5958,6 +6018,17 @@
         "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct",
         "model_revision": "master",
         "model_hub": "modelscope"
+      }, {
+        "model_format": "pytorch",
+        "model_size_in_billions": "3",
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-3B-Instruct",
+        "model_revision": "master",
+        "model_hub": "modelscope"
       },
       {
         "model_format": "pytorch",
@@ -5971,6 +6042,63 @@
         "model_revision": "master",
         "model_hub": "modelscope"
       },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-14B-Instruct",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-32B-Instruct",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-{quantization}",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": "1_5",
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-{quantization}",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": 3,
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-{quantization}",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
       {
         "model_format": "gptq",
         "model_size_in_billions": 7,
@@ -5982,6 +6110,89 @@
         "model_revision": "master",
         "model_hub": "modelscope"
       },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-{quantization}",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "Int4",
+          "Int8"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-{quantization}",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "0_5",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": "1_5",
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 3,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 7,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-7B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 14,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 32,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-AWQ",
+        "model_revision": "master",
+        "model_hub": "modelscope"
+      },
+
       {
         "model_format": "ggufv2",
         "model_size_in_billions": "1_5",
--- a/xinference/model/llm/mlx/core.py
+++ b/xinference/model/llm/mlx/core.py
@@ -17,7 +17,8 @@ import platform
 import sys
 import time
 import uuid
-from typing import Dict, Iterator, List, Optional, TypedDict, Union
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
 
 from ....fields import max_tokens_field
 from ....types import (
@@ -53,6 +54,14 @@ class MLXGenerateConfig(TypedDict, total=False):
     stream: bool
     stream_options: Optional[Union[dict, None]]
     tools: Optional[List[Dict]]
+    lora_name: Optional[str]
+
+
+@dataclass
+class PromptCache:
+    cache: List[Any] = field(default_factory=list)
+    model_key: Tuple[str, Optional[str]] = ("", None)
+    tokens: List[int] = field(default_factory=list)
 
 
 class MLXModel(LLM):
@@ -69,6 +78,8 @@ class MLXModel(LLM):
         super().__init__(model_uid, model_family, model_spec, quantization, model_path)
         self._use_fast_tokenizer = True
         self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+        self._max_kv_size = None
+        self._prompt_cache = None
         if peft_model is not None:
             raise ValueError("MLX engine has not supported lora yet")
 
@@ -127,6 +138,9 @@
             logger.debug(f"Setting cache limit to {cache_limit_gb} GB")
             mx.metal.set_cache_limit(cache_limit_gb * 1024 * 1024 * 1024)
 
+        self._max_kv_size = kwargs.get("max_kv_size", None)
+        self._prompt_cache = PromptCache()
+
         return load(
             self.model_path,
             tokenizer_config=tokenizer_config,
@@ -156,6 +170,27 @@
             return False
         return True
 
+    def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None):
+        from mlx_lm.models.cache import make_prompt_cache
+
+        assert self._prompt_cache is not None
+        cache_len = len(self._prompt_cache.tokens)
+        model_key = (self.model_path, lora_name)
+        if (
+            self._prompt_cache.model_key != model_key
+            or cache_len >= len(prompt)
+            or self._prompt_cache.tokens != prompt[:cache_len]
+        ):
+            self._prompt_cache.model_key = model_key
+            self._prompt_cache.cache = make_prompt_cache(self._model, self._max_kv_size)
+            self._prompt_cache.tokens = []
+            logger.debug("Making new prompt cache for %s", self.model_uid)
+        else:
+            prompt = prompt[cache_len:]
+            logger.debug("Cache hit for %s", self.model_uid)
+        self._prompt_cache.tokens.extend(prompt)
+        return prompt
+
     def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
         import mlx.core as mx
         from mlx_lm.utils import generate_step
@@ -167,6 +202,7 @@
         chunk_id = str(uuid.uuid4())
         stop_token_ids = kwargs.get("stop_token_ids", [])
         stream = kwargs.get("stream", False)
+        lora_name = kwargs.get("lora_name")
         stream_options = kwargs.pop("stream_options", None)
         include_usage = (
             stream_options["include_usage"]
@@ -174,12 +210,15 @@
             else False
         )
 
-        prompt_tokens = mx.array(tokenizer.encode(prompt))
+        prompt_token_ids = tokenizer.encode(prompt)
+        prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name)
+        prompt_tokens = mx.array(prompt_token_ids)
         input_echo_len = len(prompt_tokens)
 
         i = 0
         start = time.time()
         output = ""
+        tokens = []
         for (token, _), i in zip(
             generate_step(
                 prompt_tokens,
@@ -189,9 +228,11 @@
                 repetition_context_size=kwargs["repetition_context_size"],
                 top_p=kwargs["top_p"],
                 logit_bias=kwargs["logit_bias"],
+                prompt_cache=self._prompt_cache.cache,  # type: ignore
             ),
             range(max_tokens),
         ):
+            tokens.append(token)
             if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
                 break
 
@@ -230,6 +271,8 @@
             f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
         )
 
+        self._prompt_cache.tokens.extend(tokens)  # type: ignore
+
        if i == max_tokens - 1:
            finish_reason = "length"
        else:
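
The MLX hunks above add request-to-request prompt caching: the model keeps the KV cache from the previous generation, and _get_prompt_cache trims a new prompt down to the suffix that is not already cached. A pure-Python sketch of the reuse rule (no MLX required; the function and variable names here are illustrative, not from the diff):

# The cached prefix is reused only when the (model, LoRA) key matches,
# the cached tokens are a strict prefix of the new prompt, and at least
# one new token remains to feed the model.
from typing import List, Optional, Tuple

def reusable_suffix(
    cached_key: Tuple[str, Optional[str]],
    cached_tokens: List[int],
    key: Tuple[str, Optional[str]],
    prompt: List[int],
) -> Tuple[bool, List[int]]:
    cache_len = len(cached_tokens)
    if key != cached_key or cache_len >= len(prompt) or cached_tokens != prompt[:cache_len]:
        return False, prompt          # rebuild the KV cache, prefill the full prompt
    return True, prompt[cache_len:]   # cache hit: only the new suffix is prefilled

# A follow-up turn extends the previous prompt, so only new tokens are fed:
hit, todo = reusable_suffix(("m", None), [1, 2, 3], ("m", None), [1, 2, 3, 4, 5])
assert hit and todo == [4, 5]
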
--- a/xinference/model/rerank/core.py
+++ b/xinference/model/rerank/core.py
@@ -179,6 +179,7 @@ class RerankModel:
         return rerank_type
 
     def load(self):
+        logger.info("Loading rerank model: %s", self._model_path)
         flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         if (
             self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@
                 "will force set `use_fp16` to True"
             )
             self._use_fp16 = True
+
         if self._model_spec.type == "normal":
             try:
                 import sentence_transformers
@@ -250,22 +252,27 @@
         **kwargs,
     ) -> Rerank:
         assert self._model is not None
-        if kwargs:
-            raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
-                sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
+                sentence_combinations,
+                convert_to_numpy=False,
+                convert_to_tensor=True,
+                **kwargs,
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(sentence_combinations)
+            similarity_scores = self._model.compute_score(
+                sentence_combinations, **kwargs
+            )
+
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
         elif (
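
The rerank change above stops rejecting extra keyword arguments and forwards them to the backend (sentence_transformers CrossEncoder.predict for "normal" models, compute_score otherwise). A hedged client-side sketch, assuming the updated restful_client.py (also touched in this diff) forwards extra keyword arguments and that the named rerank model is already launched:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("bge-reranker-v2-m3")  # assumed already-launched rerank model

# Extra kwargs such as batch_size are now passed through to the backend
# instead of raising ValueError; batch_size is valid for the
# sentence_transformers CrossEncoder.predict path.
result = model.rerank(
    documents=["Xinference serves LLMs.", "Bananas are yellow."],
    query="What does xinference do?",
    batch_size=2,
)
print(result["results"])
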