llm-codegen-research 2.11__tar.gz → 2.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/PKG-INFO +1 -1
  2. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/pyproject.toml +6 -0
  3. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/__init__.py +4 -0
  4. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/languages/__init__.py +3 -0
  5. llm_codegen_research-2.13/src/llm_cgr/analyse/languages/rust.py +193 -0
  6. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/__init__.py +4 -0
  7. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/__init__.py +22 -0
  8. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/anthropic.py +8 -6
  9. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/deepseek.py +5 -3
  10. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/nscale.py +5 -3
  11. llm_codegen_research-2.13/src/llm_cgr/llm/clients/openai_tool.py +258 -0
  12. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/together.py +9 -6
  13. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/PKG-INFO +1 -1
  14. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/SOURCES.txt +3 -0
  15. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/tests/test_enums.py +3 -2
  16. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/tests/test_llm_api.py +6 -0
  17. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/tests/test_llm_local.py +26 -0
  18. llm_codegen_research-2.13/tests/test_llm_tool.py +96 -0
  19. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/LICENSE +0 -0
  20. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/README.md +0 -0
  21. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/setup.cfg +0 -0
  22. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/__init__.py +0 -0
  23. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/classes.py +0 -0
  24. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/languages/code_data.py +0 -0
  25. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/languages/javascript.py +0 -0
  26. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/languages/python.py +0 -0
  27. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/analyse/regexes.py +0 -0
  28. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/decorators.py +0 -0
  29. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/defaults.py +0 -0
  30. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/enums.py +0 -0
  31. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/json_utils.py +0 -0
  32. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/base.py +0 -0
  33. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/mistral.py +0 -0
  34. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/openai.py +0 -0
  35. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/clients/protocol.py +0 -0
  36. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/generate.py +0 -0
  37. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/llm/prompts.py +0 -0
  38. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/py.typed +0 -0
  39. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/scripts/test_cuda.py +0 -0
  40. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_cgr/timeout.py +0 -0
  41. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/dependency_links.txt +0 -0
  42. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/entry_points.txt +0 -0
  43. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/requires.txt +0 -0
  44. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/src/llm_codegen_research.egg-info/top_level.txt +0 -0
  45. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/tests/test_json_utils.py +0 -0
  46. {llm_codegen_research-2.11 → llm_codegen_research-2.13}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: llm-codegen-research
3
- Version: 2.11
3
+ Version: 2.13
4
4
  Summary: Useful classes and methods for researching code-generation by LLMs.
5
5
  Author-email: Lukas Twist <itsluketwist@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/itsluketwist/llm-codegen-research
@@ -46,5 +46,11 @@ dev = [
46
46
  "uv",
47
47
  ]
48
48
 
49
+ [tool.pytest.ini_options]
50
+ markers = [
51
+ # tests that make real external api calls - excluded from ci runs
52
+ "api: marks tests as making external api calls",
53
+ ]
54
+
49
55
  [project.scripts]
50
56
  test_cuda = "llm_cgr.scripts.test_cuda:main"
@@ -45,7 +45,9 @@ try:
45
45
  GenerationProtocol,
46
46
  Mistral_LLM,
47
47
  OpenAI_LLM,
48
+ OpenAI_Tool_LLM,
48
49
  TogetherAI_LLM,
50
+ Tool,
49
51
  generate,
50
52
  generate_bool,
51
53
  generate_list,
@@ -64,6 +66,8 @@ try:
64
66
  "GenerationProtocol",
65
67
  "Mistral_LLM",
66
68
  "OpenAI_LLM",
69
+ "OpenAI_Tool_LLM",
70
+ "Tool",
67
71
  "TogetherAI_LLM",
68
72
  "generate",
69
73
  "generate_bool",
@@ -1,6 +1,7 @@
1
1
  from llm_cgr.analyse.languages.code_data import CodeData
2
2
  from llm_cgr.analyse.languages.javascript import analyse_javascript_code
3
3
  from llm_cgr.analyse.languages.python import analyse_python_code
4
+ from llm_cgr.analyse.languages.rust import analyse_rust_code
4
5
 
5
6
 
6
7
  def analyse_code(code: str, language: str | None) -> CodeData:
@@ -12,6 +13,8 @@ def analyse_code(code: str, language: str | None) -> CodeData:
12
13
  return analyse_python_code(code=code)
13
14
  elif language == "javascript":
14
15
  return analyse_javascript_code(code=code)
16
+ elif language == "rust":
17
+ return analyse_rust_code(code=code)
15
18
 
