mirror of
https://github.com/openai/codex.git
synced 2026-04-22 07:21:46 +03:00
Compare commits
5 Commits
ollama
...
planning-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c274cf279 | ||
|
|
97581a86a3 | ||
|
|
658f2d677f | ||
|
|
18619dbbc1 | ||
|
|
c66fd9d59a |
@@ -147,8 +147,4 @@ const READ_ONLY_SEATBELT_POLICY = `
|
||||
(sysctl-name "kern.version")
|
||||
(sysctl-name "sysctl.proc_cputype")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
)
|
||||
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)`.trim();
|
||||
)`.trim();
|
||||
|
||||
65
codex-rs/Cargo.lock
generated
65
codex-rs/Cargo.lock
generated
@@ -671,7 +671,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"async-channel",
|
||||
"async-stream",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -700,7 +699,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"shlex",
|
||||
"similar",
|
||||
"strum_macros 0.27.2",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
@@ -709,7 +707,6 @@ dependencies = [
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"toml 0.9.2",
|
||||
"toml_edit",
|
||||
"tracing",
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
@@ -862,7 +859,6 @@ dependencies = [
|
||||
"mcp-types",
|
||||
"path-clean",
|
||||
"pretty_assertions",
|
||||
"rand 0.8.5",
|
||||
"ratatui",
|
||||
"ratatui-image",
|
||||
"regex-lite",
|
||||
@@ -872,14 +868,13 @@ dependencies = [
|
||||
"shlex",
|
||||
"strum 0.27.2",
|
||||
"strum_macros 0.27.2",
|
||||
"supports-color",
|
||||
"textwrap 0.16.2",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
"tui-input",
|
||||
"tui-markdown",
|
||||
"tui-textarea",
|
||||
"unicode-segmentation",
|
||||
"unicode-width 0.1.14",
|
||||
"uuid",
|
||||
@@ -2341,12 +2336,6 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_ci"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7655c9839580ee829dfacba1d1278c2b7883e50a277ff7541299489d6bdfdc45"
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
@@ -4184,12 +4173,6 @@ version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "smawk"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.10"
|
||||
@@ -4252,7 +4235,7 @@ dependencies = [
|
||||
"starlark_syntax",
|
||||
"static_assertions",
|
||||
"strsim 0.10.0",
|
||||
"textwrap 0.11.0",
|
||||
"textwrap",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
@@ -4388,15 +4371,6 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||
|
||||
[[package]]
|
||||
name = "supports-color"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c64fc7232dd8d2e4ac5ce4ef302b1d81e0b80d055b9d77c7c4f51f6aa4c867d6"
|
||||
dependencies = [
|
||||
"is_ci",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
@@ -4550,17 +4524,6 @@ dependencies = [
|
||||
"unicode-width 0.1.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
|
||||
dependencies = [
|
||||
"smawk",
|
||||
"unicode-linebreak",
|
||||
"unicode-width 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@@ -4816,7 +4779,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_spanned 0.6.9",
|
||||
"toml_datetime 0.6.11",
|
||||
"toml_write",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
@@ -4829,12 +4791,6 @@ dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_write"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.2"
|
||||
@@ -5032,6 +4988,17 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tui-textarea"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a5318dd619ed73c52a9417ad19046724effc1287fb75cdcc4eca1d6ac1acbae"
|
||||
dependencies = [
|
||||
"crossterm",
|
||||
"ratatui",
|
||||
"unicode-width 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
@@ -5050,12 +5017,6 @@ version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-linebreak"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
Codex supports several mechanisms for setting config values:
|
||||
|
||||
- Config-specific command-line flags, such as `--model o3` (highest precedence).
|
||||
- Convenience provider flags, such as `--ollama` (equivalent to `-c model_provider=ollama`).
|
||||
- A generic `-c`/`--config` flag that takes a `key=value` pair, such as `--config model="o3"`.
|
||||
- The key can contain dots to set a value deeper than the root, e.g. `--config model_providers.openai.wire_api="chat"`.
|
||||
- Values can contain objects, such as `--config shell_environment_policy.include_only=["PATH", "HOME", "USER"]`.
|
||||
@@ -57,13 +56,6 @@ name = "Ollama"
|
||||
base_url = "http://localhost:11434/v1"
|
||||
```
|
||||
|
||||
Alternatively, you can pass `--ollama` on the CLI, which is equivalent to `-c model_provider=ollama`.
|
||||
When using `--ollama`, Codex will verify that an Ollama server is running locally and
|
||||
will create a `[model_providers.ollama]` entry in your `config.toml` with sensible defaults
|
||||
(`base_url = "http://localhost:11434/v1"`, `wire_api = "chat"`) if one does not already exist.
|
||||
If no running Ollama server is detected, Codex will print instructions to install/start Ollama
|
||||
and exit: https://github.com/ollama/ollama?tab=readme-ov-file#ollama
|
||||
|
||||
Or a third-party provider (using a distinct environment variable for the API key):
|
||||
|
||||
```toml
|
||||
@@ -251,6 +243,21 @@ By default, `reasoning` is only set on requests to OpenAI models that are known
|
||||
model_supports_reasoning_summaries = true
|
||||
```
|
||||
|
||||
## experimental_include_plan_tool
|
||||
|
||||
Controls whether to expose the experimental plan tool (named `update_plan`) to the model and include the corresponding guidance in the system prompt.
|
||||
|
||||
Default behavior:
|
||||
- For known models (anything hardcoded in the models list), this is disabled by default.
|
||||
- For unknown models whose name starts with "gpt-", this is enabled by default so new GPT-family models get the feature without a CLI update.
|
||||
|
||||
When enabled, the model can call `update_plan` to keep an up-to-date, step-by-step plan for the task and Codex will render plan updates in the UI. When disabled, the tool is not advertised to the model and the “Plan updates” section is omitted from the prompt; any unsolicited `update_plan` calls will be treated as unsupported.
|
||||
|
||||
```toml
|
||||
# Enable the experimental plan tool and prompt instructions
|
||||
experimental_include_plan_tool = true
|
||||
```
|
||||
|
||||
## sandbox_mode
|
||||
|
||||
Codex executes model-generated shell commands inside an OS-level sandbox.
|
||||
|
||||
@@ -13,7 +13,6 @@ workspace = true
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
async-channel = "2.3.1"
|
||||
async-stream = "0.3"
|
||||
base64 = "0.22"
|
||||
bytes = "1.10.1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
@@ -31,11 +30,10 @@ mime_guess = "2.0"
|
||||
rand = "0.9"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_bytes = "0.11"
|
||||
serde_json = "1"
|
||||
serde_bytes = "0.11"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.7.0"
|
||||
strum_macros = "0.27.2"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
@@ -48,7 +46,6 @@ tokio = { version = "1", features = [
|
||||
] }
|
||||
tokio-util = "0.7.14"
|
||||
toml = "0.9.2"
|
||||
toml_edit = "0.22"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
|
||||
@@ -97,6 +97,7 @@ You can invoke apply_patch like:
|
||||
shell {"command":["apply_patch","*** Begin Patch\n*** Add File: hello.txt\n+Hello, world!\n*** End Patch\n"]}
|
||||
```
|
||||
|
||||
<!-- PLAN_TOOL:START -->
|
||||
Plan updates
|
||||
|
||||
A tool named `update_plan` is available. Use it to keep an up‑to‑date, step‑by‑step plan for the task so you can follow your progress. When making your plans, keep in mind that you are a deployed coding agent - `update_plan` calls should not involve doing anything that you aren't capable of doing. For example, `update_plan` calls should NEVER contain tasks to merge your own pull requests. Only stop to ask the user if you genuinely need their feedback on a change.
|
||||
@@ -105,3 +106,4 @@ A tool named `update_plan` is available. Use it to keep an up‑to‑date, step
|
||||
- Whenever you finish a step, call `update_plan` again, marking the finished step as `completed` and the next step as `in_progress`.
|
||||
- If your plan needs to change, call `update_plan` with the revised steps and include an `explanation` describing the change.
|
||||
- When all steps are complete, make a final `update_plan` call with all steps marked `completed`.
|
||||
<!-- PLAN_TOOL:END -->
|
||||
|
||||
@@ -37,7 +37,8 @@ pub(crate) async fn stream_chat_completions(
|
||||
// Build messages array
|
||||
let mut messages = Vec::<serde_json::Value>::new();
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(model);
|
||||
let instr_cfg = crate::client_common::InstructionsConfig::for_model(model, include_plan_tool);
|
||||
let full_instructions = prompt.get_full_instructions(&instr_cfg);
|
||||
messages.push(json!({"role": "system", "content": full_instructions}));
|
||||
|
||||
if let Some(instr) = &prompt.user_instructions {
|
||||
|
||||
@@ -62,14 +62,10 @@ impl ModelClient {
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
) -> Self {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
Self {
|
||||
config,
|
||||
auth,
|
||||
client,
|
||||
client: reqwest::Client::new(),
|
||||
provider,
|
||||
session_id,
|
||||
effort,
|
||||
@@ -145,7 +141,11 @@ impl ModelClient {
|
||||
|
||||
let token = auth.get_token().await?;
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
let instr_cfg = crate::client_common::InstructionsConfig::for_model(
|
||||
&self.config.model,
|
||||
self.config.include_plan_tool,
|
||||
);
|
||||
let full_instructions = prompt.get_full_instructions(&instr_cfg);
|
||||
let tools_json = create_tools_json_for_responses_api(
|
||||
prompt,
|
||||
&self.config.model,
|
||||
|
||||
@@ -37,15 +37,60 @@ pub struct Prompt {
|
||||
pub base_instructions_override: Option<String>,
|
||||
}
|
||||
|
||||
/// Options that influence how the full instructions are composed for a request.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct InstructionsConfig {
|
||||
pub include_plan_tool: bool,
|
||||
pub extra_sections: Vec<&'static str>,
|
||||
}
|
||||
|
||||
impl InstructionsConfig {
|
||||
pub fn for_model(model: &str, include_plan_tool: bool) -> Self {
|
||||
let mut extra_sections = Vec::new();
|
||||
if model.starts_with("gpt-4.1") {
|
||||
extra_sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS);
|
||||
}
|
||||
Self {
|
||||
include_plan_tool,
|
||||
extra_sections,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
|
||||
let base = self
|
||||
pub(crate) fn get_full_instructions(&self, cfg: &InstructionsConfig) -> Cow<'_, str> {
|
||||
let mut base = self
|
||||
.base_instructions_override
|
||||
.as_deref()
|
||||
.unwrap_or(BASE_INSTRUCTIONS);
|
||||
let mut sections: Vec<&str> = vec![base];
|
||||
if model.starts_with("gpt-4.1") {
|
||||
sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS);
|
||||
.unwrap_or(BASE_INSTRUCTIONS)
|
||||
.to_string();
|
||||
|
||||
if !cfg.include_plan_tool {
|
||||
// Remove the plan-tool section if present. Prefer explicit markers
|
||||
// for robustness, but fall back to trimming from the "Plan updates"
|
||||
// heading if markers are missing.
|
||||
let start_marker = "<!-- PLAN_TOOL:START -->";
|
||||
let end_marker = "<!-- PLAN_TOOL:END -->";
|
||||
if let (Some(start), Some(end)) = (base.find(start_marker), base.find(end_marker)) {
|
||||
if end > start {
|
||||
let mut edited = String::with_capacity(base.len());
|
||||
edited.push_str(&base[..start]);
|
||||
edited.push_str(&base[end + end_marker.len()..]);
|
||||
base = edited;
|
||||
}
|
||||
} else if let Some(idx) = base
|
||||
.find("\n\nPlan updates")
|
||||
.or_else(|| base.find("\nPlan updates"))
|
||||
.or_else(|| base.find("Plan updates"))
|
||||
{
|
||||
base.truncate(idx);
|
||||
}
|
||||
base = base.trim_end().to_string();
|
||||
}
|
||||
|
||||
let mut sections: Vec<&str> = vec![&base];
|
||||
for s in &cfg.extra_sections {
|
||||
sections.push(s);
|
||||
}
|
||||
Cow::Owned(sections.join("\n"))
|
||||
}
|
||||
@@ -197,7 +242,18 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}");
|
||||
let full = prompt.get_full_instructions("gpt-4.1");
|
||||
let cfg = InstructionsConfig::for_model("gpt-4.1", true);
|
||||
let full = prompt.get_full_instructions(&cfg);
|
||||
assert_eq!(full, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_section_removed_when_disabled() {
|
||||
let prompt = Prompt::default();
|
||||
let cfg = InstructionsConfig::for_model("gpt-4.1", false);
|
||||
let full = prompt.get_full_instructions(&cfg);
|
||||
assert!(!full.contains("Plan updates"));
|
||||
assert!(!full.contains("update_plan"));
|
||||
assert!(full.contains(APPLY_PATCH_TOOL_INSTRUCTIONS));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,13 +85,11 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::protocol::Submission;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
|
||||
@@ -227,6 +225,9 @@ pub(crate) struct Session {
|
||||
state: Mutex<State>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
user_shell: shell::Shell,
|
||||
|
||||
/// Whether the experimental plan tool is enabled for this session.
|
||||
include_plan_tool: bool,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -364,11 +365,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_command_begin(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
exec_command_context: ExecCommandContext,
|
||||
) {
|
||||
async fn notify_exec_command_begin(&self, exec_command_context: ExecCommandContext) {
|
||||
let ExecCommandContext {
|
||||
sub_id,
|
||||
call_id,
|
||||
@@ -380,15 +377,11 @@ impl Session {
|
||||
Some(ApplyPatchCommandContext {
|
||||
user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}) => {
|
||||
turn_diff_tracker.on_patch_begin(&changes);
|
||||
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
})
|
||||
}
|
||||
}) => EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}),
|
||||
None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command: command_for_display.clone(),
|
||||
@@ -402,21 +395,15 @@ impl Session {
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn on_exec_command_end(
|
||||
async fn notify_exec_command_end(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
output: &ExecToolCallOutput,
|
||||
stdout: &str,
|
||||
stderr: &str,
|
||||
exit_code: i32,
|
||||
is_apply_patch: bool,
|
||||
) {
|
||||
let ExecToolCallOutput {
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
exit_code,
|
||||
} = output;
|
||||
// Because stdout and stderr could each be up to 100 KiB, we send
|
||||
// truncated versions.
|
||||
const MAX_STREAM_OUTPUT: usize = 5 * 1024; // 5KiB
|
||||
@@ -428,15 +415,14 @@ impl Session {
|
||||
call_id: call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
success: *exit_code == 0,
|
||||
success: exit_code == 0,
|
||||
})
|
||||
} else {
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
duration: *duration,
|
||||
exit_code: *exit_code,
|
||||
exit_code,
|
||||
})
|
||||
};
|
||||
|
||||
@@ -445,20 +431,6 @@ impl Session {
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
if is_apply_patch {
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.into(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
||||
@@ -822,6 +794,7 @@ async fn submission_loop(
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
disable_response_storage,
|
||||
user_shell: default_shell,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
}));
|
||||
|
||||
// Patch restored state into the newly created session.
|
||||
@@ -1032,10 +1005,6 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
.await;
|
||||
|
||||
let last_agent_message: Option<String>;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
@@ -1067,7 +1036,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
|
||||
match run_turn(&sess, sub_id.clone(), turn_input).await {
|
||||
Ok(turn_output) => {
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
@@ -1193,7 +1162,6 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
@@ -1208,7 +1176,7 @@ async fn run_turn(
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
|
||||
match try_run_turn(sess, &sub_id, &prompt).await {
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
@@ -1254,7 +1222,6 @@ struct ProcessedResponseItem {
|
||||
|
||||
async fn try_run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
@@ -1342,8 +1309,7 @@ async fn try_run_turn(
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response =
|
||||
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
@@ -1361,16 +1327,6 @@ async fn try_run_turn(
|
||||
.ok();
|
||||
}
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = sess.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
return Ok(output);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
@@ -1421,7 +1377,7 @@ async fn run_compact_task(
|
||||
let mut retries = 0;
|
||||
|
||||
loop {
|
||||
let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await;
|
||||
let attempt_result = drain_to_completed(&sess, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => break,
|
||||
@@ -1475,7 +1431,6 @@ async fn run_compact_task(
|
||||
|
||||
async fn handle_response_item(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
@@ -1513,17 +1468,7 @@ async fn handle_response_item(
|
||||
..
|
||||
} => {
|
||||
info!("FunctionCall: {arguments}");
|
||||
Some(
|
||||
handle_function_call(
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
@@ -1558,7 +1503,6 @@ async fn handle_response_item(
|
||||
handle_container_exec_with_params(
|
||||
exec_params,
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
effective_call_id,
|
||||
)
|
||||
@@ -1576,7 +1520,6 @@ async fn handle_response_item(
|
||||
|
||||
async fn handle_function_call(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
@@ -1590,10 +1533,21 @@ async fn handle_function_call(
|
||||
return *output;
|
||||
}
|
||||
};
|
||||
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
|
||||
.await
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
}
|
||||
"update_plan" => {
|
||||
if sess.include_plan_tool {
|
||||
handle_update_plan(sess, arguments, sub_id, call_id).await
|
||||
} else {
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("unsupported call: {name}"),
|
||||
success: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
|
||||
_ => {
|
||||
match sess.mcp_connection_manager.parse_tool_name(&name) {
|
||||
Some((server, tool_name)) => {
|
||||
@@ -1665,7 +1619,6 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
|
||||
async fn handle_container_exec_with_params(
|
||||
params: ExecParams,
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> ResponseInputItem {
|
||||
@@ -1813,7 +1766,7 @@ async fn handle_container_exec_with_params(
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context.clone())
|
||||
sess.notify_exec_command_begin(exec_command_context.clone())
|
||||
.await;
|
||||
|
||||
let params = maybe_run_with_user_profile(params, sess);
|
||||
@@ -1838,22 +1791,23 @@ async fn handle_container_exec_with_params(
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
} = &output;
|
||||
} = output;
|
||||
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
sess.notify_exec_command_end(
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&output,
|
||||
&stdout,
|
||||
&stderr,
|
||||
exit_code,
|
||||
exec_command_context.apply_patch.is_some(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let is_success = *exit_code == 0;
|
||||
let is_success = exit_code == 0;
|
||||
let content = format_exec_output(
|
||||
if is_success { stdout } else { stderr },
|
||||
*exit_code,
|
||||
*duration,
|
||||
if is_success { &stdout } else { &stderr },
|
||||
exit_code,
|
||||
duration,
|
||||
);
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
@@ -1865,15 +1819,7 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
}
|
||||
Err(CodexErr::Sandbox(error)) => {
|
||||
handle_sandbox_error(
|
||||
turn_diff_tracker,
|
||||
params,
|
||||
exec_command_context,
|
||||
error,
|
||||
sandbox_type,
|
||||
sess,
|
||||
)
|
||||
.await
|
||||
handle_sandbox_error(params, exec_command_context, error, sandbox_type, sess).await
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle non-sandbox errors
|
||||
@@ -1889,7 +1835,6 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
|
||||
async fn handle_sandbox_error(
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
params: ExecParams,
|
||||
exec_command_context: ExecCommandContext,
|
||||
error: SandboxErr,
|
||||
@@ -1946,8 +1891,7 @@ async fn handle_sandbox_error(
|
||||
sess.notify_background_event(&sub_id, "retrying command without sandbox")
|
||||
.await;
|
||||
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context)
|
||||
.await;
|
||||
sess.notify_exec_command_begin(exec_command_context).await;
|
||||
|
||||
// This is an escalated retry; the policy will not be
|
||||
// examined and the sandbox has been set to `None`.
|
||||
@@ -1972,22 +1916,23 @@ async fn handle_sandbox_error(
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
} = &retry_output;
|
||||
} = retry_output;
|
||||
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
sess.notify_exec_command_end(
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&retry_output,
|
||||
&stdout,
|
||||
&stderr,
|
||||
exit_code,
|
||||
is_apply_patch,
|
||||
)
|
||||
.await;
|
||||
|
||||
let is_success = *exit_code == 0;
|
||||
let is_success = exit_code == 0;
|
||||
let content = format_exec_output(
|
||||
if is_success { stdout } else { stderr },
|
||||
*exit_code,
|
||||
*duration,
|
||||
if is_success { &stdout } else { &stderr },
|
||||
exit_code,
|
||||
duration,
|
||||
);
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
@@ -2072,7 +2017,7 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
})
|
||||
}
|
||||
|
||||
async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> CodexResult<()> {
|
||||
async fn drain_to_completed(sess: &Session, prompt: &Prompt) -> CodexResult<()> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
@@ -2082,32 +2027,7 @@ async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> Co
|
||||
));
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
// Record only to in-memory conversation history; avoid state snapshot.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.history.record_items(std::slice::from_ref(&item));
|
||||
}
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
}) => {
|
||||
let token_usage = match token_usage {
|
||||
Some(usage) => usage,
|
||||
None => {
|
||||
return Err(CodexErr::Stream(
|
||||
"token_usage was None in ResponseEvent::Completed".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
sess.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(token_usage),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return Ok(());
|
||||
}
|
||||
Ok(ResponseEvent::Completed { .. }) => return Ok(()),
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
@@ -342,6 +342,10 @@ pub struct ConfigToml {
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub internal_originator: Option<String>,
|
||||
|
||||
/// Include an experimental plan tool that the model can use to update its current plan and status of each step.
|
||||
/// This is experimental and may be removed in the future.
|
||||
pub experimental_include_plan_tool: Option<bool>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -428,11 +432,6 @@ impl Config {
|
||||
.or(config_profile.model_provider)
|
||||
.or(cfg.model_provider)
|
||||
.unwrap_or_else(|| "openai".to_string());
|
||||
// Do not implicitly inject an Ollama provider when selected via
|
||||
// `-c model_provider=ollama`. Only the `--ollama` flag path sets up the
|
||||
// provider entry and performs discovery. This ensures parity with
|
||||
// other providers: if a provider is not defined in config.toml, we
|
||||
// return a clear error below.
|
||||
let model_provider = model_providers
|
||||
.get(&model_provider_id)
|
||||
.ok_or_else(|| {
|
||||
@@ -481,16 +480,15 @@ impl Config {
|
||||
});
|
||||
|
||||
let experimental_resume = cfg.experimental_resume;
|
||||
let is_unknown_gpt = openai_model_info.is_none() && model.starts_with("gpt-");
|
||||
|
||||
// Load base instructions override from a file if specified. If the
|
||||
// path is relative, resolve it against the effective cwd so the
|
||||
// behaviour matches other path-like config values.
|
||||
let experimental_instructions_path = config_profile
|
||||
.experimental_instructions_file
|
||||
.as_ref()
|
||||
.or(cfg.experimental_instructions_file.as_ref());
|
||||
let file_base_instructions =
|
||||
Self::get_base_instructions(experimental_instructions_path, &resolved_cwd)?;
|
||||
let file_base_instructions = Self::get_base_instructions(
|
||||
cfg.experimental_instructions_file.as_ref(),
|
||||
&resolved_cwd,
|
||||
)?;
|
||||
let base_instructions = base_instructions.or(file_base_instructions);
|
||||
|
||||
let config = Self {
|
||||
@@ -534,7 +532,7 @@ impl Config {
|
||||
|
||||
model_supports_reasoning_summaries: cfg
|
||||
.model_supports_reasoning_summaries
|
||||
.unwrap_or(false),
|
||||
.unwrap_or(is_unknown_gpt),
|
||||
|
||||
chatgpt_base_url: config_profile
|
||||
.chatgpt_base_url
|
||||
@@ -542,7 +540,9 @@ impl Config {
|
||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||
|
||||
experimental_resume,
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
include_plan_tool: include_plan_tool
|
||||
.or(cfg.experimental_include_plan_tool)
|
||||
.unwrap_or(is_unknown_gpt),
|
||||
internal_originator: cfg.internal_originator,
|
||||
};
|
||||
Ok(config)
|
||||
@@ -656,6 +656,73 @@ mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_plan_and_reasoning_defaults_known_vs_unknown() -> std::io::Result<()> {
|
||||
let fixture = create_test_fixture()?;
|
||||
|
||||
// Unknown GPT-like model -> defaults ON for plan tool and reasoning summaries override.
|
||||
let unknown_gpt_overrides = ConfigOverrides {
|
||||
model: Some("gpt-unknown-2025".to_string()),
|
||||
cwd: Some(fixture.cwd()),
|
||||
..Default::default()
|
||||
};
|
||||
let unknown_gpt_cfg = Config::load_from_base_config_with_overrides(
|
||||
fixture.cfg.clone(),
|
||||
unknown_gpt_overrides,
|
||||
fixture.codex_home(),
|
||||
)?;
|
||||
assert!(
|
||||
unknown_gpt_cfg.include_plan_tool,
|
||||
"plan tool should default to ON for unknown GPT-like models"
|
||||
);
|
||||
assert!(
|
||||
unknown_gpt_cfg.model_supports_reasoning_summaries,
|
||||
"reasoning summaries should default to ON for unknown GPT-like models"
|
||||
);
|
||||
|
||||
// Unknown non-GPT model -> defaults OFF for both.
|
||||
let unknown_non_gpt_overrides = ConfigOverrides {
|
||||
model: Some("my-new-model".to_string()),
|
||||
cwd: Some(fixture.cwd()),
|
||||
..Default::default()
|
||||
};
|
||||
let unknown_non_gpt_cfg = Config::load_from_base_config_with_overrides(
|
||||
fixture.cfg.clone(),
|
||||
unknown_non_gpt_overrides,
|
||||
fixture.codex_home(),
|
||||
)?;
|
||||
assert!(
|
||||
!unknown_non_gpt_cfg.include_plan_tool,
|
||||
"plan tool should default to OFF for unknown non-GPT models"
|
||||
);
|
||||
assert!(
|
||||
!unknown_non_gpt_cfg.model_supports_reasoning_summaries,
|
||||
"reasoning summaries should default to OFF for unknown non-GPT models"
|
||||
);
|
||||
|
||||
// Known model -> defaults OFF for plan tool and reasoning summaries override value.
|
||||
let known_overrides = ConfigOverrides {
|
||||
model: Some("gpt-3.5-turbo".to_string()),
|
||||
cwd: Some(fixture.cwd()),
|
||||
..Default::default()
|
||||
};
|
||||
let known_cfg = Config::load_from_base_config_with_overrides(
|
||||
fixture.cfg.clone(),
|
||||
known_overrides,
|
||||
fixture.codex_home(),
|
||||
)?;
|
||||
assert!(
|
||||
!known_cfg.include_plan_tool,
|
||||
"plan tool should default to OFF for known models"
|
||||
);
|
||||
assert!(
|
||||
!known_cfg.model_supports_reasoning_summaries,
|
||||
"reasoning summaries override should default to OFF for known models"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_toml_parsing() {
|
||||
let history_with_persistence = r#"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::config_types::ReasoningEffort;
|
||||
use crate::config_types::ReasoningSummary;
|
||||
@@ -18,5 +17,4 @@ pub struct ConfigProfile {
|
||||
pub model_reasoning_effort: Option<ReasoningEffort>,
|
||||
pub model_reasoning_summary: Option<ReasoningSummary>,
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
@@ -102,20 +102,6 @@ pub enum CodexErr {
|
||||
|
||||
#[error("{0}")]
|
||||
EnvVar(EnvVarError),
|
||||
|
||||
// ------------------------------
|
||||
// Ollama‑specific errors
|
||||
// ------------------------------
|
||||
#[error(
|
||||
"No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama"
|
||||
)]
|
||||
OllamaServerUnreachable,
|
||||
|
||||
#[error("ollama model not found: {0}")]
|
||||
OllamaModelNotFound(String),
|
||||
|
||||
#[error("ollama pull failed: {0}")]
|
||||
OllamaPullFailed(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -140,7 +140,11 @@ pub async fn process_exec_tool_call(
|
||||
|
||||
let exit_code = raw_output.exit_status.code().unwrap_or(-1);
|
||||
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
// NOTE(ragona): This is much less restrictive than the previous check. If we exec
|
||||
// a command, and it returns anything other than success, we assume that it may have
|
||||
// been a sandboxing error and allow the user to retry. (The user of course may choose
|
||||
// not to retry, or in a non-interactive mode, would automatically reject the approval.)
|
||||
if exit_code != 0 && sandbox_type != SandboxType::None {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied(
|
||||
exit_code, stdout, stderr,
|
||||
)));
|
||||
@@ -219,26 +223,6 @@ fn create_linux_sandbox_command_args(
|
||||
linux_cmd
|
||||
}
|
||||
|
||||
/// We don't have a fully deterministic way to tell if our command failed
|
||||
/// because of the sandbox - a command in the user's zshrc file might hit an
|
||||
/// error, but the command itself might fail or succeed for other reasons.
|
||||
/// For now, we conservatively check for 'command not found' (exit code 127),
|
||||
/// and can add additional cases as necessary.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exit_code: i32) -> bool {
|
||||
if sandbox_type == SandboxType::None {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Quick rejects: well-known non-sandbox shell exit codes
|
||||
// 127: command not found, 2: misuse of shell builtins
|
||||
if exit_code == 127 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For all other cases, we assume the sandbox is the cause
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RawExecToolCallOutput {
|
||||
pub exit_status: ExitStatus,
|
||||
|
||||
@@ -37,16 +37,13 @@ mod openai_tools;
|
||||
pub mod plan_tool;
|
||||
mod project_doc;
|
||||
pub mod protocol;
|
||||
pub mod providers;
|
||||
mod rollout;
|
||||
pub(crate) mod safety;
|
||||
mod safety;
|
||||
pub mod seatbelt;
|
||||
pub mod shell;
|
||||
pub mod spawn;
|
||||
pub mod turn_diff_tracker;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use client_common::model_supports_reasoning_summaries;
|
||||
pub use safety::get_platform_sandbox;
|
||||
|
||||
@@ -387,8 +387,6 @@ pub enum EventMsg {
|
||||
/// Notification that a patch application has finished.
|
||||
PatchApplyEnd(PatchApplyEndEvent),
|
||||
|
||||
TurnDiff(TurnDiffEvent),
|
||||
|
||||
/// Response to GetHistoryEntryRequest.
|
||||
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
|
||||
|
||||
@@ -525,8 +523,6 @@ pub struct ExecCommandEndEvent {
|
||||
pub stderr: String,
|
||||
/// The command's exit code.
|
||||
pub exit_code: i32,
|
||||
/// The duration of the command execution.
|
||||
pub duration: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -600,11 +596,6 @@ pub struct PatchApplyEndEvent {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TurnDiffEvent {
|
||||
pub unified_diff: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GetHistoryEntryResponseEvent {
|
||||
pub offset: usize,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pub mod ollama;
|
||||
@@ -1,259 +0,0 @@
|
||||
use bytes::BytesMut;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::BoxStream;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
|
||||
use super::DEFAULT_BASE_URL;
|
||||
use super::PullEvent;
|
||||
use super::PullProgressReporter;
|
||||
use super::parser::pull_events_from_value;
|
||||
use super::url::base_url_to_host_root;
|
||||
use super::url::is_openai_compatible_base_url;
|
||||
|
||||
/// Client for interacting with a local Ollama instance.
|
||||
pub struct OllamaClient {
|
||||
client: reqwest::Client,
|
||||
host_root: String,
|
||||
uses_openai_compat: bool,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Build a client from a provider definition. Falls back to the default
|
||||
/// local URL if no base_url is configured.
|
||||
pub fn from_provider(provider: &ModelProviderInfo) -> Self {
|
||||
let base_url = provider
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
|
||||
let uses_openai_compat = is_openai_compatible_base_url(&base_url)
|
||||
|| matches!(provider.wire_api, WireApi::Chat)
|
||||
&& is_openai_compatible_base_url(&base_url);
|
||||
let host_root = base_url_to_host_root(&base_url);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
Self {
|
||||
client,
|
||||
host_root,
|
||||
uses_openai_compat,
|
||||
}
|
||||
}
|
||||
|
||||
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
|
||||
pub fn from_host_root(host_root: impl Into<String>) -> Self {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
Self {
|
||||
client,
|
||||
host_root: host_root.into(),
|
||||
uses_openai_compat: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
|
||||
pub async fn probe_server(&self) -> io::Result<bool> {
|
||||
let url = if self.uses_openai_compat {
|
||||
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
|
||||
} else {
|
||||
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
|
||||
};
|
||||
let resp = self.client.get(url).send().await;
|
||||
Ok(matches!(resp, Ok(r) if r.status().is_success()))
|
||||
}
|
||||
|
||||
/// Return the list of model names known to the local Ollama instance.
|
||||
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
|
||||
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
|
||||
let resp = self
|
||||
.client
|
||||
.get(tags_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
if !resp.status().is_success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
|
||||
let names = val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
Ok(names)
|
||||
}
|
||||
|
||||
/// Start a model pull and emit streaming events. The returned stream ends when
|
||||
/// a Success event is observed or the server closes the connection.
|
||||
pub async fn pull_model_stream(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> io::Result<BoxStream<'static, PullEvent>> {
|
||||
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
|
||||
let resp = self
|
||||
.client
|
||||
.post(url)
|
||||
.json(&serde_json::json!({"model": model, "stream": true}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(io::Error::other(format!(
|
||||
"failed to start pull: HTTP {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut stream = resp.bytes_stream();
|
||||
let mut buf = BytesMut::new();
|
||||
let _pending: VecDeque<PullEvent> = VecDeque::new();
|
||||
|
||||
// Using an async stream adaptor backed by unfold-like manual loop.
|
||||
let s = async_stream::stream! {
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
buf.extend_from_slice(&bytes);
|
||||
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
||||
let line = buf.split_to(pos + 1);
|
||||
if let Ok(text) = std::str::from_utf8(&line) {
|
||||
let text = text.trim();
|
||||
if text.is_empty() { continue; }
|
||||
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
|
||||
for ev in pull_events_from_value(&value) { yield ev; }
|
||||
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
|
||||
yield PullEvent::Status(format!("error: {err_msg}"));
|
||||
return;
|
||||
}
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
if status == "success" { yield PullEvent::Success; return; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Connection error: end the stream.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(s))
|
||||
}
|
||||
|
||||
/// High-level helper to pull a model and drive a progress reporter.
|
||||
pub async fn pull_with_reporter(
|
||||
&self,
|
||||
model: &str,
|
||||
reporter: &mut dyn PullProgressReporter,
|
||||
) -> io::Result<()> {
|
||||
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
|
||||
let mut stream = self.pull_model_stream(model).await?;
|
||||
while let Some(event) = stream.next().await {
|
||||
reporter.on_event(&event)?;
|
||||
if matches!(event, PullEvent::Success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
|
||||
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
|
||||
#[tokio::test]
|
||||
async fn test_fetch_models_happy_path() {
|
||||
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} is set; skipping test_fetch_models_happy_path",
|
||||
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/api/tags"))
|
||||
.respond_with(
|
||||
wiremock::ResponseTemplate::new(200).set_body_raw(
|
||||
serde_json::json!({
|
||||
"models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
|
||||
})
|
||||
.to_string(),
|
||||
"application/json",
|
||||
),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let client = OllamaClient::from_host_root(server.uri());
|
||||
let models = client.fetch_models().await.expect("fetch models");
|
||||
assert!(models.contains(&"llama3.2:3b".to_string()));
|
||||
assert!(models.contains(&"mistral".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_probe_server_happy_path_openai_compat_and_native() {
|
||||
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
|
||||
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
|
||||
// Native endpoint
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/api/tags"))
|
||||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
let native = OllamaClient::from_host_root(server.uri());
|
||||
assert!(native.probe_server().await.expect("probe native"));
|
||||
|
||||
// OpenAI compatibility endpoint
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/v1/models"))
|
||||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
let provider = ModelProviderInfo {
|
||||
name: "Ollama".to_string(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_auth: false,
|
||||
};
|
||||
let compat = OllamaClient::from_provider(&provider);
|
||||
assert!(compat.probe_server().await.expect("probe compat"));
|
||||
}
|
||||
}
|
||||
@@ -1,243 +0,0 @@
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
|
||||
use toml_edit::DocumentMut as Document;
|
||||
use toml_edit::Item;
|
||||
use toml_edit::Table;
|
||||
use toml_edit::Value as TomlValueEdit;
|
||||
|
||||
use super::DEFAULT_BASE_URL;
|
||||
|
||||
/// Read the list of models recorded under [model_providers.ollama].models.
|
||||
pub fn read_ollama_models_list(config_path: &Path) -> Vec<String> {
|
||||
match std::fs::read_to_string(config_path)
|
||||
.ok()
|
||||
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
|
||||
{
|
||||
Some(toml::Value::Table(root)) => root
|
||||
.get("model_providers")
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("ollama"))
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("models"))
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience wrapper that returns the models list as an io::Result for callers
|
||||
/// that want a uniform Result-based API.
|
||||
pub fn read_config_models(config_path: &Path) -> io::Result<Vec<String>> {
|
||||
Ok(read_ollama_models_list(config_path))
|
||||
}
|
||||
|
||||
/// Overwrite the recorded models list under [model_providers.ollama].models using toml_edit.
|
||||
pub fn write_ollama_models_list(config_path: &Path, models: &[String]) -> io::Result<()> {
|
||||
let mut doc = read_document(config_path)?;
|
||||
{
|
||||
let tbl = upsert_provider_ollama(&mut doc);
|
||||
let mut arr = toml_edit::Array::new();
|
||||
for m in models {
|
||||
arr.push(TomlValueEdit::from(m.clone()));
|
||||
}
|
||||
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
|
||||
}
|
||||
write_document(config_path, &doc)
|
||||
}
|
||||
|
||||
/// Write models list via a uniform name expected by higher layers.
|
||||
pub fn write_config_models(config_path: &Path, models: &[String]) -> io::Result<()> {
|
||||
write_ollama_models_list(config_path, models)
|
||||
}
|
||||
|
||||
/// Ensure `[model_providers.ollama]` exists with sensible defaults on disk.
|
||||
/// Returns true if it created/updated the entry.
|
||||
pub fn ensure_ollama_provider_entry(codex_home: &Path) -> io::Result<bool> {
|
||||
let config_path = codex_home.join("config.toml");
|
||||
let mut doc = read_document(&config_path)?;
|
||||
let before = doc.to_string();
|
||||
let _tbl = upsert_provider_ollama(&mut doc);
|
||||
let after = doc.to_string();
|
||||
if before != after {
|
||||
write_document(&config_path, &doc)?;
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Alias name mirroring the refactor plan wording.
|
||||
pub fn ensure_provider_entry_and_defaults(codex_home: &Path) -> io::Result<bool> {
|
||||
ensure_ollama_provider_entry(codex_home)
|
||||
}
|
||||
|
||||
/// Read whether the provider exists and how many models are recorded under it.
|
||||
pub fn read_provider_state(config_path: &Path) -> (bool, usize) {
|
||||
match std::fs::read_to_string(config_path)
|
||||
.ok()
|
||||
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
|
||||
{
|
||||
Some(toml::Value::Table(root)) => {
|
||||
let provider_present = root
|
||||
.get("model_providers")
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("ollama"))
|
||||
.map(|_| true)
|
||||
.unwrap_or(false);
|
||||
let models_count = root
|
||||
.get("model_providers")
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("ollama"))
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("models"))
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.len())
|
||||
.unwrap_or(0);
|
||||
(provider_present, models_count)
|
||||
}
|
||||
_ => (false, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- toml_edit helpers ----------
|
||||
|
||||
fn read_document(path: &Path) -> io::Result<Document> {
|
||||
match std::fs::read_to_string(path) {
|
||||
Ok(s) => Document::from_str(&s).map_err(io::Error::other),
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Document::new()),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
std::fs::write(path, doc.to_string())
|
||||
}
|
||||
|
||||
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
|
||||
// Ensure "model_providers" exists and is a table.
|
||||
let needs_init = match doc.get("model_providers") {
|
||||
None => true,
|
||||
Some(item) => !item.is_table(),
|
||||
};
|
||||
if needs_init {
|
||||
doc.as_table_mut()
|
||||
.insert("model_providers", Item::Table(Table::new()));
|
||||
}
|
||||
|
||||
// Now, get a mutable reference to the "model_providers" table without `expect`/`unwrap`.
|
||||
let providers: &mut Table = {
|
||||
// Insert if missing.
|
||||
if doc.as_table().get("model_providers").is_none() {
|
||||
doc.as_table_mut()
|
||||
.insert("model_providers", Item::Table(Table::new()));
|
||||
}
|
||||
match doc.as_table_mut().get_mut("model_providers") {
|
||||
Some(item) => {
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(Table::new());
|
||||
}
|
||||
match item.as_table_mut() {
|
||||
Some(t) => t,
|
||||
None => unreachable!("model_providers was set to a table"),
|
||||
}
|
||||
}
|
||||
None => unreachable!("model_providers should exist after insertion"),
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure "ollama" exists and is a table.
|
||||
let needs_ollama_init = match providers.get("ollama") {
|
||||
None => true,
|
||||
Some(item) => !item.is_table(),
|
||||
};
|
||||
if needs_ollama_init {
|
||||
providers.insert("ollama", Item::Table(Table::new()));
|
||||
}
|
||||
|
||||
// Get a mutable reference to the "ollama" table without `expect`/`unwrap`.
|
||||
let tbl: &mut Table = {
|
||||
let needs_set = match providers.get("ollama") {
|
||||
None => true,
|
||||
Some(item) => !item.is_table(),
|
||||
};
|
||||
if needs_set {
|
||||
providers.insert("ollama", Item::Table(Table::new()));
|
||||
}
|
||||
match providers.get_mut("ollama") {
|
||||
Some(item) => {
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(Table::new());
|
||||
}
|
||||
match item.as_table_mut() {
|
||||
Some(t) => t,
|
||||
None => unreachable!("ollama was set to a table"),
|
||||
}
|
||||
}
|
||||
None => unreachable!("ollama should exist after insertion"),
|
||||
}
|
||||
};
|
||||
|
||||
if !tbl.contains_key("name") {
|
||||
tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
|
||||
}
|
||||
if !tbl.contains_key("base_url") {
|
||||
tbl["base_url"] = Item::Value(TomlValueEdit::from(DEFAULT_BASE_URL));
|
||||
}
|
||||
if !tbl.contains_key("wire_api") {
|
||||
tbl["wire_api"] = Item::Value(TomlValueEdit::from("chat"));
|
||||
}
|
||||
tbl
|
||||
}
|
||||
|
||||
pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
|
||||
let tbl = upsert_provider_ollama(doc);
|
||||
let mut arr = toml_edit::Array::new();
|
||||
for m in models {
|
||||
arr.push(TomlValueEdit::from(m.clone()));
|
||||
}
|
||||
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use toml_edit::DocumentMut as Document;
|
||||
|
||||
#[test]
|
||||
fn test_upsert_provider_and_models() {
|
||||
let mut doc = Document::new();
|
||||
let tbl = upsert_provider_ollama(&mut doc);
|
||||
assert!(tbl.contains_key("name"));
|
||||
assert!(tbl.contains_key("base_url"));
|
||||
assert!(tbl.contains_key("wire_api"));
|
||||
set_ollama_models(&mut doc, &[String::from("llama3.2:3b")]);
|
||||
let root = doc.as_table();
|
||||
let mp = match root.get("model_providers").and_then(|i| i.as_table()) {
|
||||
Some(t) => t,
|
||||
None => panic!("model_providers"),
|
||||
};
|
||||
let ollama = match mp.get("ollama").and_then(|i| i.as_table()) {
|
||||
Some(t) => t,
|
||||
None => panic!("ollama"),
|
||||
};
|
||||
let arr = match ollama.get("models") {
|
||||
Some(v) => v,
|
||||
None => panic!("models array"),
|
||||
};
|
||||
assert!(arr.is_array(), "models should be an array");
|
||||
let s = doc.to_string();
|
||||
assert!(s.contains("model_providers"));
|
||||
assert!(s.contains("ollama"));
|
||||
assert!(s.contains("models"));
|
||||
}
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CoreResult;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434/v1";
|
||||
pub const DEFAULT_WIRE_API: crate::model_provider_info::WireApi =
|
||||
crate::model_provider_info::WireApi::Chat;
|
||||
pub const DEFAULT_PULL_ALLOWLIST: &[&str] = &["llama3.2:3b"];
|
||||
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod parser;
|
||||
pub mod url;
|
||||
|
||||
pub use client::OllamaClient;
|
||||
pub use config::read_config_models;
|
||||
pub use config::read_provider_state;
|
||||
pub use config::write_config_models;
|
||||
pub use url::base_url_to_host_root;
|
||||
pub use url::base_url_to_host_root_with_wire;
|
||||
pub use url::probe_ollama_server;
|
||||
pub use url::probe_url_for_base;
|
||||
/// Coordinator wrapper used by frontends when responding to `--ollama`.
|
||||
///
|
||||
/// - Probes the server using the configured base_url when present, otherwise
|
||||
/// falls back to DEFAULT_BASE_URL.
|
||||
/// - If the server is reachable, ensures an `[model_providers.ollama]` entry
|
||||
/// exists in `config.toml` with sensible defaults.
|
||||
/// - If no server is reachable, returns an error.
|
||||
pub async fn ensure_configured_and_running() -> CoreResult<()> {
|
||||
use crate::config::find_codex_home;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
let codex_home = find_codex_home()?;
|
||||
let config_path = codex_home.join("config.toml");
|
||||
// Try to read a configured base_url if present.
|
||||
let base_url = match std::fs::read_to_string(&config_path) {
|
||||
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
|
||||
Ok(TomlValue::Table(root)) => root
|
||||
.get("model_providers")
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("ollama"))
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("base_url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(DEFAULT_BASE_URL)
|
||||
.to_string(),
|
||||
_ => DEFAULT_BASE_URL.to_string(),
|
||||
},
|
||||
Err(_) => DEFAULT_BASE_URL.to_string(),
|
||||
};
|
||||
|
||||
// Probe reachability; map any probe error to a friendly unreachable message.
|
||||
let ok: bool = url::probe_ollama_server(&base_url)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if !ok {
|
||||
return Err(CodexErr::OllamaServerUnreachable);
|
||||
}
|
||||
|
||||
// Ensure provider entry exists with defaults.
|
||||
let _ = config::ensure_ollama_provider_entry(&codex_home)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod ensure_tests {
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ensure_configured_returns_friendly_error_when_unreachable() {
|
||||
// Skip in CI sandbox environments without network to avoid false negatives.
|
||||
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} is set; skipping test_ensure_configured_returns_friendly_error_when_unreachable",
|
||||
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let tmpdir = tempfile::TempDir::new().expect("tempdir");
|
||||
let config_path = tmpdir.path().join("config.toml");
|
||||
std::fs::create_dir_all(tmpdir.path()).unwrap();
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
r#"[model_providers.ollama]
|
||||
name = "Ollama"
|
||||
base_url = "http://127.0.0.1:1/v1"
|
||||
wire_api = "chat"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
std::env::set_var("CODEX_HOME", tmpdir.path());
|
||||
}
|
||||
|
||||
let err = ensure_configured_and_running()
|
||||
.await
|
||||
.expect_err("should report unreachable server as friendly error");
|
||||
assert!(matches!(err, CodexErr::OllamaServerUnreachable));
|
||||
}
|
||||
}
|
||||
|
||||
/// Events emitted while pulling a model from Ollama.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PullEvent {
|
||||
/// A human-readable status message (e.g., "verifying", "writing").
|
||||
Status(String),
|
||||
/// Byte-level progress update for a specific layer digest.
|
||||
ChunkProgress {
|
||||
digest: String,
|
||||
total: Option<u64>,
|
||||
completed: Option<u64>,
|
||||
},
|
||||
/// The pull finished successfully.
|
||||
Success,
|
||||
}
|
||||
|
||||
/// A simple observer for pull progress events. Implementations decide how to
|
||||
/// render progress (CLI, TUI, logs, ...).
|
||||
pub trait PullProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// A minimal CLI reporter that writes inline progress to stderr.
|
||||
pub struct CliProgressReporter {
|
||||
printed_header: bool,
|
||||
last_line_len: usize,
|
||||
last_completed_sum: u64,
|
||||
last_instant: std::time::Instant,
|
||||
totals_by_digest: HashMap<String, (u64, u64)>,
|
||||
}
|
||||
|
||||
impl Default for CliProgressReporter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CliProgressReporter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
printed_header: false,
|
||||
last_line_len: 0,
|
||||
last_completed_sum: 0,
|
||||
last_instant: std::time::Instant::now(),
|
||||
totals_by_digest: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PullProgressReporter for CliProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
|
||||
let mut out = std::io::stderr();
|
||||
match event {
|
||||
PullEvent::Status(status) => {
|
||||
// Avoid noisy manifest messages; otherwise show status inline.
|
||||
if status.eq_ignore_ascii_case("pulling manifest") {
|
||||
return Ok(());
|
||||
}
|
||||
let pad = self.last_line_len.saturating_sub(status.len());
|
||||
let line = format!("\r{status}{}", " ".repeat(pad));
|
||||
self.last_line_len = status.len();
|
||||
out.write_all(line.as_bytes())?;
|
||||
out.flush()
|
||||
}
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
if let Some(t) = *total {
|
||||
self.totals_by_digest
|
||||
.entry(digest.clone())
|
||||
.or_insert((0, 0))
|
||||
.0 = t;
|
||||
}
|
||||
if let Some(c) = *completed {
|
||||
self.totals_by_digest
|
||||
.entry(digest.clone())
|
||||
.or_insert((0, 0))
|
||||
.1 = c;
|
||||
}
|
||||
|
||||
let (sum_total, sum_completed) = self
|
||||
.totals_by_digest
|
||||
.values()
|
||||
.fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
|
||||
if sum_total > 0 {
|
||||
if !self.printed_header {
|
||||
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let header = format!("Downloading model: total {gb:.2} GB\n");
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
out.write_all(header.as_bytes())?;
|
||||
self.printed_header = true;
|
||||
}
|
||||
let now = std::time::Instant::now();
|
||||
let dt = now
|
||||
.duration_since(self.last_instant)
|
||||
.as_secs_f64()
|
||||
.max(0.001);
|
||||
let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
|
||||
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
|
||||
self.last_completed_sum = sum_completed;
|
||||
self.last_instant = now;
|
||||
|
||||
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
|
||||
let text =
|
||||
format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
|
||||
let pad = self.last_line_len.saturating_sub(text.len());
|
||||
let line = format!("\r{text}{}", " ".repeat(pad));
|
||||
self.last_line_len = text.len();
|
||||
out.write_all(line.as_bytes())?;
|
||||
out.flush()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
PullEvent::Success => {
|
||||
out.write_all(b"\n")?;
|
||||
out.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
|
||||
/// CLI behavior aligned until a dedicated TUI integration is implemented.
|
||||
#[derive(Default)]
|
||||
pub struct TuiProgressReporter(CliProgressReporter);
|
||||
impl TuiProgressReporter {
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
impl PullProgressReporter for TuiProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
|
||||
self.0.on_event(event)
|
||||
}
|
||||
}
|
||||
/// Ensure a model is available locally.
|
||||
///
|
||||
/// - If the model is already present, ensure it is recorded in config.toml.
|
||||
/// - If missing and in the default allowlist, pull it with streaming progress
|
||||
/// and record it in config.toml after success.
|
||||
/// - If missing and not allowlisted, return an error.
|
||||
pub async fn ensure_model_available(
|
||||
model: &str,
|
||||
client: &OllamaClient,
|
||||
config_path: &Path,
|
||||
reporter: &mut dyn PullProgressReporter,
|
||||
) -> CoreResult<()> {
|
||||
let mut listed = config::read_ollama_models_list(config_path);
|
||||
let available = client.fetch_models().await.unwrap_or_default();
|
||||
if available.iter().any(|m| m == model) {
|
||||
if !listed.iter().any(|m| m == model) {
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = config::write_ollama_models_list(config_path, &listed);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !DEFAULT_PULL_ALLOWLIST.contains(&model) {
|
||||
return Err(CodexErr::OllamaModelNotFound(model.to_string()));
|
||||
}
|
||||
|
||||
loop {
|
||||
let _ = client.pull_with_reporter(model, reporter).await;
|
||||
// After the stream completes (success or early EOF), check again.
|
||||
let available = client.fetch_models().await.unwrap_or_default();
|
||||
if available.iter().any(|m| m == model) {
|
||||
break;
|
||||
}
|
||||
// Keep waiting for the model to finish downloading.
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = config::write_ollama_models_list(config_path, &listed);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use super::PullEvent;
|
||||
|
||||
// Convert a single JSON object representing a pull update into one or more events.
|
||||
pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
|
||||
let mut events = Vec::new();
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
events.push(PullEvent::Status(status.to_string()));
|
||||
if status == "success" {
|
||||
events.push(PullEvent::Success);
|
||||
}
|
||||
}
|
||||
let digest = value
|
||||
.get("digest")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if total.is_some() || completed.is_some() {
|
||||
events.push(PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
});
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_status_and_success() {
|
||||
let v: JsonValue = serde_json::json!({"status":"verifying"});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"status":"success"});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 2);
|
||||
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
|
||||
assert!(matches!(events2[1], PullEvent::Success));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_progress() {
|
||||
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert_eq!(events.len(), 1);
|
||||
match &events[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:abc");
|
||||
assert_eq!(*total, Some(100));
|
||||
assert_eq!(*completed, None);
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 1);
|
||||
match &events2[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:def");
|
||||
assert_eq!(*total, None);
|
||||
assert_eq!(*completed, Some(42));
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
use crate::error::Result as CoreResult;
|
||||
|
||||
/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
|
||||
pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool {
|
||||
base_url.trim_end_matches('/').ends_with("/v1")
|
||||
}
|
||||
|
||||
/// Convert a provider base_url into the native Ollama host root.
|
||||
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
|
||||
pub fn base_url_to_host_root(base_url: &str) -> String {
|
||||
let trimmed = base_url.trim_end_matches('/');
|
||||
if trimmed.ends_with("/v1") {
|
||||
trimmed
|
||||
.trim_end_matches("/v1")
|
||||
.trim_end_matches('/')
|
||||
.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Variant that considers an explicit WireApi value; provided to centralize
|
||||
/// host root computation in one place for future extension.
|
||||
pub fn base_url_to_host_root_with_wire(
|
||||
base_url: &str,
|
||||
_wire_api: crate::model_provider_info::WireApi,
|
||||
) -> String {
|
||||
base_url_to_host_root(base_url)
|
||||
}
|
||||
|
||||
/// Compute the probe URL to verify if an Ollama server is reachable.
|
||||
/// If the configured base is OpenAI-compatible (/v1), probe "models", otherwise
|
||||
/// fall back to the native "/api/tags" endpoint.
|
||||
pub fn probe_url_for_base(base_url: &str) -> String {
|
||||
if is_openai_compatible_base_url(base_url) {
|
||||
format!("{}/models", base_url.trim_end_matches('/'))
|
||||
} else {
|
||||
format!("{}/api/tags", base_url.trim_end_matches('/'))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience helper to probe an Ollama server given a provider style base URL.
|
||||
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
|
||||
let url = probe_url_for_base(base_url);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.build()?;
|
||||
let resp = client.get(url).send().await?;
|
||||
Ok(resp.status().is_success())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_base_url_to_host_root() {
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/v1"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_probe_url_for_base() {
|
||||
assert_eq!(
|
||||
probe_url_for_base("http://localhost:11434/v1"),
|
||||
"http://localhost:11434/v1/models"
|
||||
);
|
||||
assert_eq!(
|
||||
probe_url_for_base("http://localhost:11434"),
|
||||
"http://localhost:11434/api/tags"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -65,7 +65,3 @@
|
||||
(sysctl-name "sysctl.proc_cputype")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
)
|
||||
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)
|
||||
|
||||
@@ -1,887 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use sha1::digest::Output;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::protocol::FileChange;
|
||||
|
||||
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
|
||||
const DEV_NULL: &str = "/dev/null";
|
||||
|
||||
struct BaselineFileInfo {
|
||||
path: PathBuf,
|
||||
content: Vec<u8>,
|
||||
mode: FileMode,
|
||||
oid: String,
|
||||
}
|
||||
|
||||
/// Tracks sets of changes to files and exposes the overall unified diff.
|
||||
/// Internally, the way this works is now:
|
||||
/// 1. Maintain an in-memory baseline snapshot of files when they are first seen.
|
||||
/// For new additions, do not create a baseline so that diffs are shown as proper additions (using /dev/null).
|
||||
/// 2. Keep a stable internal filename (uuid) per external path for rename tracking.
|
||||
/// 3. To compute the aggregated unified diff, compare each baseline snapshot to the current file on disk entirely in-memory
|
||||
/// using the `similar` crate and emit unified diffs with rewritten external paths.
|
||||
#[derive(Default)]
|
||||
pub struct TurnDiffTracker {
|
||||
/// Map external path -> internal filename (uuid).
|
||||
external_to_temp_name: HashMap<PathBuf, String>,
|
||||
/// Internal filename -> baseline file info.
|
||||
baseline_file_info: HashMap<String, BaselineFileInfo>,
|
||||
/// Internal filename -> external path as of current accumulated state (after applying all changes).
|
||||
/// This is where renames are tracked.
|
||||
temp_name_to_current_path: HashMap<String, PathBuf>,
|
||||
/// Cache of known git worktree roots to avoid repeated filesystem walks.
|
||||
git_root_cache: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl TurnDiffTracker {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Front-run apply patch calls to track the starting contents of any modified files.
|
||||
/// - Creates an in-memory baseline snapshot for files that already exist on disk when first seen.
|
||||
/// - For additions, we intentionally do not create a baseline snapshot so that diffs are proper additions.
|
||||
/// - Also updates internal mappings for move/rename events.
|
||||
pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
|
||||
for (path, change) in changes.iter() {
|
||||
// Ensure a stable internal filename exists for this external path.
|
||||
if !self.external_to_temp_name.contains_key(path) {
|
||||
let internal = Uuid::new_v4().to_string();
|
||||
self.external_to_temp_name
|
||||
.insert(path.clone(), internal.clone());
|
||||
self.temp_name_to_current_path
|
||||
.insert(internal.clone(), path.clone());
|
||||
|
||||
// If the file exists on disk now, snapshot as baseline; else leave missing to represent /dev/null.
|
||||
let baseline_file_info = if path.exists() {
|
||||
let mode = file_mode_for_path(path);
|
||||
let mode_val = mode.unwrap_or(FileMode::Regular);
|
||||
let content = blob_bytes(path, &mode_val).unwrap_or_default();
|
||||
let oid = if mode == Some(FileMode::Symlink) {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(&content))
|
||||
} else {
|
||||
self.git_blob_oid_for_path(path)
|
||||
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(&content)))
|
||||
};
|
||||
Some(BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content,
|
||||
mode: mode_val,
|
||||
oid,
|
||||
})
|
||||
} else {
|
||||
Some(BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content: vec![],
|
||||
mode: FileMode::Regular,
|
||||
oid: ZERO_OID.to_string(),
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(baseline_file_info) = baseline_file_info {
|
||||
self.baseline_file_info
|
||||
.insert(internal.clone(), baseline_file_info);
|
||||
}
|
||||
}
|
||||
|
||||
// Track rename/move in current mapping if provided in an Update.
|
||||
if let FileChange::Update {
|
||||
move_path: Some(dest),
|
||||
..
|
||||
} = change
|
||||
{
|
||||
let uuid_filename = match self.external_to_temp_name.get(path) {
|
||||
Some(i) => i.clone(),
|
||||
None => {
|
||||
// This should be rare, but if we haven't mapped the source, create it with no baseline.
|
||||
let i = Uuid::new_v4().to_string();
|
||||
self.baseline_file_info.insert(
|
||||
i.clone(),
|
||||
BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content: vec![],
|
||||
mode: FileMode::Regular,
|
||||
oid: ZERO_OID.to_string(),
|
||||
},
|
||||
);
|
||||
i
|
||||
}
|
||||
};
|
||||
// Update current external mapping for temp file name.
|
||||
self.temp_name_to_current_path
|
||||
.insert(uuid_filename.clone(), dest.clone());
|
||||
// Update forward file_mapping: external current -> internal name.
|
||||
self.external_to_temp_name.remove(path);
|
||||
self.external_to_temp_name
|
||||
.insert(dest.clone(), uuid_filename);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn get_path_for_internal(&self, internal: &str) -> Option<PathBuf> {
|
||||
self.temp_name_to_current_path
|
||||
.get(internal)
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
self.baseline_file_info
|
||||
.get(internal)
|
||||
.map(|info| info.path.clone())
|
||||
})
|
||||
}
|
||||
|
||||
/// Find the git worktree root for a file/directory by walking up to the first ancestor containing a `.git` entry.
|
||||
/// Uses a simple cache of known roots and avoids negative-result caching for simplicity.
|
||||
fn find_git_root_cached(&mut self, start: &Path) -> Option<PathBuf> {
|
||||
let dir = if start.is_dir() {
|
||||
start
|
||||
} else {
|
||||
start.parent()?
|
||||
};
|
||||
|
||||
// Fast path: if any cached root is an ancestor of this path, use it.
|
||||
if let Some(root) = self
|
||||
.git_root_cache
|
||||
.iter()
|
||||
.find(|r| dir.starts_with(r))
|
||||
.cloned()
|
||||
{
|
||||
return Some(root);
|
||||
}
|
||||
|
||||
// Walk up to find a `.git` marker.
|
||||
let mut cur = dir.to_path_buf();
|
||||
loop {
|
||||
let git_marker = cur.join(".git");
|
||||
if git_marker.is_dir() || git_marker.is_file() {
|
||||
if !self.git_root_cache.iter().any(|r| r == &cur) {
|
||||
self.git_root_cache.push(cur.clone());
|
||||
}
|
||||
return Some(cur);
|
||||
}
|
||||
|
||||
// On Windows, avoid walking above the drive or UNC share root.
|
||||
#[cfg(windows)]
|
||||
{
|
||||
if is_windows_drive_or_unc_root(&cur) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(parent) = cur.parent() {
|
||||
cur = parent.to_path_buf();
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a display string for `path` relative to its git root if found, else absolute.
|
||||
fn relative_to_git_root_str(&mut self, path: &Path) -> String {
|
||||
let s = if let Some(root) = self.find_git_root_cached(path) {
|
||||
if let Ok(rel) = path.strip_prefix(&root) {
|
||||
rel.display().to_string()
|
||||
} else {
|
||||
path.display().to_string()
|
||||
}
|
||||
} else {
|
||||
path.display().to_string()
|
||||
};
|
||||
s.replace('\\', "/")
|
||||
}
|
||||
|
||||
/// Ask git to compute the blob SHA-1 for the file at `path` within its repository.
|
||||
/// Returns None if no repository is found or git invocation fails.
|
||||
fn git_blob_oid_for_path(&mut self, path: &Path) -> Option<String> {
|
||||
let root = self.find_git_root_cached(path)?;
|
||||
// Compute a path relative to the repo root for better portability across platforms.
|
||||
let rel = path.strip_prefix(&root).unwrap_or(path);
|
||||
let output = Command::new("git")
|
||||
.arg("-C")
|
||||
.arg(&root)
|
||||
.arg("hash-object")
|
||||
.arg("--")
|
||||
.arg(rel)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let s = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if s.len() == 40 { Some(s) } else { None }
|
||||
}
|
||||
|
||||
/// Recompute the aggregated unified diff by comparing all of the in-memory snapshots that were
|
||||
/// collected before the first time they were touched by apply_patch during this turn with
|
||||
/// the current repo state.
|
||||
pub fn get_unified_diff(&mut self) -> Result<Option<String>> {
|
||||
let mut aggregated = String::new();
|
||||
|
||||
// Compute diffs per tracked internal file in a stable order by external path.
|
||||
let mut baseline_file_names: Vec<String> =
|
||||
self.baseline_file_info.keys().cloned().collect();
|
||||
// Sort lexicographically by full repo-relative path to match git behavior.
|
||||
baseline_file_names.sort_by_key(|internal| {
|
||||
self.get_path_for_internal(internal)
|
||||
.map(|p| self.relative_to_git_root_str(&p))
|
||||
.unwrap_or_default()
|
||||
});
|
||||
|
||||
for internal in baseline_file_names {
|
||||
aggregated.push_str(self.get_file_diff(&internal).as_str());
|
||||
if !aggregated.ends_with('\n') {
|
||||
aggregated.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
if aggregated.trim().is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(aggregated))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
|
||||
let mut aggregated = String::new();
|
||||
|
||||
// Snapshot lightweight fields only.
|
||||
let (baseline_external_path, baseline_mode, left_oid) = {
|
||||
if let Some(info) = self.baseline_file_info.get(internal_file_name) {
|
||||
(info.path.clone(), info.mode, info.oid.clone())
|
||||
} else {
|
||||
(PathBuf::new(), FileMode::Regular, ZERO_OID.to_string())
|
||||
}
|
||||
};
|
||||
let current_external_path = match self.get_path_for_internal(internal_file_name) {
|
||||
Some(p) => p,
|
||||
None => return aggregated,
|
||||
};
|
||||
|
||||
let current_mode = file_mode_for_path(¤t_external_path).unwrap_or(FileMode::Regular);
|
||||
let right_bytes = blob_bytes(¤t_external_path, ¤t_mode);
|
||||
|
||||
// Compute displays with &mut self before borrowing any baseline content.
|
||||
let left_display = self.relative_to_git_root_str(&baseline_external_path);
|
||||
let right_display = self.relative_to_git_root_str(¤t_external_path);
|
||||
|
||||
// Compute right oid before borrowing baseline content.
|
||||
let right_oid = if let Some(b) = right_bytes.as_ref() {
|
||||
if current_mode == FileMode::Symlink {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(b))
|
||||
} else {
|
||||
self.git_blob_oid_for_path(¤t_external_path)
|
||||
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(b)))
|
||||
}
|
||||
} else {
|
||||
ZERO_OID.to_string()
|
||||
};
|
||||
|
||||
// Borrow baseline content only after all &mut self uses are done.
|
||||
let left_present = left_oid.as_str() != ZERO_OID;
|
||||
let left_bytes: Option<&[u8]> = if left_present {
|
||||
self.baseline_file_info
|
||||
.get(internal_file_name)
|
||||
.map(|i| i.content.as_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Fast path: identical bytes or both missing.
|
||||
if left_bytes == right_bytes.as_deref() {
|
||||
return aggregated;
|
||||
}
|
||||
|
||||
aggregated.push_str(&format!("diff --git a/{left_display} b/{right_display}\n"));
|
||||
|
||||
let is_add = !left_present && right_bytes.is_some();
|
||||
let is_delete = left_present && right_bytes.is_none();
|
||||
|
||||
if is_add {
|
||||
aggregated.push_str(&format!("new file mode {current_mode}\n"));
|
||||
} else if is_delete {
|
||||
aggregated.push_str(&format!("deleted file mode {baseline_mode}\n"));
|
||||
} else if baseline_mode != current_mode {
|
||||
aggregated.push_str(&format!("old mode {baseline_mode}\n"));
|
||||
aggregated.push_str(&format!("new mode {current_mode}\n"));
|
||||
}
|
||||
|
||||
let left_text = left_bytes.and_then(|b| std::str::from_utf8(b).ok());
|
||||
let right_text = right_bytes
|
||||
.as_deref()
|
||||
.and_then(|b| std::str::from_utf8(b).ok());
|
||||
|
||||
let can_text_diff = matches!(
|
||||
(left_text, right_text, is_add, is_delete),
|
||||
(Some(_), Some(_), _, _) | (_, Some(_), true, _) | (Some(_), _, _, true)
|
||||
);
|
||||
|
||||
if can_text_diff {
|
||||
let l = left_text.unwrap_or("");
|
||||
let r = right_text.unwrap_or("");
|
||||
|
||||
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
|
||||
|
||||
let old_header = if left_present {
|
||||
format!("a/{left_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
let new_header = if right_bytes.is_some() {
|
||||
format!("b/{right_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
|
||||
let diff = similar::TextDiff::from_lines(l, r);
|
||||
let unified = diff
|
||||
.unified_diff()
|
||||
.context_radius(3)
|
||||
.header(&old_header, &new_header)
|
||||
.to_string();
|
||||
|
||||
aggregated.push_str(&unified);
|
||||
} else {
|
||||
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
|
||||
let old_header = if left_present {
|
||||
format!("a/{left_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
let new_header = if right_bytes.is_some() {
|
||||
format!("b/{right_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
aggregated.push_str(&format!("--- {old_header}\n"));
|
||||
aggregated.push_str(&format!("+++ {new_header}\n"));
|
||||
aggregated.push_str("Binary files differ\n");
|
||||
}
|
||||
aggregated
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the Git SHA-1 blob object ID for the given content (bytes).
|
||||
fn git_blob_sha1_hex_bytes(data: &[u8]) -> Output<sha1::Sha1> {
|
||||
// Git blob hash is sha1 of: "blob <len>\0<data>"
|
||||
let header = format!("blob {}\0", data.len());
|
||||
use sha1::Digest;
|
||||
let mut hasher = sha1::Sha1::new();
|
||||
hasher.update(header.as_bytes());
|
||||
hasher.update(data);
|
||||
hasher.finalize()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum FileMode {
|
||||
Regular,
|
||||
#[cfg(unix)]
|
||||
Executable,
|
||||
Symlink,
|
||||
}
|
||||
|
||||
impl FileMode {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
FileMode::Regular => "100644",
|
||||
#[cfg(unix)]
|
||||
FileMode::Executable => "100755",
|
||||
FileMode::Symlink => "120000",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FileMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn file_mode_for_path(path: &Path) -> Option<FileMode> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let meta = fs::symlink_metadata(path).ok()?;
|
||||
let ft = meta.file_type();
|
||||
if ft.is_symlink() {
|
||||
return Some(FileMode::Symlink);
|
||||
}
|
||||
let mode = meta.permissions().mode();
|
||||
let is_exec = (mode & 0o111) != 0;
|
||||
Some(if is_exec {
|
||||
FileMode::Executable
|
||||
} else {
|
||||
FileMode::Regular
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn file_mode_for_path(_path: &Path) -> Option<FileMode> {
|
||||
// Default to non-executable on non-unix.
|
||||
Some(FileMode::Regular)
|
||||
}
|
||||
|
||||
fn blob_bytes(path: &Path, mode: &FileMode) -> Option<Vec<u8>> {
|
||||
if path.exists() {
|
||||
let contents = if *mode == FileMode::Symlink {
|
||||
symlink_blob_bytes(path)
|
||||
.ok_or_else(|| anyhow!("failed to read symlink target for {}", path.display()))
|
||||
} else {
|
||||
fs::read(path)
|
||||
.with_context(|| format!("failed to read current file for diff {}", path.display()))
|
||||
};
|
||||
contents.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn symlink_blob_bytes(path: &Path) -> Option<Vec<u8>> {
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
let target = std::fs::read_link(path).ok()?;
|
||||
Some(target.as_os_str().as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn symlink_blob_bytes(_path: &Path) -> Option<Vec<u8>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn is_windows_drive_or_unc_root(p: &std::path::Path) -> bool {
|
||||
use std::path::Component;
|
||||
let mut comps = p.components();
|
||||
matches!(
|
||||
(comps.next(), comps.next(), comps.next()),
|
||||
(Some(Component::Prefix(_)), Some(Component::RootDir), None)
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Compute the Git SHA-1 blob object ID for the given content (string).
|
||||
/// This delegates to the bytes version to avoid UTF-8 lossy conversions here.
|
||||
fn git_blob_sha1_hex(data: &str) -> String {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes()))
|
||||
}
|
||||
|
||||
fn normalize_diff_for_test(input: &str, root: &Path) -> String {
|
||||
let root_str = root.display().to_string().replace('\\', "/");
|
||||
let replaced = input.replace(&root_str, "<TMP>");
|
||||
// Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin
|
||||
let mut blocks: Vec<String> = Vec::new();
|
||||
let mut current = String::new();
|
||||
for line in replaced.lines() {
|
||||
if line.starts_with("diff --git ") && !current.is_empty() {
|
||||
blocks.push(current);
|
||||
current = String::new();
|
||||
}
|
||||
if !current.is_empty() {
|
||||
current.push('\n');
|
||||
}
|
||||
current.push_str(line);
|
||||
}
|
||||
if !current.is_empty() {
|
||||
blocks.push(current);
|
||||
}
|
||||
blocks.sort();
|
||||
let mut out = blocks.join("\n");
|
||||
if !out.ends_with('\n') {
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_add_and_update() {
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("a.txt");
|
||||
|
||||
// First patch: add file (baseline should be /dev/null).
|
||||
let add_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Add {
|
||||
content: "foo\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&add_changes);
|
||||
|
||||
// Simulate apply: create the file on disk.
|
||||
fs::write(&file, "foo\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -0,0 +1 @@
|
||||
+foo
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Second patch: update the file on disk.
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Simulate apply: append a new line.
|
||||
fs::write(&file, "foo\nbar\n").unwrap();
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected_combined = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -0,0 +1,2 @@
|
||||
+foo
|
||||
+bar
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected_combined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_delete() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("b.txt");
|
||||
fs::write(&file, "x\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let del_changes = HashMap::from([(file.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_changes);
|
||||
|
||||
// Simulate apply: delete the file from disk.
|
||||
let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
fs::remove_file(&file).unwrap();
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let left_oid = git_blob_sha1_hex("x\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/b.txt b/<TMP>/b.txt
|
||||
deleted file mode {baseline_mode}
|
||||
index {left_oid}..{ZERO_OID}
|
||||
--- a/<TMP>/b.txt
|
||||
+++ {DEV_NULL}
|
||||
@@ -1 +0,0 @@
|
||||
-x
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_move_and_update() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("src.txt");
|
||||
let dest = dir.path().join("dst.txt");
|
||||
fs::write(&src, "line\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv_changes = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv_changes);
|
||||
|
||||
// Simulate apply: move and update content.
|
||||
fs::rename(&src, &dest).unwrap();
|
||||
fs::write(&dest, "line2\n").unwrap();
|
||||
|
||||
let out = acc.get_unified_diff().unwrap().unwrap();
|
||||
let out = normalize_diff_for_test(&out, dir.path());
|
||||
let expected = {
|
||||
let left_oid = git_blob_sha1_hex("line\n");
|
||||
let right_oid = git_blob_sha1_hex("line2\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/src.txt b/<TMP>/dst.txt
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/src.txt
|
||||
+++ b/<TMP>/dst.txt
|
||||
@@ -1 +1 @@
|
||||
-line
|
||||
+line2
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(out, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn move_without_1change_yields_no_diff() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("moved.txt");
|
||||
let dest = dir.path().join("renamed.txt");
|
||||
fs::write(&src, "same\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv_changes = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv_changes);
|
||||
|
||||
// Simulate apply: move only, no content change.
|
||||
fs::rename(&src, &dest).unwrap();
|
||||
|
||||
let diff = acc.get_unified_diff().unwrap();
|
||||
assert_eq!(diff, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn move_declared_but_file_only_appears_at_dest_is_add() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("src.txt");
|
||||
let dest = dir.path().join("dest.txt");
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".into(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv);
|
||||
// No file existed initially; create only dest
|
||||
fs::write(&dest, "hello\n").unwrap();
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("hello\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/src.txt b/<TMP>/dest.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/dest.txt
|
||||
@@ -0,0 +1 @@
|
||||
+hello
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_persists_across_new_baseline_for_new_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let a = dir.path().join("a.txt");
|
||||
let b = dir.path().join("b.txt");
|
||||
fs::write(&a, "foo\n").unwrap();
|
||||
fs::write(&b, "z\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
// First: update existing a.txt (baseline snapshot is created for a).
|
||||
let update_a = HashMap::from([(
|
||||
a.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_a);
|
||||
// Simulate apply: modify a.txt on disk.
|
||||
fs::write(&a, "foo\nbar\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let left_oid = git_blob_sha1_hex("foo\n");
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/a.txt
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -1 +1,2 @@
|
||||
foo
|
||||
+bar
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
|
||||
let del_b = HashMap::from([(b.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_b);
|
||||
// Simulate apply: delete b.txt.
|
||||
let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular);
|
||||
fs::remove_file(&b).unwrap();
|
||||
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected = {
|
||||
let left_oid_a = git_blob_sha1_hex("foo\n");
|
||||
let right_oid_a = git_blob_sha1_hex("foo\nbar\n");
|
||||
let left_oid_b = git_blob_sha1_hex("z\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
index {left_oid_a}..{right_oid_a}
|
||||
--- a/<TMP>/a.txt
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -1 +1,2 @@
|
||||
foo
|
||||
+bar
|
||||
diff --git a/<TMP>/b.txt b/<TMP>/b.txt
|
||||
deleted file mode {baseline_mode}
|
||||
index {left_oid_b}..{ZERO_OID}
|
||||
--- a/<TMP>/b.txt
|
||||
+++ {DEV_NULL}
|
||||
@@ -1 +0,0 @@
|
||||
-z
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_files_differ_update() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("bin.dat");
|
||||
|
||||
// Initial non-UTF8 bytes
|
||||
let left_bytes: Vec<u8> = vec![0xff, 0xfe, 0xfd, 0x00];
|
||||
// Updated non-UTF8 bytes
|
||||
let right_bytes: Vec<u8> = vec![0x01, 0x02, 0x03, 0x00];
|
||||
|
||||
fs::write(&file, &left_bytes).unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Apply update on disk
|
||||
fs::write(&file, &right_bytes).unwrap();
|
||||
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes));
|
||||
let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes));
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/bin.dat b/<TMP>/bin.dat
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/bin.dat
|
||||
+++ b/<TMP>/bin.dat
|
||||
Binary files differ
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filenames_with_spaces_add_and_update() {
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("name with spaces.txt");
|
||||
|
||||
// First patch: add file (baseline should be /dev/null).
|
||||
let add_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Add {
|
||||
content: "foo\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&add_changes);
|
||||
|
||||
// Simulate apply: create the file on disk.
|
||||
fs::write(&file, "foo\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/name with spaces.txt
|
||||
@@ -0,0 +1 @@
|
||||
+foo
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Second patch: update the file on disk.
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Simulate apply: append a new line with a space.
|
||||
fs::write(&file, "foo\nbar baz\n").unwrap();
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected_combined = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar baz\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/name with spaces.txt
|
||||
@@ -0,0 +1,2 @@
|
||||
+foo
|
||||
+bar baz
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected_combined);
|
||||
}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
#![cfg(target_os = "macos")]
|
||||
#![expect(clippy::expect_used)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use codex_core::get_platform_sandbox;
|
||||
|
||||
async fn run_test_cmd(tmp: TempDir, cmd: Vec<&str>, should_be_ok: bool) {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
let sandbox_type = get_platform_sandbox().expect("should be able to get sandbox type");
|
||||
assert_eq!(sandbox_type, SandboxType::MacosSeatbelt);
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd.iter().map(|s| s.to_string()).collect(),
|
||||
cwd: tmp.path().to_path_buf(),
|
||||
timeout_ms: Some(1000),
|
||||
env: HashMap::new(),
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(params, sandbox_type, ctrl_c, &policy, &None, None).await;
|
||||
|
||||
assert!(result.is_ok() == should_be_ok);
|
||||
}
|
||||
|
||||
/// Command succeeds with exit code 0 normally
|
||||
#[tokio::test]
|
||||
async fn exit_code_0_succeeds() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let cmd = vec!["echo", "hello"];
|
||||
|
||||
run_test_cmd(tmp, cmd, true).await
|
||||
}
|
||||
|
||||
/// Command not found returns exit code 127, this is not considered a sandbox error
|
||||
#[tokio::test]
|
||||
async fn exit_command_not_found_is_ok() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let cmd = vec!["/bin/bash", "-c", "nonexistent_command_12345"];
|
||||
run_test_cmd(tmp, cmd, true).await
|
||||
}
|
||||
|
||||
/// Writing a file fails and should be considered a sandbox error
|
||||
#[tokio::test]
|
||||
async fn write_file_fails_as_sandbox_error() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let path = tmp.path().join("test.txt");
|
||||
let cmd = vec![
|
||||
"/user/bin/touch",
|
||||
path.to_str().expect("should be able to get path"),
|
||||
];
|
||||
|
||||
run_test_cmd(tmp, cmd, false).await;
|
||||
}
|
||||
@@ -177,7 +177,8 @@ async fn live_shell_function_call() {
|
||||
match ev.msg {
|
||||
EventMsg::ExecCommandBegin(codex_core::protocol::ExecCommandBeginEvent {
|
||||
command,
|
||||
..
|
||||
call_id: _,
|
||||
cwd: _,
|
||||
}) => {
|
||||
assert_eq!(command, vec!["echo", MARKER]);
|
||||
saw_begin = true;
|
||||
@@ -185,7 +186,8 @@ async fn live_shell_function_call() {
|
||||
EventMsg::ExecCommandEnd(codex_core::protocol::ExecCommandEndEvent {
|
||||
stdout,
|
||||
exit_code,
|
||||
..
|
||||
call_id: _,
|
||||
stderr: _,
|
||||
}) => {
|
||||
assert_eq!(exit_code, 0, "echo returned non‑zero exit code");
|
||||
assert!(stdout.contains(MARKER));
|
||||
|
||||
@@ -14,12 +14,6 @@ pub struct Cli {
|
||||
#[arg(long, short = 'm')]
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Convenience flag to select the local Ollama provider.
|
||||
/// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
|
||||
/// creates a model_providers.ollama entry in config.toml if missing.
|
||||
#[arg(long = "ollama", default_value_t = false)]
|
||||
pub ollama: bool,
|
||||
|
||||
/// Select the sandbox policy to use when executing model-generated shell
|
||||
/// commands.
|
||||
#[arg(long = "sandbox", short = 's')]
|
||||
|
||||
@@ -20,7 +20,6 @@ use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_core::protocol::TurnDiffEvent;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
@@ -107,6 +106,7 @@ impl EventProcessorWithHumanOutput {
|
||||
|
||||
struct ExecCommandBegin {
|
||||
command: Vec<String>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
struct PatchApplyBegin {
|
||||
@@ -228,6 +228,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
@@ -243,14 +244,16 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin { command, .. }) = exec_command
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_duration(duration)),
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
@@ -400,7 +403,6 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
..
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
@@ -430,10 +432,6 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::TurnDiff(TurnDiffEvent { unified_diff }) => {
|
||||
ts_println!(self, "{}", "turn diff:".style(self.magenta));
|
||||
println!("{unified_diff}");
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
|
||||
@@ -31,17 +31,10 @@ use tracing_subscriber::EnvFilter;
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
// Shared Ollama helpers are centralized in codex_core::providers::ollama.
|
||||
use codex_core::providers::ollama::CliProgressReporter;
|
||||
use codex_core::providers::ollama::OllamaClient;
|
||||
use codex_core::providers::ollama::ensure_configured_and_running;
|
||||
use codex_core::providers::ollama::ensure_model_available;
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let Cli {
|
||||
images,
|
||||
model,
|
||||
ollama,
|
||||
config_profile,
|
||||
full_auto,
|
||||
dangerously_bypass_approvals_and_sandbox,
|
||||
@@ -55,9 +48,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
config_overrides,
|
||||
} = cli;
|
||||
|
||||
// Track whether the user explicitly provided a model via --model.
|
||||
let user_specified_model = model.is_some();
|
||||
|
||||
// Determine the prompt based on CLI arg and/or stdin.
|
||||
let prompt = match prompt {
|
||||
Some(p) if p != "-" => p,
|
||||
@@ -124,16 +114,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
sandbox_mode_cli_arg.map(Into::<SandboxMode>::into)
|
||||
};
|
||||
|
||||
// When the user opts into the Ollama provider via `--ollama`, ensure we
|
||||
// have a configured provider entry and that a local server is running.
|
||||
if ollama {
|
||||
if let Err(e) = ensure_configured_and_running().await {
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Load configuration and determine approval policy
|
||||
let overrides = ConfigOverrides {
|
||||
model,
|
||||
@@ -143,11 +123,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
approval_policy: Some(AskForApproval::Never),
|
||||
sandbox_mode,
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
model_provider: if ollama {
|
||||
Some("ollama".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
model_provider: None,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: None,
|
||||
include_plan_tool: None,
|
||||
@@ -162,22 +138,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
|
||||
// If the user passed both --ollama and --model, ensure the requested model
|
||||
// is present locally or pull it automatically (subject to allowlist).
|
||||
if ollama && user_specified_model {
|
||||
let model_name = config.model.clone();
|
||||
let client = OllamaClient::from_provider(&config.model_provider);
|
||||
let config_path = config.codex_home.join("config.toml");
|
||||
let mut reporter = CliProgressReporter::new();
|
||||
if let Err(e) =
|
||||
ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
|
||||
{
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
|
||||
} else {
|
||||
|
||||
@@ -263,7 +263,6 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
|
||||
@@ -97,7 +97,6 @@ pub async fn run_conversation_loop(
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`ElicitRequest`].
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
|
||||
@@ -89,18 +89,14 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
// This is the first request from the server, so the id should be 0 given
|
||||
// how things are currently implemented.
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
let params = serde_json::from_value::<ExecApprovalElicitRequestParams>(
|
||||
elicitation_request
|
||||
.params
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("elicitation_request.params must be set"))?,
|
||||
)?;
|
||||
let expected_elicitation_request = create_expected_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
shell_command.clone(),
|
||||
workdir_for_shell_function_call.path(),
|
||||
codex_request_id.to_string(),
|
||||
params.codex_event_id.clone(),
|
||||
// Internal Codex id: empirically it is 1, but this is
|
||||
// admittedly an internal detail that could change.
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
|
||||
@@ -48,8 +48,6 @@ serde_json = { version = "1", features = ["preserve_order"] }
|
||||
shlex = "1.3.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
supports-color = "3.0.2"
|
||||
textwrap = "0.16.2"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -62,13 +60,13 @@ tracing-appender = "0.2.3"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
tui-input = "0.14.0"
|
||||
tui-markdown = "0.3.3"
|
||||
tui-textarea = "0.7.0"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.1"
|
||||
uuid = "1"
|
||||
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
insta = "1.43.1"
|
||||
pretty_assertions = "1"
|
||||
rand = "0.8"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
@@ -43,7 +43,6 @@ enum AppState<'a> {
|
||||
},
|
||||
/// The start-up warning that recommends running codex inside a Git repo.
|
||||
GitWarning { screen: GitWarningScreen },
|
||||
// (no additional states)
|
||||
}
|
||||
|
||||
pub(crate) struct App<'a> {
|
||||
@@ -439,15 +438,14 @@ impl App<'_> {
|
||||
);
|
||||
self.pending_history_lines.clear();
|
||||
}
|
||||
terminal.draw(|frame| match &mut self.app_state {
|
||||
match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
if let Some((x, y)) = widget.cursor_pos(frame.area()) {
|
||||
frame.set_cursor_position((x, y));
|
||||
}
|
||||
frame.render_widget_ref(&**widget, frame.area())
|
||||
terminal.draw(|frame| frame.render_widget_ref(&**widget, frame.area()))?;
|
||||
}
|
||||
AppState::GitWarning { screen } => frame.render_widget_ref(&*screen, frame.area()),
|
||||
})?;
|
||||
AppState::GitWarning { screen } => {
|
||||
terminal.draw(|frame| frame.render_widget_ref(&*screen, frame.area()))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Constraint;
|
||||
use ratatui::layout::Layout;
|
||||
use ratatui::layout::Margin;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::style::Color;
|
||||
use ratatui::style::Style;
|
||||
@@ -13,11 +8,13 @@ use ratatui::style::Styled;
|
||||
use ratatui::style::Stylize;
|
||||
use ratatui::text::Line;
|
||||
use ratatui::text::Span;
|
||||
use ratatui::widgets::Block;
|
||||
use ratatui::widgets::BorderType;
|
||||
use ratatui::widgets::Borders;
|
||||
use ratatui::widgets::StatefulWidgetRef;
|
||||
use ratatui::widgets::Widget;
|
||||
use ratatui::widgets::WidgetRef;
|
||||
use tui_textarea::Input;
|
||||
use tui_textarea::Key;
|
||||
use tui_textarea::TextArea;
|
||||
|
||||
use super::chat_composer_history::ChatComposerHistory;
|
||||
use super::command_popup::CommandPopup;
|
||||
@@ -25,10 +22,7 @@ use super::file_search_popup::FileSearchPopup;
|
||||
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use crate::bottom_pane::textarea::TextArea;
|
||||
use crate::bottom_pane::textarea::TextAreaState;
|
||||
use codex_file_search::FileMatch;
|
||||
use std::cell::RefCell;
|
||||
|
||||
const BASE_PLACEHOLDER_TEXT: &str = "...";
|
||||
/// If the pasted content exceeds this number of characters, replace it with a
|
||||
@@ -41,14 +35,8 @@ pub enum InputResult {
|
||||
None,
|
||||
}
|
||||
|
||||
struct TokenUsageInfo {
|
||||
token_usage: TokenUsage,
|
||||
model_context_window: Option<u64>,
|
||||
}
|
||||
|
||||
pub(crate) struct ChatComposer {
|
||||
textarea: TextArea,
|
||||
textarea_state: RefCell<TextAreaState>,
|
||||
pub(crate) struct ChatComposer<'a> {
|
||||
textarea: TextArea<'a>,
|
||||
active_popup: ActivePopup,
|
||||
app_event_tx: AppEventSender,
|
||||
history: ChatComposerHistory,
|
||||
@@ -57,8 +45,6 @@ pub(crate) struct ChatComposer {
|
||||
dismissed_file_popup_token: Option<String>,
|
||||
current_file_query: Option<String>,
|
||||
pending_pastes: Vec<(String, String)>,
|
||||
token_usage_info: Option<TokenUsageInfo>,
|
||||
has_focus: bool,
|
||||
}
|
||||
|
||||
/// Popup state – at most one can be visible at any time.
|
||||
@@ -68,17 +54,20 @@ enum ActivePopup {
|
||||
File(FileSearchPopup),
|
||||
}
|
||||
|
||||
impl ChatComposer {
|
||||
impl ChatComposer<'_> {
|
||||
pub fn new(
|
||||
has_input_focus: bool,
|
||||
app_event_tx: AppEventSender,
|
||||
enhanced_keys_supported: bool,
|
||||
) -> Self {
|
||||
let mut textarea = TextArea::default();
|
||||
textarea.set_placeholder_text(BASE_PLACEHOLDER_TEXT);
|
||||
textarea.set_cursor_line_style(ratatui::style::Style::default());
|
||||
|
||||
let use_shift_enter_hint = enhanced_keys_supported;
|
||||
|
||||
Self {
|
||||
textarea: TextArea::new(),
|
||||
textarea_state: RefCell::new(TextAreaState::default()),
|
||||
let mut this = Self {
|
||||
textarea,
|
||||
active_popup: ActivePopup::None,
|
||||
app_event_tx,
|
||||
history: ChatComposerHistory::new(),
|
||||
@@ -87,13 +76,13 @@ impl ChatComposer {
|
||||
dismissed_file_popup_token: None,
|
||||
current_file_query: None,
|
||||
pending_pastes: Vec::new(),
|
||||
token_usage_info: None,
|
||||
has_focus: has_input_focus,
|
||||
}
|
||||
};
|
||||
this.update_border(has_input_focus);
|
||||
this
|
||||
}
|
||||
|
||||
pub fn desired_height(&self, width: u16) -> u16 {
|
||||
self.textarea.desired_height(width - 1)
|
||||
pub fn desired_height(&self) -> u16 {
|
||||
self.textarea.lines().len().max(1) as u16
|
||||
+ match &self.active_popup {
|
||||
ActivePopup::None => 1u16,
|
||||
ActivePopup::Command(c) => c.calculate_required_height(),
|
||||
@@ -101,21 +90,6 @@ impl ChatComposer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cursor_pos(&self, area: Rect) -> Option<(u16, u16)> {
|
||||
let popup_height = match &self.active_popup {
|
||||
ActivePopup::Command(popup) => popup.calculate_required_height(),
|
||||
ActivePopup::File(popup) => popup.calculate_required_height(),
|
||||
ActivePopup::None => 1,
|
||||
};
|
||||
let [textarea_rect, _] =
|
||||
Layout::vertical([Constraint::Min(0), Constraint::Max(popup_height)]).areas(area);
|
||||
let mut textarea_rect = textarea_rect;
|
||||
textarea_rect.width = textarea_rect.width.saturating_sub(1);
|
||||
textarea_rect.x += 1;
|
||||
let state = self.textarea_state.borrow();
|
||||
self.textarea.cursor_pos_with_state(textarea_rect, &state)
|
||||
}
|
||||
|
||||
/// Returns true if the composer currently contains no user input.
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.textarea.is_empty()
|
||||
@@ -129,10 +103,28 @@ impl ChatComposer {
|
||||
token_usage: TokenUsage,
|
||||
model_context_window: Option<u64>,
|
||||
) {
|
||||
self.token_usage_info = Some(TokenUsageInfo {
|
||||
token_usage,
|
||||
model_context_window,
|
||||
});
|
||||
let placeholder = match (token_usage.total_tokens, model_context_window) {
|
||||
(total_tokens, Some(context_window)) => {
|
||||
let percent_remaining: u8 = if context_window > 0 {
|
||||
// Calculate the percentage of context left.
|
||||
let percent = 100.0 - (total_tokens as f32 / context_window as f32 * 100.0);
|
||||
percent.clamp(0.0, 100.0) as u8
|
||||
} else {
|
||||
// If we don't have a context window, we cannot compute the
|
||||
// percentage.
|
||||
100
|
||||
};
|
||||
// When https://github.com/openai/codex/issues/1257 is resolved,
|
||||
// check if `percent_remaining < 25`, and if so, recommend
|
||||
// /compact.
|
||||
format!("{BASE_PLACEHOLDER_TEXT} — {percent_remaining}% context left")
|
||||
}
|
||||
(total_tokens, None) => {
|
||||
format!("{BASE_PLACEHOLDER_TEXT} — {total_tokens} tokens used")
|
||||
}
|
||||
};
|
||||
|
||||
self.textarea.set_placeholder_text(placeholder);
|
||||
}
|
||||
|
||||
/// Record the history metadata advertised by `SessionConfiguredEvent` so
|
||||
@@ -150,12 +142,8 @@ impl ChatComposer {
|
||||
offset: usize,
|
||||
entry: Option<String>,
|
||||
) -> bool {
|
||||
let Some(text) = self.history.on_entry_response(log_id, offset, entry) else {
|
||||
return false;
|
||||
};
|
||||
self.textarea.set_text(&text);
|
||||
self.textarea.set_cursor(0);
|
||||
true
|
||||
self.history
|
||||
.on_entry_response(log_id, offset, entry, &mut self.textarea)
|
||||
}
|
||||
|
||||
pub fn handle_paste(&mut self, pasted: String) -> bool {
|
||||
@@ -191,7 +179,7 @@ impl ChatComposer {
|
||||
|
||||
pub fn set_ctrl_c_quit_hint(&mut self, show: bool, has_focus: bool) {
|
||||
self.ctrl_c_quit_hint = show;
|
||||
self.set_has_focus(has_focus);
|
||||
self.update_border(has_focus);
|
||||
}
|
||||
|
||||
/// Handle a key event coming from the main UI.
|
||||
@@ -219,47 +207,49 @@ impl ChatComposer {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
match key_event {
|
||||
KeyEvent {
|
||||
code: KeyCode::Up, ..
|
||||
} => {
|
||||
match key_event.into() {
|
||||
Input { key: Key::Up, .. } => {
|
||||
popup.move_up();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Down,
|
||||
..
|
||||
} => {
|
||||
Input { key: Key::Down, .. } => {
|
||||
popup.move_down();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Tab, ..
|
||||
} => {
|
||||
Input { key: Key::Tab, .. } => {
|
||||
if let Some(cmd) = popup.selected_command() {
|
||||
let first_line = self.textarea.text().lines().next().unwrap_or("");
|
||||
let first_line = self
|
||||
.textarea
|
||||
.lines()
|
||||
.first()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let starts_with_cmd = first_line
|
||||
.trim_start()
|
||||
.starts_with(&format!("/{}", cmd.command()));
|
||||
|
||||
if !starts_with_cmd {
|
||||
self.textarea.set_text(&format!("/{} ", cmd.command()));
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
let _ = self.textarea.insert_str(format!("/{} ", cmd.command()));
|
||||
}
|
||||
}
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
modifiers: KeyModifiers::NONE,
|
||||
..
|
||||
Input {
|
||||
key: Key::Enter,
|
||||
shift: false,
|
||||
alt: false,
|
||||
ctrl: false,
|
||||
} => {
|
||||
if let Some(cmd) = popup.selected_command() {
|
||||
// Send command to the app layer.
|
||||
self.app_event_tx.send(AppEvent::DispatchCommand(*cmd));
|
||||
|
||||
// Clear textarea so no residual text remains.
|
||||
self.textarea.set_text("");
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
|
||||
// Hide popup since the command has been dispatched.
|
||||
self.active_popup = ActivePopup::None;
|
||||
@@ -278,23 +268,16 @@ impl ChatComposer {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
match key_event {
|
||||
KeyEvent {
|
||||
code: KeyCode::Up, ..
|
||||
} => {
|
||||
match key_event.into() {
|
||||
Input { key: Key::Up, .. } => {
|
||||
popup.move_up();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Down,
|
||||
..
|
||||
} => {
|
||||
Input { key: Key::Down, .. } => {
|
||||
popup.move_down();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Esc, ..
|
||||
} => {
|
||||
Input { key: Key::Esc, .. } => {
|
||||
// Hide popup without modifying text, remember token to avoid immediate reopen.
|
||||
if let Some(tok) = Self::current_at_token(&self.textarea) {
|
||||
self.dismissed_file_popup_token = Some(tok.to_string());
|
||||
@@ -302,13 +285,12 @@ impl ChatComposer {
|
||||
self.active_popup = ActivePopup::None;
|
||||
(InputResult::None, true)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Tab, ..
|
||||
}
|
||||
| KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
modifiers: KeyModifiers::NONE,
|
||||
..
|
||||
Input { key: Key::Tab, .. }
|
||||
| Input {
|
||||
key: Key::Enter,
|
||||
ctrl: false,
|
||||
alt: false,
|
||||
shift: false,
|
||||
} => {
|
||||
if let Some(sel) = popup.selected_match() {
|
||||
let sel_path = sel.to_string();
|
||||
@@ -333,89 +315,46 @@ impl ChatComposer {
|
||||
/// - A token is delimited by ASCII whitespace (space, tab, newline).
|
||||
/// - If the token under the cursor starts with `@` and contains at least
|
||||
/// one additional character, that token (without `@`) is returned.
|
||||
fn current_at_token(textarea: &TextArea) -> Option<String> {
|
||||
let cursor_offset = textarea.cursor();
|
||||
let text = textarea.text();
|
||||
fn current_at_token(textarea: &tui_textarea::TextArea) -> Option<String> {
|
||||
let (row, col) = textarea.cursor();
|
||||
|
||||
// Adjust the provided byte offset to the nearest valid char boundary at or before it.
|
||||
let mut safe_cursor = cursor_offset.min(text.len());
|
||||
// If we're not on a char boundary, move back to the start of the current char.
|
||||
if safe_cursor < text.len() && !text.is_char_boundary(safe_cursor) {
|
||||
// Find the last valid boundary <= cursor_offset.
|
||||
safe_cursor = text
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= cursor_offset)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
}
|
||||
// Guard against out-of-bounds rows.
|
||||
let line = textarea.lines().get(row)?.as_str();
|
||||
|
||||
// Split the line around the (now safe) cursor position.
|
||||
let before_cursor = &text[..safe_cursor];
|
||||
let after_cursor = &text[safe_cursor..];
|
||||
// Calculate byte offset for cursor position
|
||||
let cursor_byte_offset = line.chars().take(col).map(|c| c.len_utf8()).sum::<usize>();
|
||||
|
||||
// Detect whether we're on whitespace at the cursor boundary.
|
||||
let at_whitespace = if safe_cursor < text.len() {
|
||||
text[safe_cursor..]
|
||||
.chars()
|
||||
.next()
|
||||
.map(|c| c.is_whitespace())
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
// Split the line at the cursor position so we can search for word
|
||||
// boundaries on both sides.
|
||||
let before_cursor = &line[..cursor_byte_offset];
|
||||
let after_cursor = &line[cursor_byte_offset..];
|
||||
|
||||
// Left candidate: token containing the cursor position.
|
||||
let start_left = before_cursor
|
||||
// Find start index (first character **after** the previous multi-byte whitespace).
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
let end_left_rel = after_cursor
|
||||
|
||||
// Find end index (first multi-byte whitespace **after** the cursor position).
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_left = safe_cursor + end_left_rel;
|
||||
let token_left = if start_left < end_left {
|
||||
Some(&text[start_left..end_left])
|
||||
let end_idx = cursor_byte_offset + end_rel_idx;
|
||||
|
||||
if start_idx >= end_idx {
|
||||
return None;
|
||||
}
|
||||
|
||||
let token = &line[start_idx..end_idx];
|
||||
|
||||
if token.starts_with('@') && token.len() > 1 {
|
||||
Some(token[1..].to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Right candidate: token immediately after any whitespace from the cursor.
|
||||
let ws_len_right: usize = after_cursor
|
||||
.chars()
|
||||
.take_while(|c| c.is_whitespace())
|
||||
.map(|c| c.len_utf8())
|
||||
.sum();
|
||||
let start_right = safe_cursor + ws_len_right;
|
||||
let end_right_rel = text[start_right..]
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(text.len() - start_right);
|
||||
let end_right = start_right + end_right_rel;
|
||||
let token_right = if start_right < end_right {
|
||||
Some(&text[start_right..end_right])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let left_at = token_left
|
||||
.filter(|t| t.starts_with('@') && t.len() > 1)
|
||||
.map(|t| t[1..].to_string());
|
||||
let right_at = token_right
|
||||
.filter(|t| t.starts_with('@') && t.len() > 1)
|
||||
.map(|t| t[1..].to_string());
|
||||
|
||||
if at_whitespace {
|
||||
return right_at.or(left_at);
|
||||
}
|
||||
if after_cursor.starts_with('@') {
|
||||
return right_at.or(left_at);
|
||||
}
|
||||
left_at.or(right_at)
|
||||
}
|
||||
|
||||
/// Replace the active `@token` (the one under the cursor) with `path`.
|
||||
@@ -424,73 +363,94 @@ impl ChatComposer {
|
||||
/// where the cursor is within the token and regardless of how many
|
||||
/// `@tokens` exist in the line.
|
||||
fn insert_selected_path(&mut self, path: &str) {
|
||||
let cursor_offset = self.textarea.cursor();
|
||||
let text = self.textarea.text();
|
||||
let (row, col) = self.textarea.cursor();
|
||||
|
||||
let before_cursor = &text[..cursor_offset];
|
||||
let after_cursor = &text[cursor_offset..];
|
||||
// Materialize the textarea lines so we can mutate them easily.
|
||||
let mut lines: Vec<String> = self.textarea.lines().to_vec();
|
||||
|
||||
// Determine token boundaries.
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
if let Some(line) = lines.get_mut(row) {
|
||||
// Calculate byte offset for cursor position
|
||||
let cursor_byte_offset = line.chars().take(col).map(|c| c.len_utf8()).sum::<usize>();
|
||||
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_idx = cursor_offset + end_rel_idx;
|
||||
let before_cursor = &line[..cursor_byte_offset];
|
||||
let after_cursor = &line[cursor_byte_offset..];
|
||||
|
||||
// Replace the slice `[start_idx, end_idx)` with the chosen path and a trailing space.
|
||||
let mut new_text =
|
||||
String::with_capacity(text.len() - (end_idx - start_idx) + path.len() + 1);
|
||||
new_text.push_str(&text[..start_idx]);
|
||||
new_text.push_str(path);
|
||||
new_text.push(' ');
|
||||
new_text.push_str(&text[end_idx..]);
|
||||
// Determine token boundaries.
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
|
||||
self.textarea.set_text(&new_text);
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_idx = cursor_byte_offset + end_rel_idx;
|
||||
|
||||
// Replace the slice `[start_idx, end_idx)` with the chosen path and a trailing space.
|
||||
let mut new_line =
|
||||
String::with_capacity(line.len() - (end_idx - start_idx) + path.len() + 1);
|
||||
new_line.push_str(&line[..start_idx]);
|
||||
new_line.push_str(path);
|
||||
new_line.push(' ');
|
||||
new_line.push_str(&line[end_idx..]);
|
||||
|
||||
*line = new_line;
|
||||
|
||||
// Re-populate the textarea.
|
||||
let new_text = lines.join("\n");
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
let _ = self.textarea.insert_str(new_text);
|
||||
|
||||
// Note: tui-textarea currently exposes only relative cursor
|
||||
// movements. Leaving the cursor position unchanged is acceptable
|
||||
// as subsequent typing will move the cursor naturally.
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle key event when no popup is visible.
|
||||
fn handle_key_event_without_popup(&mut self, key_event: KeyEvent) -> (InputResult, bool) {
|
||||
match key_event {
|
||||
let input: Input = key_event.into();
|
||||
match input {
|
||||
// -------------------------------------------------------------
|
||||
// History navigation (Up / Down) – only when the composer is not
|
||||
// empty or when the cursor is at the correct position, to avoid
|
||||
// interfering with normal cursor movement.
|
||||
// -------------------------------------------------------------
|
||||
KeyEvent {
|
||||
code: KeyCode::Up | KeyCode::Down,
|
||||
..
|
||||
} => {
|
||||
if self
|
||||
.history
|
||||
.should_handle_navigation(self.textarea.text(), self.textarea.cursor())
|
||||
{
|
||||
let replace_text = match key_event.code {
|
||||
KeyCode::Up => self.history.navigate_up(&self.app_event_tx),
|
||||
KeyCode::Down => self.history.navigate_down(&self.app_event_tx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if let Some(text) = replace_text {
|
||||
self.textarea.set_text(&text);
|
||||
self.textarea.set_cursor(0);
|
||||
Input { key: Key::Up, .. } => {
|
||||
if self.history.should_handle_navigation(&self.textarea) {
|
||||
let consumed = self
|
||||
.history
|
||||
.navigate_up(&mut self.textarea, &self.app_event_tx);
|
||||
if consumed {
|
||||
return (InputResult::None, true);
|
||||
}
|
||||
}
|
||||
self.handle_input_basic(key_event)
|
||||
self.handle_input_basic(input)
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
modifiers: KeyModifiers::NONE,
|
||||
..
|
||||
Input { key: Key::Down, .. } => {
|
||||
if self.history.should_handle_navigation(&self.textarea) {
|
||||
let consumed = self
|
||||
.history
|
||||
.navigate_down(&mut self.textarea, &self.app_event_tx);
|
||||
if consumed {
|
||||
return (InputResult::None, true);
|
||||
}
|
||||
}
|
||||
self.handle_input_basic(input)
|
||||
}
|
||||
Input {
|
||||
key: Key::Enter,
|
||||
shift: false,
|
||||
alt: false,
|
||||
ctrl: false,
|
||||
} => {
|
||||
let mut text = self.textarea.text().to_string();
|
||||
self.textarea.set_text("");
|
||||
let mut text = self.textarea.lines().join("\n");
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
|
||||
// Replace all pending pastes in the text
|
||||
for (placeholder, actual) in &self.pending_pastes {
|
||||
@@ -507,15 +467,41 @@ impl ChatComposer {
|
||||
(InputResult::Submitted(text), true)
|
||||
}
|
||||
}
|
||||
Input {
|
||||
key: Key::Enter, ..
|
||||
}
|
||||
| Input {
|
||||
key: Key::Char('j'),
|
||||
ctrl: true,
|
||||
alt: false,
|
||||
shift: false,
|
||||
} => {
|
||||
self.textarea.insert_newline();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
Input {
|
||||
key: Key::Char('d'),
|
||||
ctrl: true,
|
||||
alt: false,
|
||||
shift: false,
|
||||
} => {
|
||||
self.textarea.input(Input {
|
||||
key: Key::Delete,
|
||||
ctrl: false,
|
||||
alt: false,
|
||||
shift: false,
|
||||
});
|
||||
(InputResult::None, true)
|
||||
}
|
||||
input => self.handle_input_basic(input),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle generic Input events that modify the textarea content.
|
||||
fn handle_input_basic(&mut self, input: KeyEvent) -> (InputResult, bool) {
|
||||
fn handle_input_basic(&mut self, input: Input) -> (InputResult, bool) {
|
||||
// Special handling for backspace on placeholders
|
||||
if let KeyEvent {
|
||||
code: KeyCode::Backspace,
|
||||
if let Input {
|
||||
key: Key::Backspace,
|
||||
..
|
||||
} = input
|
||||
{
|
||||
@@ -524,9 +510,20 @@ impl ChatComposer {
|
||||
}
|
||||
}
|
||||
|
||||
if let Input {
|
||||
key: Key::Char('u'),
|
||||
ctrl: true,
|
||||
alt: false,
|
||||
..
|
||||
} = input
|
||||
{
|
||||
self.textarea.delete_line_by_head();
|
||||
return (InputResult::None, true);
|
||||
}
|
||||
|
||||
// Normal input handling
|
||||
self.textarea.input(input);
|
||||
let text_after = self.textarea.text();
|
||||
let text_after = self.textarea.lines().join("\n");
|
||||
|
||||
// Check if any placeholders were removed and remove their corresponding pending pastes
|
||||
self.pending_pastes
|
||||
@@ -538,16 +535,21 @@ impl ChatComposer {
|
||||
/// Attempts to remove a placeholder if the cursor is at the end of one.
|
||||
/// Returns true if a placeholder was removed.
|
||||
fn try_remove_placeholder_at_cursor(&mut self) -> bool {
|
||||
let p = self.textarea.cursor();
|
||||
let text = self.textarea.text();
|
||||
let (row, col) = self.textarea.cursor();
|
||||
let line = self
|
||||
.textarea
|
||||
.lines()
|
||||
.get(row)
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Find any placeholder that ends at the cursor position
|
||||
let placeholder_to_remove = self.pending_pastes.iter().find_map(|(ph, _)| {
|
||||
if p < ph.len() {
|
||||
if col < ph.len() {
|
||||
return None;
|
||||
}
|
||||
let potential_ph_start = p - ph.len();
|
||||
if text[potential_ph_start..p] == *ph {
|
||||
let potential_ph_start = col - ph.len();
|
||||
if line[potential_ph_start..col] == *ph {
|
||||
Some(ph.clone())
|
||||
} else {
|
||||
None
|
||||
@@ -555,7 +557,17 @@ impl ChatComposer {
|
||||
});
|
||||
|
||||
if let Some(placeholder) = placeholder_to_remove {
|
||||
self.textarea.replace_range(p - placeholder.len()..p, "");
|
||||
// Remove the entire placeholder from the text
|
||||
let placeholder_len = placeholder.len();
|
||||
for _ in 0..placeholder_len {
|
||||
self.textarea.input(Input {
|
||||
key: Key::Backspace,
|
||||
ctrl: false,
|
||||
alt: false,
|
||||
shift: false,
|
||||
});
|
||||
}
|
||||
// Remove from pending pastes
|
||||
self.pending_pastes.retain(|(ph, _)| ph != &placeholder);
|
||||
true
|
||||
} else {
|
||||
@@ -567,7 +579,16 @@ impl ChatComposer {
|
||||
/// textarea. This must be called after every modification that can change
|
||||
/// the text so the popup is shown/updated/hidden as appropriate.
|
||||
fn sync_command_popup(&mut self) {
|
||||
let first_line = self.textarea.text().lines().next().unwrap_or("");
|
||||
// Inspect only the first line to decide whether to show the popup. In
|
||||
// the common case (no leading slash) we avoid copying the entire
|
||||
// textarea contents.
|
||||
let first_line = self
|
||||
.textarea
|
||||
.lines()
|
||||
.first()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let input_starts_with_slash = first_line.starts_with('/');
|
||||
match &mut self.active_popup {
|
||||
ActivePopup::Command(popup) => {
|
||||
@@ -623,29 +644,74 @@ impl ChatComposer {
|
||||
self.dismissed_file_popup_token = None;
|
||||
}
|
||||
|
||||
fn set_has_focus(&mut self, has_focus: bool) {
|
||||
self.has_focus = has_focus;
|
||||
fn update_border(&mut self, has_focus: bool) {
|
||||
let border_style = if has_focus {
|
||||
Style::default().fg(Color::Cyan)
|
||||
} else {
|
||||
Style::default().dim()
|
||||
};
|
||||
|
||||
self.textarea.set_block(
|
||||
ratatui::widgets::Block::default()
|
||||
.borders(Borders::LEFT)
|
||||
.border_type(BorderType::QuadrantOutside)
|
||||
.border_style(border_style),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatComposer {
|
||||
impl WidgetRef for &ChatComposer<'_> {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
let popup_height = match &self.active_popup {
|
||||
ActivePopup::Command(popup) => popup.calculate_required_height(),
|
||||
ActivePopup::File(popup) => popup.calculate_required_height(),
|
||||
ActivePopup::None => 1,
|
||||
};
|
||||
let [textarea_rect, popup_rect] =
|
||||
Layout::vertical([Constraint::Min(0), Constraint::Max(popup_height)]).areas(area);
|
||||
match &self.active_popup {
|
||||
ActivePopup::Command(popup) => {
|
||||
popup.render_ref(popup_rect, buf);
|
||||
let popup_height = popup.calculate_required_height();
|
||||
|
||||
// Split the provided rect so that the popup is rendered at the
|
||||
// **bottom** and the textarea occupies the remaining space above.
|
||||
let popup_height = popup_height.min(area.height);
|
||||
let textarea_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: area.height.saturating_sub(popup_height),
|
||||
};
|
||||
let popup_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + textarea_rect.height,
|
||||
width: area.width,
|
||||
height: popup_height,
|
||||
};
|
||||
|
||||
popup.render(popup_rect, buf);
|
||||
self.textarea.render(textarea_rect, buf);
|
||||
}
|
||||
ActivePopup::File(popup) => {
|
||||
popup.render_ref(popup_rect, buf);
|
||||
let popup_height = popup.calculate_required_height();
|
||||
|
||||
let popup_height = popup_height.min(area.height);
|
||||
let textarea_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: area.height.saturating_sub(popup_height),
|
||||
};
|
||||
let popup_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + textarea_rect.height,
|
||||
width: area.width,
|
||||
height: popup_height,
|
||||
};
|
||||
|
||||
popup.render(popup_rect, buf);
|
||||
self.textarea.render(textarea_rect, buf);
|
||||
}
|
||||
ActivePopup::None => {
|
||||
let bottom_line_rect = popup_rect;
|
||||
let mut textarea_rect = area;
|
||||
textarea_rect.height = textarea_rect.height.saturating_sub(1);
|
||||
self.textarea.render(textarea_rect, buf);
|
||||
let mut bottom_line_rect = area;
|
||||
bottom_line_rect.y += textarea_rect.height;
|
||||
bottom_line_rect.height = 1;
|
||||
let key_hint_style = Style::default().fg(Color::Cyan);
|
||||
let hint = if self.ctrl_c_quit_hint {
|
||||
vec![
|
||||
@@ -674,56 +740,6 @@ impl WidgetRef for &ChatComposer {
|
||||
.render_ref(bottom_line_rect, buf);
|
||||
}
|
||||
}
|
||||
Block::default()
|
||||
.border_style(Style::default().dim())
|
||||
.borders(Borders::LEFT)
|
||||
.border_type(BorderType::QuadrantOutside)
|
||||
.border_style(Style::default().fg(if self.has_focus {
|
||||
Color::Cyan
|
||||
} else {
|
||||
Color::Gray
|
||||
}))
|
||||
.render_ref(
|
||||
Rect::new(textarea_rect.x, textarea_rect.y, 1, textarea_rect.height),
|
||||
buf,
|
||||
);
|
||||
let mut textarea_rect = textarea_rect;
|
||||
textarea_rect.width = textarea_rect.width.saturating_sub(1);
|
||||
textarea_rect.x += 1;
|
||||
let mut state = self.textarea_state.borrow_mut();
|
||||
StatefulWidgetRef::render_ref(&(&self.textarea), textarea_rect, buf, &mut state);
|
||||
if self.textarea.text().is_empty() {
|
||||
let placeholder = if let Some(token_usage_info) = &self.token_usage_info {
|
||||
let token_usage = &token_usage_info.token_usage;
|
||||
let model_context_window = token_usage_info.model_context_window;
|
||||
match (token_usage.total_tokens, model_context_window) {
|
||||
(total_tokens, Some(context_window)) => {
|
||||
let percent_remaining: u8 = if context_window > 0 {
|
||||
// Calculate the percentage of context left.
|
||||
let percent =
|
||||
100.0 - (total_tokens as f32 / context_window as f32 * 100.0);
|
||||
percent.clamp(0.0, 100.0) as u8
|
||||
} else {
|
||||
// If we don't have a context window, we cannot compute the
|
||||
// percentage.
|
||||
100
|
||||
};
|
||||
// When https://github.com/openai/codex/issues/1257 is resolved,
|
||||
// check if `percent_remaining < 25`, and if so, recommend
|
||||
// /compact.
|
||||
format!("{BASE_PLACEHOLDER_TEXT} — {percent_remaining}% context left")
|
||||
}
|
||||
(total_tokens, None) => {
|
||||
format!("{BASE_PLACEHOLDER_TEXT} — {total_tokens} tokens used")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
BASE_PLACEHOLDER_TEXT.to_string()
|
||||
};
|
||||
Line::from(placeholder)
|
||||
.style(Style::default().dim())
|
||||
.render_ref(textarea_rect.inner(Margin::new(1, 0)), buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -733,7 +749,7 @@ mod tests {
|
||||
use crate::bottom_pane::ChatComposer;
|
||||
use crate::bottom_pane::InputResult;
|
||||
use crate::bottom_pane::chat_composer::LARGE_PASTE_CHAR_THRESHOLD;
|
||||
use crate::bottom_pane::textarea::TextArea;
|
||||
use tui_textarea::TextArea;
|
||||
|
||||
#[test]
|
||||
fn test_current_at_token_basic_cases() {
|
||||
@@ -776,9 +792,9 @@ mod tests {
|
||||
];
|
||||
|
||||
for (input, cursor_pos, expected, description) in test_cases {
|
||||
let mut textarea = TextArea::new();
|
||||
let mut textarea = TextArea::default();
|
||||
textarea.insert_str(input);
|
||||
textarea.set_cursor(cursor_pos);
|
||||
textarea.move_cursor(tui_textarea::CursorMove::Jump(0, cursor_pos));
|
||||
|
||||
let result = ChatComposer::current_at_token(&textarea);
|
||||
assert_eq!(
|
||||
@@ -810,9 +826,9 @@ mod tests {
|
||||
];
|
||||
|
||||
for (input, cursor_pos, expected, description) in test_cases {
|
||||
let mut textarea = TextArea::new();
|
||||
let mut textarea = TextArea::default();
|
||||
textarea.insert_str(input);
|
||||
textarea.set_cursor(cursor_pos);
|
||||
textarea.move_cursor(tui_textarea::CursorMove::Jump(0, cursor_pos));
|
||||
|
||||
let result = ChatComposer::current_at_token(&textarea);
|
||||
assert_eq!(
|
||||
@@ -847,13 +863,13 @@ mod tests {
|
||||
// Full-width space boundaries
|
||||
(
|
||||
"test @İstanbul",
|
||||
8,
|
||||
6,
|
||||
Some("İstanbul".to_string()),
|
||||
"@ token after full-width space",
|
||||
),
|
||||
(
|
||||
"@ЙЦУ @诶",
|
||||
10,
|
||||
6,
|
||||
Some("诶".to_string()),
|
||||
"Full-width space between Unicode tokens",
|
||||
),
|
||||
@@ -867,9 +883,9 @@ mod tests {
|
||||
];
|
||||
|
||||
for (input, cursor_pos, expected, description) in test_cases {
|
||||
let mut textarea = TextArea::new();
|
||||
let mut textarea = TextArea::default();
|
||||
textarea.insert_str(input);
|
||||
textarea.set_cursor(cursor_pos);
|
||||
textarea.move_cursor(tui_textarea::CursorMove::Jump(0, cursor_pos));
|
||||
|
||||
let result = ChatComposer::current_at_token(&textarea);
|
||||
assert_eq!(
|
||||
@@ -891,7 +907,7 @@ mod tests {
|
||||
|
||||
let needs_redraw = composer.handle_paste("hello".to_string());
|
||||
assert!(needs_redraw);
|
||||
assert_eq!(composer.textarea.text(), "hello");
|
||||
assert_eq!(composer.textarea.lines(), ["hello"]);
|
||||
assert!(composer.pending_pastes.is_empty());
|
||||
|
||||
let (result, _) =
|
||||
@@ -916,7 +932,7 @@ mod tests {
|
||||
let needs_redraw = composer.handle_paste(large.clone());
|
||||
assert!(needs_redraw);
|
||||
let placeholder = format!("[Pasted Content {} chars]", large.chars().count());
|
||||
assert_eq!(composer.textarea.text(), placeholder);
|
||||
assert_eq!(composer.textarea.lines(), [placeholder.as_str()]);
|
||||
assert_eq!(composer.pending_pastes.len(), 1);
|
||||
assert_eq!(composer.pending_pastes[0].0, placeholder);
|
||||
assert_eq!(composer.pending_pastes[0].1, large);
|
||||
@@ -992,7 +1008,7 @@ mod tests {
|
||||
composer.handle_paste("b".repeat(LARGE_PASTE_CHAR_THRESHOLD + 4));
|
||||
composer.handle_paste("c".repeat(LARGE_PASTE_CHAR_THRESHOLD + 6));
|
||||
// Move cursor to end and press backspace
|
||||
composer.textarea.set_cursor(composer.textarea.text().len());
|
||||
composer.textarea.move_cursor(tui_textarea::CursorMove::End);
|
||||
composer.handle_key_event(KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE));
|
||||
}
|
||||
|
||||
@@ -1107,7 +1123,7 @@ mod tests {
|
||||
current_pos += content.len();
|
||||
}
|
||||
(
|
||||
composer.textarea.text().to_string(),
|
||||
composer.textarea.lines().join("\n"),
|
||||
composer.pending_pastes.len(),
|
||||
current_pos,
|
||||
)
|
||||
@@ -1118,18 +1134,25 @@ mod tests {
|
||||
let mut deletion_states = vec![];
|
||||
|
||||
// First deletion
|
||||
composer.textarea.set_cursor(states[0].2);
|
||||
composer
|
||||
.textarea
|
||||
.move_cursor(tui_textarea::CursorMove::Jump(0, states[0].2 as u16));
|
||||
composer.handle_key_event(KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE));
|
||||
deletion_states.push((
|
||||
composer.textarea.text().to_string(),
|
||||
composer.textarea.lines().join("\n"),
|
||||
composer.pending_pastes.len(),
|
||||
));
|
||||
|
||||
// Second deletion
|
||||
composer.textarea.set_cursor(composer.textarea.text().len());
|
||||
composer
|
||||
.textarea
|
||||
.move_cursor(tui_textarea::CursorMove::Jump(
|
||||
0,
|
||||
composer.textarea.lines().join("\n").len() as u16,
|
||||
));
|
||||
composer.handle_key_event(KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE));
|
||||
deletion_states.push((
|
||||
composer.textarea.text().to_string(),
|
||||
composer.textarea.lines().join("\n"),
|
||||
composer.pending_pastes.len(),
|
||||
));
|
||||
|
||||
@@ -1168,13 +1191,17 @@ mod tests {
|
||||
composer.handle_paste(paste.clone());
|
||||
composer
|
||||
.textarea
|
||||
.set_cursor((placeholder.len() - pos_from_end) as usize);
|
||||
.move_cursor(tui_textarea::CursorMove::Jump(
|
||||
0,
|
||||
(placeholder.len() - pos_from_end) as u16,
|
||||
));
|
||||
composer.handle_key_event(KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE));
|
||||
let result = (
|
||||
composer.textarea.text().contains(&placeholder),
|
||||
composer.textarea.lines().join("\n").contains(&placeholder),
|
||||
composer.pending_pastes.len(),
|
||||
);
|
||||
composer.textarea.set_text("");
|
||||
composer.textarea.select_all();
|
||||
composer.textarea.cut();
|
||||
result
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use tui_textarea::CursorMove;
|
||||
use tui_textarea::TextArea;
|
||||
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -64,52 +67,59 @@ impl ChatComposerHistory {
|
||||
|
||||
/// Should Up/Down key presses be interpreted as history navigation given
|
||||
/// the current content and cursor position of `textarea`?
|
||||
pub fn should_handle_navigation(&self, text: &str, cursor: usize) -> bool {
|
||||
pub fn should_handle_navigation(&self, textarea: &TextArea) -> bool {
|
||||
if self.history_entry_count == 0 && self.local_history.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if text.is_empty() {
|
||||
if textarea.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Textarea is not empty – only navigate when cursor is at start and
|
||||
// text matches last recalled history entry so regular editing is not
|
||||
// hijacked.
|
||||
if cursor != 0 {
|
||||
let (row, col) = textarea.cursor();
|
||||
if row != 0 || col != 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
matches!(&self.last_history_text, Some(prev) if prev == text)
|
||||
let lines = textarea.lines();
|
||||
matches!(&self.last_history_text, Some(prev) if prev == &lines.join("\n"))
|
||||
}
|
||||
|
||||
/// Handle <Up>. Returns true when the key was consumed and the caller
|
||||
/// should request a redraw.
|
||||
pub fn navigate_up(&mut self, app_event_tx: &AppEventSender) -> Option<String> {
|
||||
pub fn navigate_up(&mut self, textarea: &mut TextArea, app_event_tx: &AppEventSender) -> bool {
|
||||
let total_entries = self.history_entry_count + self.local_history.len();
|
||||
if total_entries == 0 {
|
||||
return None;
|
||||
return false;
|
||||
}
|
||||
|
||||
let next_idx = match self.history_cursor {
|
||||
None => (total_entries as isize) - 1,
|
||||
Some(0) => return None, // already at oldest
|
||||
Some(0) => return true, // already at oldest
|
||||
Some(idx) => idx - 1,
|
||||
};
|
||||
|
||||
self.history_cursor = Some(next_idx);
|
||||
self.populate_history_at_index(next_idx as usize, app_event_tx)
|
||||
self.populate_history_at_index(next_idx as usize, textarea, app_event_tx);
|
||||
true
|
||||
}
|
||||
|
||||
/// Handle <Down>.
|
||||
pub fn navigate_down(&mut self, app_event_tx: &AppEventSender) -> Option<String> {
|
||||
pub fn navigate_down(
|
||||
&mut self,
|
||||
textarea: &mut TextArea,
|
||||
app_event_tx: &AppEventSender,
|
||||
) -> bool {
|
||||
let total_entries = self.history_entry_count + self.local_history.len();
|
||||
if total_entries == 0 {
|
||||
return None;
|
||||
return false;
|
||||
}
|
||||
|
||||
let next_idx_opt = match self.history_cursor {
|
||||
None => return None, // not browsing
|
||||
None => return false, // not browsing
|
||||
Some(idx) if (idx as usize) + 1 >= total_entries => None,
|
||||
Some(idx) => Some(idx + 1),
|
||||
};
|
||||
@@ -117,15 +127,16 @@ impl ChatComposerHistory {
|
||||
match next_idx_opt {
|
||||
Some(idx) => {
|
||||
self.history_cursor = Some(idx);
|
||||
self.populate_history_at_index(idx as usize, app_event_tx)
|
||||
self.populate_history_at_index(idx as usize, textarea, app_event_tx);
|
||||
}
|
||||
None => {
|
||||
// Past newest – clear and exit browsing mode.
|
||||
self.history_cursor = None;
|
||||
self.last_history_text = None;
|
||||
Some(String::new())
|
||||
self.replace_textarea_content(textarea, "");
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Integrate a GetHistoryEntryResponse event.
|
||||
@@ -134,18 +145,19 @@ impl ChatComposerHistory {
|
||||
log_id: u64,
|
||||
offset: usize,
|
||||
entry: Option<String>,
|
||||
) -> Option<String> {
|
||||
textarea: &mut TextArea,
|
||||
) -> bool {
|
||||
if self.history_log_id != Some(log_id) {
|
||||
return None;
|
||||
return false;
|
||||
}
|
||||
let text = entry?;
|
||||
let Some(text) = entry else { return false };
|
||||
self.fetched_history.insert(offset, text.clone());
|
||||
|
||||
if self.history_cursor == Some(offset as isize) {
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text);
|
||||
self.replace_textarea_content(textarea, &text);
|
||||
return true;
|
||||
}
|
||||
None
|
||||
false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
@@ -155,20 +167,21 @@ impl ChatComposerHistory {
|
||||
fn populate_history_at_index(
|
||||
&mut self,
|
||||
global_idx: usize,
|
||||
textarea: &mut TextArea,
|
||||
app_event_tx: &AppEventSender,
|
||||
) -> Option<String> {
|
||||
) {
|
||||
if global_idx >= self.history_entry_count {
|
||||
// Local entry.
|
||||
if let Some(text) = self
|
||||
.local_history
|
||||
.get(global_idx - self.history_entry_count)
|
||||
{
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text.clone());
|
||||
let t = text.clone();
|
||||
self.replace_textarea_content(textarea, &t);
|
||||
}
|
||||
} else if let Some(text) = self.fetched_history.get(&global_idx) {
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text.clone());
|
||||
let t = text.clone();
|
||||
self.replace_textarea_content(textarea, &t);
|
||||
} else if let Some(log_id) = self.history_log_id {
|
||||
let op = Op::GetHistoryEntryRequest {
|
||||
offset: global_idx,
|
||||
@@ -176,7 +189,14 @@ impl ChatComposerHistory {
|
||||
};
|
||||
app_event_tx.send(AppEvent::CodexOp(op));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn replace_textarea_content(&mut self, textarea: &mut TextArea, text: &str) {
|
||||
textarea.select_all();
|
||||
textarea.cut();
|
||||
let _ = textarea.insert_str(text);
|
||||
textarea.move_cursor(CursorMove::Jump(0, 0));
|
||||
self.last_history_text = Some(text.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,9 +217,11 @@ mod tests {
|
||||
// Pretend there are 3 persistent entries.
|
||||
history.set_metadata(1, 3);
|
||||
|
||||
let mut textarea = TextArea::default();
|
||||
|
||||
// First Up should request offset 2 (latest) and await async data.
|
||||
assert!(history.should_handle_navigation("", 0));
|
||||
assert!(history.navigate_up(&tx).is_none()); // don't replace the text yet
|
||||
assert!(history.should_handle_navigation(&textarea));
|
||||
assert!(history.navigate_up(&mut textarea, &tx));
|
||||
|
||||
// Verify that an AppEvent::CodexOp with the correct GetHistoryEntryRequest was sent.
|
||||
let event = rx.try_recv().expect("expected AppEvent to be sent");
|
||||
@@ -213,15 +235,14 @@ mod tests {
|
||||
},
|
||||
history_request1
|
||||
);
|
||||
assert_eq!(textarea.lines().join("\n"), ""); // still empty
|
||||
|
||||
// Inject the async response.
|
||||
assert_eq!(
|
||||
Some("latest".into()),
|
||||
history.on_entry_response(1, 2, Some("latest".into()))
|
||||
);
|
||||
assert!(history.on_entry_response(1, 2, Some("latest".into()), &mut textarea));
|
||||
assert_eq!(textarea.lines().join("\n"), "latest");
|
||||
|
||||
// Next Up should move to offset 1.
|
||||
assert!(history.navigate_up(&tx).is_none()); // don't replace the text yet
|
||||
assert!(history.navigate_up(&mut textarea, &tx));
|
||||
|
||||
// Verify second CodexOp event for offset 1.
|
||||
let event2 = rx.try_recv().expect("expected second event");
|
||||
@@ -236,9 +257,7 @@ mod tests {
|
||||
history_request_2
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some("older".into()),
|
||||
history.on_entry_response(1, 1, Some("older".into()))
|
||||
);
|
||||
history.on_entry_response(1, 1, Some("older".into()), &mut textarea);
|
||||
assert_eq!(textarea.lines().join("\n"), "older");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ mod chat_composer_history;
|
||||
mod command_popup;
|
||||
mod file_search_popup;
|
||||
mod status_indicator_view;
|
||||
mod textarea;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancellationEvent {
|
||||
@@ -37,7 +36,7 @@ use status_indicator_view::StatusIndicatorView;
|
||||
pub(crate) struct BottomPane<'a> {
|
||||
/// Composer is retained even when a BottomPaneView is displayed so the
|
||||
/// input state is retained when the view is closed.
|
||||
composer: ChatComposer,
|
||||
composer: ChatComposer<'a>,
|
||||
|
||||
/// If present, this is displayed instead of the `composer`.
|
||||
active_view: Option<Box<dyn BottomPaneView<'a> + 'a>>,
|
||||
@@ -75,19 +74,7 @@ impl BottomPane<'_> {
|
||||
self.active_view
|
||||
.as_ref()
|
||||
.map(|v| v.desired_height(width))
|
||||
.unwrap_or(self.composer.desired_height(width))
|
||||
}
|
||||
|
||||
pub fn cursor_pos(&self, area: Rect) -> Option<(u16, u16)> {
|
||||
// Hide the cursor whenever an overlay view is active (e.g. the
|
||||
// status indicator shown while a task is running, or approval modal).
|
||||
// In these states the textarea is not interactable, so we should not
|
||||
// show its caret.
|
||||
if self.active_view.is_some() {
|
||||
None
|
||||
} else {
|
||||
self.composer.cursor_pos(area)
|
||||
}
|
||||
.unwrap_or(self.composer.desired_height())
|
||||
}
|
||||
|
||||
/// Forward a key event to the active view or the composer.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::codex_wrapper::CodexConversation;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
@@ -389,7 +390,6 @@ impl ChatWidget<'_> {
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
exit_code,
|
||||
duration,
|
||||
stdout,
|
||||
stderr,
|
||||
}) => {
|
||||
@@ -400,7 +400,7 @@ impl ChatWidget<'_> {
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
duration: Duration::from_secs(0),
|
||||
},
|
||||
));
|
||||
}
|
||||
@@ -509,10 +509,6 @@ impl ChatWidget<'_> {
|
||||
self.bottom_pane
|
||||
.set_token_usage(self.token_usage.clone(), self.config.model_context_window);
|
||||
}
|
||||
|
||||
pub fn cursor_pos(&self, area: Rect) -> Option<(u16, u16)> {
|
||||
self.bottom_pane.cursor_pos(area)
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatWidget<'_> {
|
||||
|
||||
@@ -17,12 +17,6 @@ pub struct Cli {
|
||||
#[arg(long, short = 'm')]
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Convenience flag to select the local Ollama provider.
|
||||
/// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
|
||||
/// creates a model_providers.ollama entry in config.toml if missing.
|
||||
#[arg(long = "ollama", default_value_t = false)]
|
||||
pub ollama: bool,
|
||||
|
||||
/// Configuration profile from config.toml to specify default options.
|
||||
#[arg(long = "profile", short = 'p')]
|
||||
pub config_profile: Option<String>,
|
||||
|
||||
@@ -9,17 +9,9 @@ use codex_core::config_types::SandboxMode;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use codex_login::load_auth;
|
||||
use crossterm::event::Event as CEvent;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
use crossterm::event::{self};
|
||||
use crossterm::terminal::disable_raw_mode;
|
||||
use crossterm::terminal::enable_raw_mode;
|
||||
use log_layer::TuiLogLayer;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::path::PathBuf;
|
||||
use tracing::error;
|
||||
use tracing_appender::non_blocking;
|
||||
@@ -55,253 +47,6 @@ mod updates;
|
||||
use color_eyre::owo_colors::OwoColorize;
|
||||
|
||||
pub use cli::Cli;
|
||||
// Centralized Ollama helpers from core
|
||||
use codex_core::providers::ollama::OllamaClient;
|
||||
use codex_core::providers::ollama::TuiProgressReporter;
|
||||
use codex_core::providers::ollama::ensure_configured_and_running;
|
||||
use codex_core::providers::ollama::ensure_model_available;
|
||||
use codex_core::providers::ollama::read_config_models;
|
||||
use codex_core::providers::ollama::read_provider_state;
|
||||
use codex_core::providers::ollama::write_config_models;
|
||||
|
||||
fn print_inline_message_no_models(
|
||||
host_root: &str,
|
||||
config_path: &std::path::Path,
|
||||
provider_was_present_before: bool,
|
||||
) -> io::Result<()> {
|
||||
let mut out = std::io::stdout();
|
||||
let path = config_path.display().to_string();
|
||||
// green bold helper
|
||||
let b = |s: &str| format!("\x1b[1m{s}\x1b[0m");
|
||||
// Ensure we start clean at column 0.
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
out.write_all(
|
||||
format!(
|
||||
"{}\n\n",
|
||||
b("we've discovered no models on your local Ollama instance.")
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
out.write_all(format!("\rendpoint: {host_root}\n").as_bytes())?;
|
||||
if provider_was_present_before {
|
||||
out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
|
||||
} else {
|
||||
out.write_all(
|
||||
format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
|
||||
)?;
|
||||
}
|
||||
out.write_all(
|
||||
b"\rmodels: none recorded in config (pull models with `ollama pull <model>`).\n\n",
|
||||
)?;
|
||||
out.flush()
|
||||
}
|
||||
|
||||
fn run_inline_models_picker(
|
||||
host_root: &str,
|
||||
available: &[String],
|
||||
preselected: &[String],
|
||||
config_path: &std::path::Path,
|
||||
provider_was_present_before: bool,
|
||||
) -> io::Result<()> {
|
||||
let mut out = std::io::stdout();
|
||||
let mut selected: Vec<bool> = available
|
||||
.iter()
|
||||
.map(|m| preselected.iter().any(|x| x == m))
|
||||
.collect();
|
||||
let mut cursor: usize = 0;
|
||||
|
||||
let mut first = true;
|
||||
let mut lines_printed: usize = 0;
|
||||
|
||||
enable_raw_mode()?;
|
||||
|
||||
loop {
|
||||
// Render block
|
||||
render_inline_picker(
|
||||
&mut out,
|
||||
host_root,
|
||||
available,
|
||||
&selected,
|
||||
cursor,
|
||||
&mut first,
|
||||
&mut lines_printed,
|
||||
)?;
|
||||
|
||||
// Wait for key
|
||||
match event::read()? {
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Up, ..
|
||||
})
|
||||
| CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('k'),
|
||||
..
|
||||
}) => {
|
||||
cursor = cursor.saturating_sub(1);
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Down,
|
||||
..
|
||||
})
|
||||
| CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('j'),
|
||||
..
|
||||
}) => {
|
||||
if cursor + 1 < available.len() {
|
||||
cursor += 1;
|
||||
}
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char(' '),
|
||||
..
|
||||
}) => {
|
||||
if let Some(s) = selected.get_mut(cursor) {
|
||||
*s = !*s;
|
||||
}
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('a'),
|
||||
..
|
||||
}) => {
|
||||
let all_sel = selected.iter().all(|s| *s);
|
||||
selected.fill(!all_sel);
|
||||
}
|
||||
// Allow quitting the entire app from the inline picker with Ctrl+C or Ctrl+D.
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('c'),
|
||||
modifiers,
|
||||
..
|
||||
}) if modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
// Restore terminal state and exit with SIGINT-like code.
|
||||
disable_raw_mode()?;
|
||||
// Start on a clean line before exiting.
|
||||
out.write_all(b"\r\x1b[2K\n")?;
|
||||
std::process::exit(130);
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('d'),
|
||||
modifiers,
|
||||
..
|
||||
}) if modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
// Restore terminal state and exit cleanly.
|
||||
disable_raw_mode()?;
|
||||
out.write_all(b"\r\x1b[2K\n")?;
|
||||
std::process::exit(0);
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
..
|
||||
}) => {
|
||||
break;
|
||||
}
|
||||
CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Char('q'),
|
||||
..
|
||||
})
|
||||
| CEvent::Key(KeyEvent {
|
||||
code: KeyCode::Esc, ..
|
||||
}) => {
|
||||
// Skip saving – print summary and continue.
|
||||
disable_raw_mode()?;
|
||||
print_config_summary_after_save(config_path, provider_was_present_before, None)?;
|
||||
return Ok(());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
disable_raw_mode()?;
|
||||
// Ensure the summary starts on a clean, left‑aligned new line.
|
||||
out.write_all(b"\r\x1b[2K\n")?;
|
||||
|
||||
// Compute chosen
|
||||
let chosen: Vec<String> = available
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(selected.iter())
|
||||
.filter_map(|(name, sel)| if *sel { Some(name) } else { None })
|
||||
.collect();
|
||||
|
||||
let _ = write_config_models(config_path, &chosen);
|
||||
print_config_summary_after_save(config_path, provider_was_present_before, Some(chosen.len()))
|
||||
}
|
||||
|
||||
fn render_inline_picker(
|
||||
out: &mut std::io::Stdout,
|
||||
host_root: &str,
|
||||
items: &[String],
|
||||
selected: &[bool],
|
||||
cursor: usize,
|
||||
first: &mut bool,
|
||||
lines_printed: &mut usize,
|
||||
) -> io::Result<()> {
|
||||
// If not first render, move to the start of the block. We will clear each line as we redraw.
|
||||
if !*first {
|
||||
out.write_all(format!("\x1b[{}A", *lines_printed).as_bytes())?; // up N lines
|
||||
// Ensure we start at column 1 for a clean redraw.
|
||||
out.write_all(b"\r")?;
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
let bold = |s: &str| format!("\x1b[1m{s}\x1b[0m");
|
||||
lines.push(bold(&format!("discovered models on ollama ({host_root}):")));
|
||||
lines
|
||||
.push("↑/↓ move, space to toggle, 'a' (un)select all, enter confirm, 'q' skip".to_string());
|
||||
lines.push(String::new());
|
||||
for (i, name) in items.iter().enumerate() {
|
||||
let mark = if selected.get(i).copied().unwrap_or(false) {
|
||||
"\x1b[32m[x]\x1b[0m" // green
|
||||
} else {
|
||||
"[ ]"
|
||||
};
|
||||
let mut line = format!("{mark} {name}");
|
||||
if i == cursor {
|
||||
line = format!("\x1b[7m{line}\x1b[0m"); // reverse video for current row
|
||||
}
|
||||
lines.push(line);
|
||||
}
|
||||
|
||||
for l in &lines {
|
||||
// Move to column 0 and clear the entire line before writing.
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
out.write_all(l.as_bytes())?;
|
||||
out.write_all(b"\n")?;
|
||||
}
|
||||
out.flush()?;
|
||||
*first = false;
|
||||
*lines_printed = lines.len();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_config_summary_after_save(
|
||||
config_path: &std::path::Path,
|
||||
provider_was_present_before: bool,
|
||||
models_count_after: Option<usize>,
|
||||
) -> io::Result<()> {
|
||||
let mut out = std::io::stdout();
|
||||
// Start clean and at column 0
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
let path = config_path.display().to_string();
|
||||
if provider_was_present_before {
|
||||
out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
|
||||
} else {
|
||||
out.write_all(
|
||||
format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
|
||||
)?;
|
||||
}
|
||||
if let Some(after) = models_count_after {
|
||||
let names = read_config_models(config_path).unwrap_or_default();
|
||||
if names.is_empty() {
|
||||
out.write_all(format!("\rmodels: recorded {after}\n\n").as_bytes())?;
|
||||
} else {
|
||||
out.write_all(
|
||||
format!("\rmodels: recorded {} ({})\n\n", after, names.join(", ")).as_bytes(),
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
out.write_all(b"\rmodels: no changes recorded\n\n")?;
|
||||
}
|
||||
out.flush()
|
||||
}
|
||||
|
||||
pub async fn run_main(
|
||||
cli: Cli,
|
||||
@@ -324,46 +69,18 @@ pub async fn run_main(
|
||||
)
|
||||
};
|
||||
|
||||
// Track config.toml state for messaging before launching TUI.
|
||||
let provider_was_present_before = if cli.ollama {
|
||||
let codex_home = codex_core::config::find_codex_home()?;
|
||||
let config_path = codex_home.join("config.toml");
|
||||
let (p, _m) = read_provider_state(&config_path);
|
||||
p
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let config = {
|
||||
// If the user selected the Ollama provider via `--ollama`, verify a
|
||||
// local server is reachable and ensure a provider entry exists in
|
||||
// config.toml. Exit early with a helpful message otherwise.
|
||||
if cli.ollama {
|
||||
if let Err(e) = ensure_configured_and_running().await {
|
||||
#[allow(clippy::print_stderr)]
|
||||
{
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
}
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Load configuration and support CLI overrides.
|
||||
let overrides = ConfigOverrides {
|
||||
model: cli.model.clone(),
|
||||
approval_policy,
|
||||
sandbox_mode,
|
||||
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
|
||||
model_provider: if cli.ollama {
|
||||
Some("ollama".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
model_provider: None,
|
||||
config_profile: cli.config_profile.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: None,
|
||||
include_plan_tool: Some(true),
|
||||
include_plan_tool: None,
|
||||
};
|
||||
// Parse `-c` overrides from the CLI.
|
||||
let cli_kv_overrides = match cli.config_overrides.parse_overrides() {
|
||||
@@ -384,73 +101,6 @@ pub async fn run_main(
|
||||
}
|
||||
}
|
||||
};
|
||||
// If the user passed --ollama, either ensure an explicitly requested model is
|
||||
// available (automatic pull if allowlisted) or offer an inline picker when no
|
||||
// specific model was provided.
|
||||
if cli.ollama {
|
||||
// Determine host root for the Ollama native API (e.g. http://localhost:11434).
|
||||
let base_url = config
|
||||
.model_provider
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
|
||||
let host_root = base_url
|
||||
.trim_end_matches('/')
|
||||
.trim_end_matches("/v1")
|
||||
.to_string();
|
||||
let config_path = config.codex_home.join("config.toml");
|
||||
|
||||
if let Some(ref model_name) = cli.model {
|
||||
// Explicit model requested: ensure it is available locally without prompting.
|
||||
let client = OllamaClient::from_provider(&config.model_provider);
|
||||
let mut reporter = TuiProgressReporter::new();
|
||||
if let Err(e) =
|
||||
ensure_model_available(model_name, &client, &config_path, &mut reporter).await
|
||||
{
|
||||
let mut out = std::io::stderr();
|
||||
tracing::error!("{e}");
|
||||
let _ = out.write_all(format!("{e}\n").as_bytes());
|
||||
let _ = out.flush();
|
||||
std::process::exit(1);
|
||||
}
|
||||
} else {
|
||||
// No specific model was requested: fetch available models from the local instance
|
||||
// and, if they differ from what is listed in config.toml, display a minimal
|
||||
// inline selection UI before launching the TUI.
|
||||
let client = OllamaClient::from_provider(&config.model_provider);
|
||||
let available_models: Vec<String> = client.fetch_models().await.unwrap_or_default();
|
||||
|
||||
// Read existing models in config.
|
||||
let existing_models: Vec<String> = read_config_models(&config_path).unwrap_or_default();
|
||||
|
||||
if available_models.is_empty() {
|
||||
// Inform the user and continue launching the TUI.
|
||||
print_inline_message_no_models(
|
||||
&host_root,
|
||||
&config_path,
|
||||
provider_was_present_before,
|
||||
)?;
|
||||
} else {
|
||||
// Compare sets to decide whether to show the prompt.
|
||||
let set_eq = {
|
||||
use std::collections::HashSet;
|
||||
let a: HashSet<_> = available_models.iter().collect();
|
||||
let b: HashSet<_> = existing_models.iter().collect();
|
||||
a == b
|
||||
};
|
||||
|
||||
if !set_eq {
|
||||
run_inline_models_picker(
|
||||
&host_root,
|
||||
&available_models,
|
||||
&existing_models,
|
||||
&config_path,
|
||||
provider_was_present_before,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let log_dir = codex_core::config::log_dir(&config)?;
|
||||
std::fs::create_dir_all(&log_dir)?;
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
pub mod ollama_model_picker;
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
|
||||
use ratatui::style::{Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::widgets::{Block, BorderType, Borders, Paragraph, WidgetRef};
|
||||
use ratatui::prelude::Widget;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub enum PickerOutcome {
|
||||
Submit(Vec<String>),
|
||||
Cancel,
|
||||
None,
|
||||
}
|
||||
|
||||
pub struct OllamaModelPickerScreen {
|
||||
pub host_root: String,
|
||||
pub config_path: PathBuf,
|
||||
available: Vec<String>,
|
||||
selected: Vec<bool>,
|
||||
cursor: usize,
|
||||
pub loading: bool,
|
||||
}
|
||||
|
||||
impl OllamaModelPickerScreen {
|
||||
pub fn new(host_root: String, config_path: PathBuf, preselected: Vec<String>) -> Self {
|
||||
Self {
|
||||
host_root,
|
||||
config_path,
|
||||
available: Vec::new(),
|
||||
selected: preselected.into_iter().map(|_| false).collect(),
|
||||
cursor: 0,
|
||||
loading: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn desired_height(&self, _width: u16) -> u16 {
|
||||
18u16
|
||||
}
|
||||
|
||||
pub fn update_available(&mut self, available: Vec<String>) {
|
||||
// Build selection state using existing selected names where possible.
|
||||
let prev_selected_names: Vec<String> = self
|
||||
.available
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(self.selected.iter().cloned())
|
||||
.filter_map(|(n, sel)| if sel { Some(n) } else { None })
|
||||
.collect();
|
||||
|
||||
self.available = available.clone();
|
||||
self.selected = available
|
||||
.iter()
|
||||
.map(|n| prev_selected_names.iter().any(|p| p == n))
|
||||
.collect();
|
||||
if self.cursor >= self.available.len() {
|
||||
self.cursor = self.available.len().saturating_sub(1);
|
||||
}
|
||||
self.loading = false;
|
||||
}
|
||||
|
||||
pub fn handle_key_event(&mut self, key: KeyEvent) -> PickerOutcome {
|
||||
match key.code {
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
if self.cursor > 0 { self.cursor -= 1; }
|
||||
PickerOutcome::None
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
if self.cursor + 1 < self.available.len() { self.cursor += 1; }
|
||||
PickerOutcome::None
|
||||
}
|
||||
KeyCode::Char(' ') => {
|
||||
if let Some(s) = self.selected.get_mut(self.cursor) {
|
||||
*s = !*s;
|
||||
}
|
||||
PickerOutcome::None
|
||||
}
|
||||
KeyCode::Char('a') => {
|
||||
let all = self.selected.iter().all(|s| *s);
|
||||
self.selected.fill(!all);
|
||||
PickerOutcome::None
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let chosen: Vec<String> = self
|
||||
.available
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(self.selected.iter().cloned())
|
||||
.filter_map(|(n, sel)| if sel { Some(n) } else { None })
|
||||
.collect();
|
||||
PickerOutcome::Submit(chosen)
|
||||
}
|
||||
KeyCode::Esc | KeyCode::Char('q') => PickerOutcome::Cancel,
|
||||
_ => PickerOutcome::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &OllamaModelPickerScreen {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
const MIN_WIDTH: u16 = 40;
|
||||
const MIN_HEIGHT: u16 = 15;
|
||||
let popup_width = std::cmp::max(MIN_WIDTH, (area.width as f32 * 0.7) as u16);
|
||||
let popup_height = std::cmp::max(MIN_HEIGHT, (area.height as f32 * 0.6) as u16);
|
||||
let popup_x = area.x + (area.width.saturating_sub(popup_width)) / 2;
|
||||
let popup_y = area.y + (area.height.saturating_sub(popup_height)) / 2;
|
||||
let popup_area = Rect::new(popup_x, popup_y, popup_width, popup_height);
|
||||
|
||||
let popup_block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_type(BorderType::Plain)
|
||||
.title(Span::styled(
|
||||
"Select Ollama models",
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
));
|
||||
let inner = popup_block.inner(popup_area);
|
||||
popup_block.render(popup_area, buf);
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([Constraint::Length(3), Constraint::Min(3), Constraint::Length(3)])
|
||||
.split(inner);
|
||||
|
||||
// Header
|
||||
let header = format!("endpoint: {}\n↑/↓ move, space toggle, 'a' (un)select all, enter confirm, 'q' skip", self.host_root);
|
||||
Paragraph::new(header).alignment(Alignment::Left).render(chunks[0], buf);
|
||||
|
||||
// Body: list of models or a loading message
|
||||
if self.loading {
|
||||
Paragraph::new("discovering models...").alignment(Alignment::Center).render(chunks[1], buf);
|
||||
} else if self.available.is_empty() {
|
||||
Paragraph::new("No models discovered on the local Ollama instance.")
|
||||
.alignment(Alignment::Center)
|
||||
.render(chunks[1], buf);
|
||||
} else {
|
||||
// Render each line manually with highlight for cursor.
|
||||
let mut lines: Vec<Line> = Vec::with_capacity(self.available.len());
|
||||
for (i, name) in self.available.iter().enumerate() {
|
||||
let mark = if self.selected.get(i).copied().unwrap_or(false) { "[x]" } else { "[ ]" };
|
||||
let content = format!("{mark} {name}");
|
||||
if i == self.cursor {
|
||||
lines.push(Line::from(content).style(Style::default().add_modifier(Modifier::REVERSED)));
|
||||
} else {
|
||||
lines.push(Line::from(content));
|
||||
}
|
||||
}
|
||||
Paragraph::new(lines).render(chunks[1], buf);
|
||||
}
|
||||
|
||||
// Footer/help
|
||||
Paragraph::new("press Enter to save, 'q' to continue without changes")
|
||||
.alignment(Alignment::Center)
|
||||
.render(chunks[2], buf);
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ impl StatusIndicatorWidget {
|
||||
thread::spawn(move || {
|
||||
let mut counter = 0usize;
|
||||
while running_clone.load(Ordering::Relaxed) {
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
counter = counter.wrapping_add(1);
|
||||
frame_idx_clone.store(counter, Ordering::Relaxed);
|
||||
app_event_tx_clone.send(AppEvent::RequestRedraw);
|
||||
@@ -98,51 +98,46 @@ impl WidgetRef for StatusIndicatorWidget {
|
||||
.borders(Borders::LEFT)
|
||||
.border_type(BorderType::QuadrantOutside)
|
||||
.border_style(widget_style.dim());
|
||||
// Animated 3‑dot pattern inside brackets. The *active* dot is bold
|
||||
// white, the others are dim.
|
||||
const DOT_COUNT: usize = 3;
|
||||
let idx = self.frame_idx.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let header_text = "Working";
|
||||
let header_chars: Vec<char> = header_text.chars().collect();
|
||||
|
||||
let padding = 4usize; // virtual padding around the word for smoother loop
|
||||
let period = header_chars.len() + padding * 2;
|
||||
let pos = idx % period;
|
||||
|
||||
let has_true_color = supports_color::on_cached(supports_color::Stream::Stdout)
|
||||
.map(|level| level.has_16m)
|
||||
.unwrap_or(false);
|
||||
|
||||
// Width of the bright band (in characters).
|
||||
let band_half_width = 2.0;
|
||||
let phase = idx % (DOT_COUNT * 2 - 2);
|
||||
let active = if phase < DOT_COUNT {
|
||||
phase
|
||||
} else {
|
||||
(DOT_COUNT * 2 - 2) - phase
|
||||
};
|
||||
|
||||
let mut header_spans: Vec<Span<'static>> = Vec::new();
|
||||
for (i, ch) in header_chars.iter().enumerate() {
|
||||
let i_pos = i as isize + padding as isize;
|
||||
let pos = pos as isize;
|
||||
let dist = (i_pos - pos).abs() as f32;
|
||||
|
||||
let t = if dist <= band_half_width {
|
||||
let x = std::f32::consts::PI * (dist / band_half_width);
|
||||
0.5 * (1.0 + x.cos())
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
header_spans.push(Span::styled(
|
||||
"Working ",
|
||||
Style::default()
|
||||
.fg(Color::White)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
|
||||
let brightness = 0.4 + 0.6 * t;
|
||||
let level = (brightness * 255.0).clamp(0.0, 255.0) as u8;
|
||||
let style = if has_true_color {
|
||||
header_spans.push(Span::styled(
|
||||
"[",
|
||||
Style::default()
|
||||
.fg(Color::White)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
|
||||
for i in 0..DOT_COUNT {
|
||||
let style = if i == active {
|
||||
Style::default()
|
||||
.fg(Color::Rgb(level, level, level))
|
||||
.fg(Color::White)
|
||||
.add_modifier(Modifier::BOLD)
|
||||
} else {
|
||||
// Bold makes dark gray and gray look the same, so don't use it
|
||||
// when true color is not supported.
|
||||
Style::default().fg(color_for_level(level))
|
||||
Style::default().dim()
|
||||
};
|
||||
|
||||
header_spans.push(Span::styled(ch.to_string(), style));
|
||||
header_spans.push(Span::styled(".", style));
|
||||
}
|
||||
|
||||
header_spans.push(Span::styled(
|
||||
" ",
|
||||
"] ",
|
||||
Style::default()
|
||||
.fg(Color::White)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
@@ -194,13 +189,3 @@ impl WidgetRef for StatusIndicatorWidget {
|
||||
paragraph.render_ref(area, buf);
|
||||
}
|
||||
}
|
||||
|
||||
fn color_for_level(level: u8) -> Color {
|
||||
if level < 128 {
|
||||
Color::DarkGray
|
||||
} else if level < 192 {
|
||||
Color::Gray
|
||||
} else {
|
||||
Color::White
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user