test(rlm): cover bridge batch and depth guard
This commit is contained in:
+125
-26
@@ -121,19 +121,8 @@ impl RlmBridge {
|
||||
}
|
||||
|
||||
async fn dispatch_llm_batch(&self, prompts: Vec<String>, model: Option<String>) -> BatchResp {
|
||||
if prompts.is_empty() {
|
||||
return BatchResp { results: vec![] };
|
||||
}
|
||||
if prompts.len() > MAX_BATCH {
|
||||
return BatchResp {
|
||||
results: prompts
|
||||
.iter()
|
||||
.map(|_| SingleResp {
|
||||
text: String::new(),
|
||||
error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())),
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
if let Some(resp) = batch_guard(prompts.len()) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
let model = Arc::new(
|
||||
@@ -201,19 +190,8 @@ impl RlmBridge {
|
||||
}
|
||||
|
||||
async fn dispatch_rlm_batch(&self, prompts: Vec<String>, model: Option<String>) -> BatchResp {
|
||||
if prompts.is_empty() {
|
||||
return BatchResp { results: vec![] };
|
||||
}
|
||||
if prompts.len() > MAX_BATCH {
|
||||
return BatchResp {
|
||||
results: prompts
|
||||
.iter()
|
||||
.map(|_| SingleResp {
|
||||
text: String::new(),
|
||||
error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())),
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
if let Some(resp) = batch_guard(prompts.len()) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
let model = Arc::new(model);
|
||||
@@ -227,6 +205,23 @@ impl RlmBridge {
|
||||
}
|
||||
}
|
||||
|
||||
fn batch_guard(prompt_count: usize) -> Option<BatchResp> {
|
||||
if prompt_count == 0 {
|
||||
return Some(BatchResp { results: vec![] });
|
||||
}
|
||||
if prompt_count > MAX_BATCH {
|
||||
return Some(BatchResp {
|
||||
results: (0..prompt_count)
|
||||
.map(|_| SingleResp {
|
||||
text: String::new(),
|
||||
error: Some(format!("batch too large: {prompt_count} > {MAX_BATCH}")),
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
impl RpcDispatcher for RlmBridge {
|
||||
fn dispatch<'a>(
|
||||
&'a self,
|
||||
@@ -255,3 +250,107 @@ impl RpcDispatcher for RlmBridge {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{Config, ProviderConfig, ProvidersConfig};
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
fn client_for(server: &MockServer) -> DeepSeekClient {
|
||||
let config = Config {
|
||||
provider: Some("sglang".to_string()),
|
||||
providers: Some(ProvidersConfig {
|
||||
sglang: ProviderConfig {
|
||||
base_url: Some(server.uri()),
|
||||
..ProviderConfig::default()
|
||||
},
|
||||
..ProvidersConfig::default()
|
||||
}),
|
||||
..Config::default()
|
||||
};
|
||||
DeepSeekClient::new(&config).expect("test client")
|
||||
}
|
||||
|
||||
fn chat_response(text: &str) -> serde_json::Value {
|
||||
json!({
|
||||
"id": "chatcmpl-test",
|
||||
"model": "test-model",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 3,
|
||||
"completion_tokens": 5
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_guard_allows_non_empty_batches_at_the_cap() {
|
||||
assert!(batch_guard(MAX_BATCH).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_guard_returns_empty_response_for_empty_batches() {
|
||||
let response = batch_guard(0).expect("empty batch should be handled");
|
||||
assert!(response.results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_guard_returns_one_error_per_oversized_prompt() {
|
||||
let response = batch_guard(MAX_BATCH + 2).expect("oversized batch should be handled");
|
||||
assert_eq!(response.results.len(), MAX_BATCH + 2);
|
||||
assert!(response.results.iter().all(|result| {
|
||||
result.text.is_empty()
|
||||
&& result
|
||||
.error
|
||||
.as_deref()
|
||||
.is_some_and(|err| err.contains("batch too large"))
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(ResponseTemplate::new(404).set_body_string("responses unavailable"))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(chat_response("fallback answer")),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0);
|
||||
let response = bridge
|
||||
.dispatch(RpcRequest::Rlm {
|
||||
prompt: "nested prompt".to_string(),
|
||||
model: Some("override-model".to_string()),
|
||||
})
|
||||
.await;
|
||||
|
||||
match response {
|
||||
RpcResponse::Single(single) => {
|
||||
assert_eq!(single.text, "fallback answer");
|
||||
assert!(single.error.is_none());
|
||||
}
|
||||
other => panic!("expected single response, got {other:?}"),
|
||||
}
|
||||
|
||||
let usage = bridge.usage.lock().await;
|
||||
assert_eq!(usage.input_tokens, 3);
|
||||
assert_eq!(usage.output_tokens, 5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -835,6 +835,22 @@ mod tests {
|
||||
assert!(text.contains("FINAL"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_metadata_truncates_long_context_without_leaking_tail() {
|
||||
let secret_tail = "DO_NOT_LEAK_CONTEXT_TAIL";
|
||||
let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
|
||||
let msg = build_metadata_message(&prompt, None, 0, None, None);
|
||||
let text = extract_text_blocks(&msg.content);
|
||||
|
||||
assert!(text.contains(&format!("- Length: {} chars", prompt.chars().count())));
|
||||
assert!(text.contains("- Preview: \""));
|
||||
assert!(text.contains("..."));
|
||||
assert!(
|
||||
!text.contains(secret_tail),
|
||||
"metadata leaked the non-preview tail of context"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_metadata_with_iteration_shows_previous_code() {
|
||||
let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));
|
||||
|
||||
Reference in New Issue
Block a user