Compare commits

...

10 Commits

Author SHA1 Message Date
Kevin Liu
b056f31bda disable scrollbar 2025-07-19 19:40:30 -07:00
Michael Bolin
018003e52f feat: leverage elicitations in the MCP server (#1623)
This updates the MCP server so that if it receives an
`ExecApprovalRequest` from the `Codex` session, it in turn sends an [MCP
elicitation](https://modelcontextprotocol.io/specification/draft/client/elicitation)
to the client to ask for the approval decision. Upon getting a response,
it forwards the client's decision via `Op::ExecApproval`.

Admittedly, we should be doing the same thing for
`ApplyPatchApprovalRequest`, but this is our first time experimenting
with elicitations, so I'm inclined to defer wiring that code path up
until we feel good about how this one works.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/1623).
* __->__ #1623
* #1622
* #1621
* #1620
2025-07-19 01:32:03 -04:00
Michael Bolin
11fd3123be chore: introduce OutgoingMessageSender (#1622)
Previous to this change, `MessageProcessor` had a
`tokio::sync::mpsc::Sender<JSONRPCMessage>` as an abstraction for server
code to send a message down to the MCP client. Because `Sender` is cheap
to `clone()`, it was straightforward to make it available to tasks
scheduled with `tokio::task::spawn()`.

This worked well when we were only sending notifications or responses
back down to the client, but we want to add support for sending
elicitations in #1623, which means that we need to be able to send
_requests_ to the client, and now we need a bit of centralization to
ensure all request ids are unique.

To that end, this PR introduces `OutgoingMessageSender`, which houses
the existing `Sender<OutgoingMessage>` as well as an `AtomicI64` to mint
out new, unique request ids. It has methods like `send_request()` and
`send_response()` so that callers do not have to deal with
`JSONRPCMessage` directly, as having to set the `jsonrpc` for each
message was a bit tedious (this cleans up `codex_tool_runner.rs` quite a
bit).

We do not have `OutgoingMessageSender` implement `Clone` because it is
important that the `AtomicI64` is shared across all users of
`OutgoingMessageSender`. As such, `Arc<OutgoingMessageSender>` must be
used instead, as it is frequently shared with new tokio tasks.

As part of this change, we update `message_processor.rs` to embrace
`await`, though we must be careful that no individual handler blocks the
main loop and prevents other messages from being handled.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/1622).
* #1623
* __->__ #1622
* #1621
* #1620
2025-07-19 00:30:56 -04:00
Michael Bolin
e78ec00e73 chore: support MCP schema 2025-06-18 (#1621)
This updates the schema in `generate_mcp_types.py` from `2025-03-26` to
`2025-06-18`, regenerates `mcp-types/src/lib.rs`, and then updates all
the code that uses `mcp-types` to honor the changes.

Ran

```
npx @modelcontextprotocol/inspector just codex mcp
```

and verified that I was able to invoke the `codex` tool, as expected.


---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/1621).
* #1623
* #1622
* __->__ #1621
2025-07-19 00:09:34 -04:00
Michael Bolin
a06d4f58e4 chore: clean up generate_mcp_types.py so codegen matches existing output (#1620) 2025-07-18 21:40:39 -04:00
aibrahim-oai
83eefb55fb Add session loading support to Codex (#1602)
## Summary
- extend rollout format to store all session data in JSON
- add resume/write helpers for rollouts
- track session state after each conversation
- support `LoadSession` op to resume a previous rollout
- allow starting Codex with an existing session via
`experimental_resume` config variable

We need a way later for exploring the available sessions in a user
friendly way.

## Testing
- `cargo test --no-run` *(fails: `cargo: command not found`)*

------
https://chatgpt.com/codex/tasks/task_i_68792a29dd5c832190bf6930d3466fba

This video is outdated. you should use `-c experimental_resume:<full
path>` instead of `--resume <full path>`


https://github.com/user-attachments/assets/7a9975c7-aa04-4f4e-899a-9e87defd947a
2025-07-18 17:04:04 -07:00
aibrahim-oai
9846adeabf Refactor env settings into config (#1601)
## Summary
- add OpenAI retry and timeout fields to Config
- inject these settings in tests instead of mutating env vars
- plumb Config values through client and chat completions logic
- document new configuration options

## Testing
- `cargo test -p codex-core --no-run`

------
https://chatgpt.com/codex/tasks/task_i_68792c5b04cc832195c03050c8b6ea94

---------

Co-authored-by: Michael Bolin <mbolin@openai.com>
2025-07-18 19:12:39 +00:00
aibrahim-oai
d5a2148deb Fix ctrl+c interrupt while streaming (#1617)
Interrupting while streaming now causes is broken because we aren't
clearing the delta buffer.
2025-07-18 12:08:25 -07:00
Michael Bolin
cc874c9205 chore: use AtomicBool instead of Mutex<bool> (#1616) 2025-07-18 11:13:34 -07:00
pakrym-oai
6f2b01bb6b feat: ensure session ID header is sent in Response API request (#1614)
Include the current session id in Responses API requests.
2025-07-18 09:59:07 -07:00
33 changed files with 3899 additions and 549 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -799,6 +799,7 @@ dependencies = [
"schemars 0.8.22",
"serde",
"serde_json",
"shlex",
"tokio",
"toml 0.9.1",
"tracing",

View File

@@ -64,7 +64,11 @@ impl CliConfigOverrides {
// `-c model=o3` without the quotes.
let value: Value = match parse_toml_value(value_str) {
Ok(v) => v,
Err(_) => Value::String(value_str.to_string()),
Err(_) => {
// Strip leading/trailing quotes if present
let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\'');
Value::String(trimmed.to_string())
}
};
Ok((key.to_string(), value))

View File

@@ -92,6 +92,32 @@ http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Features": "EXAMPLE_FEATURES" }
```
### Per-provider network tuning
The following optional settings control retry behaviour and streaming idle timeouts **per model provider**. They must be specified inside the corresponding `[model_providers.<id>]` block in `config.toml`. (Older releases accepted toplevel keys; those are now ignored.)
Example:
```toml
[model_providers.openai]
name = "OpenAI"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
# network tuning overrides (all optional; falls back to builtin defaults)
request_max_retries = 4 # retry failed HTTP requests
stream_max_retries = 10 # retry dropped SSE streams
stream_idle_timeout_ms = 300000 # 5m idle timeout
```
#### request_max_retries
How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
#### stream_max_retries
Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
#### stream_idle_timeout_ms
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
## model_provider
Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. You can override the `base_url` for the built-in `openai` provider via the `OPENAI_BASE_URL` environment variable.
@@ -444,7 +470,7 @@ Currently, `"vscode"` is the default, though Codex does not verify VS Code is in
## hide_agent_reasoning
Codex intermittently emits "reasoning" events that show the models internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
Codex intermittently emits "reasoning" events that show the model's internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
Setting `hide_agent_reasoning` to `true` suppresses these events in **both** the TUI as well as the headless `exec` sub-command:

View File

@@ -21,8 +21,6 @@ use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
@@ -121,6 +119,7 @@ pub(crate) async fn stream_chat_completions(
);
let mut attempt = 0;
let max_retries = provider.request_max_retries();
loop {
attempt += 1;
@@ -136,7 +135,11 @@ pub(crate) async fn stream_chat_completions(
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
tokio::spawn(process_chat_sse(stream, tx_event));
tokio::spawn(process_chat_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
));
return Ok(ResponseStream { rx_event });
}
Ok(res) => {
@@ -146,7 +149,7 @@ pub(crate) async fn stream_chat_completions(
return Err(CodexErr::UnexpectedStatus(status, body));
}
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(CodexErr::RetryLimit(status));
}
@@ -162,7 +165,7 @@ pub(crate) async fn stream_chat_completions(
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
@@ -175,14 +178,15 @@ pub(crate) async fn stream_chat_completions(
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We

View File

@@ -15,6 +15,7 @@ use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
use tracing::warn;
use uuid::Uuid;
use crate::chat_completions::AggregateStreamExt;
use crate::chat_completions::stream_chat_completions;
@@ -29,8 +30,6 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
@@ -44,6 +43,7 @@ pub struct ModelClient {
config: Arc<Config>,
client: reqwest::Client,
provider: ModelProviderInfo,
session_id: Uuid,
effort: ReasoningEffortConfig,
summary: ReasoningSummaryConfig,
}
@@ -54,11 +54,13 @@ impl ModelClient {
provider: ModelProviderInfo,
effort: ReasoningEffortConfig,
summary: ReasoningSummaryConfig,
session_id: Uuid,
) -> Self {
Self {
config,
client: reqwest::Client::new(),
provider,
session_id,
effort,
summary,
}
@@ -109,7 +111,7 @@ impl ModelClient {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
// short circuit for tests
warn!(path, "Streaming from fixture");
return stream_from_fixture(path).await;
return stream_from_fixture(path, self.provider.clone()).await;
}
let full_instructions = prompt.get_full_instructions(&self.config.model);
@@ -136,6 +138,7 @@ impl ModelClient {
);
let mut attempt = 0;
let max_retries = self.provider.request_max_retries();
loop {
attempt += 1;
@@ -143,6 +146,7 @@ impl ModelClient {
.provider
.create_request_builder(&self.client)?
.header("OpenAI-Beta", "responses=experimental")
.header("session_id", self.session_id.to_string())
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload);
@@ -153,7 +157,11 @@ impl ModelClient {
// spawn task to process SSE
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
tokio::spawn(process_sse(stream, tx_event));
tokio::spawn(process_sse(
stream,
tx_event,
self.provider.stream_idle_timeout(),
));
return Ok(ResponseStream { rx_event });
}
@@ -172,7 +180,7 @@ impl ModelClient {
return Err(CodexErr::UnexpectedStatus(status, body));
}
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(CodexErr::RetryLimit(status));
}
@@ -189,7 +197,7 @@ impl ModelClient {
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
@@ -198,6 +206,10 @@ impl ModelClient {
}
}
}
pub fn get_provider(&self) -> ModelProviderInfo {
self.provider.clone()
}
}
#[derive(Debug, Deserialize, Serialize)]
@@ -249,14 +261,16 @@ struct ResponseCompletedOutputTokensDetails {
reasoning_tokens: u64,
}
async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
async fn process_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
// If the stream stays completely silent for an extended period treat it as disconnected.
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// The response id returned from the "complete" message.
let mut response_completed: Option<ResponseCompleted> = None;
@@ -317,7 +331,7 @@ where
// duplicated `output` array embedded in the `response.completed`
// payload. That produced two concrete issues:
// 1. No realtime streaming the user only saw output after the
// entire turn had finished, which broke the typing UX and
// entire turn had finished, which broke the "typing" UX and
// made longrunning turns look stalled.
// 2. Duplicate `function_call_output` items both the
// individual *and* the completed array were forwarded, which
@@ -390,7 +404,10 @@ where
}
/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
async fn stream_from_fixture(
path: impl AsRef<Path>,
provider: ModelProviderInfo,
) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let f = std::fs::File::open(path.as_ref())?;
let lines = std::io::BufReader::new(f).lines();
@@ -404,7 +421,11 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let rdr = std::io::Cursor::new(content);
let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
tokio::spawn(process_sse(stream, tx_event));
tokio::spawn(process_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
));
Ok(ResponseStream { rx_event })
}
@@ -424,7 +445,10 @@ mod tests {
/// Runs the SSE parser on pre-chunked byte slices and returns every event
/// (including any final `Err` from a stream-closure check).
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
async fn collect_events(
chunks: &[&[u8]],
provider: ModelProviderInfo,
) -> Vec<Result<ResponseEvent>> {
let mut builder = IoBuilder::new();
for chunk in chunks {
builder.read(chunk);
@@ -433,7 +457,7 @@ mod tests {
let reader = builder.build();
let stream = ReaderStream::new(reader).map_err(CodexErr::Io);
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(16);
tokio::spawn(process_sse(stream, tx));
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
let mut events = Vec::new();
while let Some(ev) = rx.recv().await {
@@ -444,7 +468,10 @@ mod tests {
/// Builds an in-memory SSE stream from JSON fixtures and returns only the
/// successfully parsed events (panics on internal channel errors).
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
async fn run_sse(
events: Vec<serde_json::Value>,
provider: ModelProviderInfo,
) -> Vec<ResponseEvent> {
let mut body = String::new();
for e in events {
let kind = e
@@ -460,7 +487,7 @@ mod tests {
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io);
tokio::spawn(process_sse(stream, tx));
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
@@ -505,7 +532,25 @@ mod tests {
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let events = collect_events(
&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()],
provider,
)
.await;
assert_eq!(events.len(), 3);
@@ -546,8 +591,21 @@ mod tests {
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let events = collect_events(&[sse1.as_bytes()]).await;
let events = collect_events(&[sse1.as_bytes()], provider).await;
assert_eq!(events.len(), 2);
@@ -635,7 +693,21 @@ mod tests {
let mut evs = vec![case.event];
evs.push(completed.clone());
let out = run_sse(evs).await;
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let out = run_sse(evs, provider).await;
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
assert!(
(case.expect_first)(&out[0]),

View File

@@ -49,7 +49,6 @@ use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::process_exec_tool_call;
use crate::exec_env::create_env;
use crate::flags::OPENAI_STREAM_MAX_RETRIES;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::models::ContentItem;
@@ -103,6 +102,9 @@ impl Codex {
/// of `Codex` and the ID of the `SessionInitialized` event that was
/// submitted to start the session.
pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<(Codex, String)> {
// experimental resume path (undocumented)
let resume_path = config.experimental_resume.clone();
info!("resume_path: {resume_path:?}");
let (tx_sub, rx_sub) = async_channel::bounded(64);
let (tx_event, rx_event) = async_channel::bounded(1600);
@@ -118,6 +120,7 @@ impl Codex {
disable_response_storage: config.disable_response_storage,
notify: config.notify.clone(),
cwd: config.cwd.clone(),
resume_path: resume_path.clone(),
};
let config = Arc::new(config);
@@ -307,24 +310,30 @@ impl Session {
/// transcript, if enabled.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
debug!("Recording items for conversation: {items:?}");
self.record_rollout_items(items).await;
self.record_state_snapshot(items).await;
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
transcript.record_items(items);
}
}
/// Append the given items to the session's rollout transcript (if enabled)
/// and persist them to disk.
async fn record_rollout_items(&self, items: &[ResponseItem]) {
// Clone the recorder outside of the mutex so we don't hold the lock
// across an await point (MutexGuard is not Send).
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
let snapshot = {
let state = self.state.lock().unwrap();
crate::rollout::SessionStateSnapshot {
previous_response_id: state.previous_response_id.clone(),
}
};
let recorder = {
let guard = self.rollout.lock().unwrap();
guard.as_ref().cloned()
};
if let Some(rec) = recorder {
if let Err(e) = rec.record_state(snapshot).await {
error!("failed to record rollout state: {e:#}");
}
if let Err(e) = rec.record_items(items).await {
error!("failed to record rollout items: {e:#}");
}
@@ -518,7 +527,7 @@ async fn submission_loop(
ctrl_c: Arc<Notify>,
) {
// Generate a unique ID for the lifetime of this Codex session.
let session_id = Uuid::new_v4();
let mut session_id = Uuid::new_v4();
let mut sess: Option<Arc<Session>> = None;
// shorthand - send an event when there is no active session
@@ -571,8 +580,11 @@ async fn submission_loop(
disable_response_storage,
notify,
cwd,
resume_path,
} => {
info!("Configuring session: model={model}; provider={provider:?}");
info!(
"Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}"
);
if !cwd.is_absolute() {
let message = format!("cwd is not absolute: {cwd:?}");
error!(message);
@@ -585,12 +597,48 @@ async fn submission_loop(
}
return;
}
// Optionally resume an existing rollout.
let mut restored_items: Option<Vec<ResponseItem>> = None;
let mut restored_prev_id: Option<String> = None;
let rollout_recorder: Option<RolloutRecorder> =
if let Some(path) = resume_path.as_ref() {
match RolloutRecorder::resume(path).await {
Ok((rec, saved)) => {
session_id = saved.session_id;
restored_prev_id = saved.state.previous_response_id;
if !saved.items.is_empty() {
restored_items = Some(saved.items);
}
Some(rec)
}
Err(e) => {
warn!("failed to resume rollout from {path:?}: {e}");
None
}
}
} else {
None
};
let rollout_recorder = match rollout_recorder {
Some(rec) => Some(rec),
None => match RolloutRecorder::new(&config, session_id, instructions.clone())
.await
{
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
}
},
};
let client = ModelClient::new(
config.clone(),
provider.clone(),
model_reasoning_effort,
model_reasoning_summary,
session_id,
);
// abort any current running session and clone its state
@@ -644,21 +692,6 @@ async fn submission_loop(
});
}
}
// Attempt to create a RolloutRecorder *before* moving the
// `instructions` value into the Session struct.
// TODO: if ConfigureSession is sent twice, we will create an
// overlapping rollout file. Consider passing RolloutRecorder
// from above.
let rollout_recorder =
match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
}
};
sess = Some(Arc::new(Session {
client,
tx_event: tx_event.clone(),
@@ -676,6 +709,19 @@ async fn submission_loop(
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
}));
// Patch restored state into the newly created session.
if let Some(sess_arc) = &sess {
if restored_prev_id.is_some() || restored_items.is_some() {
let mut st = sess_arc.state.lock().unwrap();
st.previous_response_id = restored_prev_id;
if let (Some(hist), Some(items)) =
(st.zdr_transcript.as_mut(), restored_items.as_ref())
{
hist.record_items(items.iter());
}
}
}
// Gather history metadata for SessionConfiguredEvent.
let (history_log_id, history_entry_count) =
crate::message_history::history_metadata(&config).await;
@@ -744,6 +790,8 @@ async fn submission_loop(
}
}
Op::AddToHistory { text } => {
// TODO: What should we do if we got AddToHistory before ConfigureSession?
// currently, if ConfigureSession has resume path, this history will be ignored
let id = session_id;
let config = config.clone();
tokio::spawn(async move {
@@ -919,15 +967,17 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
) => {
items_to_record_in_conversation_history.push(item);
let (content, success): (String, Option<bool>) = match result {
Ok(CallToolResult { content, is_error }) => {
match serde_json::to_string(content) {
Ok(content) => (content, *is_error),
Err(e) => {
warn!("Failed to serialize MCP tool call output: {e}");
(e.to_string(), Some(true))
}
Ok(CallToolResult {
content,
is_error,
structured_content: _,
}) => match serde_json::to_string(content) {
Ok(content) => (content, *is_error),
Err(e) => {
warn!("Failed to serialize MCP tool call output: {e}");
(e.to_string(), Some(true))
}
}
},
Err(e) => (e.clone(), Some(true)),
};
items_to_record_in_conversation_history.push(
@@ -1026,12 +1076,13 @@ async fn run_turn(
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e) => {
if retries < *OPENAI_STREAM_MAX_RETRIES {
// Use the configured provider-specific stream retry budget.
let max_retries = sess.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
let delay = backoff(retries);
warn!(
"stream disconnected - retrying turn ({retries}/{} in {delay:?})...",
*OPENAI_STREAM_MAX_RETRIES
"stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
);
// Surface retry information to any UI/frontend so the
@@ -1040,8 +1091,7 @@ async fn run_turn(
sess.notify_background_event(
&sub_id,
format!(
"stream error: {e}; retrying {retries}/{} in {:?}",
*OPENAI_STREAM_MAX_RETRIES, delay
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}"
),
)
.await;
@@ -1123,7 +1173,28 @@ async fn try_run_turn(
let mut stream = sess.client.clone().stream(&prompt).await?;
let mut output = Vec::new();
while let Some(Ok(event)) = stream.next().await {
loop {
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let Some(event) = event else {
// Channel closed without yielding a final Completed event or explicit error.
// Treat as a disconnected stream so the caller can retry.
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
));
};
let event = match event {
Ok(ev) => ev,
Err(e) => {
// Propagate the underlying stream error to the caller (run_turn), which
// will apply the configured `stream_max_retries` policy.
return Err(e);
}
};
match event {
ResponseEvent::Created => {
let mut state = sess.state.lock().unwrap();
@@ -1164,7 +1235,7 @@ async fn try_run_turn(
let mut state = sess.state.lock().unwrap();
state.previous_response_id = Some(response_id);
break;
return Ok(output);
}
ResponseEvent::OutputTextDelta(delta) => {
let event = Event {
@@ -1182,7 +1253,6 @@ async fn try_run_turn(
}
}
}
Ok(output)
}
async fn handle_response_item(
@@ -1285,7 +1355,7 @@ async fn handle_function_call(
let params = match parse_container_exec_arguments(arguments, sess, &call_id) {
Ok(params) => params,
Err(output) => {
return output;
return *output;
}
};
handle_container_exec_with_params(params, sess, sub_id, call_id).await
@@ -1328,7 +1398,7 @@ fn parse_container_exec_arguments(
arguments: String,
sess: &Session,
call_id: &str,
) -> Result<ExecParams, ResponseInputItem> {
) -> Result<ExecParams, Box<ResponseInputItem>> {
// parse command
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)),
@@ -1341,7 +1411,7 @@ fn parse_container_exec_arguments(
success: None,
},
};
Err(output)
Err(Box::new(output))
}
}
}

View File

@@ -137,6 +137,9 @@ pub struct Config {
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
pub chatgpt_base_url: String,
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
pub experimental_resume: Option<PathBuf>,
}
impl Config {
@@ -321,6 +324,9 @@ pub struct ConfigToml {
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
pub chatgpt_base_url: Option<String>,
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
pub experimental_resume: Option<PathBuf>,
}
impl ConfigToml {
@@ -448,6 +454,9 @@ impl Config {
.as_ref()
.map(|info| info.max_output_tokens)
});
let experimental_resume = cfg.experimental_resume;
let config = Self {
model,
model_context_window,
@@ -494,6 +503,8 @@ impl Config {
.chatgpt_base_url
.or(cfg.chatgpt_base_url)
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
experimental_resume,
};
Ok(config)
}
@@ -682,6 +693,9 @@ name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
request_max_retries = 4 # retry failed HTTP requests
stream_max_retries = 10 # retry dropped SSE streams
stream_idle_timeout_ms = 300000 # 5m idle timeout
[profiles.o3]
model = "o3"
@@ -722,6 +736,9 @@ disable_response_storage = true
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(4),
stream_max_retries: Some(10),
stream_idle_timeout_ms: Some(300_000),
};
let model_provider_map = {
let mut model_provider_map = built_in_model_providers();
@@ -800,6 +817,7 @@ disable_response_storage = true
model_reasoning_summary: ReasoningSummary::Detailed,
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
},
o3_profile_config
);
@@ -846,6 +864,7 @@ disable_response_storage = true
model_reasoning_summary: ReasoningSummary::default(),
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
};
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -907,6 +926,7 @@ disable_response_storage = true
model_reasoning_summary: ReasoningSummary::default(),
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
};
assert_eq!(expected_zdr_profile_config, zdr_profile_config);

View File

@@ -11,14 +11,6 @@ env_flags! {
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
value.parse().map(Duration::from_millis)
};
pub OPENAI_REQUEST_MAX_RETRIES: u64 = 4;
pub OPENAI_STREAM_MAX_RETRIES: u64 = 10;
// We generally don't want to disconnect; this updates the timeout to be five minutes
// which matches the upstream typescript codex impl.
pub OPENAI_STREAM_IDLE_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
value.parse().map(Duration::from_millis)
};
/// Fixture path for offline tests (see client.rs).
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;

View File

@@ -18,6 +18,7 @@ use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::Tool;
use serde_json::json;
use sha1::Digest;
use sha1::Sha1;
use tokio::task::JoinSet;
@@ -135,10 +136,14 @@ impl McpConnectionManager {
experimental: None,
roots: None,
sampling: None,
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
// indicates this should be an empty object.
elicitation: Some(json!({})),
},
client_info: Implementation {
name: "codex-mcp-client".to_owned(),
version: env!("CARGO_PKG_VERSION").to_owned(),
title: Some("Codex".into()),
},
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
};
@@ -288,6 +293,8 @@ mod tests {
r#type: "object".to_string(),
},
name: tool_name.to_string(),
output_schema: None,
title: None,
},
}
}

View File

@@ -9,6 +9,7 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;
use crate::error::EnvVarError;
use crate::openai_api_key::get_openai_api_key;
@@ -16,6 +17,9 @@ use crate::openai_api_key::get_openai_api_key;
/// Value for the `OpenAI-Originator` header that is sent with requests to
/// OpenAI.
const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 10;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
@@ -26,7 +30,7 @@ const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The experimental Responses API exposed by OpenAI at `/v1/responses`.
/// The experimental "Responses" API exposed by OpenAI at `/v1/responses`.
Responses,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
@@ -64,6 +68,16 @@ pub struct ModelProviderInfo {
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<u64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<u64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<u64>,
}
impl ModelProviderInfo {
@@ -161,6 +175,25 @@ impl ModelProviderInfo {
None => Ok(None),
}
}
/// Effective maximum number of request retries for this provider.
pub fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
}
/// Effective maximum number of stream reconnection attempts for this provider.
pub fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
}
/// Effective idle timeout for streaming responses.
pub fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
}
/// Built-in default provider list.
@@ -205,6 +238,10 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
.into_iter()
.collect(),
),
// Use global defaults for retry/timeout unless overridden in config.toml.
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
},
),
]
@@ -234,6 +271,9 @@ base_url = "http://localhost:11434/v1"
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -259,6 +299,9 @@ query_params = { api-version = "2025-04-01-preview" }
}),
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -287,6 +330,9 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
env_http_headers: Some(maplit::hashmap! {
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
}),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();

View File

@@ -69,6 +69,10 @@ pub enum Op {
/// `ConfigureSession` operation so that the business-logic layer can
/// operate deterministically.
cwd: std::path::PathBuf,
/// Path to a rollout file to resume from.
#[serde(skip_serializing_if = "Option::is_none")]
resume_path: Option<std::path::PathBuf>,
},
/// Abort current task.

View File

@@ -1,33 +1,47 @@
//! Functionality to persist a Codex conversation *rollout* a linear list of
//! [`ResponseItem`] objects exchanged during a session to disk so that
//! sessions can be replayed or inspected later (mirrors the behaviour of the
//! upstream TypeScript implementation).
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tracing::info;
use uuid::Uuid;
use crate::config::Config;
use crate::models::ResponseItem;
/// Folder inside `~/.codex` that holds saved rollouts.
const SESSIONS_SUBDIR: &str = "sessions";
#[derive(Serialize)]
struct SessionMeta {
id: String,
timestamp: String,
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
#[derive(Serialize, Deserialize, Clone, Default)]
pub struct SessionMeta {
pub id: Uuid,
pub timestamp: String,
pub instructions: Option<String>,
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SessionStateSnapshot {
pub previous_response_id: Option<String>,
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SavedSession {
pub session: SessionMeta,
#[serde(default)]
pub items: Vec<ResponseItem>,
#[serde(default)]
pub state: SessionStateSnapshot,
pub session_id: Uuid,
}
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
@@ -41,7 +55,13 @@ struct SessionMeta {
/// ```
#[derive(Clone)]
pub(crate) struct RolloutRecorder {
tx: Sender<String>,
tx: Sender<RolloutCmd>,
}
#[derive(Clone)]
enum RolloutCmd {
AddItems(Vec<ResponseItem>),
UpdateState(SessionStateSnapshot),
}
impl RolloutRecorder {
@@ -59,7 +79,6 @@ impl RolloutRecorder {
timestamp,
} = create_log_file(config, uuid)?;
// Build the static session metadata JSON first.
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
@@ -69,46 +88,29 @@ impl RolloutRecorder {
let meta = SessionMeta {
timestamp,
id: session_id.to_string(),
id: session_id,
instructions,
};
// A reasonably-sized bounded channel. If the buffer fills up the send
// future will yield, which is fine we only need to ensure we do not
// perform *blocking* I/O on the callers thread.
let (tx, mut rx) = mpsc::channel::<String>(256);
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
// Spawn a Tokio task that owns the file handle and performs async
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
tokio::task::spawn(async move {
let mut file = tokio::fs::File::from_std(file);
tokio::task::spawn(rollout_writer(
tokio::fs::File::from_std(file),
rx,
Some(meta),
));
while let Some(line) = rx.recv().await {
// Write line + newline, then flush to disk.
if let Err(e) = file.write_all(line.as_bytes()).await {
tracing::warn!("rollout writer: failed to write line: {e}");
break;
}
if let Err(e) = file.write_all(b"\n").await {
tracing::warn!("rollout writer: failed to write newline: {e}");
break;
}
if let Err(e) = file.flush().await {
tracing::warn!("rollout writer: failed to flush: {e}");
break;
}
}
});
let recorder = Self { tx };
// Ensure SessionMeta is the first item in the file.
recorder.record_item(&meta).await?;
Ok(recorder)
Ok(Self { tx })
}
/// Append `items` to the rollout file.
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
let mut filtered = Vec::new();
for item in items {
match item {
// Note that function calls may look a bit strange if they are
@@ -117,27 +119,86 @@ impl RolloutRecorder {
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. } => {}
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// These should never be serialized.
continue;
}
}
self.record_item(item).await?;
}
Ok(())
if filtered.is_empty() {
return Ok(());
}
self.tx
.send(RolloutCmd::AddItems(filtered))
.await
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
}
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
// Serialize the item to JSON first so that the writer thread only has
// to perform the actual write.
let json = serde_json::to_string(item)
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
self.tx
.send(json)
.send(RolloutCmd::UpdateState(state))
.await
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
}
pub async fn resume(path: &Path) -> std::io::Result<(Self, SavedSession)> {
info!("Resuming rollout from {path:?}");
let text = tokio::fs::read_to_string(path).await?;
let mut lines = text.lines();
let meta_line = lines
.next()
.ok_or_else(|| IoError::other("empty session file"))?;
let session: SessionMeta = serde_json::from_str(meta_line)
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
let mut items = Vec::new();
let mut state = SessionStateSnapshot::default();
for line in lines {
if line.trim().is_empty() {
continue;
}
let v: Value = match serde_json::from_str(line) {
Ok(v) => v,
Err(_) => continue,
};
if v.get("record_type")
.and_then(|rt| rt.as_str())
.map(|s| s == "state")
.unwrap_or(false)
{
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
state = s
}
continue;
}
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
match item {
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
}
}
}
let saved = SavedSession {
session: session.clone(),
items: items.clone(),
state: state.clone(),
session_id: session.id,
};
let file = std::fs::OpenOptions::new()
.append(true)
.read(true)
.open(path)?;
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
tokio::task::spawn(rollout_writer(tokio::fs::File::from_std(file), rx, None));
info!("Resumed rollout successfully from {path:?}");
Ok((Self { tx }, saved))
}
}
@@ -185,3 +246,54 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
timestamp,
})
}
async fn rollout_writer(
mut file: tokio::fs::File,
mut rx: mpsc::Receiver<RolloutCmd>,
meta: Option<SessionMeta>,
) {
if let Some(meta) = meta {
if let Ok(json) = serde_json::to_string(&meta) {
let _ = file.write_all(json.as_bytes()).await;
let _ = file.write_all(b"\n").await;
let _ = file.flush().await;
}
}
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => {
for item in items {
match item {
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. } => {
if let Ok(json) = serde_json::to_string(&item) {
let _ = file.write_all(json.as_bytes()).await;
let _ = file.write_all(b"\n").await;
}
}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
}
}
let _ = file.flush().await;
}
RolloutCmd::UpdateState(state) => {
#[derive(Serialize)]
struct StateLine<'a> {
record_type: &'static str,
#[serde(flatten)]
state: &'a SessionStateSnapshot,
}
if let Ok(json) = serde_json::to_string(&StateLine {
record_type: "state",
state: &state,
}) {
let _ = file.write_all(json.as_bytes()).await;
let _ = file.write_all(b"\n").await;
let _ = file.flush().await;
}
}
}
}
}

