Fix tool result ordering with compaction

This commit is contained in:
Hunter Bown
2026-01-28 09:39:38 -06:00
parent cbe2a30ea4
commit 9bbe82e7f4
2 changed files with 153 additions and 5 deletions
+73 -4
View File
@@ -3,6 +3,7 @@
//! Uses the OpenAI Responses API when available, falling back to Chat Completions
//! if the Responses endpoint is unsupported by the target base URL.
use std::collections::HashSet;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
@@ -427,6 +428,7 @@ fn build_chat_messages(
) -> Vec<Value> {
let mut out = Vec::new();
let include_reasoning = requires_reasoning_content(model);
let mut pending_tool_calls: HashSet<String> = HashSet::new();
if let Some(instructions) = system_to_instructions(system.cloned()) {
if !instructions.trim().is_empty() {
@@ -442,7 +444,8 @@ fn build_chat_messages(
let mut text_parts = Vec::new();
let mut thinking_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut tool_results = Vec::new();
let mut tool_call_ids = Vec::new();
let mut tool_results: Vec<(String, Value)> = Vec::new();
for block in &message.content {
match block {
@@ -458,16 +461,17 @@ fn build_chat_messages(
"arguments": args,
}
}));
tool_call_ids.push(id.clone());
}
ContentBlock::ToolResult {
tool_use_id,
content,
} => {
tool_results.push(json!({
tool_results.push((tool_use_id.clone(), json!({
"role": "tool",
"tool_call_id": tool_use_id,
"content": content,
}));
})));
}
}
}
@@ -483,6 +487,9 @@ fn build_chat_messages(
}
if !tool_calls.is_empty() {
msg["tool_calls"] = json!(tool_calls);
pending_tool_calls = tool_call_ids.into_iter().collect();
} else {
pending_tool_calls.clear();
}
out.push(msg);
} else if role == "user" {
@@ -496,7 +503,21 @@ fn build_chat_messages(
}
if !tool_results.is_empty() {
out.extend(tool_results);
if pending_tool_calls.is_empty() {
logging::warn("Dropping tool results without matching tool_calls".to_string());
} else {
for (tool_id, tool_msg) in tool_results {
if pending_tool_calls.remove(&tool_id) {
out.push(tool_msg);
} else {
logging::warn(format!(
"Dropping tool result for unknown tool_call_id: {tool_id}"
));
}
}
}
} else if role != "assistant" {
pending_tool_calls.clear();
}
}
@@ -778,6 +799,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn chat_messages_include_reasoning_content_for_reasoner() {
@@ -819,4 +841,51 @@ mod tests {
.expect("assistant message");
assert!(assistant.get("reasoning_content").is_none());
}
#[test]
fn chat_messages_drop_orphan_tool_results() {
    // A user message carrying a tool result, with no assistant tool call
    // anywhere before it in the conversation.
    let orphan_result = Message {
        role: "user".to_string(),
        content: vec![ContentBlock::ToolResult {
            tool_use_id: "tool-1".to_string(),
            content: "ok".to_string(),
        }],
    };
    let messages = vec![orphan_result];

    let out = build_chat_messages(None, &messages, "deepseek-chat");

    // No message with role "tool" may appear in the output.
    let no_tool_messages = out
        .iter()
        .all(|entry| entry.get("role").and_then(Value::as_str) != Some("tool"));
    assert!(no_tool_messages);
}
#[test]
fn chat_messages_include_tool_results_when_call_present() {
    /// Returns true when `entry` has a string `role` field equal to `wanted`.
    fn has_role(entry: &Value, wanted: &str) -> bool {
        entry.get("role").and_then(Value::as_str) == Some(wanted)
    }

    // Assistant turn that issues the tool call...
    let call = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::ToolUse {
            id: "tool-1".to_string(),
            name: "list_dir".to_string(),
            input: json!({}),
        }],
    };
    // ...followed by the user turn answering it with the same id.
    let result = Message {
        role: "user".to_string(),
        content: vec![ContentBlock::ToolResult {
            tool_use_id: "tool-1".to_string(),
            content: "ok".to_string(),
        }],
    };
    let messages = vec![call, result];

    let out = build_chat_messages(None, &messages, "deepseek-chat");

    // The matched tool result must be emitted as a "tool" message.
    assert!(out.iter().any(|entry| has_role(entry, "tool")));

    // The assistant message must still carry its `tool_calls` field.
    let assistant = out
        .iter()
        .find(|entry| has_role(entry, "assistant"))
        .expect("assistant message");
    assert!(assistant.get("tool_calls").is_some());
}
}
+80 -1
View File
@@ -4,7 +4,7 @@
use anyhow::Result;
use regex::Regex;
use std::collections::{BTreeSet, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Write;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
@@ -332,6 +332,9 @@ fn plan_compaction(
pinned_indices.extend(pins.iter().copied().filter(|idx| *idx < len));
}
// Ensure tool result messages are not kept without their corresponding tool call.
enforce_tool_call_pairs(messages, &mut pinned_indices);
let summarize_indices = (0..len)
.filter(|idx| !pinned_indices.contains(idx))
.collect();
@@ -343,6 +346,55 @@ fn plan_compaction(
}
}
/// Keeps the pinned set consistent with tool-call / tool-result pairing.
///
/// For every pinned message that contains `ContentBlock::ToolResult` blocks:
/// - each message holding a `ContentBlock::ToolUse` with a matching id is
///   pinned as well, and
/// - if none of its results has a matching `ToolUse` anywhere in `messages`,
///   the message itself is unpinned.
///
/// Pinned messages without tool results are left alone. The function is a
/// no-op when `pinned_indices` is empty, and pinned indices that fall outside
/// `messages` are skipped instead of causing a panic.
///
/// NOTE(review): only the result -> call direction is enforced. A pinned tool
/// call whose result message is not pinned is left as-is; confirm whether the
/// caller needs that direction handled too.
fn enforce_tool_call_pairs(messages: &[Message], pinned_indices: &mut BTreeSet<usize>) {
    if pinned_indices.is_empty() {
        return;
    }

    // Map every tool-call id to the index of the message that issued it.
    // Keys borrow from `messages`, so no id is cloned. If an id appears more
    // than once, the last occurrence wins.
    let mut tool_call_indices: HashMap<&str, usize> = HashMap::new();
    for (idx, msg) in messages.iter().enumerate() {
        for block in &msg.content {
            if let ContentBlock::ToolUse { id, .. } = block {
                tool_call_indices.insert(id.as_str(), idx);
            }
        }
    }

    // Changes are collected first and applied afterwards, because the set
    // cannot be mutated while it is being iterated.
    let mut to_add = Vec::new();
    let mut to_remove = Vec::new();

    for &idx in pinned_indices.iter() {
        // Checked lookup: an out-of-range pin is ignored rather than panicking.
        let Some(msg) = messages.get(idx) else {
            continue;
        };

        let mut has_results = false;
        let mut found_any = false;
        for block in &msg.content {
            if let ContentBlock::ToolResult { tool_use_id, .. } = block {
                has_results = true;
                if let Some(&call_idx) = tool_call_indices.get(tool_use_id.as_str()) {
                    to_add.push(call_idx);
                    found_any = true;
                }
            }
        }

        // A message whose tool results all lack a known call is unpinned.
        if has_results && !found_any {
            to_remove.push(idx);
        }
    }

    // Additions are applied before removals.
    pinned_indices.extend(to_add);
    for idx in to_remove {
        pinned_indices.remove(&idx);
    }
}
fn estimate_tokens_for_message(message: &Message) -> usize {
message
.content
@@ -688,6 +740,7 @@ pub fn merge_system_prompts(
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn msg(role: &str, text: &str) -> Message {
Message {
@@ -869,6 +922,32 @@ mod tests {
assert!(plan.pinned_indices.contains(&0));
}
#[test]
fn plan_compaction_pins_tool_calls_for_tool_results() {
    // Index 1: assistant message issuing the tool call.
    let tool_call = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::ToolUse {
            id: "tool-1".to_string(),
            name: "read_file".to_string(),
            input: json!({"path": "src/main.rs"}),
        }],
    };
    // Index 2: user message carrying the result for that call.
    let tool_result = Message {
        role: "user".to_string(),
        content: vec![ContentBlock::ToolResult {
            tool_use_id: "tool-1".to_string(),
            content: "ok src/main.rs".to_string(),
        }],
    };
    let messages = vec![msg("user", "noise"), tool_call, tool_result];

    let plan = plan_compaction(&messages, None, 1, None, None);

    // The tool result (index 2) must be pinned, and so must the message
    // holding its matching tool call (index 1).
    let pinned = &plan.pinned_indices;
    assert!(pinned.contains(&2));
    assert!(pinned.contains(&1));
}
#[test]
fn should_compact_ignores_fully_pinned_context() {
let config = CompactionConfig {