langchain-githubcopilot-chat 0.1.0__py3-none-any.whl

@@ -0,0 +1,923 @@
+ """GitHub Copilot Chat model integration via GitHub Models inference API."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from typing import (
+     Any,
+     AsyncIterator,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+     Sequence,
+     Tuple,
+     Type,
+     Union,
+ )
+
+ import httpx
+ from langchain_core.callbacks import (
+     AsyncCallbackManagerForLLMRun,
+     CallbackManagerForLLMRun,
+ )
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.language_models.base import LangSmithParams
+ from langchain_core.messages import (
+     AIMessage,
+     AIMessageChunk,
+     BaseMessage,
+     ChatMessage,
+     HumanMessage,
+     SystemMessage,
+     ToolMessage,
+ )
+ from langchain_core.messages.ai import UsageMetadata
+ from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
+ from langchain_core.output_parsers.openai_tools import (
+     make_invalid_tool_call,
+     parse_tool_call,
+ )
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+ from langchain_core.tools import BaseTool
+ from langchain_core.utils.function_calling import convert_to_openai_tool
+ from pydantic import Field, SecretStr, model_validator
+
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+
+ _ROLE_MAP = {
+     "human": "user",
+     "ai": "assistant",
+     "system": "system",
+     "developer": "developer",
+     "tool": "tool",
+ }
+
+ _GITHUB_MODELS_BASE_URL = "https://models.github.ai"
+ _INFERENCE_PATH = "/inference/chat/completions"
+ _ORG_INFERENCE_PATH = "/orgs/{org}/inference/chat/completions"
+ _API_VERSION = "2026-03-10"
+
+
+ def _message_to_dict(message: BaseMessage) -> Dict[str, Any]:
+     """Convert a LangChain message to the GitHub Models API message format."""
+     if isinstance(message, SystemMessage):
+         return {"role": "system", "content": message.content}
+     elif isinstance(message, HumanMessage):
+         # Support multimodal content (list of content blocks)
+         if isinstance(message.content, list):
+             parts = []
+             for block in message.content:
+                 if isinstance(block, dict):
+                     btype = block.get("type", "")
+                     if btype == "text":
+                         parts.append({"type": "text", "text": block["text"]})
+                     elif btype == "image_url":
+                         parts.append(
+                             {
+                                 "type": "image_url",
+                                 "image_url": block.get("image_url", {}),
+                             }
+                         )
+                     else:
+                         parts.append(block)
+                 else:
+                     parts.append({"type": "text", "text": str(block)})
+             return {"role": "user", "content": parts}
+         return {"role": "user", "content": message.content}
+     elif isinstance(message, AIMessage):
+         msg: Dict[str, Any] = {"role": "assistant", "content": message.content or ""}
+         # Attach tool calls if present
+         if message.tool_calls:
+             msg["tool_calls"] = [
+                 {
+                     "id": tc["id"],
+                     "type": "function",
+                     "function": {
+                         "name": tc["name"],
+                         "arguments": json.dumps(tc["args"]),
+                     },
+                 }
+                 for tc in message.tool_calls
+             ]
+         elif message.additional_kwargs.get("tool_calls"):
+             msg["tool_calls"] = message.additional_kwargs["tool_calls"]
+         return msg
+     elif isinstance(message, ToolMessage):
+         return {
+             "role": "tool",
+             "tool_call_id": message.tool_call_id,
+             "content": message.content,
+         }
+     elif isinstance(message, ChatMessage):
+         role = _ROLE_MAP.get(message.role, message.role)
+         return {"role": role, "content": message.content}
+     else:
+         # Fallback: treat as user message
+         return {"role": "user", "content": str(message.content)}
+
+
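+ # Editorial sketch (not part of the released wheel): the mapping performed by
+ # ``_message_to_dict`` on a minimal exchange, using nothing beyond this
+ # module's imports. A ToolMessage keeps its ``tool_call_id`` so the API can
+ # pair a tool result with the call that produced it.
+ def _example_message_mapping() -> List[Dict[str, Any]]:
+     """Illustrative only; never called by this module."""
+     return [
+         _message_to_dict(SystemMessage(content="Be terse.")),
+         # -> {"role": "system", "content": "Be terse."}
+         _message_to_dict(HumanMessage(content="hi")),
+         # -> {"role": "user", "content": "hi"}
+         _message_to_dict(ToolMessage(content="42", tool_call_id="call_1")),
+         # -> {"role": "tool", "tool_call_id": "call_1", "content": "42"}
+     ]
+
+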
+ def _format_tools_for_api(
+     tools: Sequence[Union[Dict[str, Any], BaseTool, Type]],
+ ) -> List[Dict[str, Any]]:
+     """Convert LangChain tools into the OpenAI-compatible format
+     expected by GitHub Models.
+     """
+     formatted = []
+     for tool in tools:
+         if isinstance(tool, dict) and tool.get("type") == "function":
+             formatted.append(tool)
+         else:
+             oai_tool = convert_to_openai_tool(tool)  # type: ignore[arg-type]
+             formatted.append(oai_tool)
+     return formatted
+
+
+ def _parse_tool_calls(
+     raw_tool_calls: List[Dict[str, Any]],
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+     """Parse raw API tool_calls into LangChain (valid, invalid) tool calls.
+
+     Valid calls are returned in LangChain ``tool_calls`` format. Calls whose
+     arguments fail to parse are returned separately so they can be attached
+     to ``AIMessage.invalid_tool_calls`` rather than being mixed in with the
+     valid ones.
+     """
+     tool_calls: List[Dict[str, Any]] = []
+     invalid_tool_calls: List[Dict[str, Any]] = []
+     for raw in raw_tool_calls:
+         try:
+             parsed = parse_tool_call(raw, return_id=True)
+             if parsed is not None:
+                 tool_calls.append(parsed)
+         except Exception as exc:
+             invalid = make_invalid_tool_call(raw, str(exc))
+             invalid_tool_calls.append(dict(invalid))
+     return tool_calls, invalid_tool_calls
+
+
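+ # Editorial sketch (not part of the released wheel): what
+ # ``_format_tools_for_api`` emits for a Pydantic tool. The exact JSON schema
+ # is produced by ``convert_to_openai_tool``; ``GetWeather`` is an
+ # illustrative name.
+ def _example_tool_format() -> List[Dict[str, Any]]:
+     from pydantic import BaseModel
+
+     class GetWeather(BaseModel):
+         """Get the current weather in a given location."""
+
+         location: str
+
+     return _format_tools_for_api([GetWeather])
+     # -> [{"type": "function",
+     #      "function": {"name": "GetWeather",
+     #                   "description": "Get the current weather in a given location.",
+     #                   "parameters": {"properties": {"location": {...}}, ...}}}]
+
+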
+ # The GitHub Models API only accepts "auto", "required", or "none" as string
+ # tool_choice values. LangChain internally uses "any" (equivalent to
+ # "required"), and forcing a specific tool uses the dict form
+ # {"type": "function", "function": {"name": "..."}}.
+ _TOOL_CHOICE_MAP: Dict[str, str] = {
+     "any": "required",
+ }
+
+
+ def _normalize_tool_choice(
+     tool_choice: Any,
+ ) -> Union[str, Dict[str, Any]]:
+     """Normalise a tool_choice value for the GitHub Models API.
+
+     - ``"any"`` → ``"required"`` (LangChain internal alias)
+     - ``"auto"`` / ``"required"`` / ``"none"`` → passed through unchanged
+     - any other string is treated as a tool name and wrapped as
+       ``{"type": "function", "function": {"name": ...}}``, the form the API
+       accepts for forcing a specific function
+     - dict values are passed through unchanged
+     """
+     if isinstance(tool_choice, str):
+         choice = _TOOL_CHOICE_MAP.get(tool_choice, tool_choice)
+         if choice not in ("auto", "required", "none"):
+             # Bare tool name: wrap it in the dict form the API expects.
+             return {"type": "function", "function": {"name": choice}}
+         return choice
+     # dict form — pass through unchanged
+     return tool_choice
+
+
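+ # Editorial sketch (not part of the released wheel): the three input shapes
+ # that ``_normalize_tool_choice`` handles; ``GetWeather`` is illustrative.
+ def _example_tool_choice_normalization() -> None:
+     assert _normalize_tool_choice("any") == "required"
+     assert _normalize_tool_choice("auto") == "auto"
+     assert _normalize_tool_choice("GetWeather") == {
+         "type": "function",
+         "function": {"name": "GetWeather"},
+     }
+
+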
+ def _build_ai_message(
+     choice: Dict[str, Any], usage: Optional[Dict[str, Any]]
+ ) -> AIMessage:
+     """Build an AIMessage from a single API response choice."""
+     msg = choice.get("message", {})
+     content: Union[str, List] = msg.get("content") or ""
+     finish_reason = choice.get("finish_reason", "")
+
+     additional_kwargs: Dict[str, Any] = {}
+     tool_calls: List[Dict[str, Any]] = []
+     invalid_tool_calls: List[Dict[str, Any]] = []
+     raw_tool_calls = msg.get("tool_calls", [])
+     if raw_tool_calls:
+         additional_kwargs["tool_calls"] = raw_tool_calls
+         tool_calls, invalid_tool_calls = _parse_tool_calls(raw_tool_calls)
+
+     usage_metadata: Optional[UsageMetadata] = None
+     if usage:
+         usage_metadata = UsageMetadata(
+             input_tokens=usage.get("prompt_tokens", 0),
+             output_tokens=usage.get("completion_tokens", 0),
+             total_tokens=usage.get("total_tokens", 0),
+         )
+
+     response_metadata: Dict[str, Any] = {
+         "finish_reason": finish_reason,
+     }
+     if usage:
+         response_metadata["usage"] = usage
+
+     return AIMessage(
+         content=content,
+         additional_kwargs=additional_kwargs,
+         tool_calls=tool_calls,
+         invalid_tool_calls=invalid_tool_calls,
+         response_metadata=response_metadata,
+         usage_metadata=usage_metadata,
+     )
+
+
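+ # Editorial sketch (not part of the released wheel): assembling an AIMessage
+ # from a hand-written response choice, assuming only this module's imports.
+ def _example_build_ai_message() -> AIMessage:
+     choice = {
+         "message": {"role": "assistant", "content": "Hello!"},
+         "finish_reason": "stop",
+     }
+     usage = {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}
+     msg = _build_ai_message(choice, usage)
+     # msg.content == "Hello!" and msg.usage_metadata["total_tokens"] == 7
+     return msg
+
+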
+ # ---------------------------------------------------------------------------
+ # Main class
+ # ---------------------------------------------------------------------------
+
+
+ class ChatGithubCopilot(BaseChatModel):
+     """GitHub Copilot Chat model integration via the GitHub Models inference API.
+
+     GitHub Models provides access to top AI models (OpenAI GPT-4.1, DeepSeek,
+     Llama, and more) through a unified OpenAI-compatible REST API. This class
+     wraps that API so that every model available in the GitHub Models catalog
+     can be used as a drop-in LangChain ``BaseChatModel``.
+
+     Setup:
+         Install ``langchain-githubcopilot-chat`` and set the
+         ``GITHUB_TOKEN`` environment variable (a classic or fine-grained PAT
+         with the ``models: read`` scope, or a GitHub Copilot subscription
+         token).
+
+         .. code-block:: bash
+
+             pip install -U langchain-githubcopilot-chat
+             export GITHUB_TOKEN="github_pat_..."
+
+     Key init args — completion params:
+         model: str
+             Model ID in the ``{publisher}/{model_name}`` format, e.g.
+             ``"openai/gpt-4.1"`` or ``"meta/llama-3.3-70b-instruct"``.
+         temperature: Optional[float]
+             Sampling temperature in ``[0, 1]``. Higher → more creative.
+         max_tokens: Optional[int]
+             Maximum number of tokens to generate.
+         top_p: Optional[float]
+             Nucleus sampling probability mass in ``[0, 1]``.
+         stop: Optional[List[str]]
+             Stop sequences.
+         frequency_penalty: Optional[float]
+             Frequency penalty in ``[-2, 2]``.
+         presence_penalty: Optional[float]
+             Presence penalty in ``[-2, 2]``.
+         seed: Optional[int]
+             Random seed for deterministic sampling (best-effort).
+
+     Key init args — client params:
+         github_token: Optional[SecretStr]
+             GitHub token. Falls back to ``GITHUB_TOKEN`` env var.
+         base_url: str
+             Base URL of the GitHub Models API.
+             Defaults to ``"https://models.github.ai"``.
+         org: Optional[str]
+             Organisation login. When set, every request is attributed to that
+             org (uses the ``/orgs/{org}/inference/chat/completions`` endpoint).
+         api_version: str
+             GitHub Models REST API version header value.
+             Defaults to ``"2026-03-10"``.
+         timeout: Optional[float]
+             HTTP request timeout in seconds.
+         max_retries: int
+             Number of automatic retries on transient errors (default ``2``).
+
+     Instantiate:
+         .. code-block:: python
+
+             from langchain_githubcopilot_chat import ChatGithubCopilot
+
+             llm = ChatGithubCopilot(
+                 model="openai/gpt-4.1",
+                 temperature=0,
+                 max_tokens=1024,
+                 # github_token="github_pat_...",  # or set GITHUB_TOKEN env var
+             )
+
+     Invoke:
+         .. code-block:: python
+
+             messages = [
+                 ("system", "You are a helpful translator. Translate to French."),
+                 ("human", "I love programming."),
+             ]
+             ai_msg = llm.invoke(messages)
+             print(ai_msg.content)
+             # "J'adore la programmation."
+
+     Stream:
+         .. code-block:: python
+
+             for chunk in llm.stream(messages):
+                 print(chunk.content, end="", flush=True)
+
+     Async:
+         .. code-block:: python
+
+             ai_msg = await llm.ainvoke(messages)
+
+             async for chunk in llm.astream(messages):
+                 print(chunk.content, end="", flush=True)
+
+     Tool calling:
+         .. code-block:: python
+
+             from pydantic import BaseModel, Field
+
+             class GetWeather(BaseModel):
+                 '''Get the current weather in a given location.'''
+                 location: str = Field(
+                     ..., description="City and state, e.g. Paris, France"
+                 )
+
+             llm_with_tools = llm.bind_tools([GetWeather])
+             ai_msg = llm_with_tools.invoke("What is the weather like in Paris?")
+             print(ai_msg.tool_calls)
+             # [{'name': 'GetWeather', 'args': {'location': 'Paris, France'},
+             #   'id': '...'}]
+
+     Structured output:
+         .. code-block:: python
+
+             from typing import Optional
+             from pydantic import BaseModel, Field
+
+             class Joke(BaseModel):
+                 '''Joke to tell user.'''
+                 setup: str = Field(description="The setup of the joke")
+                 punchline: str = Field(description="The punchline to the joke")
+                 rating: Optional[int] = Field(description="Funniness rating 1-10")
+
+             structured_llm = llm.with_structured_output(Joke)
+             structured_llm.invoke("Tell me a joke about cats")
+
+     JSON mode:
+         .. code-block:: python
+
+             json_llm = llm.bind(response_format={"type": "json_object"})
+             ai_msg = json_llm.invoke(
+                 "Return a JSON object with key 'numbers' and a list of 5 random ints."
+             )
+             print(ai_msg.content)
+
+     Image input:
+         .. code-block:: python
+
+             import base64, httpx
+             from langchain_core.messages import HumanMessage
+
+             image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+             image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
+             message = HumanMessage(
+                 content=[
+                     {"type": "text", "text": "Describe the weather in this image."},
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
+                     },
+                 ]
+             )
+             ai_msg = llm.invoke([message])
+             print(ai_msg.content)
+
+     Token usage:
+         .. code-block:: python
+
+             ai_msg = llm.invoke(messages)
+             print(ai_msg.usage_metadata)
+             # {'input_tokens': 28, 'output_tokens': 18, 'total_tokens': 46}
+
+     Response metadata:
+         .. code-block:: python
+
+             ai_msg = llm.invoke(messages)
+             print(ai_msg.response_metadata)
+             # {'finish_reason': 'stop', 'usage': {'prompt_tokens': 28, ...}}
+     """
+
+     # ------------------------------------------------------------------
+     # Fields
+     # ------------------------------------------------------------------
+
+     model_name: str = Field(alias="model")
+     """Model ID in the ``{publisher}/{model_name}`` format.
+
+     Examples: ``"openai/gpt-4.1"``, ``"meta/llama-3.3-70b-instruct"``.
+     """
+
+     github_token: Optional[SecretStr] = Field(default=None)
+     """GitHub token with ``models: read`` scope.
+
+     If not provided, the value of the ``GITHUB_TOKEN`` environment variable
+     is used.
+     """
+
+     base_url: str = _GITHUB_MODELS_BASE_URL
+     """Base URL for the GitHub Models REST API."""
+
+     org: Optional[str] = None
+     """Organisation login for attributed inference requests.
+
+     When set, requests are sent to
+     ``/orgs/{org}/inference/chat/completions`` instead of
+     ``/inference/chat/completions``.
+     """
+
+     api_version: str = _API_VERSION
+     """GitHub Models API version sent as the ``X-GitHub-Api-Version`` header."""
+
+     temperature: Optional[float] = None
+     """Sampling temperature in ``[0, 1]``."""
+
+     max_tokens: Optional[int] = None
+     """Maximum number of tokens to generate."""
+
+     top_p: Optional[float] = None
+     """Nucleus sampling probability mass in ``[0, 1]``."""
+
+     stop: Optional[List[str]] = None
+     """Stop sequences that terminate generation."""
+
+     frequency_penalty: Optional[float] = None
+     """Frequency penalty in ``[-2, 2]``."""
+
+     presence_penalty: Optional[float] = None
+     """Presence penalty in ``[-2, 2]``."""
+
+     seed: Optional[int] = None
+     """Random seed for (best-effort) deterministic sampling."""
+
+     timeout: Optional[float] = None
+     """HTTP request timeout in seconds."""
+
+     max_retries: int = 2
+     """Number of automatic retries on transient errors."""
+
+     # ------------------------------------------------------------------
+     # Validators / setup
+     # ------------------------------------------------------------------
+
+     @model_validator(mode="before")
+     @classmethod
+     def _validate_token(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+         """Resolve the GitHub token from the environment if not supplied.
+
+         Priority order:
+             1. Explicitly passed ``github_token``
+             2. Explicitly passed ``api_key`` alias
+             3. ``GITHUB_TOKEN`` environment variable
+         """
+         # Pop the ``api_key`` alias so pydantic does not reject it as an
+         # unknown field.
+         token = values.get("github_token") or values.pop("api_key", None)
+         if not token:
+             token = os.environ.get("GITHUB_TOKEN")
+         if token:
+             values["github_token"] = token
+         return values
+
+     # ------------------------------------------------------------------
+     # Internal helpers
+     # ------------------------------------------------------------------
+
+     @property
+     def _token(self) -> str:
+         """Return the raw GitHub token string."""
+         if self.github_token:
+             return self.github_token.get_secret_value()
+         env_token = os.environ.get("GITHUB_TOKEN", "")
+         if not env_token:
+             raise ValueError(
+                 "A GitHub token is required. Set the GITHUB_TOKEN environment "
+                 "variable or pass ``github_token`` when instantiating "
+                 "ChatGithubCopilot."
+             )
+         return env_token
+
+     @property
+     def _inference_url(self) -> str:
+         """Return the full chat-completions endpoint URL."""
+         if self.org:
+             path = _ORG_INFERENCE_PATH.format(org=self.org)
+         else:
+             path = _INFERENCE_PATH
+         return self.base_url.rstrip("/") + path
+
+     def _build_headers(self) -> Dict[str, str]:
+         return {
+             "Authorization": f"Bearer {self._token}",
+             "Accept": "application/vnd.github+json",
+             "Content-Type": "application/json",
+             "X-GitHub-Api-Version": self.api_version,
+         }
+
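+     # Editorial note (illustration, not part of the released wheel): the
+     # headers above correspond to a request of roughly this shape:
+     #
+     #   curl "$BASE_URL/inference/chat/completions" \
+     #     -H "Authorization: Bearer $GITHUB_TOKEN" \
+     #     -H "Accept: application/vnd.github+json" \
+     #     -H "Content-Type: application/json" \
+     #     -H "X-GitHub-Api-Version: <api_version>" \
+     #     -d '{"model": "openai/gpt-4.1", "messages": [...]}'
+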
+     def _build_payload(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         stream: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Assemble the JSON body for the inference API."""
+         payload: Dict[str, Any] = {
+             "model": self.model_name,
+             "messages": [_message_to_dict(m) for m in messages],
+             "stream": stream,
+         }
+         if stream:
+             payload["stream_options"] = {"include_usage": True}
+
+         # Optional sampling params (kwargs override instance-level defaults).
+         # Use a sentinel so falsy-but-valid overrides such as ``temperature=0``
+         # or ``seed=0`` are not silently dropped.
+         _missing = object()
+         for field_name, api_key in [
+             ("temperature", "temperature"),
+             ("max_tokens", "max_tokens"),
+             ("top_p", "top_p"),
+             ("frequency_penalty", "frequency_penalty"),
+             ("presence_penalty", "presence_penalty"),
+             ("seed", "seed"),
+         ]:
+             value = kwargs.pop(api_key, _missing)
+             if value is _missing:
+                 value = getattr(self, field_name, None)
+         if value is not None:
+                 payload[api_key] = value
+
+         # Stop sequences
+         effective_stop = stop or self.stop
+         if effective_stop:
+             payload["stop"] = effective_stop
+
+         # Tools / tool_choice
+         tools = kwargs.pop("tools", None)
+         if tools:
+             payload["tools"] = _format_tools_for_api(tools)
+         tool_choice = kwargs.pop("tool_choice", None)
+         if tool_choice:
+             payload["tool_choice"] = _normalize_tool_choice(tool_choice)
+
+         # Response format (JSON mode / structured output)
+         response_format = kwargs.pop("response_format", None)
+         if response_format:
+             payload["response_format"] = response_format
+
+         # Pass through any remaining caller-supplied kwargs
+         payload.update(kwargs)
+         return payload
+
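+     # Editorial sketch (not part of the released wheel): for
+     # ``ChatGithubCopilot(model="openai/gpt-4.1", temperature=0)`` the body
+     # built for a plain ``invoke("hi")`` call looks like:
+     #
+     #     {
+     #         "model": "openai/gpt-4.1",
+     #         "messages": [{"role": "user", "content": "hi"}],
+     #         "stream": False,
+     #         "temperature": 0,
+     #     }
+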
+     def _do_request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+         """Perform a synchronous (non-streaming) HTTP POST with retries."""
+         headers = self._build_headers()
+         last_exc: Optional[Exception] = None
+         for attempt in range(self.max_retries + 1):
+             try:
+                 response = httpx.post(
+                     self._inference_url,
+                     headers=headers,
+                     json=payload,
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 return response.json()
+             except (httpx.TimeoutException, httpx.TransportError) as exc:
+                 last_exc = exc
+                 if attempt == self.max_retries:
+                     raise
+             except httpx.HTTPStatusError as exc:
+                 # Don't retry on 4xx client errors
+                 if exc.response.status_code < 500:
+                     raise
+                 last_exc = exc
+                 if attempt == self.max_retries:
+                     raise
+         raise RuntimeError("Unexpected retry loop exit") from last_exc
+
+     def _do_stream(self, payload: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+         """Perform a synchronous streaming HTTP POST and yield parsed SSE chunks."""
+         headers = self._build_headers()
+         with httpx.stream(
+             "POST",
+             self._inference_url,
+             headers=headers,
+             json=payload,
+             timeout=self.timeout,
+         ) as response:
+             response.raise_for_status()
+             for line in response.iter_lines():
+                 line = line.strip()
+                 if not line or line == "data: [DONE]":
+                     continue
+                 if line.startswith("data: "):
+                     line = line[len("data: ") :]
+                 try:
+                     yield json.loads(line)
+                 except json.JSONDecodeError:
+                     continue
+
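+     # Editorial note (illustration, not part of the released wheel): a typical
+     # SSE line from the endpoint, and the dict this helper yields for it:
+     #
+     #   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
+     #
+     #   -> {"choices": [{"delta": {"content": "Hel"}, "finish_reason": None}]}
+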
+     async def _do_request_async(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+         """Perform an asynchronous (non-streaming) HTTP POST with retries."""
+         headers = self._build_headers()
+         last_exc: Optional[Exception] = None
+         async with httpx.AsyncClient(timeout=self.timeout) as client:
+             for attempt in range(self.max_retries + 1):
+                 try:
+                     response = await client.post(
+                         self._inference_url,
+                         headers=headers,
+                         json=payload,
+                     )
+                     response.raise_for_status()
+                     return response.json()
+                 except (httpx.TimeoutException, httpx.TransportError) as exc:
+                     last_exc = exc
+                     if attempt == self.max_retries:
+                         raise
+                 except httpx.HTTPStatusError as exc:
+                     if exc.response.status_code < 500:
+                         raise
+                     last_exc = exc
+                     if attempt == self.max_retries:
+                         raise
+         raise RuntimeError("Unexpected retry loop exit") from last_exc
+
+     async def _do_stream_async(
+         self, payload: Dict[str, Any]
+     ) -> AsyncIterator[Dict[str, Any]]:
+         """Perform an asynchronous streaming HTTP POST and yield parsed SSE chunks."""
+         headers = self._build_headers()
+         async with httpx.AsyncClient(timeout=self.timeout) as client:
+             async with client.stream(
+                 "POST",
+                 self._inference_url,
+                 headers=headers,
+                 json=payload,
+             ) as response:
+                 response.raise_for_status()
+                 async for line in response.aiter_lines():
+                     line = line.strip()
+                     if not line or line == "data: [DONE]":
+                         continue
+                     if line.startswith("data: "):
+                         line = line[len("data: ") :]
+                     try:
+                         yield json.loads(line)
+                     except json.JSONDecodeError:
+                         continue
+
+     # ------------------------------------------------------------------
+     # Stream delta → AIMessageChunk helpers
+     # ------------------------------------------------------------------
+
+     @staticmethod
+     def _chunk_from_delta(
+         delta: Dict[str, Any],
+         finish_reason: Optional[str],
+         usage: Optional[Dict[str, Any]],
+     ) -> AIMessageChunk:
+         """Convert a single SSE delta object into an ``AIMessageChunk``."""
+         content = delta.get("content") or ""
+         additional_kwargs: Dict[str, Any] = {}
+         tool_call_chunks = []
+
+         raw_tool_calls = delta.get("tool_calls") or []
+         for raw_tc in raw_tool_calls:
+             index = raw_tc.get("index", 0)
+             tc_id = raw_tc.get("id")
+
+             func = raw_tc.get("function", {})
+             tool_call_chunks.append(
+                 create_tool_call_chunk(
+                     name=func.get("name"),
+                     args=func.get("arguments"),
+                     id=tc_id,
+                     index=index,
+                 )
+             )
+
+         response_metadata: Dict[str, Any] = {}
+         if finish_reason:
+             response_metadata["finish_reason"] = finish_reason
+
+         usage_metadata: Optional[UsageMetadata] = None
+         if usage:
+             usage_metadata = UsageMetadata(
+                 input_tokens=usage.get("prompt_tokens", 0),
+                 output_tokens=usage.get("completion_tokens", 0),
+                 total_tokens=usage.get("total_tokens", 0),
+             )
+             response_metadata["usage"] = usage
+
+         return AIMessageChunk(
+             content=content,
+             additional_kwargs=additional_kwargs,
+             tool_call_chunks=tool_call_chunks,
+             response_metadata=response_metadata,
+             usage_metadata=usage_metadata,
+         )
+
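+     # Editorial sketch (not part of the released wheel): a streamed tool-call
+     # delta such as
+     #
+     #     {"tool_calls": [{"index": 0, "id": "call_1",
+     #                      "function": {"name": "GetWeather",
+     #                                   "arguments": "{\"loc"}}]}
+     #
+     # becomes an ``AIMessageChunk`` whose ``tool_call_chunks`` carry the
+     # partial ``arguments`` string; LangChain stitches chunks together by
+     # ``index`` when successive chunks are added.
+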
+     # ------------------------------------------------------------------
+     # LangChain BaseChatModel interface
+     # ------------------------------------------------------------------
+
+     @property
+     def _llm_type(self) -> str:
+         return "chat-github-copilot"
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         return {
+             "model_name": self.model_name,
+             "temperature": self.temperature,
+             "max_tokens": self.max_tokens,
+         }
+
+     def _get_ls_params(
+         self,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> LangSmithParams:
+         params = self._identifying_params
+         return LangSmithParams(
+             ls_provider="github-copilot",
+             ls_model_name=self.model_name,
+             ls_model_type="chat",
+             ls_temperature=params.get("temperature"),
+             ls_max_tokens=params.get("max_tokens"),
+             ls_stop=stop or self.stop or [],
+         )
+
+     def _generate(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         """Call the GitHub Models chat completions API and return a ChatResult."""
+         payload = self._build_payload(messages, stop=stop, stream=False, **kwargs)
+         response_data = self._do_request(payload)
+
+         choices = response_data.get("choices", [])
+         if not choices:
+             raise ValueError(
+                 f"GitHub Models API returned no choices. Response: {response_data}"
+             )
+
+         usage = response_data.get("usage")
+         generations = []
+         for choice in choices:
+             ai_msg = _build_ai_message(choice, usage)
+             generations.append(ChatGeneration(message=ai_msg))
+
+         return ChatResult(generations=generations)
+
+     def _stream(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[ChatGenerationChunk]:
+         """Stream token-level chunks from the GitHub Models API."""
+         payload = self._build_payload(messages, stop=stop, stream=True, **kwargs)
+
+         for raw_chunk in self._do_stream(payload):
+             choices = raw_chunk.get("choices", [])
+             # ``usage`` is present in the final chunk when include_usage=True.
+             usage = raw_chunk.get("usage")
+
+             if not choices and usage:
+                 # Final usage-only chunk
+                 chunk = ChatGenerationChunk(
+                     message=AIMessageChunk(
+                         content="",
+                         usage_metadata=UsageMetadata(
+                             input_tokens=usage.get("prompt_tokens", 0),
+                             output_tokens=usage.get("completion_tokens", 0),
+                             total_tokens=usage.get("total_tokens", 0),
+                         ),
+                         response_metadata={"usage": usage},
+                     )
+                 )
+                 if run_manager:
+                     run_manager.on_llm_new_token("", chunk=chunk)
+                 yield chunk
+                 continue
+
+             for choice in choices:
+                 delta = choice.get("delta", {})
+                 finish_reason = choice.get("finish_reason")
+                 ai_chunk = self._chunk_from_delta(delta, finish_reason, usage)
+                 gen_chunk = ChatGenerationChunk(message=ai_chunk)
+
+                 if run_manager and ai_chunk.content:
+                     run_manager.on_llm_new_token(str(ai_chunk.content), chunk=gen_chunk)
+                 yield gen_chunk
+
+     async def _agenerate(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         """Async version of ``_generate``."""
+         payload = self._build_payload(messages, stop=stop, stream=False, **kwargs)
+         response_data = await self._do_request_async(payload)
+
+         choices = response_data.get("choices", [])
+         if not choices:
+             raise ValueError(
+                 f"GitHub Models API returned no choices. Response: {response_data}"
+             )
+
+         usage = response_data.get("usage")
+         generations = []
+         for choice in choices:
+             ai_msg = _build_ai_message(choice, usage)
+             generations.append(ChatGeneration(message=ai_msg))
+
+         return ChatResult(generations=generations)
+
+     async def _astream(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[ChatGenerationChunk]:
+         """Async streaming version of ``_stream``."""
+         payload = self._build_payload(messages, stop=stop, stream=True, **kwargs)
+
+         async for raw_chunk in self._do_stream_async(payload):
+             choices = raw_chunk.get("choices", [])
+             usage = raw_chunk.get("usage")
+
+             if not choices and usage:
+                 chunk = ChatGenerationChunk(
+                     message=AIMessageChunk(
+                         content="",
+                         usage_metadata=UsageMetadata(
+                             input_tokens=usage.get("prompt_tokens", 0),
+                             output_tokens=usage.get("completion_tokens", 0),
+                             total_tokens=usage.get("total_tokens", 0),
+                         ),
+                         response_metadata={"usage": usage},
+                     )
+                 )
+                 if run_manager:
+                     await run_manager.on_llm_new_token("", chunk=chunk)
+                 yield chunk
+                 continue
+
+             for choice in choices:
+                 delta = choice.get("delta", {})
+                 finish_reason = choice.get("finish_reason")
+                 ai_chunk = self._chunk_from_delta(delta, finish_reason, usage)
+                 gen_chunk = ChatGenerationChunk(message=ai_chunk)
+
+                 if run_manager and ai_chunk.content:
+                     await run_manager.on_llm_new_token(
+                         str(ai_chunk.content), chunk=gen_chunk
+                     )
+                 yield gen_chunk
+
+     # ------------------------------------------------------------------
+     # Tool calling support
+     # ------------------------------------------------------------------
+
+     def bind_tools(
+         self,
+         tools: Sequence[Union[Dict[str, Any], BaseTool, Type, Any]],
+         *,
+         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+         **kwargs: Any,
+     ) -> "ChatGithubCopilot":
+         """Bind tools to this model, enabling tool calling.
+
+         Args:
+             tools: A list of tools to bind. Accepts LangChain ``BaseTool``
+                 instances, Pydantic models, or pre-formatted OpenAI tool dicts.
+             tool_choice: Controls tool selection. One of ``"auto"``,
+                 ``"required"``, ``"none"``, the name of a specific tool, or an
+                 OpenAI-style ``{"type": "function", ...}`` dict. Defaults to
+                 ``"auto"`` when tools are provided.
+
+         Returns:
+             A new ``ChatGithubCopilot`` instance with ``tools`` bound.
+
+         Example:
+             .. code-block:: python
+
+                 from pydantic import BaseModel, Field
+
+                 class SearchWeb(BaseModel):
+                     '''Search the web for up-to-date information.'''
+                     query: str = Field(..., description="The search query")
+
+                 llm_with_tools = llm.bind_tools([SearchWeb])
+                 ai_msg = llm_with_tools.invoke("Who won the 2024 Olympics 100m sprint?")
+                 print(ai_msg.tool_calls)
+         """
+         formatted_tools = _format_tools_for_api(tools)
+         tool_choice_param: Optional[Union[str, Dict[str, Any]]] = tool_choice or (
+             "auto" if formatted_tools else None
+         )
+         return self.bind(
+             tools=formatted_tools,
+             tool_choice=tool_choice_param,
+             **kwargs,
+         )  # type: ignore[return-value]
+
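+
+ # Editorial sketch (not part of the released wheel): forcing a specific tool
+ # with the dict-form ``tool_choice`` that the API accepts. ``GetWeather`` is
+ # an illustrative name.
+ def _example_force_tool(llm: ChatGithubCopilot) -> BaseMessage:
+     from pydantic import BaseModel
+
+     class GetWeather(BaseModel):
+         """Get the current weather in a given location."""
+
+         location: str
+
+     forced = llm.bind_tools(
+         [GetWeather],
+         tool_choice={"type": "function", "function": {"name": "GetWeather"}},
+     )
+     return forced.invoke("What is the weather like in Paris?")
+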
+
+ # ---------------------------------------------------------------------------
+ # Backwards-compatible alias (matches the generated stub name)
+ # ---------------------------------------------------------------------------
+
+ ChatGithubcopilotChat = ChatGithubCopilot
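+
+
+ if __name__ == "__main__":
+     # Editorial smoke-test sketch, not part of the released wheel. Requires a
+     # real GITHUB_TOKEN with the ``models: read`` scope; the model ID is taken
+     # from the class docstring.
+     _demo = ChatGithubCopilot(model="openai/gpt-4.1", temperature=0)
+     print(_demo.invoke("Say hello in French.").content)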