diff --git a/Cargo.toml b/Cargo.toml index d6a809e2..165b46b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deepseek-tui" -version = "0.3.3" +version = "0.3.4" edition = "2024" description = "Unofficial DeepSeek CLI - Just run 'deepseek' to start chatting" license = "MIT" diff --git a/src/client.rs b/src/client.rs index 4db9b13a..fe37963b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -524,6 +524,94 @@ fn build_chat_messages( } } + // Safety net: after compaction, an assistant message may have tool_calls + // whose results were summarized away. The API rejects these, so strip + // the tool_calls (downgrading to a plain assistant message) and remove + // the now-orphaned tool result messages. + let mut i = 0; + while i < out.len() { + let is_assistant_with_tools = out[i].get("role").and_then(Value::as_str) + == Some("assistant") + && out[i].get("tool_calls").is_some(); + + if is_assistant_with_tools { + let expected_ids: HashSet = out[i] + .get("tool_calls") + .and_then(Value::as_array) + .map(|calls| { + calls + .iter() + .filter_map(|c| c.get("id").and_then(Value::as_str).map(String::from)) + .collect() + }) + .unwrap_or_default(); + + // Collect tool result IDs immediately following this assistant message. + let mut found_ids: HashSet = HashSet::new(); + let mut tool_result_end = i + 1; + while tool_result_end < out.len() { + if out[tool_result_end].get("role").and_then(Value::as_str) == Some("tool") { + if let Some(id) = out[tool_result_end] + .get("tool_call_id") + .and_then(Value::as_str) + { + found_ids.insert(id.to_string()); + } + tool_result_end += 1; + } else { + break; + } + } + + // Also scan non-contiguous tool results up to the next assistant message + // in case compaction left gaps. + let mut scan = tool_result_end; + while scan < out.len() { + if out[scan].get("role").and_then(Value::as_str) == Some("assistant") { + break; + } + if out[scan].get("role").and_then(Value::as_str) == Some("tool") { + if let Some(id) = out[scan].get("tool_call_id").and_then(Value::as_str) { + found_ids.insert(id.to_string()); + } + } + scan += 1; + } + + if !expected_ids.is_subset(&found_ids) { + let missing: Vec<_> = expected_ids.difference(&found_ids).collect(); + logging::warn(format!( + "Stripping orphaned tool_calls from assistant message \ + (expected {} tool results, found {}, missing: {:?})", + expected_ids.len(), + found_ids.len(), + missing + )); + if let Some(obj) = out[i].as_object_mut() { + obj.remove("tool_calls"); + } + // Remove contiguous tool results first + if tool_result_end > i + 1 { + out.drain((i + 1)..tool_result_end); + } + // Remove any remaining non-contiguous tool results referencing expected_ids + // (scan backward to avoid index shifting issues) + let mut j = out.len(); + while j > i + 1 { + j -= 1; + if out[j].get("role").and_then(Value::as_str) == Some("tool") { + if let Some(id) = out[j].get("tool_call_id").and_then(Value::as_str) { + if expected_ids.contains(id) { + out.remove(j); + } + } + } + } + } + } + i += 1; + } + out } @@ -893,4 +981,139 @@ mod tests { .expect("assistant message"); assert!(assistant.get("tool_calls").is_some()); } + + #[test] + fn chat_messages_strips_orphaned_tool_calls_after_compaction() { + // Simulates post-compaction state: assistant has tool_calls but the + // tool result messages were summarized away. + let messages = vec![ + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "tool-orphan".to_string(), + name: "read_file".to_string(), + input: json!({"path": "src/main.rs"}), + }], + }, + // No tool result follows — it was removed by compaction. + Message { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: "continue".to_string(), + cache_control: None, + }], + }, + ]; + + let out = build_chat_messages(None, &messages, "deepseek-chat"); + let assistant = out + .iter() + .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) + .expect("assistant message"); + // The safety net should have stripped tool_calls. + assert!( + assistant.get("tool_calls").is_none(), + "orphaned tool_calls should be stripped by safety net" + ); + } + + #[test] + fn chat_messages_keeps_valid_tool_calls_intact() { + // Complete call+result pair should NOT be stripped. + let messages = vec![ + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "tool-ok".to_string(), + name: "list_dir".to_string(), + input: json!({}), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "tool-ok".to_string(), + content: "files".to_string(), + }], + }, + ]; + + let out = build_chat_messages(None, &messages, "deepseek-chat"); + let assistant = out + .iter() + .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) + .expect("assistant message"); + assert!( + assistant.get("tool_calls").is_some(), + "valid tool_calls should remain intact" + ); + assert!( + out.iter() + .any(|value| value.get("role").and_then(Value::as_str) == Some("tool")), + "tool result should remain" + ); + } + + #[test] + fn chat_messages_strips_partial_tool_results() { + let messages = vec![ + Message { + role: "assistant".to_string(), + content: vec![ + ContentBlock::ToolUse { + id: "t1".to_string(), + name: "read_file".to_string(), + input: json!({"path": "a.rs"}), + }, + ContentBlock::ToolUse { + id: "t2".to_string(), + name: "read_file".to_string(), + input: json!({"path": "b.rs"}), + }, + ContentBlock::ToolUse { + id: "t3".to_string(), + name: "shell".to_string(), + input: json!({"cmd": "ls"}), + }, + ], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "t1".to_string(), + content: "content a".to_string(), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "t2".to_string(), + content: "content b".to_string(), + }], + }, + // No result for t3 + Message { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: "continue".to_string(), + cache_control: None, + }], + }, + ]; + + let out = build_chat_messages(None, &messages, "deepseek-chat"); + let assistant = out + .iter() + .find(|v| v.get("role").and_then(Value::as_str) == Some("assistant")) + .expect("assistant message"); + assert!( + assistant.get("tool_calls").is_none(), + "partial tool_calls should be stripped" + ); + assert!( + !out.iter() + .any(|v| v.get("role").and_then(Value::as_str) == Some("tool")), + "all orphaned tool results should be removed" + ); + } } diff --git a/src/compaction.rs b/src/compaction.rs index c25c8dc2..1f4d934b 100644 --- a/src/compaction.rs +++ b/src/compaction.rs @@ -12,6 +12,7 @@ use std::time::Duration; use crate::client::DeepSeekClient; use crate::llm_client::LlmClient; +use crate::logging; use crate::models::{ CacheControl, ContentBlock, Message, MessageRequest, SystemBlock, SystemPrompt, }; @@ -351,71 +352,96 @@ fn enforce_tool_call_pairs(messages: &[Message], pinned_indices: &mut BTreeSet = HashMap::new(); - let mut tool_result_indices: HashMap = HashMap::new(); + // Build maps: tool_id → message index across ALL messages (not just pinned). + let mut call_id_to_idx: HashMap = HashMap::new(); + let mut result_id_to_idx: HashMap = HashMap::new(); for (idx, msg) in messages.iter().enumerate() { for block in &msg.content { match block { ContentBlock::ToolUse { id, .. } => { - tool_call_indices.insert(id.clone(), idx); + call_id_to_idx.insert(id.clone(), idx); } ContentBlock::ToolResult { tool_use_id, .. } => { - tool_result_indices.insert(tool_use_id.clone(), idx); + result_id_to_idx.insert(tool_use_id.clone(), idx); } _ => {} } } } - let mut to_add = Vec::new(); - let mut to_remove = Vec::new(); + // Fixpoint loop: re-check until stable. + // Newly pinned messages may introduce new pair requirements; + // removed messages may orphan their counterparts. + // Track permanently removed indices so they cannot be re-added + // by a counterpart in a later iteration (prevents oscillation). + let mut permanently_removed: HashSet = HashSet::new(); - // Pass 1: If a tool result is pinned, ensure its tool call is also pinned. - // If the tool call is not found, remove the orphaned result. - for &idx in pinned_indices.iter() { - let msg = &messages[idx]; - let mut tool_ids = Vec::new(); - for block in &msg.content { - if let ContentBlock::ToolResult { tool_use_id, .. } = block { - tool_ids.push(tool_use_id.clone()); - } - } - if tool_ids.is_empty() { - continue; - } + let max_iters = messages.len().max(10); + let mut converged = false; + for _ in 0..max_iters { + let mut to_add = Vec::new(); + let mut to_remove = Vec::new(); - let mut found_any = false; - for tool_id in tool_ids { - if let Some(call_idx) = tool_call_indices.get(&tool_id).copied() { - to_add.push(call_idx); - found_any = true; - } - } - if !found_any { - to_remove.push(idx); - } - } + let snapshot: Vec = pinned_indices.iter().copied().collect(); - // Pass 2: If a tool call is pinned, ensure its tool result is also pinned. - // This prevents "orphaned tool calls" API errors. - for &idx in pinned_indices.iter() { - let msg = &messages[idx]; - for block in &msg.content { - if let ContentBlock::ToolUse { id, .. } = block { - if let Some(result_idx) = tool_result_indices.get(id).copied() { - to_add.push(result_idx); + for idx in snapshot { + let msg = &messages[idx]; + for block in &msg.content { + match block { + // Pinned result → its call must also be pinned (or remove result) + ContentBlock::ToolResult { tool_use_id, .. } => { + match call_id_to_idx.get(tool_use_id) { + Some(&call_idx) if !permanently_removed.contains(&call_idx) => { + to_add.push(call_idx); + } + _ => { + to_remove.push(idx); + } + } + } + // Pinned call → its result must also be pinned (or remove call) + ContentBlock::ToolUse { id, .. } => match result_id_to_idx.get(id) { + Some(&result_idx) if !permanently_removed.contains(&result_idx) => { + to_add.push(result_idx); + } + _ => { + to_remove.push(idx); + } + }, + _ => {} } } } - } - for idx in to_add { - pinned_indices.insert(idx); + // Removals take priority: if a message is both needed and orphaned, + // remove it now; the fixpoint loop will cascade the orphaning. + let remove_set: HashSet = to_remove.iter().copied().collect(); + let mut changed = false; + for idx in to_add { + if !remove_set.contains(&idx) && pinned_indices.insert(idx) { + changed = true; + } + } + for idx in to_remove { + if pinned_indices.remove(&idx) { + permanently_removed.insert(idx); + changed = true; + } + } + + if !changed { + converged = true; + break; + } } - for idx in to_remove { - pinned_indices.remove(&idx); + if !converged { + logging::warn(format!( + "enforce_tool_call_pairs did not converge after {max_iters} iterations \ + ({} messages, {} pinned)", + messages.len(), + pinned_indices.len() + )); } } @@ -1026,4 +1052,231 @@ mod tests { // Pinned recent messages exceed the token budget, so unpinned noise should trigger compaction. assert!(should_compact(&messages, &config, None, None, None)); } + + #[test] + fn enforce_tool_call_pairs_removes_orphaned_tool_call() { + // An assistant message with a tool call but no matching result anywhere + // in the history should be removed from the pinned set. + let messages = vec![ + msg("user", "noise"), + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "orphan-call".to_string(), + name: "read_file".to_string(), + input: json!({"path": "src/main.rs"}), + }], + }, + msg("assistant", "recent"), + ]; + + let mut pinned = BTreeSet::from([0, 1, 2]); + enforce_tool_call_pairs(&messages, &mut pinned); + + // The orphaned tool call message (index 1) should be removed. + assert!( + !pinned.contains(&1), + "orphaned tool call should be removed from pinned set" + ); + // Other messages stay. + assert!(pinned.contains(&0)); + assert!(pinned.contains(&2)); + } + + #[test] + fn enforce_tool_call_pairs_removes_orphaned_tool_result() { + // A tool result whose call doesn't exist anywhere should be removed. + let messages = vec![ + msg("user", "noise"), + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "orphan-result".to_string(), + content: "ok".to_string(), + }], + }, + msg("assistant", "recent"), + ]; + + let mut pinned = BTreeSet::from([0, 1, 2]); + enforce_tool_call_pairs(&messages, &mut pinned); + + assert!( + !pinned.contains(&1), + "orphaned tool result should be removed from pinned set" + ); + assert!(pinned.contains(&0)); + assert!(pinned.contains(&2)); + } + + #[test] + fn enforce_tool_call_pairs_preserves_valid_pairs() { + // A complete call+result pair should remain intact. + let messages = vec![ + msg("user", "do something"), + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "tool-ok".to_string(), + name: "list_dir".to_string(), + input: json!({}), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "tool-ok".to_string(), + content: "files here".to_string(), + }], + }, + msg("assistant", "done"), + ]; + + let mut pinned = BTreeSet::from([1, 2, 3]); + enforce_tool_call_pairs(&messages, &mut pinned); + + assert!(pinned.contains(&1), "tool call should stay pinned"); + assert!(pinned.contains(&2), "tool result should stay pinned"); + assert!(pinned.contains(&3)); + } + + #[test] + fn enforce_tool_call_pairs_pins_transitive_pairs() { + // If only the result is initially pinned, the call should be pulled in. + // The call message may also contain another tool call whose result should + // then be pulled in transitively. + let messages = vec![ + msg("user", "start"), + Message { + role: "assistant".to_string(), + content: vec![ + ContentBlock::ToolUse { + id: "t1".to_string(), + name: "read_file".to_string(), + input: json!({"path": "a.rs"}), + }, + ContentBlock::ToolUse { + id: "t2".to_string(), + name: "read_file".to_string(), + input: json!({"path": "b.rs"}), + }, + ], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "t1".to_string(), + content: "content of a.rs".to_string(), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "t2".to_string(), + content: "content of b.rs".to_string(), + }], + }, + msg("assistant", "done"), + ]; + + // Only pin the result for t1 initially. + let mut pinned = BTreeSet::from([2, 4]); + enforce_tool_call_pairs(&messages, &mut pinned); + + // The call message (index 1) should be pulled in because t1's result is pinned. + assert!( + pinned.contains(&1), + "call message should be transitively pinned" + ); + // Since the call message also contains t2, t2's result (index 3) should also be pinned. + assert!( + pinned.contains(&3), + "t2 result should be transitively pinned via the call message" + ); + } + + #[test] + fn enforce_tool_call_pairs_cascading_removal() { + // Removing an orphaned call should cascade to remove its result. + // Message 1: assistant with t1 (call) — t1 has a result at index 2 + // Message 2: user with t1 (result) + // Message 3: assistant with t2 (call) — t2 has NO result + // Message 4: user with t2 result referencing the call + // + // If t2 has no result in history, message 3 is removed. That's straightforward. + // Here we test: if a call message is removed because ONE of its calls is orphaned, + // the result for the other call also gets removed in subsequent iterations. + let messages = vec![ + msg("user", "start"), + Message { + role: "assistant".to_string(), + content: vec![ + ContentBlock::ToolUse { + id: "good".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + ContentBlock::ToolUse { + id: "orphan".to_string(), + name: "shell".to_string(), + input: json!({}), + }, + ], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "good".to_string(), + content: "ok".to_string(), + }], + }, + // Note: NO result for "orphan" exists anywhere + msg("assistant", "done"), + ]; + + let mut pinned = BTreeSet::from([1, 2, 3]); + enforce_tool_call_pairs(&messages, &mut pinned); + + // Message 1 has an orphaned tool call ("orphan"), so it's removed. + assert!( + !pinned.contains(&1), + "message with orphaned call should be removed" + ); + // Message 2 (result for "good") now has no matching call pinned, so it's also removed. + assert!( + !pinned.contains(&2), + "result whose call was removed should cascade-remove" + ); + // Message 3 (plain text) stays. + assert!(pinned.contains(&3)); + } + + #[test] + fn enforce_tool_call_pairs_converges_long_chain() { + let mut messages = vec![msg("user", "start")]; + for i in 0..15 { + messages.push(Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: format!("t{i}"), + name: "read_file".to_string(), + input: json!({}), + }], + }); + messages.push(Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: format!("t{i}"), + content: format!("result {i}"), + }], + }); + } + messages.push(msg("assistant", "done")); + + let mut pinned: BTreeSet = (0..messages.len()).collect(); + enforce_tool_call_pairs(&messages, &mut pinned); + + // All pairs should remain intact (no orphans) + assert_eq!(pinned.len(), messages.len()); + } } diff --git a/src/tui/ui.rs b/src/tui/ui.rs index 1cacdc0d..1566d2e8 100644 --- a/src/tui/ui.rs +++ b/src/tui/ui.rs @@ -1129,6 +1129,9 @@ async fn dispatch_user_message( engine_handle: &EngineHandle, message: QueuedMessage, ) -> Result<()> { + // Set immediately to prevent double-dispatch before TurnStarted event arrives. + app.is_loading = true; + let override_query = maybe_auto_switch_to_rlm(app, &message.display); let content = if let Some(query) = override_query.as_deref() { message.content_with_query(query)