fix(undo): sync session context after snapshot restore (#1139) (#1150)

This commit is contained in:
jiaren wang
2026-05-08 16:14:04 +09:00
committed by GitHub
parent f969de91aa
commit 0d7cbe37a8
3 changed files with 581 additions and 6 deletions
+555 -6
View File
@@ -7,7 +7,7 @@ use std::time::Instant;
use super::CommandResult;
use crate::compaction::estimate_input_tokens_conservative;
use crate::localization::{Locale, MessageId, tr};
use crate::models::{SystemPrompt, context_window_for_model};
use crate::models::{ContentBlock, SystemPrompt, context_window_for_model};
use crate::tui::app::{App, AppAction, TurnCacheRecord};
use crate::tui::history::HistoryCell;
@@ -272,6 +272,7 @@ mod tests {
use crate::config::Config;
use crate::models::{ContentBlock, Message, SystemBlock};
use crate::tui::app::{App, TuiOptions};
use crate::tui::history::{GenericToolCell, ToolCell, ToolStatus};
use std::path::PathBuf;
fn create_test_app() -> App {
@@ -627,6 +628,449 @@ mod tests {
assert!(msg.contains("Retrying"));
assert!(msg.contains("..."));
}
#[test]
fn test_patch_undo_requests_session_resync_after_restore() {
use crate::snapshot::SnapshotRepo;
use crate::test_support::lock_test_env;
use std::sync::MutexGuard;
use tempfile::tempdir;
struct HomeGuard {
prev: Option<std::ffi::OsString>,
_lock: MutexGuard<'static, ()>,
}
impl Drop for HomeGuard {
fn drop(&mut self) {
// SAFETY: process-wide lock still held.
unsafe {
match self.prev.take() {
Some(v) => std::env::set_var("HOME", v),
None => std::env::remove_var("HOME"),
}
}
}
}
fn scoped_home(home: &std::path::Path) -> HomeGuard {
let lock = lock_test_env();
let prev = std::env::var_os("HOME");
// SAFETY: serialized by the global env lock.
unsafe {
std::env::set_var("HOME", home);
}
HomeGuard { prev, _lock: lock }
}
let tmp = tempdir().unwrap();
let workspace = tmp.path().join("ws");
std::fs::create_dir_all(&workspace).unwrap();
let _guard = scoped_home(tmp.path());
let repo = SnapshotRepo::open_or_init(&workspace).unwrap();
std::fs::write(workspace.join("a.txt"), b"original").unwrap();
repo.snapshot("pre-turn:1").unwrap();
std::fs::write(workspace.join("a.txt"), b"modified").unwrap();
repo.snapshot("post-turn:1").unwrap();
let mut app = create_test_app();
app.workspace = workspace.clone();
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::Text {
text: "please edit a.txt".to_string(),
cache_control: None,
}],
});
let result = patch_undo(&mut app);
assert!(!result.is_error);
assert!(matches!(
result.action,
Some(AppAction::SyncSession {
ref messages,
ref workspace,
..
}) if messages == &app.api_messages && workspace == &app.workspace
));
}
#[test]
fn test_patch_undo_walks_back_to_older_snapshot_on_repeat() {
use crate::snapshot::SnapshotRepo;
use crate::test_support::lock_test_env;
use std::sync::MutexGuard;
use tempfile::tempdir;
struct HomeGuard {
prev: Option<std::ffi::OsString>,
_lock: MutexGuard<'static, ()>,
}
impl Drop for HomeGuard {
fn drop(&mut self) {
// SAFETY: process-wide lock still held.
unsafe {
match self.prev.take() {
Some(v) => std::env::set_var("HOME", v),
None => std::env::remove_var("HOME"),
}
}
}
}
fn scoped_home(home: &std::path::Path) -> HomeGuard {
let lock = lock_test_env();
let prev = std::env::var_os("HOME");
// SAFETY: serialized by the global env lock.
unsafe {
std::env::set_var("HOME", home);
}
HomeGuard { prev, _lock: lock }
}
let tmp = tempdir().unwrap();
let workspace = tmp.path().join("ws");
std::fs::create_dir_all(&workspace).unwrap();
let _guard = scoped_home(tmp.path());
let repo = SnapshotRepo::open_or_init(&workspace).unwrap();
let file = workspace.join("a.txt");
std::fs::write(&file, b"zero").unwrap();
repo.snapshot("tool:first").unwrap();
std::fs::write(&file, b"one").unwrap();
repo.snapshot("tool:second").unwrap();
std::fs::write(&file, b"two").unwrap();
let mut app = create_test_app();
app.workspace = workspace.clone();
let first = patch_undo(&mut app);
assert!(!first.is_error);
assert_eq!(std::fs::read_to_string(&file).unwrap(), "one");
let second = patch_undo(&mut app);
assert!(!second.is_error);
assert_eq!(std::fs::read_to_string(&file).unwrap(), "zero");
}
#[test]
fn test_patch_undo_prunes_tool_turn_context() {
use crate::snapshot::SnapshotRepo;
use crate::test_support::lock_test_env;
use std::sync::MutexGuard;
use tempfile::tempdir;
struct HomeGuard {
prev: Option<std::ffi::OsString>,
_lock: MutexGuard<'static, ()>,
}
impl Drop for HomeGuard {
fn drop(&mut self) {
// SAFETY: process-wide lock still held.
unsafe {
match self.prev.take() {
Some(v) => std::env::set_var("HOME", v),
None => std::env::remove_var("HOME"),
}
}
}
}
fn scoped_home(home: &std::path::Path) -> HomeGuard {
let lock = lock_test_env();
let prev = std::env::var_os("HOME");
// SAFETY: serialized by the global env lock.
unsafe {
std::env::set_var("HOME", home);
}
HomeGuard { prev, _lock: lock }
}
let tmp = tempdir().unwrap();
let workspace = tmp.path().join("ws");
std::fs::create_dir_all(&workspace).unwrap();
let _guard = scoped_home(tmp.path());
let repo = SnapshotRepo::open_or_init(&workspace).unwrap();
let file = workspace.join("a.txt");
std::fs::write(&file, b"alpha").unwrap();
repo.snapshot("tool:call-1").unwrap();
std::fs::write(&file, b"alpha-fixed").unwrap();
let mut app = create_test_app();
app.workspace = workspace.clone();
app.history.push(HistoryCell::User {
content: "please edit a.txt".to_string(),
});
app.history.push(HistoryCell::Assistant {
content: "I will update the file.".to_string(),
streaming: false,
});
app.history
.push(HistoryCell::Tool(ToolCell::Generic(GenericToolCell {
name: "write_file".to_string(),
status: ToolStatus::Success,
input_summary: Some("a.txt".to_string()),
output: Some("updated".to_string()),
prompts: None,
spillover_path: None,
})));
app.history.push(HistoryCell::Assistant {
content: "Done, file is fixed now.".to_string(),
streaming: false,
});
app.tool_cells.insert("call-1".to_string(), 2);
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::Text {
text: "please edit a.txt".to_string(),
cache_control: None,
}],
});
app.api_messages.push(Message {
role: "assistant".to_string(),
content: vec![
ContentBlock::Text {
text: "I will update the file.".to_string(),
cache_control: None,
},
ContentBlock::ToolUse {
id: "call-1".to_string(),
name: "write_file".to_string(),
input: serde_json::json!({"path": "a.txt"}),
caller: None,
},
],
});
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::ToolResult {
tool_use_id: "call-1".to_string(),
content: "updated".to_string(),
is_error: None,
content_blocks: None,
}],
});
app.api_messages.push(Message {
role: "assistant".to_string(),
content: vec![ContentBlock::Text {
text: "Done, file is fixed now.".to_string(),
cache_control: None,
}],
});
let result = patch_undo(&mut app);
assert!(!result.is_error);
assert_eq!(std::fs::read_to_string(&file).unwrap(), "alpha");
assert_eq!(app.history.len(), 3);
assert!(matches!(
app.history.last(),
Some(HistoryCell::System { content }) if content.contains("/undo reverted workspace")
));
assert_eq!(app.api_messages.len(), 2);
assert!(matches!(
&app.api_messages[0].content[0],
ContentBlock::Text { text, .. } if text == "please edit a.txt"
));
assert_eq!(app.api_messages[1].content.len(), 1);
assert!(matches!(
&app.api_messages[1].content[0],
ContentBlock::Text { text, .. } if text == "I will update the file."
));
}
#[test]
fn test_patch_undo_prunes_pre_turn_context() {
use crate::snapshot::SnapshotRepo;
use crate::test_support::lock_test_env;
use std::sync::MutexGuard;
use tempfile::tempdir;
struct HomeGuard {
prev: Option<std::ffi::OsString>,
_lock: MutexGuard<'static, ()>,
}
impl Drop for HomeGuard {
fn drop(&mut self) {
// SAFETY: process-wide lock still held.
unsafe {
match self.prev.take() {
Some(v) => std::env::set_var("HOME", v),
None => std::env::remove_var("HOME"),
}
}
}
}
fn scoped_home(home: &std::path::Path) -> HomeGuard {
let lock = lock_test_env();
let prev = std::env::var_os("HOME");
// SAFETY: serialized by the global env lock.
unsafe {
std::env::set_var("HOME", home);
}
HomeGuard { prev, _lock: lock }
}
let tmp = tempdir().unwrap();
let workspace = tmp.path().join("ws");
std::fs::create_dir_all(&workspace).unwrap();
let _guard = scoped_home(tmp.path());
let repo = SnapshotRepo::open_or_init(&workspace).unwrap();
let file = workspace.join("a.txt");
std::fs::write(&file, b"alpha").unwrap();
repo.snapshot("pre-turn:1").unwrap();
std::fs::write(&file, b"alpha-fixed").unwrap();
let mut app = create_test_app();
app.workspace = workspace.clone();
app.history.push(HistoryCell::User {
content: "please edit a.txt".to_string(),
});
app.history.push(HistoryCell::Assistant {
content: "Done, file is fixed now.".to_string(),
streaming: false,
});
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::Text {
text: "please edit a.txt".to_string(),
cache_control: None,
}],
});
app.api_messages.push(Message {
role: "assistant".to_string(),
content: vec![ContentBlock::Text {
text: "Done, file is fixed now.".to_string(),
cache_control: None,
}],
});
let result = patch_undo(&mut app);
assert!(!result.is_error);
assert_eq!(std::fs::read_to_string(&file).unwrap(), "alpha");
assert_eq!(app.history.len(), 1);
assert!(matches!(
app.history.last(),
Some(HistoryCell::System { content }) if content.contains("/undo reverted workspace")
));
assert!(app.api_messages.is_empty());
}
#[test]
fn test_prune_undone_tool_context_preserves_prior_tool_pairs() {
let mut app = create_test_app();
app.history.push(HistoryCell::User {
content: "edit two files".to_string(),
});
app.history.push(HistoryCell::Assistant {
content: "I will update both files.".to_string(),
streaming: false,
});
app.history
.push(HistoryCell::Tool(ToolCell::Generic(GenericToolCell {
name: "write_file".to_string(),
status: ToolStatus::Success,
input_summary: Some("a.txt".to_string()),
output: Some("updated a".to_string()),
prompts: None,
spillover_path: None,
})));
app.history
.push(HistoryCell::Tool(ToolCell::Generic(GenericToolCell {
name: "write_file".to_string(),
status: ToolStatus::Success,
input_summary: Some("b.txt".to_string()),
output: Some("updated b".to_string()),
prompts: None,
spillover_path: None,
})));
app.history.push(HistoryCell::Assistant {
content: "Done.".to_string(),
streaming: false,
});
app.tool_cells.insert("call-a".to_string(), 2);
app.tool_cells.insert("call-b".to_string(), 3);
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::Text {
text: "edit two files".to_string(),
cache_control: None,
}],
});
app.api_messages.push(Message {
role: "assistant".to_string(),
content: vec![
ContentBlock::Text {
text: "I will update both files.".to_string(),
cache_control: None,
},
ContentBlock::ToolUse {
id: "call-a".to_string(),
name: "write_file".to_string(),
input: serde_json::json!({"path": "a.txt"}),
caller: None,
},
ContentBlock::ToolUse {
id: "call-b".to_string(),
name: "write_file".to_string(),
input: serde_json::json!({"path": "b.txt"}),
caller: None,
},
],
});
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::ToolResult {
tool_use_id: "call-a".to_string(),
content: "updated a".to_string(),
is_error: None,
content_blocks: None,
}],
});
app.api_messages.push(Message {
role: "user".to_string(),
content: vec![ContentBlock::ToolResult {
tool_use_id: "call-b".to_string(),
content: "updated b".to_string(),
is_error: None,
content_blocks: None,
}],
});
app.api_messages.push(Message {
role: "assistant".to_string(),
content: vec![ContentBlock::Text {
text: "Done.".to_string(),
cache_control: None,
}],
});
prune_undone_tool_context(&mut app, "call-b");
assert_eq!(app.history.len(), 3);
assert_eq!(app.api_messages.len(), 3);
assert!(matches!(
&app.api_messages[1].content[..],
[
ContentBlock::Text { .. },
ContentBlock::ToolUse { id, .. }
] if id == "call-a"
));
assert!(matches!(
&app.api_messages[2].content[0],
ContentBlock::ToolResult { tool_use_id, .. } if tool_use_id == "call-a"
));
}
}
/// Remove last message pair (user + assistant).
@@ -670,6 +1114,89 @@ pub fn undo_conversation(app: &mut App) -> CommandResult {
}
}
fn prune_undone_tool_context(app: &mut App, tool_id: &str) {
if let Some(history_idx) = app.tool_cells.get(tool_id).copied() {
app.truncate_history_to(history_idx);
}
let Some((msg_idx, block_idx)) =
app.api_messages
.iter()
.enumerate()
.find_map(|(msg_idx, msg)| {
msg.content
.iter()
.position(
|block| matches!(block, ContentBlock::ToolUse { id, .. } if id == tool_id),
)
.map(|block_idx| (msg_idx, block_idx))
})
else {
return;
};
let kept_blocks = app.api_messages[msg_idx].content[..block_idx].to_vec();
let kept_tool_ids: std::collections::HashSet<String> = kept_blocks
.iter()
.filter_map(|block| match block {
ContentBlock::ToolUse { id, .. } => Some(id.clone()),
_ => None,
})
.collect();
if kept_blocks.is_empty() {
app.api_messages.truncate(msg_idx);
return;
}
let preserved_tool_results: Vec<_> =
app.api_messages
.iter()
.skip(msg_idx + 1)
.take_while(|msg| {
msg.role == "user"
&& !msg.content.is_empty()
&& msg
.content
.iter()
.all(|block| tool_result_id(block).is_some())
})
.filter(|msg| {
msg.role == "user"
&& !msg.content.is_empty()
&& msg.content.iter().all(|block| {
tool_result_id(block).is_some_and(|id| kept_tool_ids.contains(id))
})
})
.cloned()
.collect();
app.api_messages.truncate(msg_idx + 1);
app.api_messages[msg_idx].content = kept_blocks;
app.api_messages.extend(preserved_tool_results);
}
fn prune_undone_turn_context(app: &mut App) {
if let Some(history_idx) = app
.history
.iter()
.rposition(|cell| matches!(cell, HistoryCell::User { .. }))
{
app.truncate_history_to(history_idx);
}
if let Some(api_idx) = app.api_messages.iter().rposition(|msg| msg.role == "user") {
app.api_messages.truncate(api_idx);
}
}
fn tool_result_id(block: &ContentBlock) -> Option<&String> {
match block {
ContentBlock::ToolResult { tool_use_id, .. }
| ContentBlock::ToolSearchToolResult { tool_use_id, .. }
| ContentBlock::CodeExecutionToolResult { tool_use_id, .. } => Some(tool_use_id),
_ => None,
}
}
/// Revert the most recent write tool (apply_patch/edit_file/write_file) or turn.
///
/// Opens the side-git snapshot repo and finds the most recent snapshot,
@@ -703,20 +1230,34 @@ pub fn patch_undo(app: &mut App) -> CommandResult {
return CommandResult::message("No snapshots found to undo — nothing to revert.");
}
// Prefer the most recent `tool:` snapshot; fall back to `pre-turn:`.
// Prefer the newest revertable `tool:` / `pre-turn:` snapshot whose
// tracked content differs from the current workspace. This lets
// repeated `/undo` walk back through older snapshots instead of
// restoring the same no-op target forever.
let target = snapshots
.iter()
.find(|s| s.label.starts_with("tool:"))
.or_else(|| snapshots.iter().find(|s| s.label.starts_with("pre-turn:")));
.filter(|s| s.label.starts_with("tool:") || s.label.starts_with("pre-turn:"))
.find(|s| match repo.work_tree_matches_snapshot(&s.id) {
Ok(matches) => !matches,
Err(_) => true,
});
let Some(target) = target else {
return CommandResult::message("No tool or pre-turn snapshots found — nothing to revert.");
return CommandResult::message(
"No older tool or pre-turn snapshots differ from the current workspace — nothing to revert.",
);
};
if let Err(e) = repo.restore(&target.id) {
return CommandResult::error(format!("Restore failed: {e}"));
}
if let Some(tool_id) = target.label.strip_prefix("tool:") {
prune_undone_tool_context(app, tool_id);
} else if target.label.starts_with("pre-turn:") {
prune_undone_turn_context(app);
}
// Show diff stat so the user knows what changed.
let diff_stat = std::process::Command::new("git")
.args(["diff", "--stat"])
@@ -752,7 +1293,15 @@ pub fn patch_undo(app: &mut App) -> CommandResult {
),
});
CommandResult::message(summary)
CommandResult::with_message_and_action(
summary,
AppAction::SyncSession {
messages: app.api_messages.clone(),
system_prompt: app.system_prompt.clone(),
model: app.model.clone(),
workspace: app.workspace.clone(),
},
)
}
/// Load the last user message back into the composer for editing.
+18
View File
@@ -294,6 +294,24 @@ impl SnapshotRepo {
Ok(())
}
/// Return whether the current workspace matches the given snapshot's
/// tracked file content.
///
/// This is intentionally narrower than a full "workspace identical"
/// claim: it compares the current working tree against the snapshot's
/// tracked paths via git's diff machinery. That is sufficient for
/// `/undo` cursoring — if the diff is empty, restoring this snapshot
/// again would be a no-op, so the caller should continue scanning
/// older snapshots.
pub fn work_tree_matches_snapshot(&self, id: &SnapshotId) -> io::Result<bool> {
let diff = run_git(
&self.git_dir,
&self.work_tree,
&["diff", "--quiet", id.as_str(), "--", ":/"],
)?;
Ok(diff.status.success())
}
fn tree_paths(&self, treeish: &str) -> io::Result<HashSet<PathBuf>> {
let ls = run_git(
&self.git_dir,
+8
View File
@@ -5748,6 +5748,14 @@ async fn handle_view_events(
ViewEvent::BacktrackConfirm => {
if let Some(depth) = app.backtrack.confirm() {
apply_backtrack(app, depth);
let _ = engine_handle
.send(Op::SyncSession {
messages: app.api_messages.clone(),
system_prompt: app.system_prompt.clone(),
model: app.model.clone(),
workspace: app.workspace.clone(),
})
.await;
}
}
ViewEvent::BacktrackCancel => {