Fix tool result ordering with compaction
This commit is contained in:
+73
-4
@@ -3,6 +3,7 @@
|
||||
//! Uses the OpenAI Responses API when available, falling back to Chat Completions
|
||||
//! if the Responses endpoint is unsupported by the target base URL.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
@@ -427,6 +428,7 @@ fn build_chat_messages(
|
||||
) -> Vec<Value> {
|
||||
let mut out = Vec::new();
|
||||
let include_reasoning = requires_reasoning_content(model);
|
||||
let mut pending_tool_calls: HashSet<String> = HashSet::new();
|
||||
|
||||
if let Some(instructions) = system_to_instructions(system.cloned()) {
|
||||
if !instructions.trim().is_empty() {
|
||||
@@ -442,7 +444,8 @@ fn build_chat_messages(
|
||||
let mut text_parts = Vec::new();
|
||||
let mut thinking_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
let mut tool_call_ids = Vec::new();
|
||||
let mut tool_results: Vec<(String, Value)> = Vec::new();
|
||||
|
||||
for block in &message.content {
|
||||
match block {
|
||||
@@ -458,16 +461,17 @@ fn build_chat_messages(
|
||||
"arguments": args,
|
||||
}
|
||||
}));
|
||||
tool_call_ids.push(id.clone());
|
||||
}
|
||||
ContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
} => {
|
||||
tool_results.push(json!({
|
||||
tool_results.push((tool_use_id.clone(), json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_use_id,
|
||||
"content": content,
|
||||
}));
|
||||
})));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -483,6 +487,9 @@ fn build_chat_messages(
|
||||
}
|
||||
if !tool_calls.is_empty() {
|
||||
msg["tool_calls"] = json!(tool_calls);
|
||||
pending_tool_calls = tool_call_ids.into_iter().collect();
|
||||
} else {
|
||||
pending_tool_calls.clear();
|
||||
}
|
||||
out.push(msg);
|
||||
} else if role == "user" {
|
||||
@@ -496,7 +503,21 @@ fn build_chat_messages(
|
||||
}
|
||||
|
||||
if !tool_results.is_empty() {
|
||||
out.extend(tool_results);
|
||||
if pending_tool_calls.is_empty() {
|
||||
logging::warn("Dropping tool results without matching tool_calls".to_string());
|
||||
} else {
|
||||
for (tool_id, tool_msg) in tool_results {
|
||||
if pending_tool_calls.remove(&tool_id) {
|
||||
out.push(tool_msg);
|
||||
} else {
|
||||
logging::warn(format!(
|
||||
"Dropping tool result for unknown tool_call_id: {tool_id}"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role != "assistant" {
|
||||
pending_tool_calls.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -778,6 +799,7 @@ where
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn chat_messages_include_reasoning_content_for_reasoner() {
|
||||
@@ -819,4 +841,51 @@ mod tests {
|
||||
.expect("assistant message");
|
||||
assert!(assistant.get("reasoning_content").is_none());
|
||||
}
|
||||
|
||||
#[test]
fn chat_messages_drop_orphan_tool_results() {
    // The history contains a single tool result and no assistant tool call,
    // so nothing exists for the result to be paired with.
    let orphan_result = ContentBlock::ToolResult {
        tool_use_id: "tool-1".to_string(),
        content: "ok".to_string(),
    };
    let messages = vec![Message {
        role: "user".to_string(),
        content: vec![orphan_result],
    }];

    let out = build_chat_messages(None, &messages, "deepseek-chat");

    // No emitted chat message may carry the "tool" role.
    let no_tool_messages = out
        .iter()
        .all(|entry| entry.get("role").and_then(Value::as_str) != Some("tool"));
    assert!(no_tool_messages);
}
|
||||
|
||||
#[test]
fn chat_messages_include_tool_results_when_call_present() {
    // An assistant tool call immediately followed by its result: both sides
    // of the pair are expected to appear in the output.
    let call_message = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::ToolUse {
            id: "tool-1".to_string(),
            name: "list_dir".to_string(),
            input: json!({}),
        }],
    };
    let result_message = Message {
        role: "user".to_string(),
        content: vec![ContentBlock::ToolResult {
            tool_use_id: "tool-1".to_string(),
            content: "ok".to_string(),
        }],
    };
    let messages = vec![call_message, result_message];

    let out = build_chat_messages(None, &messages, "deepseek-chat");

    // The tool result must be present as a "tool" role message.
    let tool_message_present = out
        .iter()
        .any(|entry| entry.get("role").and_then(Value::as_str) == Some("tool"));
    assert!(tool_message_present);

    // The assistant message must still advertise its tool calls.
    let assistant = out
        .iter()
        .find(|entry| entry.get("role").and_then(Value::as_str) == Some("assistant"))
        .expect("assistant message");
    assert!(assistant.get("tool_calls").is_some());
}
|
||||
}
|
||||
|
||||
+80
-1
@@ -4,7 +4,7 @@
|
||||
|
||||
use anyhow::Result;
|
||||
use regex::Regex;
|
||||
use std::collections::{BTreeSet, HashSet};
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::fmt::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::OnceLock;
|
||||
@@ -332,6 +332,9 @@ fn plan_compaction(
|
||||
pinned_indices.extend(pins.iter().copied().filter(|idx| *idx < len));
|
||||
}
|
||||
|
||||
// Ensure tool result messages are not kept without their corresponding tool call.
|
||||
enforce_tool_call_pairs(messages, &mut pinned_indices);
|
||||
|
||||
let summarize_indices = (0..len)
|
||||
.filter(|idx| !pinned_indices.contains(idx))
|
||||
.collect();
|
||||
@@ -343,6 +346,55 @@ fn plan_compaction(
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_tool_call_pairs(messages: &[Message], pinned_indices: &mut BTreeSet<usize>) {
|
||||
if pinned_indices.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut tool_call_indices: HashMap<String, usize> = HashMap::new();
|
||||
for (idx, msg) in messages.iter().enumerate() {
|
||||
for block in &msg.content {
|
||||
if let ContentBlock::ToolUse { id, .. } = block {
|
||||
tool_call_indices.insert(id.clone(), idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut to_add = Vec::new();
|
||||
let mut to_remove = Vec::new();
|
||||
|
||||
for &idx in pinned_indices.iter() {
|
||||
let msg = &messages[idx];
|
||||
let mut tool_ids = Vec::new();
|
||||
for block in &msg.content {
|
||||
if let ContentBlock::ToolResult { tool_use_id, .. } = block {
|
||||
tool_ids.push(tool_use_id.clone());
|
||||
}
|
||||
}
|
||||
if tool_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut found_any = false;
|
||||
for tool_id in tool_ids {
|
||||
if let Some(call_idx) = tool_call_indices.get(&tool_id).copied() {
|
||||
to_add.push(call_idx);
|
||||
found_any = true;
|
||||
}
|
||||
}
|
||||
if !found_any {
|
||||
to_remove.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
for idx in to_add {
|
||||
pinned_indices.insert(idx);
|
||||
}
|
||||
for idx in to_remove {
|
||||
pinned_indices.remove(&idx);
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_tokens_for_message(message: &Message) -> usize {
|
||||
message
|
||||
.content
|
||||
@@ -688,6 +740,7 @@ pub fn merge_system_prompts(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn msg(role: &str, text: &str) -> Message {
|
||||
Message {
|
||||
@@ -869,6 +922,32 @@ mod tests {
|
||||
assert!(plan.pinned_indices.contains(&0));
|
||||
}
|
||||
|
||||
#[test]
fn plan_compaction_pins_tool_calls_for_tool_results() {
    // Layout: [0] unrelated user text, [1] assistant tool call, [2] its result.
    let tool_call = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::ToolUse {
            id: "tool-1".to_string(),
            name: "read_file".to_string(),
            input: json!({"path": "src/main.rs"}),
        }],
    };
    let tool_result = Message {
        role: "user".to_string(),
        content: vec![ContentBlock::ToolResult {
            tool_use_id: "tool-1".to_string(),
            content: "ok src/main.rs".to_string(),
        }],
    };
    let messages = vec![msg("user", "noise"), tool_call, tool_result];

    let plan = plan_compaction(&messages, None, 1, None, None);

    // The result message is pinned, and so is the call it answers.
    let result_pinned = plan.pinned_indices.contains(&2);
    assert!(result_pinned);
    let call_pinned = plan.pinned_indices.contains(&1);
    assert!(call_pinned);
}
|
||||
|
||||
#[test]
|
||||
fn should_compact_ignores_fully_pinned_context() {
|
||||
let config = CompactionConfig {
|
||||
|
||||
Reference in New Issue
Block a user