langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0

langfun/core/llms/openai_compatible.py (new file)
@@ -0,0 +1,179 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Base for OpenAI compatible models (including OpenAI)."""
+
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import rest
+ import pyglove as pg
+
+
+ @lf.use_init_args(['api_endpoint', 'model'])
+ class OpenAICompatible(rest.REST):
+   """Base for OpenAI compatible models."""
+
+   model: Annotated[
+       str, 'The name of the model to use.',
+   ] = ''
+
+   multimodal: Annotated[
+       bool, 'Whether this model has multimodal support.'
+   ] = False
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     return {
+         'Content-Type': 'application/json'
+     }
+
+   def _request_args(
+       self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+     """Returns a dict as request arguments."""
+     # Reference:
+     # https://platform.openai.com/docs/api-reference/completions/create
+     # NOTE(daiyip): options.top_k is not applicable.
+     args = dict(
+         n=options.n,
+         top_logprobs=options.top_logprobs,
+     )
+     if self.model:
+       args['model'] = self.model
+     if options.logprobs:
+       args['logprobs'] = options.logprobs
+     if options.temperature is not None:
+       args['temperature'] = options.temperature
+     if options.max_tokens is not None:
+       args['max_completion_tokens'] = options.max_tokens
+     if options.top_p is not None:
+       args['top_p'] = options.top_p
+     if options.stop:
+       args['stop'] = options.stop
+     if options.random_seed is not None:
+       args['seed'] = options.random_seed
+     return args
+
+   def _content_from_message(self, message: lf.Message) -> list[dict[str, Any]]:
+     """Returns an OpenAI content object from a Langfun message."""
+     content = []
+     for chunk in message.chunk():
+       if isinstance(chunk, str):
+         item = dict(type='text', text=chunk)
+       elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
+         item = dict(type='image_url', image_url=dict(url=chunk.embeddable_uri))
+       else:
+         raise ValueError(f'Unsupported modality: {chunk!r}.')
+       content.append(item)
+     return content
+
+   def request(
+       self,
+       prompt: lf.Message,
+       sampling_options: lf.LMSamplingOptions
+   ) -> dict[str, Any]:
+     """Returns the JSON input for a message."""
+     request_args = self._request_args(sampling_options)
+
+     # Users could use `metadata_json_schema` to pass additional
+     # request arguments.
+     json_schema = prompt.metadata.get('json_schema')
+     if json_schema is not None:
+       if not isinstance(json_schema, dict):
+         raise ValueError(
+             f'`json_schema` must be a dict, got {json_schema!r}.'
+         )
+       if 'title' not in json_schema:
+         raise ValueError(
+             f'The root of `json_schema` must have a `title` field, '
+             f'got {json_schema!r}.'
+         )
+       request_args.update(
+           response_format=dict(
+               type='json_schema',
+               json_schema=dict(
+                   schema=json_schema,
+                   name=json_schema['title'],
+                   strict=True,
+               )
+           )
+       )
+       prompt.metadata.formatted_text = (
+           prompt.text
+           + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+           + pg.to_json_str(request_args['response_format'], json_indent=2)
+       )
+
+     # Prepare messages.
+     messages = []
+     # Users could use `metadata_system_message` to pass system message.
+     system_message = prompt.metadata.get('system_message')
+     if system_message:
+       system_message = lf.SystemMessage.from_value(system_message)
+       messages.append(
+           dict(role='system',
+                content=self._content_from_message(system_message))
+       )
+     messages.append(
+         dict(role='user', content=self._content_from_message(prompt))
+     )
+     request = dict()
+     request.update(request_args)
+     request['messages'] = messages
+     return request
+
+   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
+     # Reference:
+     # https://platform.openai.com/docs/api-reference/chat/object
+     logprobs = None
+     choice_logprobs = choice.get('logprobs')
+     if choice_logprobs:
+       logprobs = [
+           (
+               t['token'],
+               t['logprob'],
+               [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
+           )
+           for t in choice_logprobs['content']
+       ]
+     return lf.LMSample(
+         choice['message']['content'],
+         score=0.0,
+         logprobs=logprobs,
+     )
+
+   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+     """Returns an LMSamplingResult from a JSON response."""
+     usage = json['usage']
+     return lf.LMSamplingResult(
+         samples=[self._parse_choice(choice) for choice in json['choices']],
+         usage=lf.LMSamplingUsage(
+             prompt_tokens=usage['prompt_tokens'],
+             completion_tokens=usage['completion_tokens'],
+             total_tokens=usage['total_tokens'],
+             estimated_cost=self.estimate_cost(
+                 num_input_tokens=usage['prompt_tokens'],
+                 num_output_tokens=usage['completion_tokens'],
+             )
+         ),
+     )
+
+   def estimate_cost(
+       self,
+       num_input_tokens: int,
+       num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     del num_input_tokens, num_output_tokens
+     return None
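
Note: a minimal usage sketch of the class above, assuming a hypothetical self-hosted OpenAI-compatible endpoint (the URL, model name, and prompt text are illustrative and not part of this package); the `system_message` and `json_schema` metadata hooks are the same ones exercised by the tests that follow.

    import langfun.core as lf
    from langfun.core.llms import openai_compatible

    # Hypothetical endpoint and model name, for illustration only.
    lm = openai_compatible.OpenAICompatible(
        api_endpoint='https://my-gateway.example/v1/chat/completions',
        model='my-model',
    )

    # Plain call; sampling options are mapped onto request args by
    # `_request_args` (temperature, max_completion_tokens, seed, ...).
    print(lm('hello', sampling_options=lf.LMSamplingOptions(temperature=0.0)))

    # `system_message` and `json_schema` are read from message metadata by
    # `request()` and become a system turn and a `response_format` entry.
    print(lm(lf.UserMessage(
        'Extract the person mentioned in: "Alice went home."',
        system_message='You are a careful extractor.',
        json_schema={
            'type': 'object',
            'properties': {'name': {'type': 'string'}},
            'required': ['name'],
            'title': 'Person',
        },
    )))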

langfun/core/llms/openai_compatible_test.py (new file)
@@ -0,0 +1,495 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for OpenAI compatible models."""
+
+ from typing import Any
+ import unittest
+ from unittest import mock
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+ import requests
+
+
+ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
+   del url, kwargs
+   messages = json['messages']
+   if len(messages) > 1:
+     system_message = f' system={messages[0]["content"]}'
+   else:
+     system_message = ''
+
+   if 'response_format' in json:
+     response_format = f' format={json["response_format"]["type"]}'
+   else:
+     response_format = ''
+
+   choices = []
+   for k in range(json['n']):
+     if json.get('logprobs'):
+       logprobs = dict(
+           content=[
+               dict(
+                   token='chosen_token',
+                   logprob=0.5,
+                   top_logprobs=[
+                       dict(
+                           token=f'alternative_token_{i + 1}',
+                           logprob=0.1
+                       ) for i in range(3)
+                   ]
+               )
+           ]
+       )
+     else:
+       logprobs = None
+
+     choices.append(dict(
+         message=dict(
+             content=(
+                 f'Sample {k} for message.{system_message}{response_format}'
+             )
+         ),
+         logprobs=logprobs,
+     ))
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str(
+       dict(
+           choices=choices,
+           usage=lf.LMSamplingUsage(
+               prompt_tokens=100,
+               completion_tokens=100,
+               total_tokens=200,
+           ),
+       )
+   ).encode()
+   return response
+
+
+ def mock_chat_completion_request_vision(
+     url: str, json: dict[str, Any], **kwargs
+ ):
+   del url, kwargs
+   choices = []
+   urls = [
+       c['image_url']['url']
+       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
+   ]
+   for k in range(json['n']):
+     choices.append(pg.Dict(
+         message=pg.Dict(
+             content=f'Sample {k} for message: {"".join(urls)}'
+         ),
+         logprobs=None,
+     ))
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str(
+       dict(
+           choices=choices,
+           usage=lf.LMSamplingUsage(
+               prompt_tokens=100,
+               completion_tokens=100,
+               total_tokens=200,
+           ),
+       )
+   ).encode()
+   return response
+
+
+ class OpenAICompatibleTest(unittest.TestCase):
+   """Tests for OpenAI compatible language model."""
+
+   def test_request_args(self):
+     self.assertEqual(
+         openai_compatible.OpenAICompatible(
+             api_endpoint='https://test-server',
+             model='test-model'
+         )._request_args(
+             lf.LMSamplingOptions(
+                 temperature=1.0, stop=['\n'], n=1, random_seed=123
+             )
+         ),
+         dict(
+             model='test-model',
+             top_logprobs=None,
+             n=1,
+             temperature=1.0,
+             stop=['\n'],
+             seed=123,
+         ),
+     )
+
+   def test_call_chat_completion(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model',
+       )
+       self.assertEqual(
+           lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
+           'Sample 0 for message.',
+       )
+
+   def test_call_chat_completion_with_logprobs(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model',
+       )
+       results = lm.sample(['hello'], logprobs=True)
+       for result in results:
+         result.rebind(
+             {path: 0 for path in pg.query(result, '.*total_call_interval')},
+             skip_notification=True,
+         )
+       self.assertEqual(len(results), 1)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       response=lf.AIMessage(
+                           text='Sample 0 for message.',
+                           metadata={
+                               'score': 0.0,
+                               'logprobs': [(
+                                   'chosen_token',
+                                   0.5,
+                                   [
+                                       ('alternative_token_1', 0.1),
+                                       ('alternative_token_2', 0.1),
+                                       ('alternative_token_3', 0.1),
+                                   ],
+                               )],
+                               'is_cached': False,
+                               'usage': lf.LMSamplingUsage(
+                                   prompt_tokens=100,
+                                   completion_tokens=100,
+                                   total_tokens=200,
+                                   estimated_cost=None,
+                               ),
+                           },
+                           tags=['lm-response'],
+                       ),
+                       logprobs=[(
+                           'chosen_token',
+                           0.5,
+                           [
+                               ('alternative_token_1', 0.1),
+                               ('alternative_token_2', 0.1),
+                               ('alternative_token_3', 0.1),
+                           ],
+                       )],
+                   )
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100,
+                   completion_tokens=100,
+                   total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+
+   def test_call_chat_completion_vision(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request_vision
+       lm_1 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server',
+           model='test-model1',
+           multimodal=True
+       )
+       lm_2 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server',
+           model='test-model2',
+           multimodal=True
+       )
+       for lm in (lm_1, lm_2):
+         self.assertEqual(
+             lm(
+                 lf.UserMessage(
+                     'hello <<[[image]]>>',
+                     image=lf_modalities.Image.from_uri('https://fake/image')
+                 ),
+                 sampling_options=lf.LMSamplingOptions(n=2)
+             ),
+             'Sample 0 for message: https://fake/image',
+         )
+       lm_3 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model3'
+       )
+       with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
+         lm_3(
+             lf.UserMessage(
+                 'hello <<[[image]]>>',
+                 image=lf_modalities.Image.from_uri('https://fake/image')
+             ),
+         )
+
+   def test_sample_chat_completion(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       results = lm.sample(
+           ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
+       )
+       for result in results:
+         result.rebind(
+             {path: 0 for path in pg.query(result, '.*total_call_interval')},
+             skip_notification=True,
+         )
+
+       self.assertEqual(len(results), 2)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 2 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+       self.assertEqual(
+           results[1],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 2 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+
+   def test_sample_with_contextual_options(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
+         results = lm.sample(['hello'])
+       for result in results:
+         result.rebind(
+             {path: 0 for path in pg.query(result, '.*total_call_interval')},
+             skip_notification=True,
+         )
+
+       self.assertEqual(len(results), 1)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=50,
+                               completion_tokens=50,
+                               total_tokens=100,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=50,
+                               completion_tokens=50,
+                               total_tokens=100,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200
+               ),
+           )
+       )
+
+   def test_call_with_system_message(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       self.assertEqual(
+           lm(
+               lf.UserMessage(
+                   'hello',
+                   system_message='hi',
+               ),
+               sampling_options=lf.LMSamplingOptions(n=2)
+           ),
+           '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
+       )
+
+   def test_call_with_json_schema(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       self.assertEqual(
+           lm(
+               lf.UserMessage(
+                   'hello',
+                   json_schema={
+                       'type': 'object',
+                       'properties': {
+                           'name': {'type': 'string'},
+                       },
+                       'required': ['name'],
+                       'title': 'Person',
+                   }
+               ),
+               sampling_options=lf.LMSamplingOptions(n=2)
+           ),
+           'Sample 0 for message. format=json_schema',
+       )
+
+     # Test bad json schema.
+     with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+       lm(lf.UserMessage('hello', json_schema='foo'))
+
+     with self.assertRaisesRegex(
+         ValueError, 'The root of `json_schema` must have a `title` field'
+     ):
+       lm(lf.UserMessage('hello', json_schema={}))
+
+
+ if __name__ == '__main__':
+   unittest.main()
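
Note: concrete providers are expected to subclass OpenAICompatible rather than use it directly. A rough sketch of such a subclass, assuming a hypothetical provider that needs a bearer-token header and flat per-token pricing (the class name, endpoint, and prices are illustrative, not taken from this package):

    from typing import Annotated, Any

    from langfun.core.llms import openai_compatible


    class MyProvider(openai_compatible.OpenAICompatible):
      """Hypothetical OpenAI-compatible provider."""

      api_key: Annotated[str, 'API key of the hypothetical provider.'] = ''

      @property
      def headers(self) -> dict[str, Any]:
        # Extend the base headers (Content-Type) with authentication.
        headers = super().headers
        headers['Authorization'] = f'Bearer {self.api_key}'
        return headers

      def estimate_cost(
          self, num_input_tokens: int, num_output_tokens: int
      ) -> float | None:
        # Illustrative flat pricing; the base class returns None by default.
        return 1e-6 * num_input_tokens + 2e-6 * num_output_tokens


    lm = MyProvider(
        api_endpoint='https://api.my-provider.example/v1/chat/completions',
        model='my-model',
        api_key='...',
    )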