View File

@@ -2,7 +2,6 @@
use assert_cmd::Command as AssertCommand;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use serde_json::Value;
use std::time::Duration;
use std::time::Instant;
use tempfile::TempDir;
@@ -123,6 +122,7 @@ async fn responses_api_stream_cli() {
assert!(stdout.contains("fixture hello"));
}
/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn integration_creates_and_checks_session_file() {
// Honor sandbox network restrictions for CI parity with the other tests.
@@ -170,45 +170,66 @@ async fn integration_creates_and_checks_session_file() {
String::from_utf8_lossy(&output.stderr)
);
// 5. Sessions are written asynchronously; wait briefly for the directory to appear.
// Wait for sessions dir to appear.
let sessions_dir = home.path().join("sessions");
let start = Instant::now();
while !sessions_dir.exists() && start.elapsed() < Duration::from_secs(2) {
let dir_deadline = Instant::now() + Duration::from_secs(5);
while !sessions_dir.exists() && Instant::now() < dir_deadline {
std::thread::sleep(Duration::from_millis(50));
}
assert!(sessions_dir.exists(), "sessions directory never appeared");
// 6. Scan all session files and find the one that contains our marker.
let mut matching_files = vec![];
for entry in WalkDir::new(&sessions_dir) {
let entry = entry.unwrap();
if entry.file_type().is_file() && entry.file_name().to_string_lossy().ends_with(".jsonl") {
// Find the session file that contains `marker`.
let deadline = Instant::now() + Duration::from_secs(10);
let mut matching_path: Option<std::path::PathBuf> = None;
while Instant::now() < deadline && matching_path.is_none() {
for entry in WalkDir::new(&sessions_dir) {
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
if !entry.file_type().is_file() {
continue;
}
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
continue;
}
let path = entry.path();
let content = std::fs::read_to_string(path).unwrap();
let Ok(content) = std::fs::read_to_string(path) else {
continue;
};
let mut lines = content.lines();
// Skip SessionMeta (first line)
let _ = lines.next();
if lines.next().is_none() {
continue;
}
for line in lines {
let item: Value = serde_json::from_str(line).unwrap();
if let Some("message") = item.get("type").and_then(|t| t.as_str()) {
if let Some(content) = item.get("content") {
if content.to_string().contains(&marker) {
matching_files.push(path.to_owned());
if line.trim().is_empty() {
continue;
}
let item: serde_json::Value = match serde_json::from_str(line) {
Ok(v) => v,
Err(_) => continue,
};
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
if let Some(c) = item.get("content") {
if c.to_string().contains(&marker) {
matching_path = Some(path.to_path_buf());
break;
}
}
}
}
}
if matching_path.is_none() {
std::thread::sleep(Duration::from_millis(50));
}
}
assert_eq!(
matching_files.len(),
1,
"Expected exactly one session file containing the marker, found {}",
matching_files.len()
);
let path = &matching_files[0];
// 7. Verify directory structure: sessions/YYYY/MM/DD/filename.jsonl
let path = match matching_path {
Some(p) => p,
None => panic!("No session file containing the marker was found"),
};
// Basic sanity checks on location and metadata.
let rel = match path.strip_prefix(&sessions_dir) {
Ok(r) => r,
Err(_) => panic!("session file should live under sessions/"),
@@ -237,7 +258,6 @@ async fn integration_creates_and_checks_session_file() {
day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()),
"Day dir not zero-padded 2-digit numeric: {day}"
);
// Range checks (best-effort; won't fail on leading zeros)
if let Ok(m) = month.parse::<u8>() {
assert!((1..=12).contains(&m), "Month out of range: {m}");
}
@@ -245,23 +265,32 @@ async fn integration_creates_and_checks_session_file() {
assert!((1..=31).contains(&d), "Day out of range: {d}");
}
// 8. Parse SessionMeta line and basic sanity checks.
let content = std::fs::read_to_string(path).unwrap();
let content =
std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file"));
let mut lines = content.lines();
let meta: Value = serde_json::from_str(lines.next().unwrap()).unwrap();
let meta_line = lines
.next()
.ok_or("missing session meta line")
.unwrap_or_else(|_| panic!("missing session meta line"));
let meta: serde_json::Value = serde_json::from_str(meta_line)
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
assert!(meta.get("id").is_some(), "SessionMeta missing id");
assert!(
meta.get("timestamp").is_some(),
"SessionMeta missing timestamp"
);
// 9. Confirm at least one message contains the marker.
let mut found_message = false;
for line in lines {
let item: Value = serde_json::from_str(line).unwrap();
if item.get("type").map(|t| t == "message").unwrap_or(false) {
if let Some(content) = item.get("content") {
if content.to_string().contains(&marker) {
if line.trim().is_empty() {
continue;
}
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
continue;
};
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
if let Some(c) = item.get("content") {
if c.to_string().contains(&marker) {
found_message = true;
break;
}
@@ -272,4 +301,61 @@ async fn integration_creates_and_checks_session_file() {
found_message,
"No message found in session file containing the marker"
);
// Second run: resume and append.
let orig_len = content.lines().count();
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
let prompt2 = format!("echo {marker2}");
// Crossplatform safe resume override. On Windows, backslashes in a TOML string must be escaped
// or the parse will fail and the raw literal (including quotes) may be preserved all the way down
// to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes
// to sidestep the issue.
let resume_path_str = path.to_string_lossy().replace('\\', "/");
let resume_override = format!("experimental_resume=\"{resume_path_str}\"");
let mut cmd2 = AssertCommand::new("cargo");
cmd2.arg("run")
.arg("-p")
.arg("codex-cli")
.arg("--quiet")
.arg("--")
.arg("exec")
.arg("--skip-git-repo-check")
.arg("-c")
.arg(&resume_override)
.arg("-C")
.arg(env!("CARGO_MANIFEST_DIR"))
.arg(&prompt2);
cmd2.env("CODEX_HOME", home.path())
.env("OPENAI_API_KEY", "dummy")
.env("CODEX_RS_SSE_FIXTURE", &fixture)
.env("OPENAI_BASE_URL", "http://unused.local");
let output2 = cmd2.output().unwrap();
assert!(output2.status.success(), "resume codex-cli run failed");
// The rollout writer runs on a background async task; give it a moment to flush.
let mut new_len = orig_len;
let deadline = Instant::now() + Duration::from_secs(5);
let mut content2 = String::new();
while Instant::now() < deadline {
if let Ok(c) = std::fs::read_to_string(&path) {
let count = c.lines().count();
if count > orig_len {
content2 = c;
new_len = count;
break;
}
}
std::thread::sleep(Duration::from_millis(50));
}
if content2.is_empty() {
// last attempt
content2 = std::fs::read_to_string(&path).unwrap();
new_len = content2.lines().count();
}
assert!(new_len > orig_len, "rollout file did not grow after resume");
assert!(content2.contains(&marker), "rollout lost original marker");
assert!(
content2.contains(&marker2),
"rollout missing resumed marker"
);
}

View File

@@ -0,0 +1,113 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SessionConfiguredEvent;
mod test_support;
use tempfile::TempDir;
use test_support::load_default_config_for_test;
use test_support::load_sse_fixture_with_id;
use tokio::time::timeout;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Build minimal SSE stream with completed marker using the JSON fixture.
fn sse_completed(id: &str) -> String {
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_session_id_and_model_headers_in_request() {
#![allow(clippy::unwrap_used)]
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
// Mock server
let server = MockServer::start().await;
// First request must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
query_params: None,
http_headers: Some(
[("originator".to_string(), "codex_cli_rs".to_string())]
.into_iter()
.collect(),
),
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: None,
};
// Init session
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.model_provider = model_provider;
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "hello".into(),
}],
})
.await
.unwrap();
let mut current_session_id = None;
// Wait for TaskComplete
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
.unwrap()
.unwrap();
if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg {
current_session_id = Some(session_id.to_string());
}
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
break;
}
}
// get request from the server
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.headers.get("session_id").unwrap();
let originator = request.headers.get("originator").unwrap();
assert!(current_session_id.is_some());
assert_eq!(request_body.to_str().unwrap(), &current_session_id.unwrap());
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
}

View File

@@ -45,22 +45,10 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
"OPENAI_API_KEY must be set for live tests"
);
// Environment tweaks to keep the tests snappy and inexpensive while still
// exercising retry/robustness logic.
//
// NOTE: Starting with the 2024 edition `std::env::set_var` is `unsafe`
// because changing the process environment races with any other threads
// that might be performing environment look-ups at the same time.
// Restrict the unsafety to this tiny block that happens at the very
// beginning of the test, before we spawn any background tasks that could
// observe the environment.
unsafe {
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2");
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2");
}
let codex_home = TempDir::new().unwrap();
let config = load_default_config_for_test(&codex_home);
let mut config = load_default_config_for_test(&codex_home);
config.model_provider.request_max_retries = Some(2);
config.model_provider.stream_max_retries = Some(2);
let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
Ok(agent)
@@ -79,7 +67,7 @@ async fn live_streaming_and_prev_id_reset() {
let codex = spawn_codex().await.unwrap();
// ---------- Task 1 ----------
// ---------- Task 1 ----------
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
@@ -113,7 +101,7 @@ async fn live_streaming_and_prev_id_reset() {
"Agent did not stream any AgentMessage before TaskComplete"
);
// ---------- Task 2 (same session) ----------
// ---------- Task 2 (same session) ----------
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {

View File

@@ -88,13 +88,8 @@ async fn keeps_previous_response_id_between_tasks() {
.mount(&server)
.await;
// Environment
// Update environment `set_var` is `unsafe` starting with the 2024
// edition so we group the calls into a single `unsafe { … }` block.
unsafe {
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
}
// Configure retry behavior explicitly to avoid mutating process-wide
// environment variables.
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
@@ -107,6 +102,10 @@ async fn keeps_previous_response_id_between_tasks() {
query_params: None,
http_headers: None,
env_http_headers: None,
// disable retries so we don't get duplicate calls in this test
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: None,
};
// Init session

View File

@@ -32,8 +32,6 @@ fn sse_completed(id: &str) -> String {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// this test is flaky (has race conditions), so we ignore it for now
#[ignore]
async fn retries_on_early_close() {
#![allow(clippy::unwrap_used)]
@@ -72,19 +70,8 @@ async fn retries_on_early_close() {
.mount(&server)
.await;
// Environment
//
// As of Rust 2024 `std::env::set_var` has been made `unsafe` because
// mutating the process environment is inherently racy when other threads
// are running. We therefore have to wrap every call in an explicit
// `unsafe` block. These are limited to the test-setup section so the
// scope is very small and clearly delineated.
unsafe {
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
}
// Configure retry behavior explicitly to avoid mutating process-wide
// environment variables.
let model_provider = ModelProviderInfo {
name: "openai".into(),
@@ -98,6 +85,10 @@ async fn retries_on_early_close() {
query_params: None,
http_headers: None,
env_http_headers: None,
// exercise retry path: first attempt yields incomplete stream, so allow 1 retry
request_max_retries: Some(0),
stream_max_retries: Some(1),
stream_idle_timeout_ms: Some(2000),
};
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());

View File

@@ -57,10 +57,12 @@ async fn main() -> Result<()> {
experimental: None,
roots: None,
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: "codex-mcp-client".to_owned(),
version: env!("CARGO_PKG_VERSION").to_owned(),
title: Some("Codex".to_string()),
},
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
};

View File

@@ -22,6 +22,7 @@ mcp-types = { path = "../mcp-types" }
schemars = "0.8.22"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
shlex = "1.3.0"
toml = "0.9"
tracing = { version = "0.1.41", features = ["log"] }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }

View File

@@ -108,7 +108,10 @@ pub(crate) fn create_tool_for_codex_tool_call_param() -> Tool {
Tool {
name: "codex".to_string(),
title: Some("Codex".to_string()),
input_schema: tool_input_schema,
// TODO(mbolin): This should be defined.
output_schema: None,
description: Some(
"Run a Codex session. Accepts configuration parameters matching the Codex Config struct.".to_string(),
),
@@ -179,6 +182,7 @@ mod tests {
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
let expected_tool_json = serde_json::json!({
"name": "codex",
"title": "Codex",
"description": "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.",
"inputSchema": {
"type": "object",

View File

@@ -2,33 +2,31 @@
//! Tokio task. Separated from `message_processor.rs` to keep that file small
//! and to make future feature-growth easier to manage.
use std::sync::Arc;
use codex_core::Codex;
use codex_core::codex_wrapper::init_codex;
use codex_core::config::Config as CodexConfig;
use codex_core::protocol::AgentMessageEvent;
use codex_core::protocol::Event;
use codex_core::protocol::EventMsg;
use codex_core::protocol::ExecApprovalRequestEvent;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::ReviewDecision;
use codex_core::protocol::Submission;
use codex_core::protocol::TaskCompleteEvent;
use mcp_types::CallToolResult;
use mcp_types::CallToolResultContent;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCResponse;
use mcp_types::ContentBlock;
use mcp_types::ElicitRequest;
use mcp_types::ElicitRequestParamsRequestedSchema;
use mcp_types::ModelContextProtocolRequest;
use mcp_types::RequestId;
use mcp_types::TextContent;
use tokio::sync::mpsc::Sender;
use serde::Deserialize;
use serde_json::json;
use tracing::error;
/// Convert a Codex [`Event`] to an MCP notification.
fn codex_event_to_notification(event: &Event) -> JSONRPCMessage {
#[expect(clippy::expect_used)]
JSONRPCMessage::Notification(mcp_types::JSONRPCNotification {
jsonrpc: JSONRPC_VERSION.into(),
method: "codex/event".into(),
params: Some(serde_json::to_value(event).expect("Event must serialize")),
})
}
use crate::outgoing_message::OutgoingMessageSender;
/// Run a complete Codex session and stream events back to the client.
///
@@ -38,34 +36,28 @@ pub async fn run_codex_tool_session(
id: RequestId,
initial_prompt: String,
config: CodexConfig,
outgoing: Sender<JSONRPCMessage>,
outgoing: Arc<OutgoingMessageSender>,
) {
let (codex, first_event, _ctrl_c) = match init_codex(config).await {
Ok(res) => res,
Err(e) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: format!("Failed to start Codex session: {e}"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
let _ = outgoing
.send(JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id,
result: result.into(),
}))
.await;
outgoing.send_response(id.clone(), result.into()).await;
return;
}
};
let codex = Arc::new(codex);
// Send initial SessionConfigured event.
let _ = outgoing
.send(codex_event_to_notification(&first_event))
.await;
outgoing.send_event_as_notification(&first_event).await;
// Use the original MCP request ID as the `sub_id` for the Codex submission so that
// any events emitted for this tool-call can be correlated with the
@@ -76,7 +68,7 @@ pub async fn run_codex_tool_session(
};
let submission = Submission {
id: sub_id,
id: sub_id.clone(),
op: Op::UserInput {
items: vec![InputItem::Text {
text: initial_prompt.clone(),
@@ -88,84 +80,87 @@ pub async fn run_codex_tool_session(
tracing::error!("Failed to submit initial prompt: {e}");
}
let mut last_agent_message: Option<String> = None;
// Stream events until the task needs to pause for user interaction or
// completes.
loop {
match codex.next_event().await {
Ok(event) => {
let _ = outgoing.send(codex_event_to_notification(&event)).await;
outgoing.send_event_as_notification(&event).await;
match &event.msg {
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
last_agent_message = Some(message.clone());
}
EventMsg::ExecApprovalRequest(_) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
r#type: "text".to_string(),
text: "EXEC_APPROVAL_REQUIRED".to_string(),
annotations: None,
})],
is_error: None,
};
let _ = outgoing
.send(JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id: id.clone(),
result: result.into(),
}))
match event.msg {
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
command,
cwd,
reason: _,
}) => {
let escaped_command = shlex::try_join(command.iter().map(|s| s.as_str()))
.unwrap_or_else(|_| command.join(" "));
let message = format!("Allow Codex to run `{escaped_command}` in {cwd:?}?");
let params = json!({
// These fields are required so that `params`
// conforms to ElicitRequestParams.
"message": message,
"requestedSchema": ElicitRequestParamsRequestedSchema {
r#type: "object".to_string(),
properties: json!({}),
required: None,
},
// These are additional fields the client can use to
// correlate the request with the codex tool call.
"codex_elicitation": "exec-approval",
"codex_mcp_tool_call_id": sub_id,
"codex_event_id": event.id,
"codex_command": command,
// Could convert it to base64 encoded bytes if we
// don't want to use to_string_lossy() here?
"codex_cwd": cwd.to_string_lossy().to_string()
});
let on_response = outgoing
.send_request(ElicitRequest::METHOD, Some(params))
.await;
// Listen for the response on a separate task so we do
// not block the main loop of this function.
{
let codex = codex.clone();
let event_id = event.id.clone();
tokio::spawn(async move {
on_exec_approval_response(event_id, on_response, codex).await;
});
}
break;
}
EventMsg::ApplyPatchApprovalRequest(_) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: "PATCH_APPROVAL_REQUIRED".to_string(),
annotations: None,
})],
is_error: None,
structured_content: None,
};
let _ = outgoing
.send(JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id: id.clone(),
result: result.into(),
}))
.await;
outgoing.send_response(id.clone(), result.into()).await;
break;
}
EventMsg::TaskComplete(TaskCompleteEvent {
last_agent_message: _,
}) => {
let result = if let Some(msg) = last_agent_message {
CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
r#type: "text".to_string(),
text: msg,
annotations: None,
})],
is_error: None,
}
} else {
CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
r#type: "text".to_string(),
text: String::new(),
annotations: None,
})],
is_error: None,
}
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
let text = match last_agent_message {
Some(msg) => msg.clone(),
None => "".to_string(),
};
let _ = outgoing
.send(JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id: id.clone(),
result: result.into(),
}))
.await;
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text,
annotations: None,
})],
is_error: None,
structured_content: None,
};
outgoing.send_response(id.clone(), result.into()).await;
break;
}
EventMsg::SessionConfigured(_) => {
@@ -177,6 +172,9 @@ pub async fn run_codex_tool_session(
EventMsg::AgentReasoningDelta(_) => {
// TODO: think how we want to support this in the MCP
}
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
// TODO: think how we want to support this in the MCP
}
EventMsg::Error(_)
| EventMsg::TaskStarted
| EventMsg::TokenCount(_)
@@ -200,22 +198,58 @@ pub async fn run_codex_tool_session(
}
Err(e) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: format!("Codex runtime error: {e}"),
annotations: None,
})],
is_error: Some(true),
// TODO(mbolin): Could present the error in a more
// structured way.
structured_content: None,
};
let _ = outgoing
.send(JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id: id.clone(),
result: result.into(),
}))
.await;
outgoing.send_response(id.clone(), result.into()).await;
break;
}
}
}
}
async fn on_exec_approval_response(
event_id: String,
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
codex: Arc<Codex>,
) {
let response = receiver.await;
let value = match response {
Ok(value) => value,
Err(err) => {
error!("request failed: {err:?}");
return;
}
};
// Try to deserialize `value` and then make the appropriate call to `codex`.
let response = match serde_json::from_value::<ExecApprovalResponse>(value) {
Ok(response) => response,
Err(err) => {
error!("failed to deserialize ExecApprovalResponse: {err}");
return;
}
};
if let Err(err) = codex
.submit(Op::ExecApproval {
id: event_id,
decision: response.decision,
})
.await
{
error!("failed to submit ExecApproval: {err}");
}
}
#[derive(Debug, Deserialize)]
pub struct ExecApprovalResponse {
pub decision: ReviewDecision,
}

