Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/copilot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
PermissionHandler,
PermissionRequest,
PermissionRequestResult,
PingResponse,
ProviderConfig,
ResumeSessionConfig,
SessionConfig,
SessionEvent,
SessionMetadata,
StopError,
Tool,
ToolHandler,
ToolInvocation,
Expand Down Expand Up @@ -56,11 +58,13 @@
"PermissionHandler",
"PermissionRequest",
"PermissionRequestResult",
"PingResponse",
"ProviderConfig",
"ResumeSessionConfig",
"SessionConfig",
"SessionEvent",
"SessionMetadata",
"StopError",
"Tool",
"ToolHandler",
"ToolInvocation",
Expand Down
64 changes: 35 additions & 29 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import subprocess
import threading
from dataclasses import asdict, is_dataclass
from typing import Any, Optional, cast
from typing import Any, Optional

from .generated.session_events import session_event_from_dict
from .jsonrpc import JsonRpcClient
Expand All @@ -32,10 +32,12 @@
GetAuthStatusResponse,
GetStatusResponse,
ModelInfo,
PingResponse,
ProviderConfig,
ResumeSessionConfig,
SessionConfig,
SessionMetadata,
StopError,
ToolHandler,
ToolInvocation,
ToolResult,
Expand Down Expand Up @@ -220,7 +222,7 @@ async def start(self) -> None:
self._state = "error"
raise

async def stop(self) -> list[dict[str, str]]:
async def stop(self) -> list["StopError"]:
"""
Stop the CLI server and close all active sessions.

Expand All @@ -230,16 +232,16 @@ async def stop(self) -> list[dict[str, str]]:
3. Terminates the CLI server process (if spawned by this client)

Returns:
A list of errors that occurred during cleanup, each as a dict with
a 'message' key. An empty list indicates all cleanup succeeded.
A list of StopError objects containing error messages that occurred
during cleanup. An empty list indicates all cleanup succeeded.

Example:
>>> errors = await client.stop()
>>> if errors:
... for error in errors:
... print(f"Cleanup error: {error['message']}")
... print(f"Cleanup error: {error.message}")
"""
errors: list[dict[str, str]] = []
errors: list[StopError] = []

# Atomically take ownership of all sessions and clear the dict
# so no other thread can access them
Expand All @@ -251,7 +253,9 @@ async def stop(self) -> list[dict[str, str]]:
try:
await session.destroy()
except Exception as e:
errors.append({"message": f"Failed to destroy session {session.session_id}: {e}"})
errors.append(
StopError(message=f"Failed to destroy session {session.session_id}: {e}")
)

# Close client
if self._client:
Expand Down Expand Up @@ -570,67 +574,69 @@ def get_state(self) -> ConnectionState:
"""
return self._state

async def ping(self, message: Optional[str] = None) -> dict:
async def ping(self, message: Optional[str] = None) -> "PingResponse":
"""
Send a ping request to the server to verify connectivity.

Args:
message: Optional message to include in the ping.

Returns:
A dict containing the ping response with 'message', 'timestamp',
and 'protocolVersion' keys.
A PingResponse object containing the ping response.

Raises:
RuntimeError: If the client is not connected.

Example:
>>> response = await client.ping("health check")
>>> print(f"Server responded at {response['timestamp']}")
>>> print(f"Server responded at {response.timestamp}")
"""
if not self._client:
raise RuntimeError("Client not connected")

return await self._client.request("ping", {"message": message})
result = await self._client.request("ping", {"message": message})
return PingResponse.from_dict(result)

async def get_status(self) -> "GetStatusResponse":
"""
Get CLI status including version and protocol information.

Returns:
A GetStatusResponse containing version and protocolVersion.
A GetStatusResponse object containing version and protocolVersion.

Raises:
RuntimeError: If the client is not connected.

Example:
>>> status = await client.get_status()
>>> print(f"CLI version: {status['version']}")
>>> print(f"CLI version: {status.version}")
"""
if not self._client:
raise RuntimeError("Client not connected")

return await self._client.request("status.get", {})
result = await self._client.request("status.get", {})
return GetStatusResponse.from_dict(result)

async def get_auth_status(self) -> "GetAuthStatusResponse":
"""
Get current authentication status.

Returns:
A GetAuthStatusResponse containing authentication state.
A GetAuthStatusResponse object containing authentication state.

Raises:
RuntimeError: If the client is not connected.

Example:
>>> auth = await client.get_auth_status()
>>> if auth['isAuthenticated']:
... print(f"Logged in as {auth.get('login')}")
>>> if auth.isAuthenticated:
... print(f"Logged in as {auth.login}")
"""
if not self._client:
raise RuntimeError("Client not connected")

return await self._client.request("auth.getStatus", {})
result = await self._client.request("auth.getStatus", {})
return GetAuthStatusResponse.from_dict(result)

async def list_models(self) -> list["ModelInfo"]:
"""
Expand All @@ -646,13 +652,14 @@ async def list_models(self) -> list["ModelInfo"]:
Example:
>>> models = await client.list_models()
>>> for model in models:
... print(f"{model['id']}: {model['name']}")
... print(f"{model.id}: {model.name}")
"""
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("models.list", {})
return response.get("models", [])
models_data = response.get("models", [])
return [ModelInfo.from_dict(model) for model in models_data]

async def list_sessions(self) -> list["SessionMetadata"]:
"""
Expand All @@ -661,23 +668,22 @@ async def list_sessions(self) -> list["SessionMetadata"]:
Returns metadata about each session including ID, timestamps, and summary.

Returns:
A list of session metadata dictionaries with keys: sessionId (str),
startTime (str), modifiedTime (str), summary (str, optional),
and isRemote (bool).
A list of SessionMetadata objects.

Raises:
RuntimeError: If the client is not connected.

Example:
>>> sessions = await client.list_sessions()
>>> for session in sessions:
... print(f"Session: {session['sessionId']}")
... print(f"Session: {session.sessionId}")
"""
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("session.list", {})
return response.get("sessions", [])
sessions_data = response.get("sessions", [])
return [SessionMetadata.from_dict(session) for session in sessions_data]

async def delete_session(self, session_id: str) -> None:
"""
Expand Down Expand Up @@ -714,7 +720,7 @@ async def _verify_protocol_version(self) -> None:
"""Verify that the server's protocol version matches the SDK's expected version."""
expected_version = get_sdk_protocol_version()
ping_result = await self.ping()
server_version = ping_result.get("protocolVersion")
server_version = ping_result.protocolVersion

if server_version is None:
raise RuntimeError(
Expand Down Expand Up @@ -845,11 +851,11 @@ async def read_port():
if not process or not process.stdout:
raise RuntimeError("Process not started or stdout not available")
while True:
line = cast(bytes, await loop.run_in_executor(None, process.stdout.readline))
line = await loop.run_in_executor(None, process.stdout.readline)
if not line:
raise RuntimeError("CLI process exited before announcing port")

line_str = line.decode()
line_str = line.decode() if isinstance(line, bytes) else line
match = re.search(r"listening on port (\d+)", line_str, re.IGNORECASE)
if match:
self._actual_port = int(match.group(1))
Expand Down
2 changes: 1 addition & 1 deletion python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _normalize_result(result: Any) -> ToolResult:

# ToolResult passes through directly
if isinstance(result, dict) and "resultType" in result and "textResultForLlm" in result:
return result # type: ignore
return result

# Strings pass through directly
if isinstance(result, str):
Expand Down
Loading
Loading