test(rlm): cover bridge batch and depth guard

This commit is contained in:
Hunter Bown
2026-05-01 03:09:05 -05:00
parent df53a22113
commit 84da3b7fc6
2 changed files with 141 additions and 26 deletions
+125 -26
View File
@@ -121,19 +121,8 @@ impl RlmBridge {
}
async fn dispatch_llm_batch(&self, prompts: Vec<String>, model: Option<String>) -> BatchResp {
if prompts.is_empty() {
return BatchResp { results: vec![] };
}
if prompts.len() > MAX_BATCH {
return BatchResp {
results: prompts
.iter()
.map(|_| SingleResp {
text: String::new(),
error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())),
})
.collect(),
};
if let Some(resp) = batch_guard(prompts.len()) {
return resp;
}
let model = Arc::new(
@@ -201,19 +190,8 @@ impl RlmBridge {
}
async fn dispatch_rlm_batch(&self, prompts: Vec<String>, model: Option<String>) -> BatchResp {
if prompts.is_empty() {
return BatchResp { results: vec![] };
}
if prompts.len() > MAX_BATCH {
return BatchResp {
results: prompts
.iter()
.map(|_| SingleResp {
text: String::new(),
error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())),
})
.collect(),
};
if let Some(resp) = batch_guard(prompts.len()) {
return resp;
}
let model = Arc::new(model);
@@ -227,6 +205,23 @@ impl RlmBridge {
}
}
/// Decides whether a batch of `prompt_count` prompts can be dispatched.
///
/// Returns `Some(response)` when the batch must be answered immediately
/// without dispatching anything:
///
/// * `prompt_count == 0` yields a response with no results.
/// * `prompt_count > MAX_BATCH` yields one `SingleResp` per prompt, each with
///   empty `text` and the same "batch too large" error, so the number of
///   results still matches the number of prompts that were submitted.
///
/// Returns `None` when `prompt_count` is in `1..=MAX_BATCH`, meaning the
/// caller should go ahead and process the batch.
fn batch_guard(prompt_count: usize) -> Option<BatchResp> {
    if prompt_count == 0 {
        return Some(BatchResp {
            results: Vec::new(),
        });
    }
    if prompt_count <= MAX_BATCH {
        return None;
    }

    // Every slot carries the identical message, so format it once and clone
    // it per slot instead of re-running the formatter for each prompt.
    let message = format!("batch too large: {prompt_count} > {MAX_BATCH}");
    let results = std::iter::repeat_with(|| SingleResp {
        text: String::new(),
        error: Some(message.clone()),
    })
    .take(prompt_count)
    .collect();

    Some(BatchResp { results })
}
impl RpcDispatcher for RlmBridge {
fn dispatch<'a>(
&'a self,
@@ -255,3 +250,107 @@ impl RpcDispatcher for RlmBridge {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{Config, ProviderConfig, ProvidersConfig};
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Builds a `DeepSeekClient` configured to use the `sglang` provider with its
/// base URL set to the given mock server's URI.
///
/// Panics with "test client" if the client cannot be constructed.
fn client_for(server: &MockServer) -> DeepSeekClient {
    let sglang = ProviderConfig {
        base_url: Some(server.uri()),
        ..ProviderConfig::default()
    };
    let providers = ProvidersConfig {
        sglang,
        ..ProvidersConfig::default()
    };
    let config = Config {
        provider: Some(String::from("sglang")),
        providers: Some(providers),
        ..Config::default()
    };
    DeepSeekClient::new(&config).expect("test client")
}
/// Builds a JSON body with a single `choices` entry whose assistant message
/// content is `text`, plus fixed `usage` counts (3 prompt tokens, 5 completion
/// tokens) that the tests assert against.
fn chat_response(text: &str) -> serde_json::Value {
    let message = json!({
        "role": "assistant",
        "content": text,
    });
    let choice = json!({
        "index": 0,
        "message": message,
        "finish_reason": "stop"
    });
    let usage = json!({
        "prompt_tokens": 3,
        "completion_tokens": 5
    });
    json!({
        "id": "chatcmpl-test",
        "model": "test-model",
        "choices": [choice],
        "usage": usage
    })
}
/// A batch of exactly `MAX_BATCH` prompts is not intercepted by the guard.
#[test]
fn batch_guard_allows_non_empty_batches_at_the_cap() {
    let verdict = batch_guard(MAX_BATCH);
    assert!(verdict.is_none());
}
/// An empty batch is answered by the guard with a response that has no results.
#[test]
fn batch_guard_returns_empty_response_for_empty_batches() {
    let handled = batch_guard(0);
    let response = handled.expect("empty batch should be handled");
    assert!(response.results.is_empty());
}
/// A batch above `MAX_BATCH` gets one result per prompt, each with empty text
/// and an error mentioning "batch too large".
#[test]
fn batch_guard_returns_one_error_per_oversized_prompt() {
    let oversized = MAX_BATCH + 2;
    let response = batch_guard(oversized).expect("oversized batch should be handled");
    assert_eq!(response.results.len(), oversized);

    let every_slot_is_rejected = response.results.iter().all(|entry| {
        let flagged = entry
            .error
            .as_deref()
            .is_some_and(|err| err.contains("batch too large"));
        entry.text.is_empty() && flagged
    });
    assert!(every_slot_is_rejected);
}
/// Dispatches an `Rlm` request through a bridge constructed with `0` as its
/// last argument (presumably the remaining recursion depth — confirm against
/// `RlmBridge::new`) and checks that the answer comes from the mocked
/// chat-completions endpoint and that the mocked token usage is recorded.
#[tokio::test]
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
    let server = MockServer::start().await;

    // The responses endpoint is stubbed to fail with a 404.
    let responses_unavailable =
        ResponseTemplate::new(404).set_body_string("responses unavailable");
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(responses_unavailable)
        .mount(&server)
        .await;

    // The chat-completions endpoint is stubbed to succeed with a canned answer.
    let chat_ok = ResponseTemplate::new(200).set_body_json(chat_response("fallback answer"));
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(chat_ok)
        .mount(&server)
        .await;

    let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0);
    let request = RpcRequest::Rlm {
        prompt: "nested prompt".to_string(),
        model: Some("override-model".to_string()),
    };

