kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_generator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -1,20 +1,8 @@
-import os
-from dataclasses import dataclass
 from enum import Enum
-from os import getenv
-from typing import Any, Dict, List, NoReturn
+from typing import Dict, List

-import httpx
-import requests
-from langchain_aws import ChatBedrockConverse
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 from pydantic import BaseModel

-from ..utils.config import Config
-
 """
 Provides model configuration and management for various LLM providers and models.
 This module handles the integration with different AI model providers and their respective models,
@@ -32,6 +20,9 @@ class ModelProviderName(str, Enum):
     amazon_bedrock = "amazon_bedrock"
     ollama = "ollama"
     openrouter = "openrouter"
+    fireworks_ai = "fireworks_ai"
+    kiln_fine_tune = "kiln_fine_tune"
+    kiln_custom_registry = "kiln_custom_registry"


 class ModelFamily(str, Enum):
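
The three new Kiln-specific provider IDs feed the exhaustive match statements used elsewhere in the package (see the removed provider_name_from_id at the bottom of this diff, which the file list suggests moves into the new provider_tools.py). A minimal sketch of that pattern for the new members; the display strings here are illustrative assumptions, not values from the release:

from kiln_ai.adapters.ml_model_list import ModelProviderName

# Hypothetical display-name helper following the exhaustive-match style of the
# removed provider_name_from_id. The strings for the three new members are
# assumptions for illustration only.
def display_name(provider: ModelProviderName) -> str:
    match provider:
        case ModelProviderName.fireworks_ai:
            return "Fireworks AI"
        case ModelProviderName.kiln_fine_tune:
            return "Fine Tuned Models"
        case ModelProviderName.kiln_custom_registry:
            return "Custom Models"
        case _:
            # pre-existing members (openai, groq, amazon_bedrock, ollama, openrouter)
            return provider.value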
@@ -46,6 +37,8 @@ class ModelFamily(str, Enum):
     gemma = "gemma"
     gemini = "gemini"
     claude = "claude"
+    mixtral = "mixtral"
+    qwen = "qwen"


 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -58,9 +51,11 @@ class ModelName(str, Enum):
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
+    llama_3_2_1b = "llama_3_2_1b"
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
+    llama_3_3_70b = "llama_3_3_70b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -75,6 +70,9 @@ class ModelName(str, Enum):
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
     gemini_1_5_pro = "gemini_1_5_pro"
     nemotron_70b = "nemotron_70b"
+    mixtral_8x7b = "mixtral_8x7b"
+    qwen_2p5_7b = "qwen_2p5_7b"
+    qwen_2p5_72b = "qwen_2p5_72b"


 class KilnModelProvider(BaseModel):
@@ -84,13 +82,20 @@ class KilnModelProvider(BaseModel):
     Attributes:
         name: The provider's identifier
         supports_structured_output: Whether the provider supports structured output formats
+        supports_data_gen: Whether the provider supports data generation
+        untested_model: Whether the model is untested (typically user added). The supports_ fields are not applicable.
+        provider_finetune_id: The finetune ID for the provider, if applicable
         provider_options: Additional provider-specific configuration options
+        adapter_options: Additional options specific to the adapter. Top level key should be adapter ID.
     """

     name: ModelProviderName
     supports_structured_output: bool = True
     supports_data_gen: bool = True
+    untested_model: bool = False
+    provider_finetune_id: str | None = None
     provider_options: Dict = {}
+    adapter_options: Dict = {}


 class KilnModel(BaseModel):
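
To make the new fields concrete, here is how a provider entry combining them is constructed, with every value copied from the Fireworks Llama 3.1 8B entry added later in this diff:

from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

provider = KilnModelProvider(
    name=ModelProviderName.fireworks_ai,
    supports_structured_output=False,
    supports_data_gen=False,
    provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
    provider_options={"model": "accounts/fireworks/models/llama-v3p1-8b-instruct"},
    adapter_options={
        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
    },
)
assert provider.untested_model is False  # default; only set for user-added models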
@@ -122,6 +127,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o-mini"},
+                provider_finetune_id="gpt-4o-mini-2024-07-18",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -138,6 +144,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o"},
+                provider_finetune_id="gpt-4o-2024-08-06",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
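
The pinned snapshot suffixes matter because OpenAI fine-tuning jobs must target a dated model version. A minimal sketch of the kind of request the new openai_finetune.py presumably issues, using the standard OpenAI Python SDK; the training-file ID is a placeholder:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# "file-abc123" stands in for a previously uploaded JSONL training file.
job = client.fine_tuning.jobs.create(
    training_file="file-abc123",
    model="gpt-4o-mini-2024-07-18",  # one of the provider_finetune_id values pinned above
)
print(job.id, job.status)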
@@ -257,6 +264,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-8b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 70b
@@ -271,7 +287,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
-                # not sure how AWS manages to break this, but it's not working
+                # AWS 70b not working as well as the others.
                 supports_structured_output=False,
                 supports_data_gen=False,
                 provider_options={
@@ -287,6 +303,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.ollama,
                 provider_options={"model": "llama3.1:70b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-70b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 405b
@@ -311,6 +334,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-405b-instruct"
+                },
+            ),
         ],
     ),
     # Mistral Nemo
@@ -348,6 +378,35 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.2 1B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_1b,
+        friendly_name="Llama 3.2 1B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "meta-llama/llama-3.2-1b-instruct"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2:1b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-1b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-1b-instruct"
+                },
+            ),
+        ],
+    ),
     # Llama 3.2 3B
     KilnModel(
         family=ModelFamily.llama,
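
With the entry registered, a model/provider pair resolves through get_model_and_provider, whose body is visible among the removals at the end of this diff; the file list suggests it moves into the new provider_tools.py. A sketch under that assumption:

# Assumes get_model_and_provider keeps its signature after relocating to
# provider_tools.py; its body appears verbatim in the removals below.
from kiln_ai.adapters.provider_tools import get_model_and_provider

model, provider = get_model_and_provider("llama_3_2_1b", "fireworks_ai")
if model is not None and provider is not None:
    print(provider.provider_finetune_id)
    # -> accounts/fireworks/models/llama-v3p2-1b-instruct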
@@ -366,6 +425,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-3b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -376,9 +444,12 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
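
Note the pattern in this hunk: rather than marking the model as unable to produce structured output, the entry now instructs the LangChain adapter to request JSON mode. A sketch of how an adapter could read the option; the helper below is hypothetical, and only the dictionary shape comes from the entries in this diff:

from kiln_ai.adapters.ml_model_list import KilnModelProvider

# Hypothetical lookup helper. The nesting (adapter ID -> option group -> kwargs)
# mirrors the adapter_options entries above.
def structured_output_options(provider: KilnModelProvider) -> dict:
    return provider.adapter_options.get("langchain", {}).get(
        "with_structured_output_options", {}
    )

# For the OpenRouter Llama 3.2 11B entry this returns {"method": "json_mode"},
# suitable for passing through as model.with_structured_output(schema, **options).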
@@ -386,6 +457,18 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -396,15 +479,60 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.2-vision:90b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
+    ),
+    # Llama 3.3 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_3_70b,
+        friendly_name="Llama 3.3 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
+                # OpenRouter not supporting tools yet. Once they do, can probably remove. JSON mode sometimes works, but not consistently.
                 supports_structured_output=False,
                 supports_data_gen=False,
-                provider_options={"model": "llama3.2-vision:90b"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.3"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # Finetuning not live yet
+                # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
+                },
             ),
         ],
     ),
@@ -427,6 +555,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/phi-3-vision-128k-instruct"
+                },
+            ),
         ],
     ),
     # Gemma 2 2.6b
@@ -465,6 +602,7 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
+            # fireworks AI errors - not allowing system role. Exclude until resolved.
         ],
     ),
     # Gemma 2 27b
@@ -488,290 +626,88 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
-]
-
-
-def get_model_and_provider(
-    model_name: str, provider_name: str
-) -> tuple[KilnModel | None, KilnModelProvider | None]:
-    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
-    if model is None:
-        return None, None
-    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
-    # all or nothing
-    if provider is None or model is None:
-        return None, None
-    return model, provider
-
-
-def provider_name_from_id(id: str) -> str:
-    """
-    Converts a provider ID to its human-readable name.
-
-    Args:
-        id: The provider identifier string
-
-    Returns:
-        The human-readable name of the provider
-
-    Raises:
-        ValueError: If the provider ID is invalid or unhandled
-    """
-    if id in ModelProviderName.__members__:
-        enum_id = ModelProviderName(id)
-        match enum_id:
-            case ModelProviderName.amazon_bedrock:
-                return "Amazon Bedrock"
-            case ModelProviderName.openrouter:
-                return "OpenRouter"
-            case ModelProviderName.groq:
-                return "Groq"
-            case ModelProviderName.ollama:
-                return "Ollama"
-            case ModelProviderName.openai:
-                return "OpenAI"
-            case _:
-                # triggers pyright warning if I miss a case
-                raise_exhaustive_error(enum_id)
-
-    return "Unknown provider: " + id
-
-
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
-@dataclass
-class ModelProviderWarning:
-    required_config_keys: List[str]
-    message: str
-
-
-provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
-    ModelProviderName.amazon_bedrock: ModelProviderWarning(
-        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
-        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
-    ),
-    ModelProviderName.openrouter: ModelProviderWarning(
-        required_config_keys=["open_router_api_key"],
-        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    # Mixtral 8x7B
+    KilnModel(
+        family=ModelFamily.mixtral,
+        name=ModelName.mixtral_8x7b,
+        friendly_name="Mixtral 8x7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "mixtral"},
+            ),
+        ],
     ),
-    ModelProviderName.groq: ModelProviderWarning(
-        required_config_keys=["groq_api_key"],
-        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    # Qwen 2.5 7B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_7b,
+        friendly_name="Qwen 2.5 7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-7b-instruct"},
+                # Tool calls not supported. JSON doesn't error, but fails.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5"},
+            ),
+        ],
     ),
-    ModelProviderName.openai: ModelProviderWarning(
-        required_config_keys=["open_ai_api_key"],
-        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    # Qwen 2.5 72B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_72b,
+        friendly_name="Qwen 2.5 72B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-72b-instruct"},
+                # Not consistent with structured data. Works sometimes but not often.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5:72b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/qwen2p5-72b-instruct"
+                },
+                # Fireworks will start tuning, but it never finishes.
+                # provider_finetune_id="accounts/fireworks/models/qwen2p5-72b-instruct",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
     ),
-}
-
-
-def get_config_value(key: str):
-    try:
-        return Config.shared().__getattr__(key)
-    except AttributeError:
-        return None
-
-
-def check_provider_warnings(provider_name: ModelProviderName):
-    """
-    Validates that required configuration is present for a given provider.
-
-    Args:
-        provider_name: The provider to check
-
-    Raises:
-        ValueError: If required configuration keys are missing
-    """
-    warning_check = provider_warnings.get(provider_name)
-    if warning_check is None:
-        return
-    for key in warning_check.required_config_keys:
-        if get_config_value(key) is None:
-            raise ValueError(warning_check.message)
-
-
-async def langchain_model_from(
-    name: str, provider_name: str | None = None
-) -> BaseChatModel:
-    """
-    Creates a LangChain chat model instance for the specified model and provider.
-
-    Args:
-        name: The name of the model to instantiate
-        provider_name: Optional specific provider to use (defaults to first available)
-
-    Returns:
-        A configured LangChain chat model instance
-
-    Raises:
-        ValueError: If the model/provider combination is invalid or misconfigured
-    """
-    if name not in ModelName.__members__:
-        raise ValueError(f"Invalid name: {name}")
-
-    # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
-    if model is None:
-        raise ValueError(f"Model {name} not found")
-
-    # If a provider is provided, select the provider from the model's provider_config
-    provider: KilnModelProvider | None = None
-    if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
-    elif provider_name is None:
-        # TODO: priority order
-        provider = model.providers[0]
-    else:
-        provider = next(
-            filter(lambda p: p.name == provider_name, model.providers), None
-        )
-        if provider is None:
-            raise ValueError(f"Provider {provider_name} not found for model {name}")
-
-    check_provider_warnings(provider.name)
-
-    if provider.name == ModelProviderName.openai:
-        api_key = Config.shared().open_ai_api_key
-        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.groq:
-        api_key = Config.shared().groq_api_key
-        if api_key is None:
-            raise ValueError(
-                "Attempted to use Groq without an API key set. "
-                "Get your API key from https://console.groq.com/keys"
-            )
-        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.amazon_bedrock:
-        api_key = Config.shared().bedrock_access_key
-        secret_key = Config.shared().bedrock_secret_key
-        # langchain doesn't allow passing these, so ugly hack to set env vars
-        os.environ["AWS_ACCESS_KEY_ID"] = api_key
-        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
-        return ChatBedrockConverse(
-            **provider.provider_options,
-        )
-    elif provider.name == ModelProviderName.ollama:
-        # Ollama model naming is pretty flexible. We try a few versions of the model name
-        potential_model_names = []
-        if "model" in provider.provider_options:
-            potential_model_names.append(provider.provider_options["model"])
-        if "model_aliases" in provider.provider_options:
-            potential_model_names.extend(provider.provider_options["model_aliases"])
-
-        # Get the list of models Ollama supports
-        ollama_connection = await get_ollama_connection()
-        if ollama_connection is None:
-            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
-
-        for model_name in potential_model_names:
-            if ollama_model_supported(ollama_connection, model_name):
-                return ChatOllama(model=model_name, base_url=ollama_base_url())
-
-        raise ValueError(f"Model {name} not installed on Ollama")
-    elif provider.name == ModelProviderName.openrouter:
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        return ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  # type: ignore[arg-type]
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
-        )
-    else:
-        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
-
-
-def ollama_base_url() -> str:
-    """
-    Gets the base URL for Ollama API connections.
-
-    Returns:
-        The base URL to use for Ollama API calls, using environment variable if set
-        or falling back to localhost default
-    """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
-    return "http://localhost:11434"
-
-
-async def ollama_online() -> bool:
-    """
-    Checks if the Ollama service is available and responding.
-
-    Returns:
-        True if Ollama is available and responding, False otherwise
-    """
-    try:
-        httpx.get(ollama_base_url() + "/api/tags")
-    except httpx.RequestError:
-        return False
-    return True
-
-
-class OllamaConnection(BaseModel):
-    message: str
-    models: List[str]
-
-
-# Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
-    # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.provider_options["model"]
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
-    # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
-        [
-            alias
-            for model in built_in_models
-            for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
-        ]
-    )
-
-    if "models" in tags:
-        models = tags["models"]
-        if isinstance(models, list):
-            model_names = [model["model"] for model in models]
-            available_supported_models = [
-                model
-                for model in model_names
-                if model in supported_ollama_models
-                or model in [f"{m}:latest" for m in supported_ollama_models]
-            ]
-            if available_supported_models:
-                return OllamaConnection(
-                    message="Ollama connected",
-                    models=available_supported_models,
-                )
-
-    return OllamaConnection(
-        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
-        models=[],
-    )
-
-
-async def get_ollama_connection() -> OllamaConnection | None:
-    """
-    Gets the connection status for Ollama.
-    """
-    try:
-        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
-
-    except Exception:
-        return None
-
-    return parse_ollama_tags(tags)
-
-
-def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
-    return model_name in conn.models or f"{model_name}:latest" in conn.models
+]
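
The Ollama helpers removed above are not dropped from the package: the file list adds kiln_ai/adapters/ollama_tools.py (+115 lines), so they appear to move there. A usage sketch under that assumption, with signatures taken from the removed code:

import asyncio

# Assumes the helpers keep their signatures after relocating to ollama_tools.py;
# their bodies appear verbatim in the removals above.
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_model_supported,
)

async def llama_3_3_installed() -> bool:
    conn = await get_ollama_connection()
    return conn is not None and ollama_model_supported(conn, "llama3.3")

print(asyncio.run(llama_3_3_installed()))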