View File

@@ -18,8 +18,11 @@ mod codex_tool_config;
mod codex_tool_runner;
mod json_to_toml;
mod message_processor;
mod outgoing_message;
use crate::message_processor::MessageProcessor;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
/// Size of the bounded channels used to communicate between tasks. The value
/// is a balance between throughput and memory usage 128 messages should be
@@ -35,7 +38,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
// Set up channels.
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
// Task: read from stdin, push to `incoming_tx`.
let stdin_reader_handle = tokio::spawn({
@@ -63,16 +66,15 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
// Task: process incoming messages.
let processor_handle = tokio::spawn({
let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe);
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe);
async move {
while let Some(msg) = incoming_rx.recv().await {
match msg {
JSONRPCMessage::Request(r) => processor.process_request(r),
JSONRPCMessage::Response(r) => processor.process_response(r),
JSONRPCMessage::Request(r) => processor.process_request(r).await,
JSONRPCMessage::Response(r) => processor.process_response(r).await,
JSONRPCMessage::Notification(n) => processor.process_notification(n),
JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b),
JSONRPCMessage::Error(e) => processor.process_error(e),
JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b),
}
}
@@ -83,7 +85,8 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
// Task: write outgoing messages to stdout.
let stdout_writer_handle = tokio::spawn(async move {
let mut stdout = io::stdout();
while let Some(msg) = outgoing_rx.recv().await {
while let Some(outgoing_message) = outgoing_rx.recv().await {
let msg: JSONRPCMessage = outgoing_message.into();
match serde_json::to_string(&msg) {
Ok(json) => {
if let Err(e) = stdout.write_all(json.as_bytes()).await {

View File

@@ -1,19 +1,17 @@
use std::path::PathBuf;
use std::sync::Arc;
use crate::codex_tool_config::CodexToolCallParam;
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
use crate::outgoing_message::OutgoingMessageSender;
use codex_core::config::Config as CodexConfig;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::CallToolResultContent;
use mcp_types::ClientRequest;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCBatchRequest;
use mcp_types::JSONRPCBatchResponse;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::JSONRPCResponse;
@@ -24,11 +22,10 @@ use mcp_types::ServerCapabilitiesTools;
use mcp_types::ServerNotification;
use mcp_types::TextContent;
use serde_json::json;
use tokio::sync::mpsc;
use tokio::task;
pub(crate) struct MessageProcessor {
outgoing: mpsc::Sender<JSONRPCMessage>,
outgoing: Arc<OutgoingMessageSender>,
initialized: bool,
codex_linux_sandbox_exe: Option<PathBuf>,
}
@@ -37,17 +34,17 @@ impl MessageProcessor {
/// Create a new `MessageProcessor`, retaining a handle to the outgoing
/// `Sender` so handlers can enqueue messages to be written to stdout.
pub(crate) fn new(
outgoing: mpsc::Sender<JSONRPCMessage>,
outgoing: OutgoingMessageSender,
codex_linux_sandbox_exe: Option<PathBuf>,
) -> Self {
Self {
outgoing,
outgoing: Arc::new(outgoing),
initialized: false,
codex_linux_sandbox_exe,
}
}
pub(crate) fn process_request(&mut self, request: JSONRPCRequest) {
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
// Hold on to the ID so we can respond.
let request_id = request.id.clone();
@@ -62,10 +59,10 @@ impl MessageProcessor {
// Dispatch to a dedicated handler for each request type.
match client_request {
ClientRequest::InitializeRequest(params) => {
self.handle_initialize(request_id, params);
self.handle_initialize(request_id, params).await;
}
ClientRequest::PingRequest(params) => {
self.handle_ping(request_id, params);
self.handle_ping(request_id, params).await;
}
ClientRequest::ListResourcesRequest(params) => {
self.handle_list_resources(params);
@@ -89,10 +86,10 @@ impl MessageProcessor {
self.handle_get_prompt(params);
}
ClientRequest::ListToolsRequest(params) => {
self.handle_list_tools(request_id, params);
self.handle_list_tools(request_id, params).await;
}
ClientRequest::CallToolRequest(params) => {
self.handle_call_tool(request_id, params);
self.handle_call_tool(request_id, params).await;
}
ClientRequest::SetLevelRequest(params) => {
self.handle_set_level(params);
@@ -104,8 +101,10 @@ impl MessageProcessor {
}
/// Handle a standalone JSON-RPC response originating from the peer.
pub(crate) fn process_response(&mut self, response: JSONRPCResponse) {
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
tracing::info!("<- response: {:?}", response);
let JSONRPCResponse { id, result, .. } = response;
self.outgoing.notify_client_response(id, result).await
}
/// Handle a fire-and-forget JSON-RPC notification.
@@ -145,42 +144,12 @@ impl MessageProcessor {
}
}
/// Handle a batch of requests and/or notifications.
pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) {
tracing::info!("<- batch request containing {} item(s)", batch.len());
for item in batch {
match item {
mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => {
self.process_request(req);
}
mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => {
self.process_notification(note);
}
}
}
}
/// Handle an error object received from the peer.
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
tracing::error!("<- error: {:?}", err);
}
/// Handle a batch of responses/errors.
pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) {
tracing::info!("<- batch response containing {} item(s)", batch.len());
for item in batch {
match item {
mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => {
self.process_response(resp);
}
mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => {
self.process_error(err);
}
}
}
}
fn handle_initialize(
async fn handle_initialize(
&mut self,
id: RequestId,
params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
@@ -189,19 +158,12 @@ impl MessageProcessor {
if self.initialized {
// Already initialised: send JSON-RPC error response.
let error_msg = JSONRPCMessage::Error(JSONRPCError {
jsonrpc: JSONRPC_VERSION.into(),
id,
error: JSONRPCErrorError {
code: -32600, // Invalid Request
message: "initialize called more than once".to_string(),
data: None,
},
});
if let Err(e) = self.outgoing.try_send(error_msg) {
tracing::error!("Failed to send initialization error: {e}");
}
let error = JSONRPCErrorError {
code: -32600, // Invalid Request
message: "initialize called more than once".to_string(),
data: None,
};
self.outgoing.send_error(id, error).await;
return;
}
@@ -224,37 +186,33 @@ impl MessageProcessor {
server_info: mcp_types::Implementation {
name: "codex-mcp-server".to_string(),
version: mcp_types::MCP_SCHEMA_VERSION.to_string(),
title: Some("Codex".to_string()),
},
};
self.send_response::<mcp_types::InitializeRequest>(id, result);
self.send_response::<mcp_types::InitializeRequest>(id, result)
.await;
}
fn send_response<T>(&self, id: RequestId, result: T::Result)
async fn send_response<T>(&self, id: RequestId, result: T::Result)
where
T: ModelContextProtocolRequest,
{
// result has `Serialized` instance so should never fail
#[expect(clippy::unwrap_used)]
let response = JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id,
result: serde_json::to_value(result).unwrap(),
});
if let Err(e) = self.outgoing.try_send(response) {
tracing::error!("Failed to send response: {e}");
}
let result = serde_json::to_value(result).unwrap();
self.outgoing.send_response(id, result).await;
}
fn handle_ping(
async fn handle_ping(
&self,
id: RequestId,
params: <mcp_types::PingRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("ping -> params: {:?}", params);
let result = json!({});
self.send_response::<mcp_types::PingRequest>(id, result);
self.send_response::<mcp_types::PingRequest>(id, result)
.await;
}
fn handle_list_resources(
@@ -307,7 +265,7 @@ impl MessageProcessor {
tracing::info!("prompts/get -> params: {:?}", params);
}
fn handle_list_tools(
async fn handle_list_tools(
&self,
id: RequestId,
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
@@ -318,10 +276,11 @@ impl MessageProcessor {
next_cursor: None,
};
self.send_response::<mcp_types::ListToolsRequest>(id, result);
self.send_response::<mcp_types::ListToolsRequest>(id, result)
.await;
}
fn handle_call_tool(
async fn handle_call_tool(
&self,
id: RequestId,
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
@@ -333,14 +292,16 @@ impl MessageProcessor {
if name != "codex" {
// Tool not found return error result so the LLM can react.
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: format!("Unknown tool '{name}'"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result);
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
@@ -350,7 +311,7 @@ impl MessageProcessor {
Ok(cfg) => cfg,
Err(e) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_owned(),
text: format!(
"Failed to load Codex configuration from overrides: {e}"
@@ -358,27 +319,31 @@ impl MessageProcessor {
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result);
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
},
Err(e) => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_owned(),
text: format!("Failed to parse configuration for Codex tool: {e}"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result);
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
},
None => {
let result = CallToolResult {
content: vec![CallToolResultContent::TextContent(TextContent {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text:
"Missing arguments for codex tool-call; the `prompt` field is required."
@@ -386,8 +351,10 @@ impl MessageProcessor {
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result);
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
};
@@ -398,7 +365,7 @@ impl MessageProcessor {
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
task::spawn(async move {
// Run the Codex session and stream events back to the client.
// Run the Codex session and stream events Fck to the client.
crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing)
.await;
});

View File

@@ -0,0 +1,165 @@
use std::collections::HashMap;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use codex_core::protocol::Event;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use mcp_types::Result;
use serde::Serialize;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::warn;
pub(crate) struct OutgoingMessageSender {
next_request_id: AtomicI64,
sender: mpsc::Sender<OutgoingMessage>,
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
}
impl OutgoingMessageSender {
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
Self {
next_request_id: AtomicI64::new(0),
sender,
request_id_to_callback: Mutex::new(HashMap::new()),
}
}
pub(crate) async fn send_request(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> oneshot::Receiver<Result> {
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let outgoing_message_id = id.clone();
let (tx_approve, rx_approve) = oneshot::channel();
{
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
request_id_to_callback.insert(id, tx_approve);
}
let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
id: outgoing_message_id,
method: method.to_string(),
params,
});
let _ = self.sender.send(outgoing_message).await;
rx_approve
}
pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
let entry = {
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
request_id_to_callback.remove_entry(&id)
};
match entry {
Some((id, sender)) => {
if let Err(err) = sender.send(result) {
warn!("could not notify callback for {id:?} due to: {err:?}");
}
}
None => {
warn!("could not find callback for {id:?}");
}
}
}
pub(crate) async fn send_response(&self, id: RequestId, result: Result) {
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
let _ = self.sender.send(outgoing_message).await;
}
pub(crate) async fn send_event_as_notification(&self, event: &Event) {
#[expect(clippy::expect_used)]
let params = Some(serde_json::to_value(event).expect("Event must serialize"));
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
method: "codex/event".to_string(),
params,
});
let _ = self.sender.send(outgoing_message).await;
}
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
let _ = self.sender.send(outgoing_message).await;
}
}
/// Outgoing message from the server to the client.
pub(crate) enum OutgoingMessage {
Request(OutgoingRequest),
Notification(OutgoingNotification),
Response(OutgoingResponse),
Error(OutgoingError),
}
impl From<OutgoingMessage> for JSONRPCMessage {
fn from(val: OutgoingMessage) -> Self {
use OutgoingMessage::*;
match val {
Request(OutgoingRequest { id, method, params }) => {
JSONRPCMessage::Request(JSONRPCRequest {
jsonrpc: JSONRPC_VERSION.into(),
id,
method,
params,
})
}
Notification(OutgoingNotification { method, params }) => {
JSONRPCMessage::Notification(JSONRPCNotification {
jsonrpc: JSONRPC_VERSION.into(),
method,
params,
})
}
Response(OutgoingResponse { id, result }) => {
JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id,
result,
})
}
Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
jsonrpc: JSONRPC_VERSION.into(),
id,
error,
}),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingRequest {
pub id: RequestId,
pub method: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingNotification {
pub method: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingResponse {
pub id: RequestId,
pub result: Result,
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingError {
pub error: JSONRPCErrorError,
pub id: RequestId,
}

View File

@@ -2,7 +2,7 @@
Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types.
As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic:
As documented on https://modelcontextprotocol.io/specification/2025-06-18/basic:
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.json

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# flake8: noqa: E501
import argparse
import json
import subprocess
import sys
@@ -13,10 +14,13 @@ from pathlib import Path
# Helper first so it is defined when other functions call it.
from typing import Any, Literal
SCHEMA_VERSION = "2025-03-26"
SCHEMA_VERSION = "2025-06-18"
JSONRPC_VERSION = "2.0"
STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n"
STANDARD_HASHABLE_DERIVE = (
"#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n"
)
# Will be populated with the schema's `definitions` map in `main()` so that
# helper functions (for example `define_any_of`) can perform look-ups while
@@ -26,19 +30,27 @@ DEFINITIONS: dict[str, Any] = {}
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
# Concrete *Notification types that make up the ServerNotification enum.
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
# Enum types that will need a `allow(clippy::large_enum_variant)` annotation in
# order to compile without warnings.
LARGE_ENUMS = {"ServerResult"}
def main() -> int:
num_args = len(sys.argv)
if num_args == 1:
schema_file = (
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
)
elif num_args == 2:
schema_file = Path(sys.argv[1])
else:
print("Usage: python3 codegen.py <schema.json>")
return 1
parser = argparse.ArgumentParser(
description="Embed, cluster and analyse text prompts via the OpenAI API.",
)
default_schema_file = (
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
)
parser.add_argument(
"schema_file",
nargs="?",
default=default_schema_file,
help="schema.json file to process",
)
args = parser.parse_args()
schema_file = args.schema_file
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
@@ -197,6 +209,8 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
if name.endswith("Result"):
out.extend(f"impl From<{name}> for serde_json::Value {{\n")
out.append(f" fn from(value: {name}) -> Self {{\n")
out.append(" // Leave this as it should never fail\n")
out.append(" #[expect(clippy::unwrap_used)]\n")
out.append(" serde_json::to_value(value).unwrap()\n")
out.append(" }\n")
out.append("}\n\n")
@@ -211,20 +225,7 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
any_of = definition.get("anyOf", [])
if any_of:
assert isinstance(any_of, list)
if name == "JSONRPCMessage":
# Special case for JSONRPCMessage because its definition in the
# JSON schema does not quite match how we think about this type
# definition in Rust.
deep_copied_any_of = json.loads(json.dumps(any_of))
deep_copied_any_of[2] = {
"$ref": "#/definitions/JSONRPCBatchRequest",
}
deep_copied_any_of[5] = {
"$ref": "#/definitions/JSONRPCBatchResponse",
}
out.extend(define_any_of(name, deep_copied_any_of, description))
else:
out.extend(define_any_of(name, any_of, description))
out.extend(define_any_of(name, any_of, description))
return
type_prop = definition.get("type", None)
@@ -393,7 +394,7 @@ def define_string_enum(
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
out.append(STANDARD_DERIVE)
out.append(STANDARD_HASHABLE_DERIVE)
out.append("#[serde(untagged)]\n")
out.append(f"pub enum {name} {{\n")
for simple_type in type_list:
@@ -439,6 +440,8 @@ def define_any_of(
if serde := get_serde_annotation_for_anyof_type(name):
out.append(serde + "\n")
if name in LARGE_ENUMS:
out.append("#[allow(clippy::large_enum_variant)]\n")
out.append(f"pub enum {name} {{\n")
if name == "ClientRequest":
@@ -596,6 +599,8 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
prop_name = "r#type"
elif name == "ref":
prop_name = "r#ref"
elif name == "enum":
prop_name = "r#enum"
elif snake_case := to_snake_case(name):
prop_name = snake_case
is_rename = True

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,7 @@ use serde::Serialize;
use serde::de::DeserializeOwned;
use std::convert::TryFrom;
pub const MCP_SCHEMA_VERSION: &str = "2025-03-26";
pub const MCP_SCHEMA_VERSION: &str = "2025-06-18";
pub const JSONRPC_VERSION: &str = "2.0";
/// Paired request/response types for the Model Context Protocol (MCP).
@@ -35,6 +35,12 @@ fn default_jsonrpc() -> String {
pub struct Annotations {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub audience: Option<Vec<Role>>,
#[serde(
rename = "lastModified",
default,
skip_serializing_if = "Option::is_none"
)]
pub last_modified: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<f64>,
}
@@ -50,6 +56,14 @@ pub struct AudioContent {
pub r#type: String, // &'static str = "audio"
}
/// Base interface for metadata with name (identifier) and title (display name) properties.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct BaseMetadata {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct BlobResourceContents {
pub blob: String,
@@ -58,6 +72,17 @@ pub struct BlobResourceContents {
pub uri: String,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct BooleanSchema {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub default: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String, // &'static str = "boolean"
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum CallToolRequest {}
@@ -75,29 +100,17 @@ pub struct CallToolRequestParams {
}
/// The server's response to a tool call.
///
/// Any errors that originate from the tool SHOULD be reported inside the result
/// object, with `isError` set to true, _not_ as an MCP protocol-level error
/// response. Otherwise, the LLM would not be able to see that an error occurred
/// and self-correct.
///
/// However, any errors in _finding_ the tool, an error indicating that the
/// server does not support tool calls, or any other exceptional conditions,
/// should be reported as an MCP error response.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CallToolResult {
pub content: Vec<CallToolResultContent>,
pub content: Vec<ContentBlock>,
#[serde(rename = "isError", default, skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CallToolResultContent {
TextContent(TextContent),
ImageContent(ImageContent),
AudioContent(AudioContent),
EmbeddedResource(EmbeddedResource),
#[serde(
rename = "structuredContent",
default,
skip_serializing_if = "Option::is_none"
)]
pub structured_content: Option<serde_json::Value>,
}
impl From<CallToolResult> for serde_json::Value {
@@ -127,6 +140,8 @@ pub struct CancelledNotificationParams {
/// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ClientCapabilities {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub elicitation: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub experimental: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
@@ -194,6 +209,7 @@ pub enum ClientResult {
Result(Result),
CreateMessageResult(CreateMessageResult),
ListRootsResult(ListRootsResult),
ElicitResult(ElicitResult),
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
@@ -208,9 +224,18 @@ impl ModelContextProtocolRequest for CompleteRequest {
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CompleteRequestParams {
pub argument: CompleteRequestParamsArgument,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub context: Option<CompleteRequestParamsContext>,
pub r#ref: CompleteRequestParamsRef,
}
/// Additional, optional context for completions
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CompleteRequestParamsContext {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub arguments: Option<serde_json::Value>,
}
/// The argument's information
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CompleteRequestParamsArgument {
@@ -222,7 +247,7 @@ pub struct CompleteRequestParamsArgument {
#[serde(untagged)]
pub enum CompleteRequestParamsRef {
PromptReference(PromptReference),
ResourceReference(ResourceReference),
ResourceTemplateReference(ResourceTemplateReference),
}
/// The server's response to a completion/complete request
@@ -248,6 +273,16 @@ impl From<CompleteResult> for serde_json::Value {
}
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ContentBlock {
TextContent(TextContent),
ImageContent(ImageContent),
AudioContent(AudioContent),
ResourceLink(ResourceLink),
EmbeddedResource(EmbeddedResource),
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum CreateMessageRequest {}
@@ -325,6 +360,48 @@ impl From<CreateMessageResult> for serde_json::Value {
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Cursor(String);
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ElicitRequest {}
impl ModelContextProtocolRequest for ElicitRequest {
const METHOD: &'static str = "elicitation/create";
type Params = ElicitRequestParams;
type Result = ElicitResult;
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ElicitRequestParams {
pub message: String,
#[serde(rename = "requestedSchema")]
pub requested_schema: ElicitRequestParamsRequestedSchema,
}
/// A restricted subset of JSON Schema.
/// Only top-level properties are allowed, without nesting.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ElicitRequestParamsRequestedSchema {
pub properties: serde_json::Value,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
pub r#type: String, // &'static str = "object"
}
/// The client's response to an elicitation request.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ElicitResult {
pub action: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<serde_json::Value>,
}
impl From<ElicitResult> for serde_json::Value {
fn from(value: ElicitResult) -> Self {
// Leave this as it should never fail
#[expect(clippy::unwrap_used)]
serde_json::to_value(value).unwrap()
}
}
/// The contents of a resource, embedded into a prompt or tool call result.
///
/// It is up to the client how best to render embedded resources for the benefit
@@ -346,6 +423,18 @@ pub enum EmbeddedResourceResource {
pub type EmptyResult = Result;
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct EnumSchema {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub r#enum: Vec<String>,
#[serde(rename = "enumNames", default, skip_serializing_if = "Option::is_none")]
pub enum_names: Option<Vec<String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String, // &'static str = "string"
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum GetPromptRequest {}
@@ -389,10 +478,12 @@ pub struct ImageContent {
pub r#type: String, // &'static str = "image"
}
/// Describes the name and version of an MCP implementation.
/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Implementation {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub version: String,
}
@@ -442,24 +533,6 @@ impl ModelContextProtocolNotification for InitializedNotification {
type Params = Option<serde_json::Value>;
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum JSONRPCBatchRequestItem {
JSONRPCRequest(JSONRPCRequest),
JSONRPCNotification(JSONRPCNotification),
}
pub type JSONRPCBatchRequest = Vec<JSONRPCBatchRequestItem>;
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum JSONRPCBatchResponseItem {
JSONRPCResponse(JSONRPCResponse),
JSONRPCError(JSONRPCError),
}
pub type JSONRPCBatchResponse = Vec<JSONRPCBatchResponseItem>;
/// A response to a request that indicates an error occurred.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct JSONRPCError {
@@ -483,10 +556,8 @@ pub struct JSONRPCErrorError {
pub enum JSONRPCMessage {
Request(JSONRPCRequest),
Notification(JSONRPCNotification),
BatchRequest(JSONRPCBatchRequest),
Response(JSONRPCResponse),
Error(JSONRPCError),
BatchResponse(JSONRPCBatchResponse),
}
/// A notification which does not expect a response.
@@ -777,6 +848,19 @@ pub struct Notification {
pub params: Option<serde_json::Value>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct NumberSchema {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub maximum: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub minimum: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct PaginatedRequest {
pub method: String,
@@ -817,6 +901,17 @@ impl ModelContextProtocolRequest for PingRequest {
type Result = Result;
}
/// Restricted schema definitions that only allow primitive types
/// without nested objects or arrays.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PrimitiveSchemaDefinition {
StringSchema(StringSchema),
NumberSchema(NumberSchema),
BooleanSchema(BooleanSchema),
EnumSchema(EnumSchema),
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ProgressNotification {}
@@ -836,7 +931,7 @@ pub struct ProgressNotificationParams {
pub total: Option<f64>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
#[serde(untagged)]
pub enum ProgressToken {
String(String),
@@ -851,6 +946,8 @@ pub struct Prompt {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
}
/// Describes an argument that a prompt can accept.
@@ -861,6 +958,8 @@ pub struct PromptArgument {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
@@ -877,23 +976,16 @@ impl ModelContextProtocolNotification for PromptListChangedNotification {
/// resources from the MCP server.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct PromptMessage {
pub content: PromptMessageContent,
pub content: ContentBlock,
pub role: Role,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PromptMessageContent {
TextContent(TextContent),
ImageContent(ImageContent),
AudioContent(AudioContent),
EmbeddedResource(EmbeddedResource),
}
/// Identifies a prompt.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct PromptReference {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String, // &'static str = "ref/prompt"
}
@@ -939,7 +1031,7 @@ pub struct Request {
pub params: Option<serde_json::Value>,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
#[serde(untagged)]
pub enum RequestId {
String(String),
@@ -958,6 +1050,8 @@ pub struct Resource {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub size: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub uri: String,
}
@@ -969,6 +1063,26 @@ pub struct ResourceContents {
pub uri: String,
}
/// A resource that the server is capable of reading, included in a prompt or tool call result.
///
/// Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResourceLink {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub annotations: Option<Annotations>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub size: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String, // &'static str = "resource_link"
pub uri: String,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ResourceListChangedNotification {}
@@ -977,13 +1091,6 @@ impl ModelContextProtocolNotification for ResourceListChangedNotification {
type Params = Option<serde_json::Value>;
}
/// A reference to a resource or resource template definition.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResourceReference {
pub r#type: String, // &'static str = "ref/resource"
pub uri: String,
}
/// A template description for resources available on the server.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResourceTemplate {
@@ -994,10 +1101,19 @@ pub struct ResourceTemplate {
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
#[serde(rename = "uriTemplate")]
pub uri_template: String,
}
/// A reference to a resource or resource template definition.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResourceTemplateReference {
pub r#type: String, // &'static str = "ref/resource"
pub uri: String,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ResourceUpdatedNotification {}
@@ -1140,6 +1256,7 @@ pub enum ServerRequest {
PingRequest(PingRequest),
CreateMessageRequest(CreateMessageRequest),
ListRootsRequest(ListRootsRequest),
ElicitRequest(ElicitRequest),
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
@@ -1172,6 +1289,21 @@ pub struct SetLevelRequestParams {
pub level: LoggingLevel,
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct StringSchema {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
#[serde(rename = "maxLength", default, skip_serializing_if = "Option::is_none")]
pub max_length: Option<i64>,
#[serde(rename = "minLength", default, skip_serializing_if = "Option::is_none")]
pub min_length: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub r#type: String, // &'static str = "string"
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum SubscribeRequest {}
@@ -1213,6 +1345,25 @@ pub struct Tool {
#[serde(rename = "inputSchema")]
pub input_schema: ToolInputSchema,
pub name: String,
#[serde(
rename = "outputSchema",
default,
skip_serializing_if = "Option::is_none"
)]
pub output_schema: Option<ToolOutputSchema>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
}
/// An optional JSON Schema object defining the structure of the tool's output returned in
/// the structuredContent field of a CallToolResult.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ToolOutputSchema {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub properties: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
pub r#type: String, // &'static str = "object"
}
/// A JSON Schema object defining the expected parameters for the tool.

View File

@@ -17,8 +17,8 @@ fn deserialize_initialize_request() {
"method": "initialize",
"params": {
"capabilities": {},
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
"protocolVersion": "2025-03-26"
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
"protocolVersion": "2025-06-18"
}
}"#;
@@ -37,8 +37,8 @@ fn deserialize_initialize_request() {
method: "initialize".into(),
params: Some(json!({
"capabilities": {},
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
"protocolVersion": "2025-03-26"
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
"protocolVersion": "2025-06-18"
})),
};
@@ -57,12 +57,14 @@ fn deserialize_initialize_request() {
experimental: None,
roots: None,
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: "acme-client".into(),
title: Some("Acme".to_string()),
version: "1.2.3".into(),
},
protocol_version: "2025-03-26".into(),
protocol_version: "2025-06-18".into(),
}
);
}

View File

@@ -19,7 +19,8 @@ use crossterm::event::MouseEvent;
use crossterm::event::MouseEventKind;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::channel;
use std::thread;
@@ -54,7 +55,7 @@ pub(crate) struct App<'a> {
file_search: FileSearchManager,
/// True when a redraw has been scheduled but not yet executed.
pending_redraw: Arc<Mutex<bool>>,
pending_redraw: Arc<AtomicBool>,
/// Stored parameters needed to instantiate the ChatWidget later, e.g.,
/// after dismissing the Git-repo warning.
@@ -80,7 +81,7 @@ impl App<'_> {
) -> Self {
let (app_event_tx, app_event_rx) = channel();
let app_event_tx = AppEventSender::new(app_event_tx);
let pending_redraw = Arc::new(Mutex::new(false));
let pending_redraw = Arc::new(AtomicBool::new(false));
let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone());
// Spawn a dedicated thread for reading the crossterm event loop and
@@ -177,13 +178,14 @@ impl App<'_> {
/// Schedule a redraw if one is not already pending.
#[allow(clippy::unwrap_used)]
fn schedule_redraw(&self) {
// Attempt to set the flag to `true`. If it was already `true`, another
// redraw is already pending so we can return early.
if self
.pending_redraw
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
#[allow(clippy::unwrap_used)]
let mut flag = self.pending_redraw.lock().unwrap();
if *flag {
return;
}
*flag = true;
return;
}
let tx = self.app_event_tx.clone();
@@ -191,9 +193,7 @@ impl App<'_> {
thread::spawn(move || {
thread::sleep(REDRAW_DEBOUNCE);
tx.send(AppEvent::Redraw);
#[allow(clippy::unwrap_used)]
let mut f = pending_redraw.lock().unwrap();
*f = false;
pending_redraw.store(false, Ordering::SeqCst);
});
}

View File

@@ -464,6 +464,8 @@ impl ChatWidget<'_> {
if self.bottom_pane.is_task_running() {
self.bottom_pane.clear_ctrl_c_quit_hint();
self.submit_op(Op::Interrupt);
self.answer_buffer.clear();
self.reasoning_buffer.clear();
false
} else if self.bottom_pane.ctrl_c_quit_hint_visible() {
true

View File

@@ -380,7 +380,7 @@ impl WidgetRef for ConversationHistoryWidget {
let block = Block::default()
.title(title)
.borders(Borders::ALL)
.borders(Borders::NONE)
.border_type(BorderType::Rounded)
.border_style(border_style);
@@ -391,9 +391,9 @@ impl WidgetRef for ConversationHistoryWidget {
// Cache (and if necessary recalculate) the wrapped line counts for every
// [`HistoryCell`] so that our scrolling math accounts for text
// wrapping. We always reserve one column on the right-hand side for the
// scrollbar so that the content never renders "under" the scrollbar.
let effective_width = inner.width.saturating_sub(1);
// wrapping. The full inner width is now used because the scrollbar has
// been disabled.
let effective_width = inner.width;
if effective_width == 0 {
return; // Nothing to draw avoid division by zero.
@@ -486,48 +486,7 @@ impl WidgetRef for ConversationHistoryWidget {
}
}
// Always render a scrollbar *track* so the reserved column is filled.
let overflow = num_lines.saturating_sub(viewport_height);
let mut scroll_state = ScrollbarState::default()
// The Scrollbar widget expects the *content* height minus the
// viewport height. When there is no overflow we still provide 0
// so that the widget renders only the track without a thumb.
.content_length(overflow)
.position(scroll_pos);
{
// Choose a thumb color that stands out only when this pane has focus so that the
// user's attention is naturally drawn to the active viewport. When unfocused we show
// a low-contrast thumb so the scrollbar fades into the background without becoming
// invisible.
let thumb_style = if self.has_input_focus {
Style::reset().fg(Color::LightYellow)
} else {
Style::reset().fg(Color::Gray)
};
// By default the Scrollbar widget inherits any style that was
// present in the underlying buffer cells. That means if a colored
// line happens to be underneath the scrollbar, the track (and
// potentially the thumb) adopt that color. Explicitly setting the
// track/thumb styles ensures we always draw the scrollbar with a
// consistent palette regardless of what content is behind it.
StatefulWidget::render(
Scrollbar::new(ScrollbarOrientation::VerticalRight)
.begin_symbol(Some(""))
.end_symbol(Some(""))
.begin_style(Style::reset().fg(Color::DarkGray))
.end_style(Style::reset().fg(Color::DarkGray))
.thumb_symbol("")
.thumb_style(thumb_style)
.track_symbol(Some(""))
.track_style(Style::reset().fg(Color::DarkGray)),
inner,
buf,
&mut scroll_state,
);
}
// Scrollbar intentionally disabled: scrolling still functions via key / mouse events.
// Update auxiliary stats that the scroll handlers rely on.
self.num_rendered_lines.set(num_lines);

View File

@@ -17,6 +17,7 @@ use image::GenericImageView;
use image::ImageReader;
use lazy_static::lazy_static;
use mcp_types::EmbeddedResourceResource;
use mcp_types::ResourceLink;
use ratatui::prelude::*;
use ratatui::style::Color;
use ratatui::style::Modifier;
@@ -331,8 +332,7 @@ impl HistoryCell {
) -> Option<Self> {
match result {
Ok(mcp_types::CallToolResult { content, .. }) => {
if let Some(mcp_types::CallToolResultContent::ImageContent(image)) = content.first()
{
if let Some(mcp_types::ContentBlock::ImageContent(image)) = content.first() {
let raw_data =
match base64::engine::general_purpose::STANDARD.decode(&image.data) {
Ok(data) => data,
@@ -405,21 +405,21 @@ impl HistoryCell {
for tool_call_result in content {
let line_text = match tool_call_result {
mcp_types::CallToolResultContent::TextContent(text) => {
mcp_types::ContentBlock::TextContent(text) => {
format_and_truncate_tool_result(
&text.text,
TOOL_CALL_MAX_LINES,
num_cols as usize,
)
}
mcp_types::CallToolResultContent::ImageContent(_) => {
mcp_types::ContentBlock::ImageContent(_) => {
// TODO show images even if they're not the first result, will require a refactor of `CompletedMcpToolCall`
"<image content>".to_string()
}
mcp_types::CallToolResultContent::AudioContent(_) => {
mcp_types::ContentBlock::AudioContent(_) => {
"<audio content>".to_string()
}
mcp_types::CallToolResultContent::EmbeddedResource(resource) => {
mcp_types::ContentBlock::EmbeddedResource(resource) => {
let uri = match resource.resource {
EmbeddedResourceResource::TextResourceContents(text) => {
text.uri
@@ -430,6 +430,9 @@ impl HistoryCell {
};
format!("embedded resource: {uri}")
}
mcp_types::ContentBlock::ResourceLink(ResourceLink { uri, .. }) => {
format!("link: {uri}")
}
};
lines.push(Line::styled(line_text, Style::default().fg(Color::Gray)));
}