diff --git a/src/client.rs b/src/client.rs index 8f3d4ec9..aeb664c1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ //! Uses the OpenAI Responses API when available, falling back to Chat Completions //! if the Responses endpoint is unsupported by the target base URL. +use std::collections::HashSet; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; @@ -427,6 +428,7 @@ fn build_chat_messages( ) -> Vec { let mut out = Vec::new(); let include_reasoning = requires_reasoning_content(model); + let mut pending_tool_calls: HashSet = HashSet::new(); if let Some(instructions) = system_to_instructions(system.cloned()) { if !instructions.trim().is_empty() { @@ -442,7 +444,8 @@ fn build_chat_messages( let mut text_parts = Vec::new(); let mut thinking_parts = Vec::new(); let mut tool_calls = Vec::new(); - let mut tool_results = Vec::new(); + let mut tool_call_ids = Vec::new(); + let mut tool_results: Vec<(String, Value)> = Vec::new(); for block in &message.content { match block { @@ -458,16 +461,17 @@ fn build_chat_messages( "arguments": args, } })); + tool_call_ids.push(id.clone()); } ContentBlock::ToolResult { tool_use_id, content, } => { - tool_results.push(json!({ + tool_results.push((tool_use_id.clone(), json!({ "role": "tool", "tool_call_id": tool_use_id, "content": content, - })); + }))); } } } @@ -483,6 +487,9 @@ fn build_chat_messages( } if !tool_calls.is_empty() { msg["tool_calls"] = json!(tool_calls); + pending_tool_calls = tool_call_ids.into_iter().collect(); + } else { + pending_tool_calls.clear(); } out.push(msg); } else if role == "user" { @@ -496,7 +503,21 @@ fn build_chat_messages( } if !tool_results.is_empty() { - out.extend(tool_results); + if pending_tool_calls.is_empty() { + logging::warn("Dropping tool results without matching tool_calls".to_string()); + } else { + for (tool_id, tool_msg) in tool_results { + if pending_tool_calls.remove(&tool_id) { + out.push(tool_msg); + } else { + logging::warn(format!( + "Dropping tool result for unknown tool_call_id: {tool_id}" + )); + } + } + } + } else if role != "assistant" { + pending_tool_calls.clear(); } } @@ -778,6 +799,7 @@ where #[cfg(test)] mod tests { use super::*; + use serde_json::json; #[test] fn chat_messages_include_reasoning_content_for_reasoner() { @@ -819,4 +841,51 @@ mod tests { .expect("assistant message"); assert!(assistant.get("reasoning_content").is_none()); } + + #[test] + fn chat_messages_drop_orphan_tool_results() { + let messages = vec![Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "tool-1".to_string(), + content: "ok".to_string(), + }], + }]; + + let out = build_chat_messages(None, &messages, "deepseek-chat"); + assert!(!out.iter().any(|value| { + value.get("role").and_then(Value::as_str) == Some("tool") + })); + } + + #[test] + fn chat_messages_include_tool_results_when_call_present() { + let messages = vec![ + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "list_dir".to_string(), + input: json!({}), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "tool-1".to_string(), + content: "ok".to_string(), + }], + }, + ]; + + let out = build_chat_messages(None, &messages, "deepseek-chat"); + assert!(out.iter().any(|value| { + value.get("role").and_then(Value::as_str) == Some("tool") + })); + let assistant = out + .iter() + .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) + .expect("assistant message"); + assert!(assistant.get("tool_calls").is_some()); + } } diff --git a/src/compaction.rs b/src/compaction.rs index 8403d4c5..68fa68b1 100644 --- a/src/compaction.rs +++ b/src/compaction.rs @@ -4,7 +4,7 @@ use anyhow::Result; use regex::Regex; -use std::collections::{BTreeSet, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Write; use std::path::{Path, PathBuf}; use std::sync::OnceLock; @@ -332,6 +332,9 @@ fn plan_compaction( pinned_indices.extend(pins.iter().copied().filter(|idx| *idx < len)); } + // Ensure tool result messages are not kept without their corresponding tool call. + enforce_tool_call_pairs(messages, &mut pinned_indices); + let summarize_indices = (0..len) .filter(|idx| !pinned_indices.contains(idx)) .collect(); @@ -343,6 +346,55 @@ fn plan_compaction( } } +fn enforce_tool_call_pairs(messages: &[Message], pinned_indices: &mut BTreeSet) { + if pinned_indices.is_empty() { + return; + } + + let mut tool_call_indices: HashMap = HashMap::new(); + for (idx, msg) in messages.iter().enumerate() { + for block in &msg.content { + if let ContentBlock::ToolUse { id, .. } = block { + tool_call_indices.insert(id.clone(), idx); + } + } + } + + let mut to_add = Vec::new(); + let mut to_remove = Vec::new(); + + for &idx in pinned_indices.iter() { + let msg = &messages[idx]; + let mut tool_ids = Vec::new(); + for block in &msg.content { + if let ContentBlock::ToolResult { tool_use_id, .. } = block { + tool_ids.push(tool_use_id.clone()); + } + } + if tool_ids.is_empty() { + continue; + } + + let mut found_any = false; + for tool_id in tool_ids { + if let Some(call_idx) = tool_call_indices.get(&tool_id).copied() { + to_add.push(call_idx); + found_any = true; + } + } + if !found_any { + to_remove.push(idx); + } + } + + for idx in to_add { + pinned_indices.insert(idx); + } + for idx in to_remove { + pinned_indices.remove(&idx); + } +} + fn estimate_tokens_for_message(message: &Message) -> usize { message .content @@ -688,6 +740,7 @@ pub fn merge_system_prompts( #[cfg(test)] mod tests { use super::*; + use serde_json::json; fn msg(role: &str, text: &str) -> Message { Message { @@ -869,6 +922,32 @@ mod tests { assert!(plan.pinned_indices.contains(&0)); } + #[test] + fn plan_compaction_pins_tool_calls_for_tool_results() { + let messages = vec![ + msg("user", "noise"), + Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({"path": "src/main.rs"}), + }], + }, + Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id: "tool-1".to_string(), + content: "ok src/main.rs".to_string(), + }], + }, + ]; + + let plan = plan_compaction(&messages, None, 1, None, None); + assert!(plan.pinned_indices.contains(&2)); + assert!(plan.pinned_indices.contains(&1)); + } + #[test] fn should_compact_ignores_fully_pinned_context() { let config = CompactionConfig {