    let single = match bridge.dispatch(request).await {
        RpcResponse::Single(single) => single,
        other => panic!("expected single response, got {other:?}"),
    };
    assert_eq!(single.text, "fallback answer");
    assert!(single.error.is_none());

    // Token counts must match the `usage` object returned by `chat_response`.
    let usage = bridge.usage.lock().await;
    assert_eq!(usage.input_tokens, 3);
    assert_eq!(usage.output_tokens, 5);
}
}
+16
View File
@@ -835,6 +835,22 @@ mod tests {
assert!(text.contains("FINAL"));
}
/// A prompt longer than `PROMPT_PREVIEW_LEN` must report its full character
/// count and a preview with an ellipsis, while the text past the preview
/// (the marker appended at the very end) never appears in the metadata.
#[test]
fn build_metadata_truncates_long_context_without_leaking_tail() {
    let secret_tail = "DO_NOT_LEAK_CONTEXT_TAIL";
    let mut prompt = "a".repeat(PROMPT_PREVIEW_LEN + 100);
    prompt.push_str(secret_tail);
    let expected_length_line = format!("- Length: {} chars", prompt.chars().count());

    let msg = build_metadata_message(&prompt, None, 0, None, None);
    let text = extract_text_blocks(&msg.content);

    assert!(text.contains(&expected_length_line));
    assert!(text.contains("- Preview: \""));
    assert!(text.contains("..."));
    assert!(
        !text.contains(secret_tail),
        "metadata leaked the non-preview tail of context"
    );
}
#[test]
fn build_metadata_with_iteration_shows_previous_code() {
let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));