16
19
  except Exception as exc:
17
20
  return CodeData(
@@ -0,0 +1,193 @@
1
+ """Utility functions for Rust code analysis."""
2
+
3
+ import re
4
+
5
+ from llm_cgr.analyse.languages.code_data import CodeData
6
+
7
+
8
+ # rust's three standard library crates - everything else comes from crates.io
9
+ RUST_STDLIB = frozenset({"std", "core", "alloc"})
10
+
11
+ # matches extern crate declarations: extern crate serde;
12
+ _EXTERN_CRATE_RE = re.compile(r"^\s*extern\s+crate\s+(\w+)", re.MULTILINE)
13
+
14
+
15
+ def _strip_comments(code: str) -> str:
16
+ """Remove // line comments and /* */ block comments from code."""
17
+ # remove block comments first (they can span multiple lines)
18
+ code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
19
+ # remove line comments
20
+ code = re.sub(r"//[^\n]*", "", code)
21
+ return code
22
+
23
+
24
+ def _extract_use_statements(code: str) -> list[str]:
25
+ """
26
+ Extract raw use paths from all use declarations in the code.
27
+
28
+ Uses a bracket-aware scanner rather than plain regex, because braced imports
29
+ like `use std::{io, fmt};` span multiple tokens and can be multi-line.
30
+ Returns a list of raw path strings, e.g. ["std::collections::HashMap", "tokio"].
31
+ """
32
+ paths = []
33
+ i = 0
34
+ while i < len(code):
35
+ # find the next 'use' keyword (word boundary ensures we skip 'reuse', etc.)
36
+ match = re.search(r"\buse\s+", code[i:])
37
+ if match is None:
38
+ break
39
+
40
+ start = i + match.end()
41
+ depth = 0
42
+ j = start
43
+
44
+ # scan forward until we hit a ';' at brace-depth 0
45
+ while j < len(code):
46
+ c = code[j]
47
+ if c == "{":
48
+ depth += 1
49
+ elif c == "}":
50
+ depth -= 1
51
+ elif c == ";" and depth == 0:
52
+ paths.append(code[start:j].strip())
53
+ break
54
+ j += 1
55
+
56
+ i = start # advance past the 'use' keyword we just processed
57
+
58
+ return paths
59
+
60
+
61
+ def _split_top_level(s: str) -> list[str]:
62
+ """
63
+ Split a comma-separated string into items, respecting brace nesting.
64
+
65
+ Used to split the contents of {...} groups in use paths.
66
+ """
67
+ items = []
68
+ depth = 0
69
+ current: list[str] = []
70
+
71
+ for c in s:
72
+ if c == "{":
73
+ depth += 1
74
+ current.append(c)
75
+ elif c == "}":
76
+ depth -= 1
77
+ current.append(c)
78
+ elif c == "," and depth == 0:
79
+ items.append("".join(current).strip())
80
+ current = []
81
+ else:
82
+ current.append(c)
83
+
84
+ if current:
85
+ items.append("".join(current).strip())
86
+
87
+ return items
88
+
89
+
90
+ def _expand_use_path(path: str, prefix: str = "") -> list[str]:
91
+ """
92
+ Recursively expand a Rust use path into fully-qualified dotted import strings.
93
+
94
+ Handles simple paths, braced groups, wildcards, and 'as' aliases.
95
+ For example:
96
+ "std::collections::HashMap" -> ["std.collections.HashMap"]
97
+ "std::{io::Read, fmt}" -> ["std.io.Read", "std.fmt"]
98
+ "std::io::{self, Write}" -> ["std.io", "std.io.Write"]
99
+ Returns a list of dotted import paths.
100
+ """
101
+ # strip any trailing 'as alias' at this level of the path
102
+ path = re.sub(r"\s+as\s+\w+\s*$", "", path).strip()
103
+
104
+ if not path:
105
+ return []
106
+
107
+ # handle wildcard: use std::collections::*;
108
+ if path.endswith("::*"):
109
+ prefix_part = path[:-3].replace("::", ".")
110
+ full = f"{prefix}.{prefix_part}" if prefix else prefix_part
111
+ return [f"{full}.*"]
112
+
113
+ brace_idx = path.find("{")
114
+
115
+ if brace_idx == -1:
116
+ # simple path, no braces: std::collections::HashMap
117
+ converted = path.replace("::", ".")
118
+ full = f"{prefix}.{converted}" if prefix else converted
119
+ return [full]
120
+
121
+ # braced path: std::{io::Read, fmt} or std::io::{self, Write}
122
+ # everything before the brace is the common prefix
123
+ before_brace = path[:brace_idx].rstrip(":").replace("::", ".")
124
+ new_prefix = (
125
+ f"{prefix}.{before_brace}"
126
+ if (prefix and before_brace)
127
+ else (prefix or before_brace)
128
+ )
129
+
130
+ # extract the content inside the outermost braces
131
+ close_idx = path.rfind("}")
132
+ inner = path[brace_idx + 1 : close_idx]
133
+
134
+ results = []
135
+ for item in _split_top_level(inner):
136
+ item = item.strip()
137
+ if not item:
138
+ continue
139
+
140
+ if item == "self":
141
+ # 'self' refers to the module itself (the prefix path)
142
+ results.append(new_prefix)
143
+ else:
144
+ results.extend(_expand_use_path(item, prefix=new_prefix))
145
+
146
+ return results
147
+
148
+
def analyse_rust_code(code: str) -> CodeData:
    """
    Analyse Rust code to extract imported crates and their paths.

    Only use and extern crate declarations are extracted; no usage tracking or
    syntax validation is performed (valid is always True).

    Paths rooted at `crate`, `self` or `super` refer to modules of the current
    crate, so they are recorded in `imports` but are not counted as standard or
    external libraries.
    Returns a CodeData object with import information.
    """
    # path roots that refer to the current crate rather than a dependency
    local_roots = frozenset({"crate", "self", "super"})

    std_libs: set[str] = set()
    ext_libs: set[str] = set()
    imports: set[str] = set()

    def _record_crate(name: str) -> None:
        """Classify a top-level crate name as a standard or external library."""
        if not name or name in local_roots:
            return
        if name in RUST_STDLIB:
            std_libs.add(name)
        else:
            ext_libs.add(name)

    # strip comments so commented-out 'use' lines aren't matched
    clean_code = _strip_comments(code)

    # process each use declaration
    for raw_path in _extract_use_statements(clean_code):
        for dotted_path in _expand_use_path(raw_path):
            imports.add(dotted_path)
            # the first segment is the crate name (or a local module root)
            _record_crate(dotted_path.split(".")[0])

    # handle extern crate declarations (older rust style, still used occasionally)
    for match in _EXTERN_CRATE_RE.finditer(clean_code):
        crate_name = match.group(1)
        # `extern crate self as name;` aliases the current crate - not a dependency
        if crate_name in local_roots:
            continue
        imports.add(crate_name)
        _record_crate(crate_name)

    return CodeData(
        valid=True,
        std_libs=std_libs,
        ext_libs=ext_libs,
        imports=imports,
        lib_usage={},
    )
@@ -5,7 +5,9 @@ from llm_cgr.llm.clients import (
5
5
  GenerationProtocol,
6
6
  Mistral_LLM,
7
7
  OpenAI_LLM,
8
+ OpenAI_Tool_LLM,
8
9
  TogetherAI_LLM,
10
+ Tool,
9
11
  get_llm,
10
12
  )
11
13
  from llm_cgr.llm.generate import generate, generate_bool, generate_list
@@ -24,6 +26,8 @@ __all__ = [
24
26
  "GenerationProtocol",
25
27
  "Mistral_LLM",
26
28
  "OpenAI_LLM",
29
+ "OpenAI_Tool_LLM",
30
+ "Tool",
27
31
  "TogetherAI_LLM",
28
32
  "get_llm",
29
33
  "generate",
@@ -6,6 +6,7 @@ from llm_cgr.llm.clients.deepseek import DeepSeek_LLM
6
6
  from llm_cgr.llm.clients.mistral import Mistral_LLM
7
7
  from llm_cgr.llm.clients.nscale import Nscale_LLM
8
8
  from llm_cgr.llm.clients.openai import OpenAI_LLM
9
+ from llm_cgr.llm.clients.openai_tool import OpenAI_Tool_LLM, Tool
9
10
  from llm_cgr.llm.clients.protocol import GenerationProtocol
10
11
  from llm_cgr.llm.clients.together import TogetherAI_LLM
11
12
 
@@ -27,9 +28,13 @@ def get_llm(
27
28
  top_p: float | None = None,
28
29
  max_tokens: int | None = None,
29
30
  provider: str | None = None,
31
+ tools: list[Tool] | None = None,
30
32
  ) -> GenerationProtocol:
31
33
  """
32
34
  Initialise the correct LLM client for the given model.
35
+
36
+ If tools are provided, returns an OpenAI_Tool_LLM instance. Tool calls
37
+ are currently only supported for OpenAI models.
33
38
  """
34
39
  llm_class: type[Base_LLM]
35
40
  if provider is not None:
@@ -45,6 +50,21 @@ def get_llm(
45
50
  else:
46
51
  llm_class = TogetherAI_LLM
47
52
 
53
+ # if tools are requested, use the tool-enabled subclass (openai only for now)
54
+ if tools is not None:
55
+ if llm_class is not OpenAI_LLM:
56
+ raise NotImplementedError(
57
+ "Tool calls are only supported for OpenAI models."
58
+ )
59
+ return OpenAI_Tool_LLM(
60
+ tools=tools,
61
+ model=model,
62
+ system=system,
63
+ temperature=temperature,
64
+ top_p=top_p,
65
+ max_tokens=max_tokens,
66
+ )
67
+
48
68
  return llm_class(
49
69
  model=model,
50
70
  system=system,
@@ -60,7 +80,9 @@ __all__ = [
60
80
  "DeepSeek_LLM",
61
81
  "GenerationProtocol",
62
82
  "OpenAI_LLM",
83
+ "OpenAI_Tool_LLM",
63
84
  "TogetherAI_LLM",
64
85
  "Mistral_LLM",
86
+ "Tool",
65
87
  "get_llm",
66
88
  ]
@@ -1,8 +1,9 @@
1
1
  """Classes to access LLMs via the Anthropic Claude API."""
2
2
 
3
- from typing import Any
3
+ from typing import Any, cast
4
4
 
5
5
  import anthropic
6
+ from anthropic.types import MessageParam, TextBlock
6
7
 
7
8
  from llm_cgr.defaults import DEFAULT_MAX_TOKENS
8
9
  from llm_cgr.llm.clients.base import Base_LLM
@@ -69,10 +70,11 @@ class Anthropic_LLM(Base_LLM):
69
70
  """Generate a model response from the Anthropic API."""
70
71
  response = self._client.messages.create(
71
72
  model=model,
72
- system=system or self._system or anthropic.NOT_GIVEN,
73
- messages=input,
74
- temperature=temperature if temperature is not None else anthropic.NOT_GIVEN,
75
- top_p=top_p if top_p is not None else anthropic.NOT_GIVEN,
73
+ system=system or self._system or anthropic.omit,
74
+ messages=cast(list[MessageParam], input),
75
+ temperature=temperature if temperature is not None else anthropic.omit,
76
+ top_p=top_p if top_p is not None else anthropic.omit,
76
77
  max_tokens=max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS,
77
78
  )
78
- return response.content[0].text
79
+ # cast to TextBlock as non-tool, non-thinking requests always return text
80
+ return cast(TextBlock, response.content[0]).text
@@ -1,9 +1,10 @@
1
1
  """Class to access LLMs via the OpenAI API."""
2
2
 
3
3
  import os
4
- from typing import Any
4
+ from typing import Any, cast
5
5
 
6
6
  import openai
7
+ from openai.types.chat import ChatCompletionMessageParam
7
8
 
8
9
  from llm_cgr.llm.clients.base import Base_LLM
9
10
 
@@ -67,10 +68,11 @@ class DeepSeek_LLM(Base_LLM):
67
68
  ) -> str:
68
69
  """Generate a model response from the OpenAI API."""
69
70
  response = self._client.chat.completions.create(
70
- messages=input,
71
+ messages=cast(list[ChatCompletionMessageParam], input),
71
72
  model=model,
72
73
  temperature=temperature if temperature is not None else openai.omit,
73
74
  top_p=top_p if top_p is not None else openai.omit,
74
75
  max_completion_tokens=max_tokens if max_tokens is not None else openai.omit,
75
76
  )
76
- return response.choices[0].message.content
77
+ # cast to str as text completions always return string content
78
+ return cast(str, response.choices[0].message.content)
@@ -1,9 +1,10 @@
1
1
  """Class to access LLMs via the OpenAI API."""
2
2
 
3
3
  import os
4
- from typing import Any
4
+ from typing import Any, cast
5
5
 
6
6
  import openai
7
+ from openai.types.chat import ChatCompletionMessageParam
7
8
 
8
9
  from llm_cgr.llm.clients.base import Base_LLM
9
10
 
@@ -67,10 +68,11 @@ class Nscale_LLM(Base_LLM):
67
68
  ) -> str:
68
69
  """Generate a model response from the OpenAI API."""
69
70
  response = self._client.chat.completions.create(
70
- messages=input,
71
+ messages=cast(list[ChatCompletionMessageParam], input),
71
72
  model=model,
72
73
  temperature=temperature if temperature is not None else openai.omit,
73
74
  top_p=top_p if top_p is not None else openai.omit,
74
75
  max_completion_tokens=max_tokens if max_tokens is not None else openai.omit,
75
76
  )
76
- return response.choices[0].message.content
77
+ # cast to str as text completions always return string content
78
+ return cast(str, response.choices[0].message.content)
@@ -0,0 +1,258 @@
1
+ """OpenAI client subclass with an agentic tool-call loop."""
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, cast
6
+
7
+ import openai
8
+ from openai.types.responses import ResponseFunctionToolCall, ResponseInputItemParam
9
+
10
+ from llm_cgr.llm.clients.openai import OpenAI_LLM
11
+
12
+
# upper bound on tool-call rounds per request, so a model that keeps asking
# for tools cannot keep the loop running indefinitely
MAX_TOOL_ITERATIONS: int = 10


@dataclass
class Tool:
    """
    Describe a function the model may invoke while generating a response.

    Attributes:
        name: Identifier the model uses when requesting this tool.
        description: Explanation of the tool's purpose, shown to the model so
            it can judge when the tool is useful.
        parameters: JSON schema (as a dict) for the tool's arguments.
        fn: Local callable implementing the tool; it receives the arguments as
            keyword arguments and should return a str.
    """

    # identifier sent to the API as the function name
    name: str
    # natural-language summary the model reads when choosing tools
    description: str
    # JSON schema object describing the accepted keyword arguments
    parameters: dict[str, Any]
    # implementation invoked with the decoded arguments
    fn: Callable[..., str]
35
+
36
+
class OpenAI_Tool_LLM(OpenAI_LLM):
    """OpenAI client with an agentic tool-call loop.

    Tools are supplied at construction time and used for all subsequent
    generate() and chat() calls. The client handles the full loop internally:
    call the API, execute any tool calls, feed results back, repeat until the
    model produces a final text response.

    Each request may use at most ``max_tool_iterations`` tool-calling rounds.
    If the model is still requesting tools when that budget is spent, one final
    request is made with tool calls disabled, so the model answers from the
    tool results gathered so far.
    """

    def __init__(
        self,
        tools: list[Tool],
        model: str | None = None,
        system: str | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        max_tokens: int | None = None,
        max_tool_iterations: int = MAX_TOOL_ITERATIONS,
    ) -> None:
        """
        Initialise the OpenAI tool client.

        Requires the OPENAI_API_KEY environment variable to be set.
        max_tool_iterations caps the tool-calling rounds used by a single
        generate() sample or chat() turn (defaults to MAX_TOOL_ITERATIONS).
        """
        super().__init__(
            model=model,
            system=system,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        self._tools = tools
        # cap on tool-calling rounds per generate() sample / chat() turn
        self._max_tool_iterations = max_tool_iterations
        # cumulative count of individual tool calls made by this instance
        self._tool_calls: int = 0

    @property
    def tool_calls(self) -> int:
        """Total number of tool calls made by this client since instantiation.

        Returns the cumulative count across all generate() and chat() calls.
        Tip: record the value before a call and subtract to get the count for
        that specific call.
        """
        return self._tool_calls

    def _build_tool_param(
        self,
        tool: Tool,
    ) -> dict[str, Any]:
        """Convert a Tool dataclass to the dict format the OpenAI Responses API expects."""
        return {
            "type": "function",
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.parameters,
        }

    def _resolve_sampling(
        self,
        temperature: float | None,
        top_p: float | None,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Merge per-call sampling arguments with the instance defaults.

        `is not None` checks are used (rather than `or`) so that explicit falsy
        values such as temperature=0.0 are not replaced by the defaults.
        """
        return {
            "temperature": temperature if temperature is not None else self._temperature,
            "top_p": top_p if top_p is not None else self._top_p,
            "max_tokens": max_tokens if max_tokens is not None else self._max_tokens,
        }

    def _execute_tool_call(
        self,
        call: ResponseFunctionToolCall,
        tool_map: dict[str, Tool],
    ) -> str:
        """Run one model-requested tool call and return its output string.

        Problems caused by the model's request (an unknown tool name, arguments
        that are not a valid JSON object) are reported back to the model as the
        call's output instead of aborting the whole generation, so the model can
        correct itself. Exceptions raised by the tool function itself propagate.
        """
        tool = tool_map.get(call.name)
        if tool is None:
            return f"Error: unknown tool '{call.name}'."

        try:
            kwargs = json.loads(call.arguments or "{}")
        except json.JSONDecodeError as exc:
            return f"Error: tool arguments are not valid JSON ({exc})."
        if not isinstance(kwargs, dict):
            return "Error: tool arguments must be a JSON object."

        # the API expects a string output; coerce in case fn returns e.g. an int
        return str(tool.fn(**kwargs))

    def _run_tool_loop(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float | None,
        top_p: float | None,
        max_tokens: int | None,
    ) -> str:
        """Run the agentic tool-call loop for a single turn.

        Calls the OpenAI Responses API in a loop, executing any tool calls the
        model requests and feeding their outputs back, until the model produces
        a final text response. If every one of the ``max_tool_iterations``
        rounds requested tools, one last request is sent with
        ``tool_choice="none"`` so the model must answer in text.

        ``messages`` is never mutated: the loop works on a shallow copy, so
        intermediate tool-call items never leak into the caller's chat history.
        Returns the final text response.
        """
        # convert Tool dataclasses to the API's function-tool format
        api_tools = [self._build_tool_param(t) for t in self._tools]

        # build a name -> Tool lookup map for fast dispatch during the loop
        tool_map = {t.name: t for t in self._tools}

        # typed as list[Any] so both plain message dicts and tool-call dicts
        # can be appended without fighting the type checker
        current_input: list[Any] = list(messages)

        def _create(tool_choice: Any = openai.omit) -> Any:
            """Send the current input (including tool results) to the API."""
            return self._client.responses.create(
                input=cast(list[ResponseInputItemParam], current_input),
                model=model,
                temperature=temperature if temperature is not None else openai.omit,
                top_p=top_p if top_p is not None else openai.omit,
                max_output_tokens=max_tokens if max_tokens is not None else openai.omit,
                tools=cast(Any, api_tools),
                tool_choice=tool_choice,
            )

        for _ in range(self._max_tool_iterations):
            response = _create()

            # collect any function calls the model requested in this response
            function_calls = [
                cast(ResponseFunctionToolCall, item)
                for item in response.output
                if item.type == "function_call"
            ]

            # no tool calls means the model has produced its final text answer
            if not function_calls:
                return response.output_text

            # increment the cumulative counter; parallel calls count individually
            self._tool_calls += len(function_calls)

            # the Responses API requires each function_call item to appear in
            # the next input before its matching function_call_output item
            for call in function_calls:
                current_input.append(
                    {
                        "type": "function_call",
                        "call_id": call.call_id,
                        "name": call.name,
                        "arguments": call.arguments,
                    }
                )
                current_input.append(
                    {
                        "type": "function_call_output",
                        "call_id": call.call_id,
                        "output": self._execute_tool_call(call, tool_map),
                    }
                )

        # budget spent while the model still wanted tools: a tool-call-only
        # response usually carries no text, so force a text answer that can use
        # the results appended above
        final = _create(tool_choice="none")
        return final.output_text

    def generate(
        self,
        user: str,
        system: str | None = None,
        model: str | None = None,
        samples: int = 1,
        temperature: float | None = None,
        top_p: float | None = None,
        max_tokens: int | None = None,
    ) -> list[str]:
        """Generate ``samples`` independent responses via the agentic tool-call loop.

        Sampling arguments given here override the instance defaults; an
        explicit 0 / 0.0 is honoured rather than treated as unset.
        Raises ValueError if no model is given here or at construction.
        """
        _model = model or self._model
        if _model is None:
            raise ValueError("Model must be specified for LLM APIs.")

        messages = self._build_input(
            user=user,
            system=system or self._system,
        )
        sampling = self._resolve_sampling(temperature, top_p, max_tokens)

        return [
            self._run_tool_loop(messages=messages, model=_model, **sampling)
            for _ in range(samples)
        ]

    def chat(
        self,
        user: str,
        system: str | None = None,
        model: str | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        max_tokens: int | None = None,
    ) -> str:
        """Run a chat turn via the agentic tool-call loop.

        Only the user message and the final text response are appended to
        self._history, not the intermediate tool-call scaffolding.
        Raises ValueError if no model is given here or at construction.
        """
        _model = model or self._model
        if _model is None:
            raise ValueError("Model must be specified for LLM APIs.")

        if self._history is None:
            self._history = self._build_input(
                user=user,
                system=system or self._system,
            )
        else:
            self._history.append(
                self._build_message(
                    role="user",
                    content=user,
                )
            )

        # _run_tool_loop operates on a shallow copy of self._history, so
        # intermediate tool-call items never appear in the chat history
        response = self._run_tool_loop(
            messages=self._history,
            model=_model,
            **self._resolve_sampling(temperature, top_p, max_tokens),
        )

        self._history.append(
            self._build_message(
                role="assistant",
                content=response,
            )
        )
        return response
@@ -1,6 +1,6 @@
1
1
  """Class to access LLMs via the TogetherAI API."""
2
2
 
3
- from typing import Any
3
+ from typing import Any, cast
4
4
 
5
5
  import together
6
6
 
@@ -64,9 +64,12 @@ class TogetherAI_LLM(Base_LLM):
64
64
  """Generate a model response from the TogetherAI API."""
65
65
  response = self._client.chat.completions.create(
66
66
  model=model,
67
- messages=input,
68
- temperature=temperature,
69
- top_p=top_p,
70
- max_tokens=max_tokens,
67
+ messages=cast(Any, input),
68
+ temperature=temperature if temperature is not None else together.omit,
69
+ top_p=top_p if top_p is not None else together.omit,
70
+ max_tokens=max_tokens if max_tokens is not None else together.omit,
71
71
  )
72
- return response.choices[0].message.content
72
+ # cast to Any first as together doesn't publicly export the message type,
73
+ # then cast content to str as text completions always have it set
74
+ message = cast(Any, response.choices[0].message)
75
+ return cast(str, message.content)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: llm-codegen-research
3
- Version: 2.11
3
+ Version: 2.13
4
4
  Summary: Useful classes and methods for researching code-generation by LLMs.
5
5
  Author-email: Lukas Twist <itsluketwist@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/itsluketwist/llm-codegen-research
@@ -15,6 +15,7 @@ src/llm_cgr/analyse/languages/__init__.py
15
15
  src/llm_cgr/analyse/languages/code_data.py
16
16
  src/llm_cgr/analyse/languages/javascript.py
17
17
  src/llm_cgr/analyse/languages/python.py
18
+ src/llm_cgr/analyse/languages/rust.py
18
19
  src/llm_cgr/llm/__init__.py
19
20
  src/llm_cgr/llm/generate.py
20
21
  src/llm_cgr/llm/prompts.py
@@ -25,6 +26,7 @@ src/llm_cgr/llm/clients/deepseek.py
25
26
  src/llm_cgr/llm/clients/mistral.py
26
27
  src/llm_cgr/llm/clients/nscale.py
27
28
  src/llm_cgr/llm/clients/openai.py
29
+ src/llm_cgr/llm/clients/openai_tool.py
28
30
  src/llm_cgr/llm/clients/protocol.py
29
31
  src/llm_cgr/llm/clients/together.py
30
32
  src/llm_cgr/scripts/test_cuda.py
@@ -38,4 +40,5 @@ tests/test_enums.py
38
40
  tests/test_json_utils.py
39
41
  tests/test_llm_api.py
40
42
  tests/test_llm_local.py
43
+ tests/test_llm_tool.py
41
44
  tests/test_utils.py
@@ -1,4 +1,5 @@
1
1
  from enum import auto
2
+ from typing import Any
2
3
 
3
4
  from llm_cgr import OptionsEnum
4
5
 
@@ -26,8 +27,8 @@ def test_options_enum():
26
27
  assert (TestEnum.ONE != "One") is False
27
28
  assert (TestEnum.ONE != "ONE") is False
28
29
 
29
- # check that the enum can be hashed
30
- test_dict = {TestEnum.TWO: "value"}
30
+ # check that the enum can be hashed and string keys match enum keys at runtime
31
+ test_dict: dict[Any, str] = {TestEnum.TWO: "value"}
31
32
  assert test_dict[TestEnum.TWO] == "value"
32
33
  assert test_dict["two"] == "value"
33
34
  assert {TestEnum.THREE} == {"three"}
@@ -1,8 +1,14 @@
1
1
  """Test our connection and usage of the LLM APIs."""
2
2
 
3
+ import pytest
4
+
3
5
  from llm_cgr import BASE_SYSTEM_PROMPT, generate, generate_bool, generate_list, get_llm
4
6
 
5
7
 
# tag every test in this module with the `api` marker so ci runs can deselect
# them (e.g. `pytest -m "not api"`)
pytestmark = [pytest.mark.api]
10
+
11
+
6
12
  def test_generate(model):
7
13
  """
8
14
  Test the generate method.
@@ -10,8 +10,10 @@ from llm_cgr import (
10
10
  Mistral_LLM,
11
11
  OpenAI_LLM,
12
12
  TogetherAI_LLM,
13
+ Tool,
13
14
  generate_bool,
14
15
  generate_list,
16
+ get_llm,
15
17
  )
16
18
 
17
19
 
@@ -129,6 +131,30 @@ def test_build_input():
129
131
  ]
130
132
 
131
133
 
def test_tools_unsupported_provider():
    """
    Requesting tools for a non-OpenAI model must raise NotImplementedError.

    get_llm() rejects the request before any client is constructed, so this
    test makes no network request.
    """
    # schema takes no arguments; the tool is never actually invoked
    placeholder = Tool(
        name="dummy",
        description="A dummy tool.",
        parameters={"type": "object", "properties": {}},
        fn=lambda: "result",
    )
    expected_message = "Tool calls are only supported for OpenAI models."

    # anthropic models do not support tool calls yet
    with pytest.raises(NotImplementedError, match=expected_message):
        get_llm(tools=[placeholder], model="claude-3-5-haiku-20241022")
156
+
157
+
132
158
  @pytest.mark.parametrize(
133
159
  "response,error",
134
160
  [
@@ -0,0 +1,96 @@
1
+ """Test the OpenAI_Tool_LLM agentic tool-call loop."""
2
+
3
+ import pytest
4
+
5
+ from llm_cgr import OpenAI_Tool_LLM, Tool
6
+
7
+
# tag every test in this module with the `api` marker so ci runs can deselect
# them (e.g. `pytest -m "not api"`)
pytestmark = [pytest.mark.api]
10
+
11
+
def test_tool_call_generate(openai_model):
    """
    Test that the OpenAI tool client runs the agentic loop and returns the
    correct answer via the tool.

    The add tool records every invocation, so the test verifies that a real
    tool call happened with the requested operands. The final text alone cannot
    prove this, because the model could compute 3 + 4 without the tool.
    """
    calls: list[tuple[int, int]] = []

    def add(a: int, b: int) -> str:
        """Record the call, then add two integers and return the result as a string."""
        calls.append((a, b))
        return str(a + b)

    add_tool = Tool(
        name="add",
        description="Add two integers together and return the result.",
        parameters={
            "type": "object",
            "properties": {
                "a": {"type": "integer", "description": "The first integer."},
                "b": {"type": "integer", "description": "The second integer."},
            },
            "required": ["a", "b"],
            "additionalProperties": False,
        },
        fn=add,
    )

    llm = OpenAI_Tool_LLM(tools=[add_tool], model=openai_model)
    responses = llm.generate(
        user="Use the add tool to compute 3 + 4. What is the result?"
    )

    assert isinstance(responses, list)
    assert len(responses) == 1

    result = responses[0]
    assert isinstance(result, str)
    assert "7" in result

    # the local function must actually have run with the requested operands
    assert (3, 4) in calls or (4, 3) in calls
    # every invocation is counted by the client (parallel calls individually)
    assert llm.tool_calls >= len(calls) >= 1
55
+
56
+
def test_tool_call_chat(openai_model):
    """
    Test that tool calls work correctly in a chat session, and that
    intermediate tool-call scaffolding does not corrupt the chat history.

    The multiply tool records every invocation, so the test verifies that the
    tool really ran rather than trusting the model's arithmetic.
    """
    calls: list[tuple[int, int]] = []

    def multiply(a: int, b: int) -> str:
        """Record the call, then multiply two integers and return the result as a string."""
        calls.append((a, b))
        return str(a * b)

    multiply_tool = Tool(
        name="multiply",
        description="Multiply two integers together and return the result.",
        parameters={
            "type": "object",
            "properties": {
                "a": {"type": "integer", "description": "The first integer."},
                "b": {"type": "integer", "description": "The second integer."},
            },
            "required": ["a", "b"],
            "additionalProperties": False,
        },
        fn=multiply,
    )

    llm = OpenAI_Tool_LLM(tools=[multiply_tool], model=openai_model)
    response = llm.chat(
        user="Use the multiply tool to compute 6 * 7. What is the result?"
    )

    assert isinstance(response, str)
    assert "42" in response

    # the local function must actually have run with the requested operands
    assert (6, 7) in calls or (7, 6) in calls
    assert llm.tool_calls >= len(calls) >= 1

    # history should only contain the user turn and the final assistant response;
    # no intermediate function_call or function_call_output items should have leaked in
    history = llm.history
    assert len(history) == 2
    assert history[0]["role"] == "user"
    assert history[1]["role"] == "assistant"