hello world
This commit is contained in:
commit
c99507ca1e
84 changed files with 54252 additions and 0 deletions
20
src-rust/crates/api/Cargo.toml
Normal file
20
src-rust/crates/api/Cargo.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "cc-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
995
src-rust/crates/api/src/lib.rs
Normal file
995
src-rust/crates/api/src/lib.rs
Normal file
|
|
@ -0,0 +1,995 @@
|
|||
// cc-api: Anthropic API client with streaming SSE support for the Claude Code
|
||||
// Rust port.
|
||||
//
|
||||
// Handles:
|
||||
// - POST /v1/messages with streaming
|
||||
// - SSE event parsing (message_start, content_block_start, content_block_delta,
|
||||
// content_block_stop, message_delta, message_stop, error)
|
||||
// - Delta types: text_delta, input_json_delta, thinking_delta, signature_delta
|
||||
// - Rate-limit (429) and overloaded (529) retry with exponential back-off
|
||||
// - Authentication via API key from env or config
|
||||
|
||||
use cc_core::constants::{ANTHROPIC_API_VERSION, ANTHROPIC_BETA_HEADER};
|
||||
use cc_core::error::ClaudeError;
|
||||
use cc_core::types::{ContentBlock, Message, MessageContent, Role, ToolDefinition, UsageInfo};
|
||||
use futures::StreamExt;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public re-exports
|
||||
// ---------------------------------------------------------------------------
|
||||
pub use client::AnthropicClient;
|
||||
pub use streaming::{StreamEvent, StreamHandler};
|
||||
pub use types::*;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// request / response types
|
||||
// ---------------------------------------------------------------------------
|
||||
pub mod types {
|
||||
use super::*;
|
||||
|
||||
/// The request body sent to `POST /v1/messages`.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CreateMessageRequest {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<ApiMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<SystemPrompt>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ApiToolDefinition>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ThinkingConfig {
|
||||
#[serde(rename = "type")]
|
||||
pub thinking_type: String,
|
||||
pub budget_tokens: u32,
|
||||
}
|
||||
|
||||
impl ThinkingConfig {
|
||||
pub fn enabled(budget: u32) -> Self {
|
||||
Self {
|
||||
thinking_type: "enabled".to_string(),
|
||||
budget_tokens: budget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// System prompt - either a single string or structured blocks with cache.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SystemPrompt {
|
||||
Text(String),
|
||||
Blocks(Vec<SystemBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SystemBlock {
|
||||
#[serde(rename = "type")]
|
||||
pub block_type: String,
|
||||
pub text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CacheControl {
|
||||
#[serde(rename = "type")]
|
||||
pub control_type: String,
|
||||
}
|
||||
|
||||
impl CacheControl {
|
||||
pub fn ephemeral() -> Self {
|
||||
Self {
|
||||
control_type: "ephemeral".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simplified message type for the API wire format.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiMessage {
|
||||
pub role: String,
|
||||
pub content: Value,
|
||||
}
|
||||
|
||||
impl From<&Message> for ApiMessage {
|
||||
fn from(msg: &Message) -> Self {
|
||||
let role = match msg.role {
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
};
|
||||
let content = match &msg.content {
|
||||
MessageContent::Text(t) => Value::String(t.clone()),
|
||||
MessageContent::Blocks(blocks) => {
|
||||
serde_json::to_value(blocks).unwrap_or(Value::Null)
|
||||
}
|
||||
};
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool definition in the API wire format.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
impl From<&ToolDefinition> for ApiToolDefinition {
|
||||
fn from(td: &ToolDefinition) -> Self {
|
||||
Self {
|
||||
name: td.name.clone(),
|
||||
description: td.description.clone(),
|
||||
input_schema: td.input_schema.clone(),
|
||||
cache_control: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Non-streaming response from `POST /v1/messages`.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct CreateMessageResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<Value>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: UsageInfo,
|
||||
}
|
||||
|
||||
/// Error body returned by the API.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiErrorResponse {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub error: ApiErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiErrorDetail {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SSE streaming types
|
||||
// ---------------------------------------------------------------------------
|
||||
pub mod streaming {
|
||||
use super::*;
|
||||
|
||||
/// Events emitted by the streaming SSE parser.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// The overall message has started; carries the message id and model.
|
||||
MessageStart {
|
||||
id: String,
|
||||
model: String,
|
||||
usage: UsageInfo,
|
||||
},
|
||||
/// A new content block has begun.
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: ContentBlock,
|
||||
},
|
||||
/// Incremental delta for an existing content block.
|
||||
ContentBlockDelta {
|
||||
index: usize,
|
||||
delta: ContentDelta,
|
||||
},
|
||||
/// A content block is finished.
|
||||
ContentBlockStop {
|
||||
index: usize,
|
||||
},
|
||||
/// Final message-level delta (stop_reason, usage).
|
||||
MessageDelta {
|
||||
stop_reason: Option<String>,
|
||||
usage: Option<UsageInfo>,
|
||||
},
|
||||
/// The message is complete.
|
||||
MessageStop,
|
||||
/// An error occurred during streaming.
|
||||
Error {
|
||||
error_type: String,
|
||||
message: String,
|
||||
},
|
||||
/// A ping/keep-alive event.
|
||||
Ping,
|
||||
}
|
||||
|
||||
/// The delta payload inside a `content_block_delta` event.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentDelta {
|
||||
TextDelta { text: String },
|
||||
InputJsonDelta { partial_json: String },
|
||||
ThinkingDelta { thinking: String },
|
||||
SignatureDelta { signature: String },
|
||||
}
|
||||
|
||||
/// Trait for anything that wants to consume streaming events in real time.
|
||||
pub trait StreamHandler: Send + Sync {
|
||||
fn on_event(&self, event: &StreamEvent);
|
||||
}
|
||||
|
||||
/// A no-op handler useful for non-interactive / batch mode.
|
||||
pub struct NullStreamHandler;
|
||||
impl StreamHandler for NullStreamHandler {
|
||||
fn on_event(&self, _event: &StreamEvent) {}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SSE line parser
|
||||
// ---------------------------------------------------------------------------
|
||||
mod sse_parser {
|
||||
/// Parsed SSE frame.
|
||||
#[derive(Debug)]
|
||||
pub struct SseFrame {
|
||||
pub event: Option<String>,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
/// Incrementally accumulates raw bytes/lines and yields complete frames.
|
||||
pub struct SseLineParser {
|
||||
event_type: Option<String>,
|
||||
data_buf: String,
|
||||
}
|
||||
|
||||
impl SseLineParser {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
event_type: None,
|
||||
data_buf: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed one line (without the trailing newline). Returns `Some(frame)`
|
||||
/// when a blank line signals the end of an event.
|
||||
pub fn feed_line(&mut self, line: &str) -> Option<SseFrame> {
|
||||
if line.is_empty() {
|
||||
// Blank line = end of event
|
||||
if self.data_buf.is_empty() && self.event_type.is_none() {
|
||||
return None; // spurious blank line
|
||||
}
|
||||
let frame = SseFrame {
|
||||
event: self.event_type.take(),
|
||||
data: std::mem::take(&mut self.data_buf),
|
||||
};
|
||||
return Some(frame);
|
||||
}
|
||||
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
self.event_type = Some(rest.trim().to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("data:") {
|
||||
if !self.data_buf.is_empty() {
|
||||
self.data_buf.push('\n');
|
||||
}
|
||||
self.data_buf.push_str(rest.trim());
|
||||
} else if line.starts_with(':') {
|
||||
// SSE comment / keep-alive – ignore
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic client
|
||||
// ---------------------------------------------------------------------------
|
||||
pub mod client {
|
||||
use super::*;
|
||||
|
||||
/// Configuration for the HTTP client.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub api_key: String,
|
||||
pub api_base: String,
|
||||
pub api_version: String,
|
||||
pub beta_features: String,
|
||||
pub max_retries: u32,
|
||||
pub initial_retry_delay: Duration,
|
||||
pub max_retry_delay: Duration,
|
||||
pub request_timeout: Duration,
|
||||
/// When true, send `Authorization: Bearer <api_key>` instead of `x-api-key`.
|
||||
/// Used for Claude.ai subscription (OAuth user:inference scope) tokens.
|
||||
pub use_bearer_auth: bool,
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
api_key: String::new(),
|
||||
api_base: cc_core::constants::ANTHROPIC_API_BASE.to_string(),
|
||||
api_version: ANTHROPIC_API_VERSION.to_string(),
|
||||
beta_features: ANTHROPIC_BETA_HEADER.to_string(),
|
||||
max_retries: 5,
|
||||
initial_retry_delay: Duration::from_secs(1),
|
||||
max_retry_delay: Duration::from_secs(60),
|
||||
request_timeout: Duration::from_secs(600),
|
||||
use_bearer_auth: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The main Anthropic API client.
|
||||
pub struct AnthropicClient {
|
||||
http: reqwest::Client,
|
||||
config: ClientConfig,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Build a new client. Panics if `config.api_key` is empty.
|
||||
pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
|
||||
if config.api_key.is_empty() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY or pass --api-key."
|
||||
));
|
||||
}
|
||||
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(config.request_timeout)
|
||||
.build()?;
|
||||
|
||||
Ok(Self { http, config })
|
||||
}
|
||||
|
||||
/// Convenience constructor that resolves the key from config/env.
|
||||
pub fn from_config(cfg: &cc_core::config::Config) -> anyhow::Result<Self> {
|
||||
let api_key = cfg
|
||||
.resolve_api_key()
|
||||
.ok_or_else(|| anyhow::anyhow!("No API key found"))?;
|
||||
let api_base = cfg.resolve_api_base();
|
||||
|
||||
Self::new(ClientConfig {
|
||||
api_key,
|
||||
api_base,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Non-streaming create message --------------------------------
|
||||
|
||||
/// Send a non-streaming `POST /v1/messages` and return the full response.
|
||||
pub async fn create_message(
|
||||
&self,
|
||||
mut request: CreateMessageRequest,
|
||||
) -> Result<CreateMessageResponse, ClaudeError> {
|
||||
request.stream = false;
|
||||
let body = serde_json::to_value(&request).map_err(ClaudeError::Json)?;
|
||||
|
||||
let resp = self.send_with_retry(&body).await?;
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.map_err(ClaudeError::Http)?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(self.parse_api_error(status.as_u16(), &text));
|
||||
}
|
||||
|
||||
serde_json::from_str(&text).map_err(ClaudeError::Json)
|
||||
}
|
||||
|
||||
// ---- Streaming create message ------------------------------------
|
||||
|
||||
/// Send a streaming `POST /v1/messages`. Events are dispatched to the
|
||||
/// provided `handler` in real time, and also forwarded into the returned
|
||||
/// channel so the caller can drive a select loop.
|
||||
pub async fn create_message_stream(
|
||||
&self,
|
||||
mut request: CreateMessageRequest,
|
||||
handler: Arc<dyn StreamHandler>,
|
||||
) -> Result<mpsc::Receiver<StreamEvent>, ClaudeError> {
|
||||
request.stream = true;
|
||||
let body = serde_json::to_value(&request).map_err(ClaudeError::Json)?;
|
||||
|
||||
let resp = self.send_with_retry(&body).await?;
|
||||
let status = resp.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.map_err(ClaudeError::Http)?;
|
||||
return Err(self.parse_api_error(status.as_u16(), &text));
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel(256);
|
||||
|
||||
// Spawn a task that reads the SSE byte stream and emits events.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::process_sse_stream(resp, handler, tx.clone()).await {
|
||||
let _ = tx
|
||||
.send(StreamEvent::Error {
|
||||
error_type: "stream_error".into(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
// ---- Internal helpers --------------------------------------------
|
||||
|
||||
/// Build the common request and execute with retry logic.
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
body: &Value,
|
||||
) -> Result<reqwest::Response, ClaudeError> {
|
||||
let url = format!("{}/v1/messages", self.config.api_base);
|
||||
let mut attempts = 0u32;
|
||||
let mut delay = self.config.initial_retry_delay;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
|
||||
// Use Bearer auth for Claude.ai OAuth tokens; x-api-key for regular keys.
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("anthropic-version", &self.config.api_version)
|
||||
.header("anthropic-beta", &self.config.beta_features)
|
||||
.header("content-type", "application/json")
|
||||
.header("accept", "text/event-stream");
|
||||
req = if self.config.use_bearer_auth {
|
||||
req.header("Authorization", format!("Bearer {}", &self.config.api_key))
|
||||
} else {
|
||||
req.header("x-api-key", &self.config.api_key)
|
||||
};
|
||||
let req = req.json(body);
|
||||
|
||||
let resp = req.send().await.map_err(ClaudeError::Http)?;
|
||||
let status = resp.status().as_u16();
|
||||
|
||||
// 200-299: success
|
||||
if resp.status().is_success() {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
// 429 (rate limit) or 529 (overloaded): retry
|
||||
if (status == 429 || status == 529) && attempts <= self.config.max_retries {
|
||||
// Honour Retry-After header if present
|
||||
let retry_after = resp
|
||||
.headers()
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_secs);
|
||||
|
||||
let wait = retry_after.unwrap_or(delay);
|
||||
warn!(
|
||||
status,
|
||||
attempt = attempts,
|
||||
wait_secs = wait.as_secs(),
|
||||
"Retryable API error, backing off"
|
||||
);
|
||||
tokio::time::sleep(wait).await;
|
||||
delay = (delay * 2).min(self.config.max_retry_delay);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Non-retryable error – return immediately
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(self.parse_api_error(status, &text));
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an API error body into a typed `ClaudeError`.
|
||||
fn parse_api_error(&self, status: u16, body: &str) -> ClaudeError {
|
||||
if let Ok(err) = serde_json::from_str::<ApiErrorResponse>(body) {
|
||||
match status {
|
||||
401 => ClaudeError::Auth(err.error.message),
|
||||
429 => ClaudeError::RateLimit,
|
||||
529 => ClaudeError::ApiStatus {
|
||||
status,
|
||||
message: format!("Overloaded: {}", err.error.message),
|
||||
},
|
||||
_ => ClaudeError::ApiStatus {
|
||||
status,
|
||||
message: err.error.message,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ClaudeError::ApiStatus {
|
||||
status,
|
||||
message: body.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read an SSE byte stream, parse frames, and emit `StreamEvent`s.
|
||||
async fn process_sse_stream(
|
||||
resp: reqwest::Response,
|
||||
handler: Arc<dyn StreamHandler>,
|
||||
tx: mpsc::Sender<StreamEvent>,
|
||||
) -> Result<(), ClaudeError> {
|
||||
use sse_parser::SseLineParser;
|
||||
|
||||
let mut parser = SseLineParser::new();
|
||||
let mut byte_stream = resp.bytes_stream();
|
||||
let mut leftover = String::new();
|
||||
|
||||
while let Some(chunk_result) = byte_stream.next().await {
|
||||
let chunk = chunk_result.map_err(ClaudeError::Http)?;
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
|
||||
// Prepend any leftover from the previous chunk
|
||||
let combined = if leftover.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let mut s = std::mem::take(&mut leftover);
|
||||
s.push_str(&text);
|
||||
s
|
||||
};
|
||||
|
||||
// Split into lines. If the chunk doesn't end with a newline
|
||||
// the last piece is an incomplete line – stash it.
|
||||
let mut lines: Vec<&str> = combined.split('\n').collect();
|
||||
if !combined.ends_with('\n') {
|
||||
leftover = lines.pop().unwrap_or("").to_string();
|
||||
}
|
||||
|
||||
for line in lines {
|
||||
let line = line.trim_end_matches('\r');
|
||||
if let Some(frame) = parser.feed_line(line) {
|
||||
if let Some(event) =
|
||||
Self::frame_to_event(&frame.event, &frame.data)
|
||||
{
|
||||
handler.on_event(&event);
|
||||
if tx.send(event).await.is_err() {
|
||||
// Receiver dropped – stop reading.
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert a parsed SSE frame into a typed `StreamEvent`.
|
||||
fn frame_to_event(
|
||||
event_type: &Option<String>,
|
||||
data: &str,
|
||||
) -> Option<StreamEvent> {
|
||||
let event_name = event_type.as_deref().unwrap_or("");
|
||||
|
||||
match event_name {
|
||||
"ping" => Some(StreamEvent::Ping),
|
||||
|
||||
"message_start" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let msg = v.get("message")?;
|
||||
let id = msg.get("id")?.as_str()?.to_string();
|
||||
let model = msg.get("model")?.as_str()?.to_string();
|
||||
let usage = msg
|
||||
.get("usage")
|
||||
.and_then(|u| serde_json::from_value::<UsageInfo>(u.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
|
||||
Some(StreamEvent::MessageStart { id, model, usage })
|
||||
}
|
||||
|
||||
"content_block_start" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let index = v.get("index")?.as_u64()? as usize;
|
||||
let block_value = v.get("content_block")?;
|
||||
let content_block: ContentBlock =
|
||||
serde_json::from_value(block_value.clone()).ok()?;
|
||||
Some(StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
})
|
||||
}
|
||||
|
||||
"content_block_delta" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let index = v.get("index")?.as_u64()? as usize;
|
||||
let delta_value = v.get("delta")?;
|
||||
let delta: streaming::ContentDelta =
|
||||
serde_json::from_value(delta_value.clone()).ok()?;
|
||||
Some(StreamEvent::ContentBlockDelta { index, delta })
|
||||
}
|
||||
|
||||
"content_block_stop" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let index = v.get("index")?.as_u64()? as usize;
|
||||
Some(StreamEvent::ContentBlockStop { index })
|
||||
}
|
||||
|
||||
"message_delta" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let delta = v.get("delta")?;
|
||||
let stop_reason = delta
|
||||
.get("stop_reason")
|
||||
.and_then(|s| s.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let usage = v
|
||||
.get("usage")
|
||||
.and_then(|u| serde_json::from_value::<UsageInfo>(u.clone()).ok());
|
||||
Some(StreamEvent::MessageDelta { stop_reason, usage })
|
||||
}
|
||||
|
||||
"message_stop" => Some(StreamEvent::MessageStop),
|
||||
|
||||
"error" => {
|
||||
let v: Value = serde_json::from_str(data).ok()?;
|
||||
let error = v.get("error")?;
|
||||
let error_type = error
|
||||
.get("type")
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let message = error
|
||||
.get("message")
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string();
|
||||
Some(StreamEvent::Error {
|
||||
error_type,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!(event = event_name, "Unhandled SSE event type");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Convenience builder for CreateMessageRequest
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
impl CreateMessageRequest {
|
||||
/// Create a minimal request builder.
|
||||
pub fn builder(model: impl Into<String>, max_tokens: u32) -> CreateMessageRequestBuilder {
|
||||
CreateMessageRequestBuilder {
|
||||
model: model.into(),
|
||||
max_tokens,
|
||||
messages: vec![],
|
||||
system: None,
|
||||
tools: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stop_sequences: None,
|
||||
thinking: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CreateMessageRequestBuilder {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
messages: Vec<ApiMessage>,
|
||||
system: Option<SystemPrompt>,
|
||||
tools: Option<Vec<ApiToolDefinition>>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
top_k: Option<u32>,
|
||||
stop_sequences: Option<Vec<String>>,
|
||||
thinking: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
impl CreateMessageRequestBuilder {
|
||||
pub fn messages(mut self, msgs: Vec<ApiMessage>) -> Self {
|
||||
self.messages = msgs;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_message(mut self, msg: ApiMessage) -> Self {
|
||||
self.messages.push(msg);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn system(mut self, s: SystemPrompt) -> Self {
|
||||
self.system = Some(s);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn system_text(mut self, text: impl Into<String>) -> Self {
|
||||
self.system = Some(SystemPrompt::Text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<ApiToolDefinition>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, t: f32) -> Self {
|
||||
self.temperature = Some(t);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, p: f32) -> Self {
|
||||
self.top_p = Some(p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_k(mut self, k: u32) -> Self {
|
||||
self.top_k = Some(k);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop_sequences(mut self, seqs: Vec<String>) -> Self {
|
||||
self.stop_sequences = Some(seqs);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn thinking(mut self, config: ThinkingConfig) -> Self {
|
||||
self.thinking = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> CreateMessageRequest {
|
||||
CreateMessageRequest {
|
||||
model: self.model,
|
||||
max_tokens: self.max_tokens,
|
||||
messages: self.messages,
|
||||
system: self.system,
|
||||
tools: self.tools,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
top_k: self.top_k,
|
||||
stop_sequences: self.stop_sequences,
|
||||
stream: true,
|
||||
thinking: self.thinking,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Accumulated message builder – reconstructs a full Message from stream events
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Collects streaming events and produces a finished `Message` plus usage info.
|
||||
pub struct StreamAccumulator {
|
||||
id: Option<String>,
|
||||
model: Option<String>,
|
||||
content_blocks: Vec<ContentBlock>,
|
||||
/// Partial accumulators keyed by block index.
|
||||
partials: std::collections::HashMap<usize, PartialBlock>,
|
||||
stop_reason: Option<String>,
|
||||
usage: UsageInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PartialBlock {
|
||||
Text(String),
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
json_buf: String,
|
||||
},
|
||||
Thinking {
|
||||
thinking_buf: String,
|
||||
signature_buf: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl StreamAccumulator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
model: None,
|
||||
content_blocks: vec![],
|
||||
partials: Default::default(),
|
||||
stop_reason: None,
|
||||
usage: UsageInfo::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed a stream event. Call this for every event received from the stream.
|
||||
pub fn on_event(&mut self, event: &StreamEvent) {
|
||||
match event {
|
||||
StreamEvent::MessageStart { id, model, usage } => {
|
||||
self.id = Some(id.clone());
|
||||
self.model = Some(model.clone());
|
||||
self.usage = usage.clone();
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
let partial = match content_block {
|
||||
ContentBlock::Text { text } => PartialBlock::Text(text.clone()),
|
||||
ContentBlock::ToolUse { id, name, .. } => PartialBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
json_buf: String::new(),
|
||||
},
|
||||
ContentBlock::Thinking { thinking, signature } => PartialBlock::Thinking {
|
||||
thinking_buf: thinking.clone(),
|
||||
signature_buf: signature.clone(),
|
||||
},
|
||||
_ => return,
|
||||
};
|
||||
self.partials.insert(*index, partial);
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockDelta { index, delta } => {
|
||||
if let Some(partial) = self.partials.get_mut(index) {
|
||||
match (partial, delta) {
|
||||
(PartialBlock::Text(buf), streaming::ContentDelta::TextDelta { text }) => {
|
||||
buf.push_str(text);
|
||||
}
|
||||
(
|
||||
PartialBlock::ToolUse { json_buf, .. },
|
||||
streaming::ContentDelta::InputJsonDelta { partial_json },
|
||||
) => {
|
||||
json_buf.push_str(partial_json);
|
||||
}
|
||||
(
|
||||
PartialBlock::Thinking { thinking_buf, .. },
|
||||
streaming::ContentDelta::ThinkingDelta { thinking },
|
||||
) => {
|
||||
thinking_buf.push_str(thinking);
|
||||
}
|
||||
(
|
||||
PartialBlock::Thinking { signature_buf, .. },
|
||||
streaming::ContentDelta::SignatureDelta { signature },
|
||||
) => {
|
||||
signature_buf.push_str(signature);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockStop { index } => {
|
||||
if let Some(partial) = self.partials.remove(index) {
|
||||
let block = match partial {
|
||||
PartialBlock::Text(text) => ContentBlock::Text { text },
|
||||
PartialBlock::ToolUse { id, name, json_buf } => {
|
||||
let input = serde_json::from_str(&json_buf)
|
||||
.unwrap_or(Value::Object(Default::default()));
|
||||
ContentBlock::ToolUse { id, name, input }
|
||||
}
|
||||
PartialBlock::Thinking {
|
||||
thinking_buf,
|
||||
signature_buf,
|
||||
} => ContentBlock::Thinking {
|
||||
thinking: thinking_buf,
|
||||
signature: signature_buf,
|
||||
},
|
||||
};
|
||||
self.content_blocks.push(block);
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::MessageDelta { stop_reason, usage } => {
|
||||
if let Some(sr) = stop_reason {
|
||||
self.stop_reason = Some(sr.clone());
|
||||
}
|
||||
if let Some(u) = usage {
|
||||
// The delta usage usually only has output_tokens;
|
||||
// add them to the running total.
|
||||
self.usage.output_tokens += u.output_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::MessageStop => {}
|
||||
StreamEvent::Ping => {}
|
||||
StreamEvent::Error { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize and produce the accumulated `Message`.
|
||||
pub fn finish(self) -> (Message, UsageInfo, Option<String>) {
|
||||
let msg = Message::assistant_blocks(self.content_blocks);
|
||||
(msg, self.usage, self.stop_reason)
|
||||
}
|
||||
|
||||
pub fn stop_reason(&self) -> Option<&str> {
|
||||
self.stop_reason.as_deref()
|
||||
}
|
||||
|
||||
pub fn usage(&self) -> &UsageInfo {
|
||||
&self.usage
|
||||
}
|
||||
|
||||
pub fn model(&self) -> Option<&str> {
|
||||
self.model.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sse_parser_basic() {
|
||||
let mut parser = sse_parser::SseLineParser::new();
|
||||
assert!(parser.feed_line("event: message_start").is_none());
|
||||
assert!(parser
|
||||
.feed_line(r#"data: {"message":{"id":"m1","model":"claude","usage":{"input_tokens":0,"output_tokens":0}}}"#)
|
||||
.is_none());
|
||||
let frame = parser.feed_line("").expect("should produce frame");
|
||||
assert_eq!(frame.event.as_deref(), Some("message_start"));
|
||||
assert!(frame.data.contains("m1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_message_request_builder() {
|
||||
let req = CreateMessageRequest::builder("claude-opus-4-6", 4096)
|
||||
.system_text("You are helpful.")
|
||||
.temperature(0.7)
|
||||
.build();
|
||||
assert_eq!(req.model, "claude-opus-4-6");
|
||||
assert_eq!(req.max_tokens, 4096);
|
||||
assert!(req.stream);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_accumulator_text() {
|
||||
let mut acc = StreamAccumulator::new();
|
||||
acc.on_event(&StreamEvent::MessageStart {
|
||||
id: "m1".into(),
|
||||
model: "claude".into(),
|
||||
usage: UsageInfo::default(),
|
||||
});
|
||||
acc.on_event(&StreamEvent::ContentBlockStart {
|
||||
index: 0,
|
||||
content_block: ContentBlock::Text {
|
||||
text: String::new(),
|
||||
},
|
||||
});
|
||||
acc.on_event(&StreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: streaming::ContentDelta::TextDelta {
|
||||
text: "Hello ".into(),
|
||||
},
|
||||
});
|
||||
acc.on_event(&StreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: streaming::ContentDelta::TextDelta {
|
||||
text: "world!".into(),
|
||||
},
|
||||
});
|
||||
acc.on_event(&StreamEvent::ContentBlockStop { index: 0 });
|
||||
acc.on_event(&StreamEvent::MessageDelta {
|
||||
stop_reason: Some("end_turn".into()),
|
||||
usage: None,
|
||||
});
|
||||
acc.on_event(&StreamEvent::MessageStop);
|
||||
|
||||
let (msg, _usage, stop) = acc.finish();
|
||||
assert_eq!(msg.get_text(), Some("Hello world!"));
|
||||
assert_eq!(stop.as_deref(), Some("end_turn"));
|
||||
}
|
||||
}
|
||||
29
src-rust/crates/bridge/Cargo.toml
Normal file
29
src-rust/crates/bridge/Cargo.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "cc-bridge"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-query = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
hostname = "0.4"
|
||||
998
src-rust/crates/bridge/src/lib.rs
Normal file
998
src-rust/crates/bridge/src/lib.rs
Normal file
|
|
@ -0,0 +1,998 @@
|
|||
// cc-bridge: Remote control bridge implementation.
|
||||
//
|
||||
// The bridge connects the local Claude Code CLI to the claude.ai web UI,
|
||||
// enabling mobile/web-initiated sessions. This module implements:
|
||||
//
|
||||
// - Bridge configuration management (env-var and defaults)
|
||||
// - Device fingerprinting for trusted-device identification
|
||||
// - JWT decode/expiry utilities (client-side, no signature verification)
|
||||
// - Session lifecycle (register, poll, upload events, deregister)
|
||||
// - Message and event protocol types for bidirectional communication
|
||||
// - Long-polling loop with exponential backoff and cancellation
|
||||
// - Public `start_bridge` API that spawns background task and returns channels
|
||||
//
|
||||
// Architecture mirrors the TypeScript bridge (bridgeMain.ts / bridgeApi.ts),
|
||||
// adapted to idiomatic Rust async with tokio channels and reqwest.
|
||||
|
||||
#![warn(clippy::all)]
|
||||
|
||||
use anyhow::Context;
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JWT utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decoded claims from a session-ingress JWT.
|
||||
///
|
||||
/// Parsed client-side without signature verification — used only for
|
||||
/// expiry checks and display, never for authorization decisions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JwtClaims {
|
||||
/// Subject (usually user / device identifier).
|
||||
pub sub: Option<String>,
|
||||
/// Expiry Unix timestamp (seconds).
|
||||
pub exp: Option<i64>,
|
||||
/// Issued-at Unix timestamp (seconds).
|
||||
pub iat: Option<i64>,
|
||||
/// Trusted-device identifier embedded by the server.
|
||||
pub device_id: Option<String>,
|
||||
/// Session identifier embedded by the server.
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
impl JwtClaims {
|
||||
/// Decode a JWT payload segment without verifying the signature.
|
||||
///
|
||||
/// Strips the `sk-ant-si-` session-ingress prefix if present, then
|
||||
/// base64url-decodes the second `.`-separated segment and JSON-parses it.
|
||||
/// Returns an error if the token is malformed or the JSON is invalid.
|
||||
pub fn decode(token: &str) -> anyhow::Result<Self> {
|
||||
// Strip session-ingress prefix used by Anthropic's ingress tokens.
|
||||
let jwt = if token.starts_with("sk-ant-si-") {
|
||||
&token["sk-ant-si-".len()..]
|
||||
} else {
|
||||
token
|
||||
};
|
||||
|
||||
let parts: Vec<&str> = jwt.split('.').collect();
|
||||
if parts.len() < 2 {
|
||||
anyhow::bail!("Invalid JWT: expected at least 2 dot-separated segments");
|
||||
}
|
||||
|
||||
let raw = URL_SAFE_NO_PAD
|
||||
.decode(parts[1])
|
||||
.context("JWT payload is not valid base64url")?;
|
||||
|
||||
serde_json::from_slice::<Self>(&raw)
|
||||
.context("JWT payload is not valid JSON matching JwtClaims")
|
||||
}
|
||||
|
||||
/// Returns `true` if the `exp` claim is in the past.
|
||||
///
|
||||
/// When `exp` is absent the token is treated as non-expired (permissive
|
||||
/// default), matching the TypeScript behaviour in `jwtUtils.ts`.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(exp) = self.exp {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
exp < now
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remaining lifetime in seconds, or `None` if no `exp` claim or already
|
||||
/// expired.
|
||||
pub fn remaining_secs(&self) -> Option<i64> {
|
||||
let exp = self.exp?;
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let diff = exp - now;
|
||||
if diff > 0 { Some(diff) } else { None }
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode just the expiry timestamp from a raw JWT string.
|
||||
/// Returns `None` if the token is malformed or has no `exp` claim.
|
||||
pub fn decode_jwt_expiry(token: &str) -> Option<i64> {
|
||||
JwtClaims::decode(token).ok()?.exp
|
||||
}
|
||||
|
||||
/// Returns `true` if the token is expired (or unparseable).
|
||||
pub fn jwt_is_expired(token: &str) -> bool {
|
||||
JwtClaims::decode(token)
|
||||
.map(|c| c.is_expired())
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Device fingerprint
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute a stable device fingerprint from machine-local information.
|
||||
///
|
||||
/// Combines hostname, login user name, and home directory path, then SHA-256
|
||||
/// hashes them and returns the full hex digest. Matching the TypeScript
|
||||
/// `trustedDevice.ts` algorithm so fingerprints are consistent across the
|
||||
/// two implementations.
|
||||
pub fn device_fingerprint() -> String {
|
||||
let mut input = String::with_capacity(128);
|
||||
|
||||
if let Ok(host) = hostname::get() {
|
||||
input.push_str(&host.to_string_lossy());
|
||||
}
|
||||
input.push(':');
|
||||
|
||||
if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
|
||||
input.push_str(&user);
|
||||
}
|
||||
input.push(':');
|
||||
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
input.push_str(&home.display().to_string());
|
||||
}
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(input.as_bytes());
|
||||
hex::encode(hasher.finalize())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Runtime configuration for the bridge subsystem.
|
||||
///
|
||||
/// Built either from env vars via [`BridgeConfig::from_env`] or manually
|
||||
/// by the caller. The bridge is only active when both `enabled` is `true`
|
||||
/// **and** a `session_token` is present (see [`BridgeConfig::is_active`]).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BridgeConfig {
|
||||
/// Whether the bridge feature is turned on.
|
||||
pub enabled: bool,
|
||||
/// Base URL for bridge API calls (e.g. `https://claude.ai`).
|
||||
pub server_url: String,
|
||||
/// Stable device identifier (SHA-256 fingerprint or custom value).
|
||||
pub device_id: String,
|
||||
/// Bearer token (OAuth access token or session-ingress JWT).
|
||||
pub session_token: Option<String>,
|
||||
/// How long to wait between poll cycles (milliseconds).
|
||||
pub polling_interval_ms: u64,
|
||||
/// Maximum successive failed polls before the loop gives up.
|
||||
pub max_reconnect_attempts: u32,
|
||||
/// Per-session inactivity timeout in milliseconds (default 24 h).
|
||||
pub session_timeout_ms: u64,
|
||||
/// Runner version string sent on API calls for server-side diagnostics.
|
||||
pub runner_version: String,
|
||||
}
|
||||
|
||||
impl Default for BridgeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
server_url: "https://claude.ai".to_string(),
|
||||
device_id: device_fingerprint(),
|
||||
session_token: None,
|
||||
polling_interval_ms: 1_000,
|
||||
max_reconnect_attempts: 10,
|
||||
session_timeout_ms: 24 * 60 * 60 * 1_000,
|
||||
runner_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BridgeConfig {
|
||||
/// Build config from environment variables.
|
||||
///
|
||||
/// Recognised variables:
|
||||
/// - `CLAUDE_CODE_BRIDGE_URL` — overrides `server_url` and sets `enabled = true`
|
||||
/// - `CLAUDE_CODE_BRIDGE_TOKEN` / `CLAUDE_BRIDGE_OAUTH_TOKEN` — sets `session_token`
|
||||
/// - `CLAUDE_BRIDGE_BASE_URL` — alternative URL override (ant-only dev override)
|
||||
pub fn from_env() -> Self {
|
||||
let mut config = Self::default();
|
||||
|
||||
// URL override (sets enabled implicitly)
|
||||
if let Ok(url) = std::env::var("CLAUDE_CODE_BRIDGE_URL")
|
||||
.or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL"))
|
||||
{
|
||||
if !url.is_empty() {
|
||||
config.server_url = url;
|
||||
config.enabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Token override
|
||||
if let Ok(token) = std::env::var("CLAUDE_CODE_BRIDGE_TOKEN")
|
||||
.or_else(|_| std::env::var("CLAUDE_BRIDGE_OAUTH_TOKEN"))
|
||||
{
|
||||
if !token.is_empty() {
|
||||
config.session_token = Some(token);
|
||||
}
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
/// Returns `true` only when the bridge is both enabled and has a token.
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.enabled && self.session_token.is_some()
|
||||
}
|
||||
|
||||
/// Validate that a server-provided ID is safe to interpolate into a URL
|
||||
/// path segment. Prevents path traversal (e.g. `../../admin`).
|
||||
///
|
||||
/// Mirrors `validateBridgeId()` in `bridgeApi.ts`.
|
||||
pub fn validate_id<'a>(id: &'a str, label: &str) -> anyhow::Result<&'a str> {
|
||||
static RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
||||
let re = RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
|
||||
if id.is_empty() || !re.is_match(id) {
|
||||
anyhow::bail!("Invalid {}: contains unsafe characters", label);
|
||||
}
|
||||
Ok(id)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission decision
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A tool-use permission decision sent by the web UI back to the CLI.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PermissionDecision {
|
||||
Allow,
|
||||
AllowPermanently,
|
||||
Deny,
|
||||
DenyPermanently,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge message types (web UI → CLI)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A file attachment bundled with an inbound user message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BridgeAttachment {
|
||||
/// Display name (filename or label).
|
||||
pub name: String,
|
||||
/// Raw text or base64-encoded content.
|
||||
pub content: String,
|
||||
/// MIME type, e.g. `"text/plain"`.
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
/// Messages flowing from the web UI into the CLI.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum BridgeMessage {
|
||||
/// A new user prompt from the web UI.
|
||||
UserMessage {
|
||||
content: String,
|
||||
session_id: String,
|
||||
message_id: String,
|
||||
#[serde(default)]
|
||||
attachments: Vec<BridgeAttachment>,
|
||||
},
|
||||
/// The web UI has responded to a permission request.
|
||||
PermissionResponse {
|
||||
request_id: String,
|
||||
tool_use_id: Option<String>,
|
||||
decision: PermissionDecision,
|
||||
},
|
||||
/// Cancel the in-progress operation for a session.
|
||||
Cancel {
|
||||
session_id: String,
|
||||
reason: Option<String>,
|
||||
},
|
||||
/// Keepalive — the CLI should respond with a `Pong` event.
|
||||
Ping,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge event types (CLI → web UI)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Token-budget / cost summary attached to `TurnComplete`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BridgeUsage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub cost_usd: Option<f64>,
|
||||
}
|
||||
|
||||
/// Session connection state broadcast to the web UI.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum BridgeSessionState {
|
||||
Connecting,
|
||||
Connected,
|
||||
Idle,
|
||||
Processing,
|
||||
Disconnected,
|
||||
Error,
|
||||
}
|
||||
|
||||
/// Events flowing from the CLI up to the web UI.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum BridgeEvent {
|
||||
/// Streaming text delta for the current assistant turn.
|
||||
TextDelta {
|
||||
text: String,
|
||||
message_id: String,
|
||||
index: Option<usize>,
|
||||
},
|
||||
/// A tool call has started executing.
|
||||
ToolStart {
|
||||
tool_name: String,
|
||||
tool_id: String,
|
||||
input_preview: Option<String>,
|
||||
},
|
||||
/// A tool call has finished.
|
||||
ToolEnd {
|
||||
tool_name: String,
|
||||
tool_id: String,
|
||||
result: String,
|
||||
is_error: bool,
|
||||
},
|
||||
/// The CLI needs the web UI to approve a tool use.
|
||||
PermissionRequest {
|
||||
request_id: String,
|
||||
tool_use_id: String,
|
||||
tool_name: String,
|
||||
description: String,
|
||||
options: Vec<String>,
|
||||
},
|
||||
/// The current turn has completed.
|
||||
TurnComplete {
|
||||
message_id: String,
|
||||
stop_reason: String,
|
||||
usage: Option<BridgeUsage>,
|
||||
},
|
||||
/// A non-fatal diagnostic or user-visible error message.
|
||||
Error {
|
||||
message: String,
|
||||
code: Option<String>,
|
||||
},
|
||||
/// Response to a `Ping` message.
|
||||
Pong {
|
||||
server_time: Option<u64>,
|
||||
},
|
||||
/// Session lifecycle state change.
|
||||
SessionState {
|
||||
session_id: String,
|
||||
state: BridgeSessionState,
|
||||
},
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge session state (internal)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Internal connection state of a [`BridgeSession`].
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BridgeState {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
Running,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge session
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Active bridge session: owns the HTTP client, session credentials, and
|
||||
/// state. Runs the poll loop in a background tokio task.
|
||||
pub struct BridgeSession {
|
||||
config: BridgeConfig,
|
||||
session_id: String,
|
||||
state: Arc<RwLock<BridgeState>>,
|
||||
http: reqwest::Client,
|
||||
reconnect_count: u32,
|
||||
last_ping: Option<std::time::Instant>,
|
||||
}
|
||||
|
||||
impl BridgeSession {
|
||||
/// Create a new bridge session; generates a fresh UUID for `session_id`.
|
||||
pub fn new(config: BridgeConfig) -> Self {
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.user_agent(format!(
|
||||
"claude-code-rust/{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
))
|
||||
.build()
|
||||
.expect("Failed to build reqwest client");
|
||||
|
||||
Self {
|
||||
config,
|
||||
session_id,
|
||||
state: Arc::new(RwLock::new(BridgeState::Connecting)),
|
||||
http,
|
||||
reconnect_count: 0,
|
||||
last_ping: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn current_state(&self) -> BridgeState {
|
||||
self.state.read().clone()
|
||||
}
|
||||
|
||||
fn set_state(&self, s: BridgeState) {
|
||||
*self.state.write() = s;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Session registration / deregistration
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Register this bridge session with the CCR server.
|
||||
///
|
||||
/// POST `/api/claude_code/sessions` — mirrors the TypeScript
|
||||
/// `registerBridgeEnvironment` call in `bridgeApi.ts`.
|
||||
pub async fn register(&mut self) -> anyhow::Result<()> {
|
||||
let token = self
|
||||
.config
|
||||
.session_token
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Bridge register: no session token"))?;
|
||||
|
||||
let url = format!(
|
||||
"{}/api/claude_code/sessions",
|
||||
self.config.server_url
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"session_id": self.session_id,
|
||||
"device_id": self.config.device_id,
|
||||
"client_version": self.config.runner_version,
|
||||
});
|
||||
|
||||
debug!(session_id = %self.session_id, url = %url, "Registering bridge session");
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("x-environment-runner-version", &self.config.runner_version)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Bridge register: HTTP send failed")?;
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
match status {
|
||||
200 | 201 => {
|
||||
self.set_state(BridgeState::Connected);
|
||||
info!(session_id = %self.session_id, "Bridge session registered");
|
||||
Ok(())
|
||||
}
|
||||
401 | 403 => {
|
||||
self.set_state(BridgeState::Error(format!("Auth error: {status}")));
|
||||
anyhow::bail!("Bridge register: auth error ({})", status)
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("Bridge register: server returned {}", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deregister the session on clean shutdown.
|
||||
///
|
||||
/// DELETE `/api/claude_code/sessions/{id}` — best-effort; errors are
|
||||
/// logged and swallowed so they don't block process exit.
|
||||
pub async fn deregister(&self) {
|
||||
let Some(token) = self.config.session_token.as_deref() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let url = format!(
|
||||
"{}/api/claude_code/sessions/{}",
|
||||
self.config.server_url, self.session_id
|
||||
);
|
||||
|
||||
debug!(session_id = %self.session_id, "Deregistering bridge session");
|
||||
|
||||
match self
|
||||
.http
|
||||
.delete(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => {
|
||||
info!(session_id = %self.session_id, "Bridge session deregistered");
|
||||
}
|
||||
Ok(r) => {
|
||||
warn!(
|
||||
session_id = %self.session_id,
|
||||
status = %r.status(),
|
||||
"Bridge deregister returned non-success (ignored)"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
session_id = %self.session_id,
|
||||
error = %e,
|
||||
"Bridge deregister HTTP error (ignored)"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Polling
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Long-poll for incoming messages from the web UI.
|
||||
///
|
||||
/// GET `/api/claude_code/sessions/{id}/poll`
|
||||
///
|
||||
/// - `200` → JSON array of [`BridgeMessage`]; may be empty.
|
||||
/// - `204` → No messages; returns empty vec.
|
||||
/// - `401`/`403` → Auth failure; sets state to `Disconnected` and errors.
|
||||
async fn poll_messages(&self) -> anyhow::Result<Vec<BridgeMessage>> {
|
||||
let token = self
|
||||
.config
|
||||
.session_token
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Poll: no token"))?;
|
||||
|
||||
let url = format!(
|
||||
"{}/api/claude_code/sessions/{}/poll",
|
||||
self.config.server_url, self.session_id
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.timeout(std::time::Duration::from_secs(35))
|
||||
.send()
|
||||
.await
|
||||
.context("Bridge poll: HTTP send failed")?;
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
match status {
|
||||
200 => {
|
||||
let text = resp.text().await.context("Bridge poll: reading body")?;
|
||||
if text.trim().is_empty() || text.trim() == "[]" {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let msgs: Vec<BridgeMessage> =
|
||||
serde_json::from_str(&text).context("Bridge poll: JSON parse")?;
|
||||
Ok(msgs)
|
||||
}
|
||||
204 => Ok(vec![]),
|
||||
401 | 403 => {
|
||||
self.set_state(BridgeState::Error(format!("Auth error: {status}")));
|
||||
anyhow::bail!("Bridge poll: auth error ({})", status)
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("Bridge poll: server returned {}", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Event upload
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Batch-upload outgoing events to the web UI.
|
||||
///
|
||||
/// POST `/api/claude_code/sessions/{id}/events`
|
||||
async fn upload_events(&self, events: Vec<BridgeEvent>) -> anyhow::Result<()> {
|
||||
if events.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let token = self
|
||||
.config
|
||||
.session_token
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Upload: no token"))?;
|
||||
|
||||
let url = format!(
|
||||
"{}/api/claude_code/sessions/{}/events",
|
||||
self.config.server_url, self.session_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({ "events": events });
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Bridge upload: HTTP send failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status().as_u16();
|
||||
warn!(
|
||||
session_id = %self.session_id,
|
||||
status,
|
||||
count = events.len(),
|
||||
"Bridge event upload failed"
|
||||
);
|
||||
anyhow::bail!("Bridge upload: server returned {}", status);
|
||||
}
|
||||
|
||||
debug!(
|
||||
session_id = %self.session_id,
|
||||
count = events.len(),
|
||||
"Bridge events uploaded"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Main poll loop
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Run the bridge poll loop until `cancel` is triggered or a fatal error
|
||||
/// occurs.
|
||||
///
|
||||
/// On each iteration:
|
||||
/// 1. Drain any pending outgoing events and upload them in a batch.
|
||||
/// 2. Long-poll for incoming messages and forward them to `msg_tx`.
|
||||
/// 3. Back off exponentially on consecutive errors; give up after
|
||||
/// `config.max_reconnect_attempts`.
|
||||
/// 4. Sleep `polling_interval_ms` between successful cycles.
|
||||
pub async fn run_poll_loop(
|
||||
mut self,
|
||||
msg_tx: mpsc::Sender<BridgeMessage>,
|
||||
mut event_rx: mpsc::Receiver<BridgeEvent>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
info!(session_id = %self.session_id, "Bridge poll loop started");
|
||||
|
||||
let base_interval = std::time::Duration::from_millis(
|
||||
self.config.polling_interval_ms.max(500),
|
||||
);
|
||||
let max_backoff = std::time::Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
// Respect cancellation at the top of every iteration.
|
||||
if cancel.is_cancelled() {
|
||||
info!(session_id = %self.session_id, "Bridge poll loop cancelled");
|
||||
break;
|
||||
}
|
||||
|
||||
// --- Drain and upload pending events ---
|
||||
let mut events: Vec<BridgeEvent> = Vec::new();
|
||||
while let Ok(ev) = event_rx.try_recv() {
|
||||
events.push(ev);
|
||||
}
|
||||
if !events.is_empty() {
|
||||
if let Err(e) = self.upload_events(events).await {
|
||||
warn!(session_id = %self.session_id, error = %e, "Event upload error");
|
||||
}
|
||||
}
|
||||
|
||||
// --- Poll for incoming messages ---
|
||||
match self.poll_messages().await {
|
||||
Ok(messages) => {
|
||||
// Successful poll — reset reconnect counter.
|
||||
self.reconnect_count = 0;
|
||||
|
||||
for msg in messages {
|
||||
if msg_tx.send(msg).await.is_err() {
|
||||
debug!(
|
||||
session_id = %self.session_id,
|
||||
"Incoming message channel closed; stopping poll loop"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
session_id = %self.session_id,
|
||||
error = %e,
|
||||
reconnect_count = self.reconnect_count,
|
||||
"Bridge poll error"
|
||||
);
|
||||
|
||||
self.reconnect_count += 1;
|
||||
|
||||
if self.config.max_reconnect_attempts > 0
|
||||
&& self.reconnect_count >= self.config.max_reconnect_attempts
|
||||
{
|
||||
error!(
|
||||
session_id = %self.session_id,
|
||||
"Max bridge reconnect attempts ({}) reached; stopping",
|
||||
self.config.max_reconnect_attempts
|
||||
);
|
||||
self.set_state(BridgeState::Error("max reconnects exceeded".into()));
|
||||
break;
|
||||
}
|
||||
|
||||
// Exponential backoff capped at `max_backoff`.
|
||||
let backoff = (base_interval
|
||||
* 2u32.pow(self.reconnect_count.saturating_sub(1).min(5)))
|
||||
.min(max_backoff);
|
||||
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(backoff) => {}
|
||||
_ = cancel.cancelled() => {
|
||||
info!(
|
||||
session_id = %self.session_id,
|
||||
"Bridge cancelled during backoff sleep"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Wait for the next poll cycle ---
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(base_interval) => {}
|
||||
_ = cancel.cancelled() => {
|
||||
info!(
|
||||
session_id = %self.session_id,
|
||||
"Bridge cancelled during idle sleep"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Best-effort deregister on shutdown.
|
||||
self.deregister().await;
|
||||
info!(session_id = %self.session_id, "Bridge poll loop terminated");
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bridge manager
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// High-level manager wrapping configuration and a shared HTTP client.
|
||||
///
|
||||
/// Prefer [`start_bridge`] for the simple one-shot API.
|
||||
pub struct BridgeManager {
|
||||
config: BridgeConfig,
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl BridgeManager {
|
||||
pub fn new(config: BridgeConfig) -> anyhow::Result<Self> {
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION")))
|
||||
.build()
|
||||
.context("BridgeManager: failed to build HTTP client")?;
|
||||
Ok(Self { config, http })
|
||||
}
|
||||
|
||||
/// Start the bridge polling loop, returning channel endpoints and the
|
||||
/// session ID.
|
||||
///
|
||||
/// The background task runs until `cancel` is triggered.
|
||||
pub async fn start(
|
||||
&self,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<(
|
||||
mpsc::Receiver<BridgeMessage>,
|
||||
mpsc::Sender<BridgeEvent>,
|
||||
String,
|
||||
)> {
|
||||
start_bridge_with_client(self.config.clone(), self.http.clone(), cancel).await
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Start the bridge subsystem in a background task.
|
||||
///
|
||||
/// Registers a new session with the CCR server, then spawns a tokio task
|
||||
/// running the poll loop. Returns:
|
||||
/// - `msg_rx` — incoming messages from the web UI (e.g. user prompts).
|
||||
/// - `event_tx` — sender for outgoing events (e.g. text deltas, tool calls).
|
||||
/// - `session_id` — the UUID assigned to this session.
|
||||
///
|
||||
/// The background task runs until `cancel` is triggered or too many
|
||||
/// consecutive errors occur. On shutdown the session is deregistered.
|
||||
pub async fn start_bridge(
|
||||
config: BridgeConfig,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<(
|
||||
mpsc::Receiver<BridgeMessage>,
|
||||
mpsc::Sender<BridgeEvent>,
|
||||
String,
|
||||
)> {
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION")))
|
||||
.build()
|
||||
.context("start_bridge: failed to build HTTP client")?;
|
||||
|
||||
start_bridge_with_client(config, http, cancel).await
|
||||
}
|
||||
|
||||
async fn start_bridge_with_client(
|
||||
config: BridgeConfig,
|
||||
_http: reqwest::Client,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<(
|
||||
mpsc::Receiver<BridgeMessage>,
|
||||
mpsc::Sender<BridgeEvent>,
|
||||
String,
|
||||
)> {
|
||||
if !config.is_active() {
|
||||
anyhow::bail!("start_bridge: bridge is not active (enabled={}, token={})",
|
||||
config.enabled,
|
||||
config.session_token.is_some()
|
||||
);
|
||||
}
|
||||
|
||||
let mut session = BridgeSession::new(config);
|
||||
session
|
||||
.register()
|
||||
.await
|
||||
.context("start_bridge: session registration failed")?;
|
||||
|
||||
let session_id = session.session_id().to_string();
|
||||
|
||||
// Bounded channels — back-pressure prevents unbounded memory growth on a
|
||||
// slow consumer.
|
||||
let (msg_tx, msg_rx) = mpsc::channel::<BridgeMessage>(64);
|
||||
let (event_tx, event_rx) = mpsc::channel::<BridgeEvent>(256);
|
||||
|
||||
tokio::spawn(async move {
|
||||
session.run_poll_loop(msg_tx, event_rx, cancel).await;
|
||||
});
|
||||
|
||||
info!(session_id = %session_id, "Bridge started");
|
||||
Ok((msg_rx, event_tx, session_id))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Trusted device module (re-exported for external callers)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod trusted_device {
|
||||
/// Re-export the crate-level device fingerprint function.
|
||||
pub use super::device_fingerprint;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JWT module (re-exported for external callers)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod jwt {
|
||||
pub use super::{decode_jwt_expiry, jwt_is_expired, JwtClaims};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Re-exports
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Allow downstream crates to use reqwest types without a direct dep.
|
||||
pub use reqwest;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_device_fingerprint_is_non_empty() {
|
||||
let fp = device_fingerprint();
|
||||
assert!(!fp.is_empty(), "fingerprint should not be empty");
|
||||
// SHA-256 hex is always 64 chars
|
||||
assert_eq!(fp.len(), 64, "SHA-256 hex digest should be 64 chars");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_device_fingerprint_is_stable() {
|
||||
let a = device_fingerprint();
|
||||
let b = device_fingerprint();
|
||||
assert_eq!(a, b, "fingerprint must be deterministic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jwt_decode_invalid() {
|
||||
assert!(JwtClaims::decode("notajwt").is_err());
|
||||
assert!(JwtClaims::decode("only.two").is_ok() == false || true); // either way, must not panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jwt_expired_unparseable() {
|
||||
// Unparseable token defaults to expired=true
|
||||
assert!(jwt_is_expired("bad.token.here"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_config_default_not_active() {
|
||||
let cfg = BridgeConfig::default();
|
||||
assert!(!cfg.is_active(), "default config must not be active");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_config_with_token_still_needs_enabled() {
|
||||
let mut cfg = BridgeConfig::default();
|
||||
cfg.session_token = Some("tok".into());
|
||||
assert!(!cfg.is_active(), "needs enabled=true too");
|
||||
cfg.enabled = true;
|
||||
assert!(cfg.is_active());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_id_rejects_traversal() {
|
||||
assert!(BridgeConfig::validate_id("../../etc/passwd", "id").is_err());
|
||||
assert!(BridgeConfig::validate_id("abc123", "id").is_ok());
|
||||
assert!(BridgeConfig::validate_id("env_abc-123", "id").is_ok());
|
||||
assert!(BridgeConfig::validate_id("", "id").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_decision_serde() {
|
||||
let d = PermissionDecision::AllowPermanently;
|
||||
let s = serde_json::to_string(&d).unwrap();
|
||||
assert_eq!(s, r#""allow_permanently""#);
|
||||
let back: PermissionDecision = serde_json::from_str(&s).unwrap();
|
||||
assert_eq!(back, d);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_session_state_serde() {
|
||||
let s = BridgeSessionState::Processing;
|
||||
let j = serde_json::to_string(&s).unwrap();
|
||||
assert_eq!(j, r#""processing""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_message_serde_user_message() {
|
||||
let msg = BridgeMessage::UserMessage {
|
||||
content: "hello".into(),
|
||||
session_id: "s1".into(),
|
||||
message_id: "m1".into(),
|
||||
attachments: vec![],
|
||||
};
|
||||
let j = serde_json::to_string(&msg).unwrap();
|
||||
assert!(j.contains(r#""type":"user_message""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_event_text_delta_serde() {
|
||||
let ev = BridgeEvent::TextDelta {
|
||||
text: "hello world".into(),
|
||||
message_id: "m1".into(),
|
||||
index: Some(0),
|
||||
};
|
||||
let j = serde_json::to_string(&ev).unwrap();
|
||||
assert!(j.contains(r#""type":"text_delta""#));
|
||||
assert!(j.contains("hello world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_event_pong_serde() {
|
||||
let ev = BridgeEvent::Pong { server_time: Some(1_700_000_000) };
|
||||
let j = serde_json::to_string(&ev).unwrap();
|
||||
assert!(j.contains(r#""type":"pong""#));
|
||||
}
|
||||
}
|
||||
16
src-rust/crates/buddy/Cargo.toml
Normal file
16
src-rust/crates/buddy/Cargo.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "cc-buddy"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { path = "../core" }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
1116
src-rust/crates/buddy/src/lib.rs
Normal file
1116
src-rust/crates/buddy/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
31
src-rust/crates/cli/Cargo.toml
Normal file
31
src-rust/crates/cli/Cargo.toml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
[package]
|
||||
name = "claude-code"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "claude"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-tools = { workspace = true }
|
||||
cc-query = { workspace = true }
|
||||
cc-tui = { workspace = true }
|
||||
cc-commands = { workspace = true }
|
||||
cc-mcp = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
url = { workspace = true }
|
||||
crossterm = { workspace = true }
|
||||
1160
src-rust/crates/cli/src/main.rs
Normal file
1160
src-rust/crates/cli/src/main.rs
Normal file
File diff suppressed because it is too large
Load diff
447
src-rust/crates/cli/src/oauth_flow.rs
Normal file
447
src-rust/crates/cli/src/oauth_flow.rs
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
// OAuth 2.0 PKCE login flow for the Claude Code CLI.
|
||||
//
|
||||
// Implements the same flow as the TypeScript OAuthService + authLogin():
|
||||
// 1. Generate PKCE code_verifier / code_challenge / state
|
||||
// 2. Start a temporary localhost HTTP server on a random port
|
||||
// 3. Build auth URL; print for the user and attempt to open in browser
|
||||
// 4. Wait (with 60-second timeout) for:
|
||||
// a. Automatic redirect to localhost/callback, OR
|
||||
// b. User manually pastes the authorization code at the terminal
|
||||
// 5. Exchange the authorization code for tokens via POST to TOKEN_URL
|
||||
// 6. For Console flow: call create_api_key endpoint to get an API key
|
||||
// 7. Save OAuthTokens to ~/.claude/oauth_tokens.json
|
||||
// 8. Return the credential (API key or Bearer token)
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use cc_core::oauth::{self, OAuthTokens};
|
||||
use serde::Deserialize;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{debug, info, warn};
|
||||
#[allow(unused_imports)]
|
||||
use url::Url;
|
||||
|
||||
// ---- Token exchange response ------------------------------------------------
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenExchangeResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
expires_in: u64,
|
||||
#[serde(default)]
|
||||
scope: Option<String>,
|
||||
#[serde(default)]
|
||||
account: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
organization: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ---- API key creation response ----------------------------------------------
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreateApiKeyResponse {
|
||||
raw_key: Option<String>,
|
||||
}
|
||||
|
||||
// ---- Public entry point -----------------------------------------------------
|
||||
|
||||
/// Outcome of a completed login flow.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoginResult {
|
||||
/// The credential to use: either an API key (Console flow) or Bearer token (Claude.ai).
|
||||
pub credential: String,
|
||||
/// When true, present as `Authorization: Bearer <credential>`.
|
||||
pub use_bearer_auth: bool,
|
||||
/// Cached tokens saved to disk.
|
||||
pub tokens: OAuthTokens,
|
||||
}
|
||||
|
||||
/// Run the interactive OAuth PKCE login flow.
|
||||
///
|
||||
/// `login_with_claude_ai` selects the authorization endpoint:
|
||||
/// - `false` → Console endpoint (creates an API key)
|
||||
/// - `true` → Claude.ai endpoint (user:inference scope, Bearer auth)
|
||||
pub async fn run_oauth_login_flow(login_with_claude_ai: bool) -> anyhow::Result<LoginResult> {
|
||||
// 1. PKCE
|
||||
let code_verifier = oauth::generate_code_verifier();
|
||||
let code_challenge = oauth::generate_code_challenge(&code_verifier);
|
||||
let state = oauth::generate_state();
|
||||
|
||||
// 2. Bind random localhost port for the callback server
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.context("Failed to bind OAuth callback server")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
|
||||
// 3. Build auth URLs
|
||||
let authorize_base = if login_with_claude_ai {
|
||||
oauth::CLAUDE_AI_AUTHORIZE_URL
|
||||
} else {
|
||||
oauth::CONSOLE_AUTHORIZE_URL
|
||||
};
|
||||
let manual_url = oauth::build_auth_url(&authorize_base, &code_challenge, &state, port, true);
|
||||
let automatic_url = oauth::build_auth_url(&authorize_base, &code_challenge, &state, port, false);
|
||||
|
||||
// 4. Print URL and try to open browser
|
||||
println!("\nOpening browser for authentication...");
|
||||
println!("If the browser did not open, visit:\n\n {}\n", manual_url);
|
||||
try_open_browser(&automatic_url);
|
||||
|
||||
// 5. Wait for auth code (automatic callback OR manual paste)
|
||||
let auth_code =
|
||||
wait_for_auth_code_impl(listener, &state).await.context("OAuth callback failed")?;
|
||||
debug!("OAuth auth code received");
|
||||
|
||||
// 6. Exchange code for tokens
|
||||
let token_resp = exchange_code_for_tokens(&auth_code, &state, &code_verifier, port, false)
|
||||
.await
|
||||
.context("Token exchange failed")?;
|
||||
|
||||
let expires_at_ms = chrono::Utc::now().timestamp_millis()
|
||||
+ (token_resp.expires_in as i64 * 1000);
|
||||
|
||||
let scopes: Vec<String> = token_resp
|
||||
.scope
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect();
|
||||
|
||||
let account_uuid = token_resp
|
||||
.account.as_ref()
|
||||
.and_then(|a| a.get("uuid").and_then(|v| v.as_str()).map(String::from));
|
||||
let email = token_resp
|
||||
.account.as_ref()
|
||||
.and_then(|a| a.get("email_address").and_then(|v| v.as_str()).map(String::from));
|
||||
let organization_uuid = token_resp
|
||||
.organization.as_ref()
|
||||
.and_then(|o| o.get("uuid").and_then(|v| v.as_str()).map(String::from));
|
||||
|
||||
let uses_bearer = scopes.iter().any(|s| s == oauth::CLAUDE_AI_INFERENCE_SCOPE);
|
||||
|
||||
// 7. For Console flow, exchange the access token for an API key
|
||||
let api_key = if !uses_bearer {
|
||||
match create_api_key(&token_resp.access_token).await {
|
||||
Ok(key) => {
|
||||
info!("OAuth API key created successfully");
|
||||
Some(key)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create API key from OAuth token: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 8. Build and persist tokens
|
||||
let tokens = OAuthTokens {
|
||||
access_token: token_resp.access_token.clone(),
|
||||
refresh_token: token_resp.refresh_token.clone(),
|
||||
expires_at_ms: Some(expires_at_ms),
|
||||
scopes: scopes.clone(),
|
||||
account_uuid,
|
||||
email,
|
||||
organization_uuid,
|
||||
subscription_type: None,
|
||||
api_key: api_key.clone(),
|
||||
};
|
||||
tokens.save().await.context("Failed to save OAuth tokens")?;
|
||||
|
||||
let (credential, use_bearer_auth) = if uses_bearer {
|
||||
(token_resp.access_token.clone(), true)
|
||||
} else if let Some(key) = api_key {
|
||||
(key, false)
|
||||
} else {
|
||||
bail!("Login succeeded but could not obtain a usable credential")
|
||||
};
|
||||
|
||||
Ok(LoginResult { credential, use_bearer_auth, tokens })
|
||||
}
|
||||
|
||||
// ---- Helpers ----------------------------------------------------------------
|
||||
|
||||
/// Attempt to open the URL in the system default browser (best-effort).
|
||||
fn try_open_browser(url: &str) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
// Use PowerShell to safely open URLs containing special characters (& etc.)
|
||||
let ps_cmd = format!("Start-Process '{}'", url.replace('\'', "''"));
|
||||
let _ = std::process::Command::new("powershell")
|
||||
.args(["-NoProfile", "-NonInteractive", "-Command", &ps_cmd])
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn();
|
||||
}
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let _ = std::process::Command::new("open")
|
||||
.arg(url)
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn();
|
||||
}
|
||||
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
|
||||
{
|
||||
let _ = std::process::Command::new("xdg-open")
|
||||
.arg(url)
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn();
|
||||
}
|
||||
}
|
||||
|
||||
/// Tiny async HTTP server that captures /callback?code=AUTH_CODE&state=STATE.
|
||||
async fn run_callback_server(listener: TcpListener, expected_state: &str) -> anyhow::Result<String> {
|
||||
debug!("OAuth callback server listening on port {}", listener.local_addr()?.port());
|
||||
|
||||
// Accept exactly one connection (the browser redirect)
|
||||
let (mut socket, _) = tokio::time::timeout(
|
||||
Duration::from_secs(120),
|
||||
listener.accept(),
|
||||
)
|
||||
.await
|
||||
.context("Timeout waiting for browser redirect")?
|
||||
.context("Accept failed")?;
|
||||
|
||||
// Read the HTTP request line-by-line until the blank line
|
||||
let (reader, mut writer) = socket.split();
|
||||
let mut reader = BufReader::new(reader);
|
||||
let mut request_line = String::new();
|
||||
reader.read_line(&mut request_line).await?;
|
||||
|
||||
// Drain remaining headers
|
||||
loop {
|
||||
let mut header = String::new();
|
||||
reader.read_line(&mut header).await?;
|
||||
if header.trim().is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the request line: "GET /callback?code=XXX&state=YYY HTTP/1.1"
|
||||
let path = request_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let parsed_url = url::Url::parse(&format!("http://localhost{}", path))
|
||||
.context("Failed to parse callback URL")?;
|
||||
|
||||
let code = parsed_url
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "code")
|
||||
.map(|(_, v)| v.to_string());
|
||||
|
||||
let received_state = parsed_url
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "state")
|
||||
.map(|(_, v)| v.to_string());
|
||||
|
||||
// Send success redirect to the browser before validating, so the browser shows a page
|
||||
let location = if received_state.as_deref() == Some(expected_state) && code.is_some() {
|
||||
oauth::CLAUDEAI_SUCCESS_URL
|
||||
} else {
|
||||
oauth::CLAUDEAI_SUCCESS_URL // Show same page on error (browser UX)
|
||||
};
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
|
||||
location
|
||||
);
|
||||
writer.write_all(response.as_bytes()).await?;
|
||||
|
||||
// Validate
|
||||
if received_state.as_deref() != Some(expected_state) {
|
||||
bail!("OAuth state mismatch — possible CSRF attack");
|
||||
}
|
||||
let code = code.context("No authorization code in callback")?;
|
||||
|
||||
Ok(code)
|
||||
}
|
||||
|
||||
/// Read a single line from stdin (for manual code paste).
|
||||
async fn read_line_from_stdin() -> anyhow::Result<String> {
|
||||
print!(" Or paste authorization code here: ");
|
||||
use std::io::Write;
|
||||
std::io::stdout().flush().ok();
|
||||
|
||||
let mut line = String::new();
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut reader = BufReader::new(stdin);
|
||||
reader.read_line(&mut line).await?;
|
||||
Ok(line)
|
||||
}
|
||||
|
||||
/// Exchange the authorization code for OAuth tokens.
|
||||
async fn exchange_code_for_tokens(
|
||||
code: &str,
|
||||
state: &str,
|
||||
code_verifier: &str,
|
||||
port: u16,
|
||||
use_manual_redirect: bool,
|
||||
) -> anyhow::Result<TokenExchangeResponse> {
|
||||
let redirect_uri = if use_manual_redirect {
|
||||
oauth::MANUAL_REDIRECT_URL.to_string()
|
||||
} else {
|
||||
format!("http://localhost:{}/callback", port)
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"client_id": oauth::CLIENT_ID,
|
||||
"code_verifier": code_verifier,
|
||||
"state": state,
|
||||
});
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(oauth::TOKEN_URL)
|
||||
.header("content-type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Token exchange HTTP request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
bail!("Token exchange failed ({}): {}", status, text);
|
||||
}
|
||||
|
||||
resp.json::<TokenExchangeResponse>()
|
||||
.await
|
||||
.context("Failed to parse token exchange response")
|
||||
}
|
||||
|
||||
/// Exchange an OAuth access token for an Anthropic API key (Console flow only).
|
||||
async fn create_api_key(access_token: &str) -> anyhow::Result<String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(oauth::API_KEY_URL)
|
||||
.header("Authorization", format!("Bearer {}", access_token))
|
||||
.send()
|
||||
.await
|
||||
.context("API key creation request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
bail!("API key creation failed ({}): {}", status, text);
|
||||
}
|
||||
|
||||
let data: CreateApiKeyResponse = resp.json().await.context("Failed to parse API key response")?;
|
||||
data.raw_key.context("Server returned no API key")
|
||||
}
|
||||
|
||||
// ---- Refresh token flow -----------------------------------------------------
|
||||
|
||||
/// Attempt to refresh an expired access token using the stored refresh token.
|
||||
/// Saves updated tokens on success.
|
||||
pub async fn refresh_oauth_token(tokens: &OAuthTokens) -> anyhow::Result<OAuthTokens> {
|
||||
let refresh_token = tokens
|
||||
.refresh_token
|
||||
.as_deref()
|
||||
.context("No refresh token available")?;
|
||||
|
||||
let body = serde_json::json!({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": oauth::CLIENT_ID,
|
||||
"scope": oauth::ALL_SCOPES.join(" "),
|
||||
});
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(oauth::TOKEN_URL)
|
||||
.header("content-type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Token refresh HTTP request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
bail!("Token refresh failed ({}): {}", status, text);
|
||||
}
|
||||
|
||||
let token_resp: TokenExchangeResponse = resp.json().await?;
|
||||
let expires_at_ms = chrono::Utc::now().timestamp_millis()
|
||||
+ (token_resp.expires_in as i64 * 1000);
|
||||
|
||||
let scopes: Vec<String> = token_resp
|
||||
.scope
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect();
|
||||
|
||||
let mut updated = tokens.clone();
|
||||
updated.access_token = token_resp.access_token;
|
||||
if let Some(new_rt) = token_resp.refresh_token {
|
||||
updated.refresh_token = Some(new_rt);
|
||||
}
|
||||
updated.expires_at_ms = Some(expires_at_ms);
|
||||
updated.scopes = scopes;
|
||||
|
||||
updated.save().await?;
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Wait for the OAuth authorization code from either the browser redirect (automatic)
|
||||
/// or manual paste by the user. Races the two with a 120-second timeout.
|
||||
async fn wait_for_auth_code_impl(
|
||||
listener: TcpListener,
|
||||
expected_state: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let expected_state_clone = expected_state.to_string();
|
||||
let (cb_tx, cb_rx) = tokio::sync::oneshot::channel::<anyhow::Result<String>>();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = run_callback_server(listener, &expected_state_clone).await;
|
||||
let _ = cb_tx.send(result);
|
||||
});
|
||||
|
||||
let (paste_tx, paste_rx) = tokio::sync::oneshot::channel::<String>();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(line) = read_line_from_stdin().await {
|
||||
let trimmed = line.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
let _ = paste_tx.send(trimmed);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
result = cb_rx => {
|
||||
result.unwrap_or_else(|_| Err(anyhow::anyhow!("Callback server dropped")))
|
||||
}
|
||||
code = paste_rx => {
|
||||
code.map_err(|_| anyhow::anyhow!("Stdin closed unexpectedly"))
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(120)) => {
|
||||
bail!("Authentication timed out after 120 seconds")
|
||||
}
|
||||
}
|
||||
}
|
||||
38
src-rust/crates/cli/src/system_prompt.txt
Normal file
38
src-rust/crates/cli/src/system_prompt.txt
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
You are Claude Code, an AI coding assistant by Anthropic. You help users with software engineering tasks including writing code, debugging, refactoring, explaining code, running commands, and managing projects.
|
||||
|
||||
## Core principles
|
||||
- Read files before editing them
|
||||
- Prefer editing existing files over creating new ones
|
||||
- Write clean, idiomatic, production-quality code matching the project's existing style
|
||||
- Be concise — lead with the action or answer, not preamble
|
||||
- Run tests after making changes when appropriate
|
||||
- Security: never introduce SQL injection, XSS, command injection, or other vulnerabilities
|
||||
- Don't add features or refactor beyond what was asked
|
||||
|
||||
## Available tools
|
||||
You have access to a rich set of tools:
|
||||
|
||||
**File operations:** Read, Write, Edit, Glob, Grep
|
||||
**Shell:** Bash (Unix/Linux), PowerShell (Windows)
|
||||
**Web:** WebFetch (retrieve URLs), WebSearch (search the web)
|
||||
**Notebooks:** NotebookEdit (Jupyter .ipynb cells)
|
||||
**Task management:** TodoWrite, TaskCreate/Get/Update/List/Stop/Output
|
||||
**Planning:** EnterPlanMode, ExitPlanMode
|
||||
**Git worktrees:** EnterWorktree, ExitWorktree
|
||||
**Scheduling:** CronCreate, CronDelete, CronList
|
||||
**Communication:** AskUserQuestion, SendMessage (agent-to-agent), Brief (notify user)
|
||||
**Configuration:** Config (get/set settings)
|
||||
**Skills:** Skill (execute .claude/commands/*.md templates)
|
||||
**MCP:** ListMcpResources, ReadMcpResource
|
||||
**Meta:** ToolSearch, Agent (spawn sub-agents)
|
||||
|
||||
## Workflow guidance
|
||||
- Use Agent to delegate complex parallel sub-tasks (research, code generation, testing)
|
||||
- Use TodoWrite to track multi-step plans
|
||||
- Use EnterPlanMode before making significant architectural changes
|
||||
- Use EnterWorktree to safely experiment on a separate git branch
|
||||
- Use CronCreate to schedule recurring tasks
|
||||
- Use Config to read or adjust settings dynamically
|
||||
|
||||
## Context
|
||||
Git status, CLAUDE.md files, and open-source project context are provided as part of your context window. Use them to understand the project structure before taking action.
|
||||
20
src-rust/crates/commands/Cargo.toml
Normal file
20
src-rust/crates/commands/Cargo.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "cc-commands"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-tools = { workspace = true }
|
||||
cc-query = { workspace = true }
|
||||
cc-tui = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
1279
src-rust/crates/commands/src/lib.rs
Normal file
1279
src-rust/crates/commands/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
548
src-rust/crates/commands/src/named_commands.rs
Normal file
548
src-rust/crates/commands/src/named_commands.rs
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
//! Named commands (e.g. `claude agents`, `claude ide`, `claude branch`, …).
|
||||
//!
|
||||
//! These complement slash commands with more complex top-level flows.
|
||||
//! A named command is invoked when the *first* CLI argument matches one
|
||||
//! of the registered names — before the normal REPL starts.
|
||||
//!
|
||||
//! Sources consulted while porting:
|
||||
//! src/commands/agents/index.ts
|
||||
//! src/commands/ide/index.ts
|
||||
//! src/commands/branch/index.ts
|
||||
//! src/commands/tag/index.ts
|
||||
//! src/commands/passes/index.ts
|
||||
//! src/commands/pr_comments/index.ts
|
||||
//! src/commands/install-github-app/index.ts
|
||||
//! src/commands/desktop/index.ts (implied by component structure)
|
||||
//! src/commands/mobile/index.ts (implied by component structure)
|
||||
//! src/commands/remote-setup/index.ts (implied by component structure)
|
||||
|
||||
use crate::{CommandContext, CommandResult};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A top-level named command (`claude <name> [args…]`).
|
||||
pub trait NamedCommand: Send + Sync {
|
||||
/// Primary command name, e.g. `"agents"`.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// One-line description used in `claude --help`.
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// Usage hint shown in `claude <name> --help`.
|
||||
fn usage(&self) -> &str;
|
||||
|
||||
/// Execute the command. `args` is the slice of arguments *after* the
|
||||
/// command name itself.
|
||||
fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// agents
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct AgentsCommand;
|
||||
|
||||
impl NamedCommand for AgentsCommand {
|
||||
fn name(&self) -> &str { "agents" }
|
||||
fn description(&self) -> &str { "Manage and configure sub-agents" }
|
||||
fn usage(&self) -> &str { "claude agents [list|create|edit|delete] [name]" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
match args.first().copied().unwrap_or("list") {
|
||||
"list" => CommandResult::Message(
|
||||
"Sub-agents are defined in .claude/agents/ as Markdown files.\n\
|
||||
Use 'claude agents create <name>' to scaffold a new agent."
|
||||
.to_string(),
|
||||
),
|
||||
"create" => {
|
||||
let name = args.get(1).copied().unwrap_or("my-agent");
|
||||
CommandResult::Message(format!(
|
||||
"Create a new agent by adding .claude/agents/{name}.md\n\
|
||||
Template:\n\
|
||||
---\n\
|
||||
name: {name}\n\
|
||||
description: <description>\n\
|
||||
model: claude-sonnet-4-6\n\
|
||||
---\n\n\
|
||||
<agent instructions here>"
|
||||
))
|
||||
}
|
||||
"edit" => {
|
||||
let name = match args.get(1).copied() {
|
||||
Some(n) => n,
|
||||
None => return CommandResult::Error(
|
||||
"Usage: claude agents edit <name>".to_string(),
|
||||
),
|
||||
};
|
||||
CommandResult::Message(format!(
|
||||
"Edit .claude/agents/{name}.md in your editor to update the agent."
|
||||
))
|
||||
}
|
||||
"delete" => {
|
||||
let name = match args.get(1).copied() {
|
||||
Some(n) => n,
|
||||
None => return CommandResult::Error(
|
||||
"Usage: claude agents delete <name>".to_string(),
|
||||
),
|
||||
};
|
||||
CommandResult::Message(format!(
|
||||
"Delete .claude/agents/{name}.md to remove the agent."
|
||||
))
|
||||
}
|
||||
sub => CommandResult::Error(format!("Unknown agents subcommand: '{sub}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// add-dir
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct AddDirCommand;
|
||||
|
||||
impl NamedCommand for AddDirCommand {
|
||||
fn name(&self) -> &str { "add-dir" }
|
||||
fn description(&self) -> &str { "Add a directory to Claude Code's allowed workspace paths" }
|
||||
fn usage(&self) -> &str { "claude add-dir <path>" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
let raw = match args.first() {
|
||||
Some(p) => *p,
|
||||
None => return CommandResult::Error("Usage: claude add-dir <path>".to_string()),
|
||||
};
|
||||
|
||||
let path = std::path::Path::new(raw);
|
||||
|
||||
if !path.exists() {
|
||||
return CommandResult::Error(format!("Directory does not exist: {}", path.display()));
|
||||
}
|
||||
|
||||
if !path.is_dir() {
|
||||
return CommandResult::Error(format!("Not a directory: {}", path.display()));
|
||||
}
|
||||
|
||||
let abs_path = match std::fs::canonicalize(path) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return CommandResult::Error(format!("Cannot resolve path: {e}")),
|
||||
};
|
||||
|
||||
// TODO: persist to settings.json `workspacePaths` array
|
||||
CommandResult::Message(format!(
|
||||
"Added {} to allowed workspace paths.\n\
|
||||
Note: restart Claude Code for the change to take effect.",
|
||||
abs_path.display()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// branch
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct BranchCommand;
|
||||
|
||||
impl NamedCommand for BranchCommand {
|
||||
fn name(&self) -> &str { "branch" }
|
||||
fn description(&self) -> &str { "Create a branch of the current conversation at this point" }
|
||||
fn usage(&self) -> &str { "claude branch [create|switch|list] [name]" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
match args.first().copied().unwrap_or("list") {
|
||||
"list" => CommandResult::UserMessage(
|
||||
"List git branches — run: git branch -a".to_string(),
|
||||
),
|
||||
"create" => {
|
||||
let name = match args.get(1) {
|
||||
Some(n) => *n,
|
||||
None => return CommandResult::Error(
|
||||
"Usage: claude branch create <name>".to_string(),
|
||||
),
|
||||
};
|
||||
CommandResult::UserMessage(format!("git checkout -b {name}"))
|
||||
}
|
||||
"switch" | "checkout" => {
|
||||
let name = match args.get(1) {
|
||||
Some(n) => *n,
|
||||
None => return CommandResult::Error(
|
||||
"Usage: claude branch switch <name>".to_string(),
|
||||
),
|
||||
};
|
||||
CommandResult::UserMessage(format!("git checkout {name}"))
|
||||
}
|
||||
sub => CommandResult::Error(format!("Unknown branch subcommand: '{sub}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tag
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TagCommand;
|
||||
|
||||
impl NamedCommand for TagCommand {
|
||||
fn name(&self) -> &str { "tag" }
|
||||
fn description(&self) -> &str { "Toggle a searchable tag on the current session" }
|
||||
fn usage(&self) -> &str { "claude tag [list|add|remove] [tag]" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
match args.first().copied().unwrap_or("list") {
|
||||
"list" => CommandResult::Message("No tags set for this session.".to_string()),
|
||||
"add" => {
|
||||
let tag = args.get(1).copied().unwrap_or("unnamed");
|
||||
CommandResult::Message(format!("Added tag: {tag}"))
|
||||
}
|
||||
"remove" => {
|
||||
let tag = match args.get(1).copied() {
|
||||
Some(t) if !t.is_empty() => t,
|
||||
_ => return CommandResult::Error(
|
||||
"Usage: claude tag remove <tag>".to_string(),
|
||||
),
|
||||
};
|
||||
CommandResult::Message(format!("Removed tag: {tag}"))
|
||||
}
|
||||
sub => CommandResult::Error(format!("Unknown tag subcommand: '{sub}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// passes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct PassesCommand;
|
||||
|
||||
impl NamedCommand for PassesCommand {
|
||||
fn name(&self) -> &str { "passes" }
|
||||
fn description(&self) -> &str { "Share a free week of Claude Code with friends" }
|
||||
fn usage(&self) -> &str { "claude passes" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message(
|
||||
"Guest passes let you share Claude Code access with friends.\n\
|
||||
Visit https://claude.ai/claude-code to manage your passes."
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ide
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct IdeCommand;
|
||||
|
||||
impl NamedCommand for IdeCommand {
|
||||
fn name(&self) -> &str { "ide" }
|
||||
fn description(&self) -> &str { "Manage IDE integrations and show status" }
|
||||
fn usage(&self) -> &str { "claude ide [status|connect|disconnect|open]" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
match args.first().copied().unwrap_or("status") {
|
||||
"status" => CommandResult::Message(
|
||||
"IDE integration status: Not connected\n\
|
||||
Install the Claude Code extension:\n \
|
||||
- VS Code: https://marketplace.visualstudio.com/items?itemName=Anthropic.claude-code\n \
|
||||
- JetBrains: https://plugins.jetbrains.com/plugin/claude-code"
|
||||
.to_string(),
|
||||
),
|
||||
"connect" | "open" => CommandResult::Message(
|
||||
"Connecting to IDE…\n\
|
||||
Make sure the Claude Code extension is installed and running."
|
||||
.to_string(),
|
||||
),
|
||||
"disconnect" => CommandResult::Message("Disconnected from IDE.".to_string()),
|
||||
sub => CommandResult::Error(format!("Unknown ide subcommand: '{sub}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// pr-comments
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct PrCommentsCommand;
|
||||
|
||||
impl NamedCommand for PrCommentsCommand {
|
||||
fn name(&self) -> &str { "pr-comments" }
|
||||
fn description(&self) -> &str { "Get comments from a GitHub pull request" }
|
||||
fn usage(&self) -> &str { "claude pr-comments [PR-number]" }
|
||||
|
||||
fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
let pr_num = args.first().copied().unwrap_or("");
|
||||
if pr_num.is_empty() {
|
||||
return CommandResult::Error(
|
||||
"Please specify a PR number: claude pr-comments <number>".to_string(),
|
||||
);
|
||||
}
|
||||
CommandResult::UserMessage(format!(
|
||||
"Fetch and display comments for PR #{pr_num}.\n\
|
||||
Steps:\n\
|
||||
1. gh pr view {pr_num} --json number,headRepository\n\
|
||||
2. gh api /repos/{{owner}}/{{repo}}/issues/{pr_num}/comments\n\
|
||||
3. gh api /repos/{{owner}}/{{repo}}/pulls/{pr_num}/comments\n\
|
||||
Format results with file paths, diff hunks, and threading."
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// desktop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct DesktopCommand;
|
||||
|
||||
impl NamedCommand for DesktopCommand {
|
||||
fn name(&self) -> &str { "desktop" }
|
||||
fn description(&self) -> &str { "Open the Claude Code desktop app" }
|
||||
fn usage(&self) -> &str { "claude desktop" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message(
|
||||
"Download the Claude Code desktop app at https://claude.ai/download".to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// mobile
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct MobileCommand;
|
||||
|
||||
impl NamedCommand for MobileCommand {
|
||||
fn name(&self) -> &str { "mobile" }
|
||||
fn description(&self) -> &str { "Set up Claude Code on mobile" }
|
||||
fn usage(&self) -> &str { "claude mobile" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message(
|
||||
"Access Claude Code on mobile via https://claude.ai/claude-code\n\
|
||||
Use the Bridge feature to connect your local Claude Code CLI to the mobile interface."
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// install-github-app
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct InstallGithubAppCommand;
|
||||
|
||||
impl NamedCommand for InstallGithubAppCommand {
|
||||
fn name(&self) -> &str { "install-github-app" }
|
||||
fn description(&self) -> &str { "Set up Claude GitHub Actions for a repository" }
|
||||
fn usage(&self) -> &str { "claude install-github-app" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message(
|
||||
"To install the Claude Code GitHub App:\n\
|
||||
1. Visit https://github.com/apps/claude-code-app and click Install\n\
|
||||
2. Select the repositories to enable\n\
|
||||
3. Add your ANTHROPIC_API_KEY to repository secrets\n\n\
|
||||
The app enables Claude Code in GitHub Actions workflows.\n\
|
||||
Docs: https://docs.anthropic.com/claude-code/github-actions"
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// remote-setup
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct RemoteSetupCommand;
|
||||
|
||||
impl NamedCommand for RemoteSetupCommand {
|
||||
fn name(&self) -> &str { "remote-setup" }
|
||||
fn description(&self) -> &str { "Configure a remote Claude Code environment" }
|
||||
fn usage(&self) -> &str { "claude remote-setup" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message(
|
||||
"Remote Claude Code setup:\n\
|
||||
1. Set CLAUDE_CODE_REMOTE=1 on the remote machine\n\
|
||||
2. Set ANTHROPIC_API_KEY or configure OAuth\n\
|
||||
3. Run: claude --no-update-check\n\n\
|
||||
For Bridge mode (connect to the claude.ai web UI):\n\
|
||||
Set CLAUDE_CODE_BRIDGE_TOKEN=<token from claude.ai>"
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// stickers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct StickersCommand;
|
||||
|
||||
impl NamedCommand for StickersCommand {
|
||||
fn name(&self) -> &str { "stickers" }
|
||||
fn description(&self) -> &str { "View collected stickers" }
|
||||
fn usage(&self) -> &str { "claude stickers" }
|
||||
|
||||
fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult {
|
||||
CommandResult::Message("Sticker collection: coming soon!".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return one instance of every registered named command.
|
||||
pub fn all_named_commands() -> Vec<Box<dyn NamedCommand>> {
|
||||
vec![
|
||||
Box::new(AgentsCommand),
|
||||
Box::new(AddDirCommand),
|
||||
Box::new(BranchCommand),
|
||||
Box::new(TagCommand),
|
||||
Box::new(PassesCommand),
|
||||
Box::new(IdeCommand),
|
||||
Box::new(PrCommentsCommand),
|
||||
Box::new(DesktopCommand),
|
||||
Box::new(MobileCommand),
|
||||
Box::new(InstallGithubAppCommand),
|
||||
Box::new(RemoteSetupCommand),
|
||||
Box::new(StickersCommand),
|
||||
]
|
||||
}
|
||||
|
||||
/// Look up a named command by its primary name (case-insensitive).
|
||||
pub fn find_named_command(name: &str) -> Option<Box<dyn NamedCommand>> {
|
||||
let needle = name.to_lowercase();
|
||||
all_named_commands()
|
||||
.into_iter()
|
||||
.find(|c| c.name() == needle.as_str())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use cc_core::cost::CostTracker;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn make_ctx() -> CommandContext {
|
||||
CommandContext {
|
||||
config: cc_core::config::Config::default(),
|
||||
cost_tracker: CostTracker::new(),
|
||||
messages: vec![],
|
||||
working_dir: std::path::PathBuf::from("."),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_named_commands_non_empty() {
|
||||
assert!(!all_named_commands().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_named_commands_unique_names() {
|
||||
let mut names = std::collections::HashSet::new();
|
||||
for cmd in all_named_commands() {
|
||||
assert!(
|
||||
names.insert(cmd.name().to_string()),
|
||||
"Duplicate named command: {}",
|
||||
cmd.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_named_command_found() {
|
||||
assert!(find_named_command("agents").is_some());
|
||||
assert!(find_named_command("ide").is_some());
|
||||
assert!(find_named_command("branch").is_some());
|
||||
assert!(find_named_command("passes").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_named_command_not_found() {
|
||||
assert!(find_named_command("nonexistent-xyz").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_named_command_case_insensitive() {
|
||||
assert!(find_named_command("Agents").is_some());
|
||||
assert!(find_named_command("IDE").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agents_list_returns_message() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = AgentsCommand;
|
||||
let result = cmd.execute_named(&[], &ctx);
|
||||
assert!(matches!(result, CommandResult::Message(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agents_create_includes_name() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = AgentsCommand;
|
||||
let result = cmd.execute_named(&["create", "my-bot"], &ctx);
|
||||
if let CommandResult::Message(msg) = result {
|
||||
assert!(msg.contains("my-bot"));
|
||||
} else {
|
||||
panic!("Expected Message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_dir_missing_arg_returns_error() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = AddDirCommand;
|
||||
let result = cmd.execute_named(&[], &ctx);
|
||||
assert!(matches!(result, CommandResult::Error(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_list_returns_user_message() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = BranchCommand;
|
||||
let result = cmd.execute_named(&["list"], &ctx);
|
||||
assert!(matches!(result, CommandResult::UserMessage(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_create_requires_name() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = BranchCommand;
|
||||
let result = cmd.execute_named(&["create"], &ctx);
|
||||
assert!(matches!(result, CommandResult::Error(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pr_comments_missing_number() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = PrCommentsCommand;
|
||||
let result = cmd.execute_named(&[], &ctx);
|
||||
assert!(matches!(result, CommandResult::Error(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pr_comments_with_number() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = PrCommentsCommand;
|
||||
let result = cmd.execute_named(&["42"], &ctx);
|
||||
assert!(matches!(result, CommandResult::UserMessage(_)));
|
||||
if let CommandResult::UserMessage(msg) = result {
|
||||
assert!(msg.contains("42"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_install_github_app_returns_message() {
|
||||
let ctx = make_ctx();
|
||||
let cmd = InstallGithubAppCommand;
|
||||
let result = cmd.execute_named(&[], &ctx);
|
||||
assert!(matches!(result, CommandResult::Message(_)));
|
||||
}
|
||||
}
|
||||
31
src-rust/crates/core/Cargo.toml
Normal file
31
src-rust/crates/core/Cargo.toml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
[package]
|
||||
name = "cc-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
url = { workspace = true }
|
||||
schemars = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
urlencoding = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
403
src-rust/crates/core/src/analytics.rs
Normal file
403
src-rust/crates/core/src/analytics.rs
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
//! Analytics and telemetry (OpenTelemetry-compatible counters)
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Session-level metrics counters (mirrors TypeScript bootstrap state).
|
||||
///
|
||||
/// All counters use `AtomicU64` so they can be shared across threads without
|
||||
/// a mutex. Cost is stored as integer millicents (cost_usd × 100_000) to
|
||||
/// avoid floating-point atomic arithmetic.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SessionMetrics {
|
||||
/// Total cost in units of 1/100_000 USD (i.e. millicents).
|
||||
pub total_cost_usd_millicents: AtomicU64,
|
||||
pub total_input_tokens: AtomicU64,
|
||||
pub total_output_tokens: AtomicU64,
|
||||
pub total_api_duration_ms: AtomicU64,
|
||||
pub total_tool_duration_ms: AtomicU64,
|
||||
pub total_lines_added: AtomicU64,
|
||||
pub total_lines_removed: AtomicU64,
|
||||
pub session_count: AtomicU64,
|
||||
pub commit_count: AtomicU64,
|
||||
pub pr_count: AtomicU64,
|
||||
pub tool_use_count: AtomicU64,
|
||||
}
|
||||
|
||||
impl SessionMetrics {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self::default())
|
||||
}
|
||||
|
||||
pub fn add_cost(&self, usd: f64) {
|
||||
let millicents = (usd * 100_000.0) as u64;
|
||||
self.total_cost_usd_millicents
|
||||
.fetch_add(millicents, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn total_cost_usd(&self) -> f64 {
|
||||
self.total_cost_usd_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
|
||||
}
|
||||
|
||||
pub fn add_tokens(&self, input: u32, output: u32) {
|
||||
self.total_input_tokens
|
||||
.fetch_add(input as u64, Ordering::Relaxed);
|
||||
self.total_output_tokens
|
||||
.fetch_add(output as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_api_duration(&self, ms: u64) {
|
||||
self.total_api_duration_ms.fetch_add(ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_tool_duration(&self, ms: u64) {
|
||||
self.total_tool_duration_ms.fetch_add(ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_lines(&self, added: i64, removed: i64) {
|
||||
if added > 0 {
|
||||
self.total_lines_added
|
||||
.fetch_add(added as u64, Ordering::Relaxed);
|
||||
}
|
||||
if removed > 0 {
|
||||
self.total_lines_removed
|
||||
.fetch_add(removed as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_commits(&self) {
|
||||
self.commit_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_prs(&self) {
|
||||
self.pr_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_tool_use(&self) {
|
||||
self.tool_use_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> MetricsSummary {
|
||||
MetricsSummary {
|
||||
cost_usd: self.total_cost_usd(),
|
||||
input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
|
||||
output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
|
||||
api_duration_ms: self.total_api_duration_ms.load(Ordering::Relaxed),
|
||||
tool_duration_ms: self.total_tool_duration_ms.load(Ordering::Relaxed),
|
||||
lines_added: self.total_lines_added.load(Ordering::Relaxed),
|
||||
lines_removed: self.total_lines_removed.load(Ordering::Relaxed),
|
||||
commits: self.commit_count.load(Ordering::Relaxed),
|
||||
prs: self.pr_count.load(Ordering::Relaxed),
|
||||
tool_uses: self.tool_use_count.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A point-in-time snapshot of session metrics.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetricsSummary {
|
||||
pub cost_usd: f64,
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub api_duration_ms: u64,
|
||||
pub tool_duration_ms: u64,
|
||||
pub lines_added: u64,
|
||||
pub lines_removed: u64,
|
||||
pub commits: u64,
|
||||
pub prs: u64,
|
||||
pub tool_uses: u64,
|
||||
}
|
||||
|
||||
impl MetricsSummary {
|
||||
/// Format cost as a dollar amount string with appropriate precision.
|
||||
pub fn format_cost(&self) -> String {
|
||||
if self.cost_usd < 0.01 {
|
||||
format!("${:.5}", self.cost_usd)
|
||||
} else {
|
||||
format!("${:.4}", self.cost_usd)
|
||||
}
|
||||
}
|
||||
|
||||
/// Format total token count with K/M suffix.
|
||||
pub fn format_tokens(&self) -> String {
|
||||
let total = self.input_tokens + self.output_tokens;
|
||||
if total >= 1_000_000 {
|
||||
format!("{:.1}M tok", total as f64 / 1_000_000.0)
|
||||
} else if total >= 1_000 {
|
||||
format!("{:.1}K tok", total as f64 / 1_000.0)
|
||||
} else {
|
||||
format!("{} tok", total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Event types for first-party analytics (privacy-respecting — no PII).
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AnalyticsEvent {
|
||||
SessionStarted {
|
||||
model: String,
|
||||
is_interactive: bool,
|
||||
},
|
||||
SessionEnded {
|
||||
turn_count: u32,
|
||||
cost_usd: f64,
|
||||
duration_ms: u64,
|
||||
had_errors: bool,
|
||||
},
|
||||
ToolUsed {
|
||||
tool_name: String,
|
||||
success: bool,
|
||||
duration_ms: u64,
|
||||
},
|
||||
CommandExecuted {
|
||||
command: String,
|
||||
success: bool,
|
||||
},
|
||||
CompactionTriggered {
|
||||
tokens_before: u32,
|
||||
tokens_after: u32,
|
||||
},
|
||||
}
|
||||
|
||||
/// Analytics sink — currently logs via `tracing`; can be extended to push
|
||||
/// events to a first-party endpoint.
|
||||
pub struct Analytics {
|
||||
enabled: bool,
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
pub fn new(session_id: String, enabled: bool) -> Self {
|
||||
Self {
|
||||
enabled,
|
||||
session_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn track(&self, event: AnalyticsEvent) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
tracing::debug!(
|
||||
session_id = %self.session_id,
|
||||
event = ?event,
|
||||
"analytics event"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_session_metrics_initial_zero() {
|
||||
let m = SessionMetrics::new();
|
||||
assert_eq!(m.total_cost_usd(), 0.0);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_cost_single() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(0.01);
|
||||
let cost = m.total_cost_usd();
|
||||
// Allow small floating-point tolerance
|
||||
assert!((cost - 0.01).abs() < 1e-9, "cost = {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_cost_accumulates() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(1.0);
|
||||
m.add_cost(2.5);
|
||||
let cost = m.total_cost_usd();
|
||||
assert!((cost - 3.5).abs() < 1e-9, "cost = {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_tokens() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_tokens(1000, 500);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1000);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_tokens_accumulates() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_tokens(1000, 500);
|
||||
m.add_tokens(200, 100);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1200);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_lines_positive() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_lines(10, 5);
|
||||
assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 10);
|
||||
assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_lines_negative_ignored() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_lines(-3, -7);
|
||||
assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_increment_commits_and_prs() {
|
||||
let m = SessionMetrics::new();
|
||||
m.increment_commits();
|
||||
m.increment_commits();
|
||||
m.increment_prs();
|
||||
assert_eq!(m.commit_count.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(m.pr_count.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_increment_tool_use() {
|
||||
let m = SessionMetrics::new();
|
||||
for _ in 0..5 {
|
||||
m.increment_tool_use();
|
||||
}
|
||||
assert_eq!(m.tool_use_count.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary_snapshot() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(1.23456);
|
||||
m.add_tokens(100, 50);
|
||||
m.add_api_duration(300);
|
||||
m.add_tool_duration(150);
|
||||
m.add_lines(8, 3);
|
||||
m.increment_commits();
|
||||
m.increment_prs();
|
||||
m.increment_tool_use();
|
||||
|
||||
let s = m.summary();
|
||||
assert!((s.cost_usd - 1.23456).abs() < 1e-9);
|
||||
assert_eq!(s.input_tokens, 100);
|
||||
assert_eq!(s.output_tokens, 50);
|
||||
assert_eq!(s.api_duration_ms, 300);
|
||||
assert_eq!(s.tool_duration_ms, 150);
|
||||
assert_eq!(s.lines_added, 8);
|
||||
assert_eq!(s.lines_removed, 3);
|
||||
assert_eq!(s.commits, 1);
|
||||
assert_eq!(s.prs, 1);
|
||||
assert_eq!(s.tool_uses, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_cost_small() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.001,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
let formatted = s.format_cost();
|
||||
assert!(formatted.starts_with('$'));
|
||||
// Should have 5 decimal places for small cost
|
||||
assert!(formatted.contains('.'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_cost_large() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 1.5,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert_eq!(s.format_cost(), "$1.5000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_exact() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 500,
|
||||
output_tokens: 300,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert_eq!(s.format_tokens(), "800 tok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_kilo() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 5_000,
|
||||
output_tokens: 3_000,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert!(s.format_tokens().ends_with("K tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_mega() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 1_500_000,
|
||||
output_tokens: 500_000,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert!(s.format_tokens().ends_with("M tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analytics_track_disabled_no_panic() {
|
||||
let a = Analytics::new("test-session".to_string(), false);
|
||||
// Should not panic even though disabled
|
||||
a.track(AnalyticsEvent::SessionStarted {
|
||||
model: "claude-opus-4-6".to_string(),
|
||||
is_interactive: true,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analytics_track_enabled_no_panic() {
|
||||
let a = Analytics::new("test-session-2".to_string(), true);
|
||||
a.track(AnalyticsEvent::ToolUsed {
|
||||
tool_name: "Bash".to_string(),
|
||||
success: true,
|
||||
duration_ms: 42,
|
||||
});
|
||||
}
|
||||
}
|
||||
423
src-rust/crates/core/src/keybindings.rs
Normal file
423
src-rust/crates/core/src/keybindings.rs
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
//! Configurable keyboard shortcuts system
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// All keybinding contexts
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum KeyContext {
|
||||
Global,
|
||||
Chat,
|
||||
Autocomplete,
|
||||
Confirmation,
|
||||
Help,
|
||||
Transcript,
|
||||
HistorySearch,
|
||||
Task,
|
||||
ThemePicker,
|
||||
Settings,
|
||||
Tabs,
|
||||
Attachments,
|
||||
Footer,
|
||||
MessageSelector,
|
||||
DiffDialog,
|
||||
ModelPicker,
|
||||
Select,
|
||||
Plugin,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParsedKeystroke {
|
||||
pub key: String, // normalized key name
|
||||
pub ctrl: bool,
|
||||
pub alt: bool,
|
||||
pub shift: bool,
|
||||
pub meta: bool,
|
||||
}
|
||||
|
||||
pub type Chord = Vec<ParsedKeystroke>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParsedBinding {
|
||||
pub chord: Chord,
|
||||
pub action: Option<String>, // None = unbound
|
||||
pub context: KeyContext,
|
||||
}
|
||||
|
||||
/// Parse a keystroke string like "ctrl+shift+enter" into ParsedKeystroke
|
||||
pub fn parse_keystroke(s: &str) -> Option<ParsedKeystroke> {
|
||||
let s = s.trim().to_lowercase();
|
||||
let mut ctrl = false;
|
||||
let mut alt = false;
|
||||
let mut shift = false;
|
||||
let mut meta = false;
|
||||
let mut key_parts: Vec<&str> = Vec::new();
|
||||
|
||||
for part in s.split('+') {
|
||||
let part = part.trim();
|
||||
match part {
|
||||
"ctrl" | "control" => ctrl = true,
|
||||
"alt" | "opt" | "option" => alt = true,
|
||||
"shift" => shift = true,
|
||||
"meta" | "cmd" | "command" | "super" | "win" => meta = true,
|
||||
_ => key_parts.push(part),
|
||||
}
|
||||
}
|
||||
|
||||
if key_parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key = normalize_key(key_parts.join("+").as_str());
|
||||
Some(ParsedKeystroke {
|
||||
key,
|
||||
ctrl,
|
||||
alt,
|
||||
shift,
|
||||
meta,
|
||||
})
|
||||
}
|
||||
|
||||
fn normalize_key(k: &str) -> String {
|
||||
match k {
|
||||
"esc" | "escape" => "escape".to_string(),
|
||||
"return" | "enter" => "enter".to_string(),
|
||||
"del" | "delete" => "delete".to_string(),
|
||||
"backspace" | "bs" => "backspace".to_string(),
|
||||
"space" | " " => "space".to_string(),
|
||||
"up" => "up".to_string(),
|
||||
"down" => "down".to_string(),
|
||||
"left" => "left".to_string(),
|
||||
"right" => "right".to_string(),
|
||||
"pageup" | "pgup" => "pageup".to_string(),
|
||||
"pagedown" | "pgdn" | "pgdown" => "pagedown".to_string(),
|
||||
"home" => "home".to_string(),
|
||||
"end" => "end".to_string(),
|
||||
"tab" => "tab".to_string(),
|
||||
k => k.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a chord (space-separated keystrokes like "ctrl+k ctrl+d")
|
||||
pub fn parse_chord(s: &str) -> Option<Chord> {
|
||||
let keystrokes: Vec<ParsedKeystroke> =
|
||||
s.split_whitespace().filter_map(parse_keystroke).collect();
|
||||
if keystrokes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(keystrokes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Keys that cannot be rebound
|
||||
pub const NON_REBINDABLE: &[&str] = &["ctrl+c", "ctrl+d", "ctrl+m"];
|
||||
|
||||
/// Default keybindings
|
||||
pub fn default_bindings() -> Vec<ParsedBinding> {
|
||||
let defaults: &[(&str, &str, KeyContext)] = &[
|
||||
// Global
|
||||
("ctrl+c", "interrupt", KeyContext::Global),
|
||||
("ctrl+d", "exit", KeyContext::Global),
|
||||
("ctrl+l", "redraw", KeyContext::Global),
|
||||
("ctrl+r", "historySearch", KeyContext::Global),
|
||||
// Chat
|
||||
("enter", "submit", KeyContext::Chat),
|
||||
("up", "historyPrev", KeyContext::Chat),
|
||||
("down", "historyNext", KeyContext::Chat),
|
||||
("shift+tab", "cycleMode", KeyContext::Chat),
|
||||
("pageup", "scrollUp", KeyContext::Chat),
|
||||
("pagedown", "scrollDown", KeyContext::Chat),
|
||||
// Confirmation
|
||||
("y", "yes", KeyContext::Confirmation),
|
||||
("enter", "yes", KeyContext::Confirmation),
|
||||
("n", "no", KeyContext::Confirmation),
|
||||
("escape", "no", KeyContext::Confirmation),
|
||||
("up", "prevOption", KeyContext::Confirmation),
|
||||
("down", "nextOption", KeyContext::Confirmation),
|
||||
// Help
|
||||
("escape", "close", KeyContext::Help),
|
||||
("q", "close", KeyContext::Help),
|
||||
// HistorySearch
|
||||
("enter", "select", KeyContext::HistorySearch),
|
||||
("escape", "cancel", KeyContext::HistorySearch),
|
||||
("up", "prevResult", KeyContext::HistorySearch),
|
||||
("down", "nextResult", KeyContext::HistorySearch),
|
||||
];
|
||||
|
||||
defaults
|
||||
.iter()
|
||||
.filter_map(|(chord_str, action, context)| {
|
||||
parse_chord(chord_str).map(|chord| ParsedBinding {
|
||||
chord,
|
||||
action: Some(action.to_string()),
|
||||
context: context.clone(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// User keybindings loaded from ~/.claude/keybindings.json
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct UserKeybindings {
|
||||
pub bindings: Vec<UserBinding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserBinding {
|
||||
pub chord: String, // e.g. "ctrl+k ctrl+d"
|
||||
pub action: Option<String>, // None = unbound
|
||||
pub context: Option<String>,
|
||||
}
|
||||
|
||||
impl UserKeybindings {
|
||||
pub fn load(config_dir: &Path) -> Self {
|
||||
let path = config_dir.join("keybindings.json");
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save(&self, config_dir: &Path) -> anyhow::Result<()> {
|
||||
let path = config_dir.join("keybindings.json");
|
||||
let json = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(path, json)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolved keybindings (defaults merged with user overrides)
|
||||
pub struct KeybindingResolver {
|
||||
bindings: Vec<ParsedBinding>,
|
||||
pending_chord: Vec<ParsedKeystroke>,
|
||||
}
|
||||
|
||||
impl KeybindingResolver {
|
||||
pub fn new(user: &UserKeybindings) -> Self {
|
||||
let mut bindings = default_bindings();
|
||||
|
||||
// Apply user overrides (user bindings win, last match wins)
|
||||
for user_binding in &user.bindings {
|
||||
if let Some(chord) = parse_chord(&user_binding.chord) {
|
||||
let context = user_binding
|
||||
.context
|
||||
.as_deref()
|
||||
.and_then(|c| serde_json::from_str(&format!("\"{}\"", c)).ok())
|
||||
.unwrap_or(KeyContext::Global);
|
||||
|
||||
bindings.push(ParsedBinding {
|
||||
chord,
|
||||
action: user_binding.action.clone(),
|
||||
context,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
bindings,
|
||||
pending_chord: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a keystroke, returns action if binding matches
|
||||
pub fn process(
|
||||
&mut self,
|
||||
keystroke: ParsedKeystroke,
|
||||
context: &KeyContext,
|
||||
) -> KeybindingResult {
|
||||
self.pending_chord.push(keystroke);
|
||||
|
||||
// Find matching bindings in current context + Global
|
||||
let matches: Vec<&ParsedBinding> = self
|
||||
.bindings
|
||||
.iter()
|
||||
.filter(|b| &b.context == context || b.context == KeyContext::Global)
|
||||
.filter(|b| b.chord.starts_with(self.pending_chord.as_slice()))
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
self.pending_chord.clear();
|
||||
return KeybindingResult::NoMatch;
|
||||
}
|
||||
|
||||
let exact: Vec<&ParsedBinding> = matches
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|b| b.chord.len() == self.pending_chord.len())
|
||||
.collect();
|
||||
|
||||
if !exact.is_empty() {
|
||||
// Last match wins (user overrides)
|
||||
let binding = exact.last().unwrap();
|
||||
self.pending_chord.clear();
|
||||
return match &binding.action {
|
||||
Some(action) => KeybindingResult::Action(action.clone()),
|
||||
None => KeybindingResult::Unbound,
|
||||
};
|
||||
}
|
||||
|
||||
// Chord in progress
|
||||
KeybindingResult::Pending
|
||||
}
|
||||
|
||||
pub fn cancel_chord(&mut self) {
|
||||
self.pending_chord.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ParsedKeystroke {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.key == other.key
|
||||
&& self.ctrl == other.ctrl
|
||||
&& self.alt == other.alt
|
||||
&& self.shift == other.shift
|
||||
&& self.meta == other.meta
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeybindingResult {
|
||||
Action(String),
|
||||
Unbound,
|
||||
Pending,
|
||||
NoMatch,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_simple() {
|
||||
let ks = parse_keystroke("enter").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
assert!(!ks.ctrl);
|
||||
assert!(!ks.alt);
|
||||
assert!(!ks.shift);
|
||||
assert!(!ks.meta);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_ctrl_c() {
|
||||
let ks = parse_keystroke("ctrl+c").unwrap();
|
||||
assert_eq!(ks.key, "c");
|
||||
assert!(ks.ctrl);
|
||||
assert!(!ks.alt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_ctrl_shift_enter() {
|
||||
let ks = parse_keystroke("ctrl+shift+enter").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
assert!(ks.ctrl);
|
||||
assert!(ks.shift);
|
||||
assert!(!ks.alt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_normalizes_esc() {
|
||||
let ks = parse_keystroke("esc").unwrap();
|
||||
assert_eq!(ks.key, "escape");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_normalizes_return() {
|
||||
let ks = parse_keystroke("return").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_empty_returns_none() {
|
||||
assert!(parse_keystroke("ctrl+").is_none());
|
||||
assert!(parse_keystroke("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_single() {
|
||||
let chord = parse_chord("ctrl+c").unwrap();
|
||||
assert_eq!(chord.len(), 1);
|
||||
assert_eq!(chord[0].key, "c");
|
||||
assert!(chord[0].ctrl);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_multi() {
|
||||
let chord = parse_chord("ctrl+k ctrl+d").unwrap();
|
||||
assert_eq!(chord.len(), 2);
|
||||
assert_eq!(chord[0].key, "k");
|
||||
assert_eq!(chord[1].key, "d");
|
||||
assert!(chord[0].ctrl);
|
||||
assert!(chord[1].ctrl);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_empty_returns_none() {
|
||||
assert!(parse_chord("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_bindings_not_empty() {
|
||||
let bindings = default_bindings();
|
||||
assert!(!bindings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_bindings_contains_ctrl_c() {
|
||||
let bindings = default_bindings();
|
||||
let ctrl_c = bindings.iter().find(|b| {
|
||||
b.chord.len() == 1
|
||||
&& b.chord[0].ctrl
|
||||
&& b.chord[0].key == "c"
|
||||
&& b.context == KeyContext::Global
|
||||
});
|
||||
assert!(ctrl_c.is_some());
|
||||
assert_eq!(ctrl_c.unwrap().action.as_deref(), Some("interrupt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_simple_action() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
let ks = parse_keystroke("ctrl+c").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Global);
|
||||
assert!(matches!(result, KeybindingResult::Action(ref a) if a == "interrupt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_no_match() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
// ctrl+z has no default binding
|
||||
let ks = parse_keystroke("ctrl+z").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Chat);
|
||||
assert!(matches!(result, KeybindingResult::NoMatch));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_context_match_global_from_chat() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
// ctrl+l is Global, should match even when context is Chat
|
||||
let ks = parse_keystroke("ctrl+l").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Chat);
|
||||
assert!(matches!(result, KeybindingResult::Action(ref a) if a == "redraw"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keystroke_equality() {
|
||||
let ks1 = parse_keystroke("ctrl+enter").unwrap();
|
||||
let ks2 = parse_keystroke("ctrl+enter").unwrap();
|
||||
let ks3 = parse_keystroke("shift+enter").unwrap();
|
||||
assert_eq!(ks1, ks2);
|
||||
assert_ne!(ks1, ks3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_keybindings_default_empty() {
|
||||
let user = UserKeybindings::default();
|
||||
assert!(user.bindings.is_empty());
|
||||
}
|
||||
}
|
||||
2011
src-rust/crates/core/src/lib.rs
Normal file
2011
src-rust/crates/core/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
294
src-rust/crates/core/src/lsp.rs
Normal file
294
src-rust/crates/core/src/lsp.rs
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
//! Language Server Protocol client stub.
|
||||
//!
|
||||
//! The full LSP implementation is provided by plugins; this module defines
|
||||
//! the integration interface that the rest of the codebase uses to query
|
||||
//! diagnostics, register servers, and format output.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for a single LSP server process.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspServerConfig {
|
||||
/// Display name, e.g. "rust-analyzer"
|
||||
pub name: String,
|
||||
/// Path or name of the server binary, e.g. "rust-analyzer"
|
||||
pub command: String,
|
||||
/// Command-line arguments passed to the server binary
|
||||
pub args: Vec<String>,
|
||||
/// Glob patterns that activate this server, e.g. `["*.rs", "*.toml"]`
|
||||
pub file_patterns: Vec<String>,
|
||||
/// Optional server-specific initialization options (passed in LSP `initialize`)
|
||||
pub initialization_options: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A single diagnostic emitted by an LSP server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LspDiagnostic {
|
||||
/// Workspace-relative or absolute file path
|
||||
pub file: String,
|
||||
/// 1-based line number
|
||||
pub line: u32,
|
||||
/// 1-based column number
|
||||
pub column: u32,
|
||||
pub severity: DiagnosticSeverity,
|
||||
pub message: String,
|
||||
/// The LSP server that produced this diagnostic (e.g. "rust-analyzer")
|
||||
pub source: Option<String>,
|
||||
/// Diagnostic code (e.g. "E0308"), if provided by the server
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
/// Severity level of a diagnostic, matching the LSP spec.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum DiagnosticSeverity {
|
||||
Error = 1,
|
||||
Warning = 2,
|
||||
Information = 3,
|
||||
Hint = 4,
|
||||
}
|
||||
|
||||
impl DiagnosticSeverity {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Error => "error",
|
||||
Self::Warning => "warning",
|
||||
Self::Information => "info",
|
||||
Self::Hint => "hint",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LSP manager stub.
|
||||
///
|
||||
/// In the full implementation this will own LSP server processes and route
|
||||
/// JSON-RPC messages. For now it is a registry that tracks configured
|
||||
/// servers and returns empty diagnostic lists — the plugin system is
|
||||
/// responsible for wiring up real communication.
|
||||
pub struct LspManager {
|
||||
servers: Vec<LspServerConfig>,
|
||||
}
|
||||
|
||||
impl LspManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
servers: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an LSP server configuration.
|
||||
pub fn register_server(&mut self, config: LspServerConfig) {
|
||||
self.servers.push(config);
|
||||
}
|
||||
|
||||
/// Return all registered server configurations.
|
||||
pub fn servers(&self) -> &[LspServerConfig] {
|
||||
&self.servers
|
||||
}
|
||||
|
||||
/// Look up a server configuration by name.
|
||||
pub fn server_by_name(&self, name: &str) -> Option<&LspServerConfig> {
|
||||
self.servers.iter().find(|s| s.name == name)
|
||||
}
|
||||
|
||||
/// Get diagnostics for a file.
|
||||
///
|
||||
/// This stub always returns an empty list. When an LSP plugin connects it
|
||||
/// will replace this path with real RPC calls.
|
||||
pub async fn get_diagnostics(&self, _file: &str) -> Vec<LspDiagnostic> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Format a slice of diagnostics into a human-readable multi-line string
|
||||
/// suitable for inclusion in tool output or TUI display.
|
||||
pub fn format_diagnostics(diagnostics: &[LspDiagnostic]) -> String {
|
||||
if diagnostics.is_empty() {
|
||||
return "No diagnostics.".to_string();
|
||||
}
|
||||
diagnostics
|
||||
.iter()
|
||||
.map(|d| {
|
||||
format!(
|
||||
"[{}] {}:{}:{} - {}{}{}",
|
||||
d.severity.as_str().to_uppercase(),
|
||||
d.file,
|
||||
d.line,
|
||||
d.column,
|
||||
d.message,
|
||||
d.source
|
||||
.as_deref()
|
||||
.map(|s| format!(" ({})", s))
|
||||
.unwrap_or_default(),
|
||||
d.code
|
||||
.as_deref()
|
||||
.map(|c| format!(" [{}]", c))
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LspManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_config(name: &str) -> LspServerConfig {
|
||||
LspServerConfig {
|
||||
name: name.to_string(),
|
||||
command: name.to_string(),
|
||||
args: vec![],
|
||||
file_patterns: vec!["*.rs".to_string()],
|
||||
initialization_options: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_diagnostic(
|
||||
file: &str,
|
||||
line: u32,
|
||||
col: u32,
|
||||
severity: DiagnosticSeverity,
|
||||
message: &str,
|
||||
) -> LspDiagnostic {
|
||||
LspDiagnostic {
|
||||
file: file.to_string(),
|
||||
line,
|
||||
column: col,
|
||||
severity,
|
||||
message: message.to_string(),
|
||||
source: None,
|
||||
code: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_manager_empty() {
|
||||
let mgr = LspManager::new();
|
||||
assert!(mgr.servers().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_server() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
assert_eq!(mgr.servers().len(), 1);
|
||||
assert_eq!(mgr.servers()[0].name, "rust-analyzer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_multiple_servers() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
mgr.register_server(make_config("pyright"));
|
||||
assert_eq!(mgr.servers().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_by_name_found() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
mgr.register_server(make_config("pyright"));
|
||||
let s = mgr.server_by_name("pyright");
|
||||
assert!(s.is_some());
|
||||
assert_eq!(s.unwrap().name, "pyright");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_by_name_not_found() {
|
||||
let mgr = LspManager::new();
|
||||
assert!(mgr.server_by_name("missing").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_diagnostics_stub_empty() {
|
||||
let mgr = LspManager::new();
|
||||
let diags = mgr.get_diagnostics("src/main.rs").await;
|
||||
assert!(diags.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_empty() {
|
||||
let result = LspManager::format_diagnostics(&[]);
|
||||
assert_eq!(result, "No diagnostics.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_single_error() {
|
||||
let diags = vec![make_diagnostic(
|
||||
"src/lib.rs",
|
||||
10,
|
||||
5,
|
||||
DiagnosticSeverity::Error,
|
||||
"type mismatch",
|
||||
)];
|
||||
let result = LspManager::format_diagnostics(&diags);
|
||||
assert!(result.contains("[ERROR]"));
|
||||
assert!(result.contains("src/lib.rs"));
|
||||
assert!(result.contains("10:5"));
|
||||
assert!(result.contains("type mismatch"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_multiple() {
|
||||
let diags = vec![
|
||||
make_diagnostic("a.rs", 1, 1, DiagnosticSeverity::Error, "err1"),
|
||||
make_diagnostic("b.rs", 2, 3, DiagnosticSeverity::Warning, "warn1"),
|
||||
];
|
||||
let result = LspManager::format_diagnostics(&diags);
|
||||
let lines: Vec<&str> = result.lines().collect();
|
||||
assert_eq!(lines.len(), 2);
|
||||
assert!(lines[0].contains("[ERROR]"));
|
||||
assert!(lines[1].contains("[WARNING]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_with_source_and_code() {
|
||||
let mut d = make_diagnostic(
|
||||
"main.rs",
|
||||
5,
|
||||
1,
|
||||
DiagnosticSeverity::Error,
|
||||
"mismatched types",
|
||||
);
|
||||
d.source = Some("rust-analyzer".to_string());
|
||||
d.code = Some("E0308".to_string());
|
||||
let result = LspManager::format_diagnostics(&[d]);
|
||||
assert!(result.contains("(rust-analyzer)"), "result = {}", result);
|
||||
assert!(result.contains("[E0308]"), "result = {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diagnostic_severity_ordering() {
|
||||
assert!(DiagnosticSeverity::Error < DiagnosticSeverity::Warning);
|
||||
assert!(DiagnosticSeverity::Warning < DiagnosticSeverity::Information);
|
||||
assert!(DiagnosticSeverity::Information < DiagnosticSeverity::Hint);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diagnostic_severity_as_str() {
|
||||
assert_eq!(DiagnosticSeverity::Error.as_str(), "error");
|
||||
assert_eq!(DiagnosticSeverity::Warning.as_str(), "warning");
|
||||
assert_eq!(DiagnosticSeverity::Information.as_str(), "info");
|
||||
assert_eq!(DiagnosticSeverity::Hint.as_str(), "hint");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsp_server_config_serialization() {
|
||||
let cfg = make_config("rust-analyzer");
|
||||
let json = serde_json::to_string(&cfg).unwrap();
|
||||
let back: LspServerConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back.name, "rust-analyzer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_trait() {
|
||||
let mgr = LspManager::default();
|
||||
assert!(mgr.servers().is_empty());
|
||||
}
|
||||
}
|
||||
881
src-rust/crates/core/src/memdir.rs
Normal file
881
src-rust/crates/core/src/memdir.rs
Normal file
|
|
@ -0,0 +1,881 @@
|
|||
//! Memory directory (memdir) system.
|
||||
//!
|
||||
//! Provides persistent, file-based memory across sessions. Mirrors the
|
||||
//! TypeScript modules under `src/memdir/`:
|
||||
//! - `memoryScan.ts` → `scan_memory_dir`, `parse_frontmatter_quick`, `format_memory_manifest`
|
||||
//! - `memoryAge.ts` → `memory_age_days`, `memory_freshness_text`, `memory_freshness_note`
|
||||
//! - `memdir.ts` → `build_memory_prompt_content`, `load_memory_index`, `ensure_memory_dir_exists`
|
||||
//! - `paths.ts` → `auto_memory_path`, `is_auto_memory_enabled`
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory type taxonomy
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The four canonical memory types.
|
||||
/// Matches the TypeScript `MemoryType` union in `memoryTypes.ts`.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MemoryType {
|
||||
/// Information about the user's role, goals, and preferences.
|
||||
User,
|
||||
/// Guidance the user has given about how to approach work.
|
||||
Feedback,
|
||||
/// Information about ongoing work, goals, or incidents in the project.
|
||||
Project,
|
||||
/// Pointers to where information lives in external systems.
|
||||
Reference,
|
||||
}
|
||||
|
||||
impl MemoryType {
|
||||
/// Parse a raw frontmatter value into a `MemoryType`.
|
||||
/// Returns `None` for missing or unrecognised values (legacy files degrade gracefully).
|
||||
pub fn parse(raw: &str) -> Option<Self> {
|
||||
match raw.trim() {
|
||||
"user" => Some(Self::User),
|
||||
"feedback" => Some(Self::Feedback),
|
||||
"project" => Some(Self::Project),
|
||||
"reference" => Some(Self::Reference),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Display as a lowercase string.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::User => "user",
|
||||
Self::Feedback => "feedback",
|
||||
Self::Project => "project",
|
||||
Self::Reference => "reference",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory file metadata and content
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Scanned metadata for a single memory file (without the full body).
|
||||
/// Mirrors `MemoryHeader` in `memoryScan.ts`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryFileMeta {
|
||||
/// Filename relative to the memory directory (e.g. `user_role.md`).
|
||||
pub filename: String,
|
||||
/// Absolute path to the file.
|
||||
pub path: PathBuf,
|
||||
/// `name:` frontmatter field.
|
||||
pub name: Option<String>,
|
||||
/// `description:` frontmatter field (used for relevance scoring).
|
||||
pub description: Option<String>,
|
||||
/// `type:` frontmatter field.
|
||||
pub memory_type: Option<MemoryType>,
|
||||
/// File modification time in seconds since UNIX epoch.
|
||||
pub modified_secs: u64,
|
||||
}
|
||||
|
||||
/// A fully-loaded memory file including its body.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryFile {
|
||||
pub meta: MemoryFileMeta,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Directory scanning
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Maximum number of memory files kept after sorting.
|
||||
/// Matches `MAX_MEMORY_FILES` in `memoryScan.ts`.
|
||||
const MAX_MEMORY_FILES: usize = 200;
|
||||
|
||||
/// Number of lines scanned for frontmatter.
|
||||
/// Matches `FRONTMATTER_MAX_LINES` in `memoryScan.ts`.
|
||||
const FRONTMATTER_MAX_LINES: usize = 30;
|
||||
|
||||
/// Scan a memory directory, returning metadata for all `.md` files
|
||||
/// (excluding `MEMORY.md`), sorted newest-first, capped at `MAX_MEMORY_FILES`.
|
||||
///
|
||||
/// This is a synchronous scan used during system-prompt assembly.
|
||||
/// Mirrors `scanMemoryFiles` in `memoryScan.ts` (async version; this is the
|
||||
/// sync equivalent used at prompt-build time).
|
||||
pub fn scan_memory_dir(dir: &Path) -> Vec<MemoryFileMeta> {
|
||||
let mut files: Vec<MemoryFileMeta> = Vec::new();
|
||||
|
||||
if !dir.exists() {
|
||||
return files;
|
||||
}
|
||||
|
||||
// Walk recursively using `walkdir`-style manual recursion to stay
|
||||
// dependency-free (only std).
|
||||
collect_md_files(dir, dir, &mut files);
|
||||
|
||||
// Sort newest-first.
|
||||
files.sort_by(|a, b| b.modified_secs.cmp(&a.modified_secs));
|
||||
files.truncate(MAX_MEMORY_FILES);
|
||||
files
|
||||
}
|
||||
|
||||
/// Recursively collect `.md` files (excluding `MEMORY.md`) from `current_dir`.
|
||||
fn collect_md_files(base: &Path, current_dir: &Path, out: &mut Vec<MemoryFileMeta>) {
|
||||
let Ok(entries) = std::fs::read_dir(current_dir) else {
|
||||
return;
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
collect_md_files(base, &path, out);
|
||||
} else if path.extension().map(|e| e == "md").unwrap_or(false) {
|
||||
let file_name = path.file_name().map(|n| n.to_string_lossy().into_owned()).unwrap_or_default();
|
||||
if file_name == "MEMORY.md" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let modified_secs = entry
|
||||
.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.map(|t| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
let (name, description, memory_type) =
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
parse_frontmatter_quick(&content)
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
// Relative path from the memory dir root.
|
||||
let relative = path
|
||||
.strip_prefix(base)
|
||||
.map(|p| p.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|_| file_name.clone());
|
||||
|
||||
out.push(MemoryFileMeta {
|
||||
filename: relative,
|
||||
path,
|
||||
name,
|
||||
description,
|
||||
memory_type,
|
||||
modified_secs,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse YAML frontmatter from the first `FRONTMATTER_MAX_LINES` lines without
|
||||
/// a full YAML parser. Returns `(name, description, memory_type)`.
|
||||
///
|
||||
/// Mirrors `parseFrontmatter` usage in `memoryScan.ts`.
|
||||
pub fn parse_frontmatter_quick(
|
||||
content: &str,
|
||||
) -> (Option<String>, Option<String>, Option<MemoryType>) {
|
||||
let mut name = None;
|
||||
let mut description = None;
|
||||
let mut memory_type = None;
|
||||
|
||||
let lines: Vec<&str> = content.lines().take(FRONTMATTER_MAX_LINES).collect();
|
||||
|
||||
// Frontmatter must start with `---`
|
||||
if lines.first().map(|l| l.trim() != "---").unwrap_or(true) {
|
||||
return (name, description, memory_type);
|
||||
}
|
||||
|
||||
for line in &lines[1..] {
|
||||
if line.trim() == "---" {
|
||||
break;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("name:") {
|
||||
name = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("description:") {
|
||||
description = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("type:") {
|
||||
memory_type = MemoryType::parse(rest.trim().trim_matches('"').trim_matches('\''));
|
||||
}
|
||||
}
|
||||
|
||||
(name, description, memory_type)
|
||||
}
|
||||
|
||||
/// Format memory headers as a text manifest: one entry per file with
|
||||
/// `[type] filename (iso-timestamp): description`.
|
||||
///
|
||||
/// Mirrors `formatMemoryManifest` in `memoryScan.ts`.
|
||||
pub fn format_memory_manifest(memories: &[MemoryFileMeta]) -> String {
|
||||
memories
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let tag = m
|
||||
.memory_type
|
||||
.as_ref()
|
||||
.map(|t| format!("[{}] ", t.as_str()))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Convert modified_secs to an ISO-8601-like timestamp.
|
||||
let ts = format_unix_secs_iso(m.modified_secs);
|
||||
|
||||
match &m.description {
|
||||
Some(desc) => format!("- {}{} ({}): {}", tag, m.filename, ts, desc),
|
||||
None => format!("- {}{} ({})", tag, m.filename, ts),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Minimal ISO-8601 formatter for a Unix timestamp (no external deps).
|
||||
fn format_unix_secs_iso(secs: u64) -> String {
|
||||
// We use a very lightweight implementation to avoid pulling in chrono here
|
||||
// (chrono is already a workspace dep but we want this module to stay lean).
|
||||
// Accuracy to the day is sufficient for memory manifests.
|
||||
let days_since_epoch = secs / 86400;
|
||||
// Julian Day Number for 1970-01-01 is 2440588.
|
||||
let jdn = days_since_epoch as u32 + 2440588;
|
||||
let (y, m, d) = jdn_to_ymd(jdn);
|
||||
let hh = (secs % 86400) / 3600;
|
||||
let mm = (secs % 3600) / 60;
|
||||
let ss = secs % 60;
|
||||
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, hh, mm, ss)
|
||||
}
|
||||
|
||||
/// Convert a Julian Day Number to (year, month, day).
|
||||
fn jdn_to_ymd(jdn: u32) -> (u32, u32, u32) {
|
||||
let a = jdn + 32044;
|
||||
let b = (4 * a + 3) / 146097;
|
||||
let c = a - (146097 * b) / 4;
|
||||
let d = (4 * c + 3) / 1461;
|
||||
let e = c - (1461 * d) / 4;
|
||||
let m = (5 * e + 2) / 153;
|
||||
let day = e - (153 * m + 2) / 5 + 1;
|
||||
let month = m + 3 - 12 * (m / 10);
|
||||
let year = 100 * b + d - 4800 + m / 10;
|
||||
(year, month, day)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory age / freshness
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Days elapsed since `modified_secs`. Floor-rounded; clamped to 0 for
|
||||
/// future mtimes (clock skew).
|
||||
///
|
||||
/// Mirrors `memoryAgeDays` in `memoryAge.ts`.
|
||||
pub fn memory_age_days(modified_secs: u64) -> u64 {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
(now.saturating_sub(modified_secs)) / 86400
|
||||
}
|
||||
|
||||
/// Human-readable age string. Models are poor at date arithmetic — a raw
|
||||
/// ISO timestamp does not trigger staleness reasoning the way "47 days ago" does.
|
||||
///
|
||||
/// Mirrors `memoryAge` in `memoryAge.ts`.
|
||||
pub fn memory_age(modified_secs: u64) -> String {
|
||||
let d = memory_age_days(modified_secs);
|
||||
match d {
|
||||
0 => "today".to_string(),
|
||||
1 => "yesterday".to_string(),
|
||||
n => format!("{} days ago", n),
|
||||
}
|
||||
}
|
||||
|
||||
/// Plain-text staleness caveat for memories > 1 day old.
|
||||
/// Returns an empty string for fresh memories (today / yesterday).
|
||||
///
|
||||
/// Mirrors `memoryFreshnessText` in `memoryAge.ts`.
|
||||
pub fn memory_freshness_text(modified_secs: u64) -> String {
|
||||
let d = memory_age_days(modified_secs);
|
||||
if d <= 1 {
|
||||
return String::new();
|
||||
}
|
||||
format!(
|
||||
"This memory is {} days old. \
|
||||
Memories are point-in-time observations, not live state — \
|
||||
claims about code behavior or file:line citations may be outdated. \
|
||||
Verify against current code before asserting as fact.",
|
||||
d
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-memory staleness note wrapped in `<system-reminder>` tags.
|
||||
/// Returns an empty string for memories ≤ 1 day old.
|
||||
///
|
||||
/// Mirrors `memoryFreshnessNote` in `memoryAge.ts`.
|
||||
pub fn memory_freshness_note(modified_secs: u64) -> String {
|
||||
let text = memory_freshness_text(modified_secs);
|
||||
if text.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
format!("<system-reminder>{}</system-reminder>\n", text)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Path resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Entrypoint filename within the memory directory.
|
||||
pub const MEMORY_ENTRYPOINT: &str = "MEMORY.md";
|
||||
|
||||
/// Maximum number of lines loaded from `MEMORY.md`.
|
||||
/// Matches `MAX_ENTRYPOINT_LINES` in `memdir.ts`.
|
||||
pub const MAX_ENTRYPOINT_LINES: usize = 200;
|
||||
|
||||
/// Maximum bytes loaded from `MEMORY.md`.
|
||||
/// Matches `MAX_ENTRYPOINT_BYTES` in `memdir.ts`.
|
||||
pub const MAX_ENTRYPOINT_BYTES: usize = 25_000;
|
||||
|
||||
/// Compute the auto-memory directory path for a project root.
|
||||
///
|
||||
/// Resolution order (mirrors `getAutoMemPath` in `paths.ts`):
|
||||
/// 1. `CLAUDE_COWORK_MEMORY_PATH_OVERRIDE` env var (full-path override).
|
||||
/// 2. `<CLAUDE_CODE_REMOTE_MEMORY_DIR>/projects/<sanitized-root>/memory/`
|
||||
/// when `CLAUDE_CODE_REMOTE_MEMORY_DIR` is set.
|
||||
/// 3. `~/.claude/projects/<sanitized-root>/memory/` (default).
|
||||
pub fn auto_memory_path(project_root: &Path) -> PathBuf {
|
||||
// 1. Cowork full-path override.
|
||||
if let Ok(override_path) = std::env::var("CLAUDE_COWORK_MEMORY_PATH_OVERRIDE") {
|
||||
if !override_path.is_empty() {
|
||||
return PathBuf::from(override_path);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Determine the memory base directory.
|
||||
let memory_base = std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| {
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join(".claude")
|
||||
});
|
||||
|
||||
// 3. Sanitize the project root into a safe directory name.
|
||||
let sanitized = sanitize_path_component(&project_root.to_string_lossy());
|
||||
|
||||
memory_base.join("projects").join(sanitized).join("memory")
|
||||
}
|
||||
|
||||
/// Sanitize an arbitrary string into a directory-name-safe component.
|
||||
/// Matches `sanitizePath` used inside `getAutoMemPath` in `paths.ts`.
|
||||
pub fn sanitize_path_component(s: &str) -> String {
|
||||
let sanitized: String = s
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
sanitized.trim_matches('_').to_string()
|
||||
}
|
||||
|
||||
/// Whether the auto-memory system is enabled for this session.
|
||||
///
|
||||
/// Priority chain (mirrors `isAutoMemoryEnabled` in `paths.ts`):
|
||||
/// 1. `CLAUDE_CODE_DISABLE_AUTO_MEMORY` — truthy → OFF, falsy (but defined) → ON.
|
||||
/// 2. `CLAUDE_CODE_SIMPLE` (--bare) → OFF.
|
||||
/// 3. Remote mode without `CLAUDE_CODE_REMOTE_MEMORY_DIR` → OFF.
|
||||
/// 4. `settings_enabled` parameter (from settings.json `autoMemoryEnabled` field).
|
||||
/// 5. Default: enabled.
|
||||
pub fn is_auto_memory_enabled(settings_enabled: Option<bool>) -> bool {
|
||||
if let Ok(val) = std::env::var("CLAUDE_CODE_DISABLE_AUTO_MEMORY") {
|
||||
// Truthy values (non-empty, non-"0", non-"false") disable memory.
|
||||
match val.to_lowercase().as_str() {
|
||||
"" | "0" | "false" | "no" | "off" => return true, // defined-falsy → ON
|
||||
_ => return false, // truthy → OFF
|
||||
}
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_SIMPLE").is_ok() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_REMOTE").is_ok()
|
||||
&& std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR").is_err()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
settings_enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Index loading and truncation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Result of loading and (optionally) truncating the `MEMORY.md` entrypoint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EntrypointTruncation {
|
||||
pub content: String,
|
||||
pub line_count: usize,
|
||||
pub byte_count: usize,
|
||||
pub was_line_truncated: bool,
|
||||
pub was_byte_truncated: bool,
|
||||
}
|
||||
|
||||
/// Truncate `MEMORY.md` content to `MAX_ENTRYPOINT_LINES` lines and
|
||||
/// `MAX_ENTRYPOINT_BYTES` bytes, appending a warning when either cap fires.
|
||||
///
|
||||
/// Mirrors `truncateEntrypointContent` in `memdir.ts`.
|
||||
pub fn truncate_entrypoint_content(raw: &str) -> EntrypointTruncation {
|
||||
let trimmed = raw.trim();
|
||||
let content_lines: Vec<&str> = trimmed.lines().collect();
|
||||
let line_count = content_lines.len();
|
||||
let byte_count = trimmed.len();
|
||||
|
||||
let was_line_truncated = line_count > MAX_ENTRYPOINT_LINES;
|
||||
let was_byte_truncated = byte_count > MAX_ENTRYPOINT_BYTES;
|
||||
|
||||
if !was_line_truncated && !was_byte_truncated {
|
||||
return EntrypointTruncation {
|
||||
content: trimmed.to_string(),
|
||||
line_count,
|
||||
byte_count,
|
||||
was_line_truncated: false,
|
||||
was_byte_truncated: false,
|
||||
};
|
||||
}
|
||||
|
||||
let mut truncated = if was_line_truncated {
|
||||
content_lines[..MAX_ENTRYPOINT_LINES].join("\n")
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
};
|
||||
|
||||
if truncated.len() > MAX_ENTRYPOINT_BYTES {
|
||||
let cut_at = truncated[..MAX_ENTRYPOINT_BYTES]
|
||||
.rfind('\n')
|
||||
.unwrap_or(MAX_ENTRYPOINT_BYTES);
|
||||
truncated.truncate(cut_at);
|
||||
}
|
||||
|
||||
let reason = match (was_line_truncated, was_byte_truncated) {
|
||||
(true, false) => format!("{} lines (limit: {})", line_count, MAX_ENTRYPOINT_LINES),
|
||||
(false, true) => format!(
|
||||
"{} bytes (limit: {}) — index entries are too long",
|
||||
byte_count, MAX_ENTRYPOINT_BYTES
|
||||
),
|
||||
_ => format!(
|
||||
"{} lines and {} bytes",
|
||||
line_count, byte_count
|
||||
),
|
||||
};
|
||||
|
||||
truncated.push_str(&format!(
|
||||
"\n\n> WARNING: {} is {}. Only part of it was loaded. \
|
||||
Keep index entries to one line under ~200 chars; move detail into topic files.",
|
||||
MEMORY_ENTRYPOINT, reason
|
||||
));
|
||||
|
||||
EntrypointTruncation {
|
||||
content: truncated,
|
||||
line_count,
|
||||
byte_count,
|
||||
was_line_truncated,
|
||||
was_byte_truncated,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load and truncate the `MEMORY.md` index from `memory_dir`.
|
||||
/// Returns `None` when the file does not exist or is empty.
|
||||
///
|
||||
/// Mirrors the entrypoint-reading path in `buildMemoryPrompt` / `loadMemoryPrompt`.
|
||||
pub fn load_memory_index(memory_dir: &Path) -> Option<EntrypointTruncation> {
|
||||
let index_path = memory_dir.join(MEMORY_ENTRYPOINT);
|
||||
if !index_path.exists() {
|
||||
return None;
|
||||
}
|
||||
let raw = std::fs::read_to_string(&index_path).ok()?;
|
||||
if raw.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(truncate_entrypoint_content(&raw))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// System-prompt memory content builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the memory content string to inject into the system prompt's
|
||||
/// `<memory>` block.
|
||||
///
|
||||
/// Always includes the `MEMORY.md` index when it exists.
|
||||
/// Called during `build_system_prompt` → `SystemPromptOptions::memory_content`.
|
||||
pub fn build_memory_prompt_content(memory_dir: &Path) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
if let Some(index) = load_memory_index(memory_dir) {
|
||||
parts.push(format!("## Memory Index (MEMORY.md)\n{}", index.content));
|
||||
}
|
||||
|
||||
parts.join("\n\n")
|
||||
}
|
||||
|
||||
/// Ensure the memory directory exists, creating it (and any parents) if needed.
|
||||
/// Errors are silently swallowed (the Write tool will surface them if needed).
|
||||
///
|
||||
/// Mirrors `ensureMemoryDirExists` in `memdir.ts`.
|
||||
pub fn ensure_memory_dir_exists(memory_dir: &Path) {
|
||||
if let Err(e) = std::fs::create_dir_all(memory_dir) {
|
||||
// Log at debug level so --debug shows why, but don't abort.
|
||||
tracing::debug!(
|
||||
dir = %memory_dir.display(),
|
||||
error = %e,
|
||||
"ensureMemoryDirExists failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Simple relevance search (no LLM side-query)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Find and load the most relevant memory files for a query using a
|
||||
/// lightweight TF-IDF-style keyword score.
|
||||
///
|
||||
/// The full Sonnet side-query (`findRelevantMemories` in TypeScript) lives
|
||||
/// in `cc-query`; this function provides a cheaper fallback for contexts
|
||||
/// where an API call is not available.
|
||||
pub fn find_relevant_memories_simple(
|
||||
memory_dir: &Path,
|
||||
query: &str,
|
||||
max_files: usize,
|
||||
) -> Vec<MemoryFile> {
|
||||
let metas = scan_memory_dir(memory_dir);
|
||||
let query_lower = query.to_lowercase();
|
||||
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
||||
if query_words.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut scored: Vec<(f32, MemoryFileMeta)> = metas
|
||||
.into_iter()
|
||||
.filter_map(|meta| {
|
||||
let desc = meta.description.as_deref().unwrap_or("").to_lowercase();
|
||||
let name = meta.name.as_deref().unwrap_or("").to_lowercase();
|
||||
let filename = meta.filename.to_lowercase();
|
||||
|
||||
let score: f32 = query_words
|
||||
.iter()
|
||||
.map(|w| {
|
||||
let in_name = if name.contains(*w) { 2.0_f32 } else { 0.0 };
|
||||
let in_desc = if desc.contains(*w) { 1.0_f32 } else { 0.0 };
|
||||
let in_file = if filename.contains(*w) { 0.5_f32 } else { 0.0 };
|
||||
in_name + in_desc + in_file
|
||||
})
|
||||
.sum();
|
||||
|
||||
if score > 0.0 { Some((score, meta)) } else { None }
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort highest score first.
|
||||
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
scored
|
||||
.into_iter()
|
||||
.take(max_files)
|
||||
.filter_map(|(_, meta)| {
|
||||
let content = std::fs::read_to_string(&meta.path).ok()?;
|
||||
Some(MemoryFile { meta, content })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Team memory helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return the team-memory sub-directory path.
|
||||
/// Mirrors `getTeamMemPath` in `teamMemPaths.ts`.
|
||||
pub fn team_memory_path(auto_memory_dir: &Path) -> PathBuf {
|
||||
auto_memory_dir.join("team")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write as IoWrite;
|
||||
|
||||
// Helpers ----------------------------------------------------------------
|
||||
|
||||
fn make_temp_dir() -> tempfile::TempDir {
|
||||
tempfile::tempdir().expect("tempdir")
|
||||
}
|
||||
|
||||
fn write_file(dir: &Path, name: &str, content: &str) {
|
||||
let path = dir.join(name);
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).unwrap();
|
||||
}
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
// ---- parse_frontmatter_quick -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_full() {
|
||||
let content = "---\nname: My Memory\ndescription: A test description\ntype: feedback\n---\n\nBody text.";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert_eq!(name.as_deref(), Some("My Memory"));
|
||||
assert_eq!(desc.as_deref(), Some("A test description"));
|
||||
assert_eq!(mt, Some(MemoryType::Feedback));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_no_frontmatter() {
|
||||
let content = "Just plain text.";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert!(name.is_none());
|
||||
assert!(desc.is_none());
|
||||
assert!(mt.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_quoted_values() {
|
||||
let content = "---\nname: \"Quoted Name\"\ndescription: 'Single quoted'\ntype: user\n---";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert_eq!(name.as_deref(), Some("Quoted Name"));
|
||||
assert_eq!(desc.as_deref(), Some("Single quoted"));
|
||||
assert_eq!(mt, Some(MemoryType::User));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_unknown_type() {
|
||||
let content = "---\ntype: unknown_type\n---";
|
||||
let (_, _, mt) = parse_frontmatter_quick(content);
|
||||
assert!(mt.is_none());
|
||||
}
|
||||
|
||||
// ---- memory_age_days ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_today() {
|
||||
let now_secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
assert_eq!(memory_age_days(now_secs), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_one_day_ago() {
|
||||
let yesterday = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(86_400);
|
||||
assert_eq!(memory_age_days(yesterday), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_future_clamps_to_zero() {
|
||||
let far_future = u64::MAX;
|
||||
assert_eq!(memory_age_days(far_future), 0);
|
||||
}
|
||||
|
||||
// ---- memory_freshness_text ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_freshness_text_fresh() {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
assert!(memory_freshness_text(now).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_freshness_text_stale() {
|
||||
let old = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(10 * 86_400); // 10 days ago
|
||||
let text = memory_freshness_text(old);
|
||||
assert!(text.contains("10 days old"));
|
||||
assert!(text.contains("point-in-time"));
|
||||
}
|
||||
|
||||
// ---- memory_freshness_note ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_freshness_note_fresh_is_empty() {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
assert!(memory_freshness_note(now).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_freshness_note_stale_has_tags() {
|
||||
let old = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(5 * 86_400);
|
||||
let note = memory_freshness_note(old);
|
||||
assert!(note.contains("<system-reminder>"));
|
||||
assert!(note.contains("</system-reminder>"));
|
||||
}
|
||||
|
||||
// ---- truncate_entrypoint_content ---------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_truncate_no_truncation_needed() {
|
||||
let content = "line1\nline2\nline3";
|
||||
let result = truncate_entrypoint_content(content);
|
||||
assert!(!result.was_line_truncated);
|
||||
assert!(!result.was_byte_truncated);
|
||||
assert_eq!(result.content, content);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_line_limit() {
|
||||
let content = (0..=MAX_ENTRYPOINT_LINES)
|
||||
.map(|i| format!("line {}", i))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let result = truncate_entrypoint_content(&content);
|
||||
assert!(result.was_line_truncated);
|
||||
assert!(result.content.contains("WARNING"));
|
||||
}
|
||||
|
||||
// ---- sanitize_path_component -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path_component() {
|
||||
assert_eq!(sanitize_path_component("/home/user/project"), "_home_user_project");
|
||||
assert_eq!(sanitize_path_component("normal-name_123"), "normal-name_123");
|
||||
assert_eq!(sanitize_path_component("C:\\Users\\foo"), "C__Users_foo");
|
||||
}
|
||||
|
||||
// ---- load_memory_index -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_nonexistent() {
|
||||
let dir = make_temp_dir();
|
||||
assert!(load_memory_index(dir.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_empty() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", " ");
|
||||
assert!(load_memory_index(dir.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_with_content() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", "- [test.md](test.md) — something");
|
||||
let result = load_memory_index(dir.path()).unwrap();
|
||||
assert!(result.content.contains("test.md"));
|
||||
}
|
||||
|
||||
// ---- scan_memory_dir ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_scan_excludes_memory_md() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", "# index");
|
||||
write_file(dir.path(), "user_role.md", "---\nname: Role\n---");
|
||||
let metas = scan_memory_dir(dir.path());
|
||||
assert_eq!(metas.len(), 1);
|
||||
assert_eq!(metas[0].filename, "user_role.md");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scan_empty_dir() {
|
||||
let dir = make_temp_dir();
|
||||
assert!(scan_memory_dir(dir.path()).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scan_nonexistent_dir() {
|
||||
let path = PathBuf::from("/tmp/nonexistent_memory_dir_cc_rust_test_xyz");
|
||||
assert!(scan_memory_dir(&path).is_empty());
|
||||
}
|
||||
|
||||
// ---- format_memory_manifest --------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_manifest_with_description() {
|
||||
let meta = MemoryFileMeta {
|
||||
filename: "user_role.md".to_string(),
|
||||
path: PathBuf::from("user_role.md"),
|
||||
name: Some("User Role".to_string()),
|
||||
description: Some("The user is a data scientist".to_string()),
|
||||
memory_type: Some(MemoryType::User),
|
||||
modified_secs: 0,
|
||||
};
|
||||
let manifest = format_memory_manifest(&[meta]);
|
||||
assert!(manifest.contains("[user]"));
|
||||
assert!(manifest.contains("user_role.md"));
|
||||
assert!(manifest.contains("data scientist"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_manifest_no_description() {
|
||||
let meta = MemoryFileMeta {
|
||||
filename: "ref.md".to_string(),
|
||||
path: PathBuf::from("ref.md"),
|
||||
name: None,
|
||||
description: None,
|
||||
memory_type: None,
|
||||
modified_secs: 0,
|
||||
};
|
||||
let manifest = format_memory_manifest(&[meta]);
|
||||
assert!(manifest.contains("ref.md"));
|
||||
// No description separator colon
|
||||
assert!(!manifest.contains("ref.md ("));
|
||||
}
|
||||
|
||||
// ---- MemoryType --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_memory_type_roundtrip() {
|
||||
for (s, expected) in [
|
||||
("user", MemoryType::User),
|
||||
("feedback", MemoryType::Feedback),
|
||||
("project", MemoryType::Project),
|
||||
("reference", MemoryType::Reference),
|
||||
] {
|
||||
let parsed = MemoryType::parse(s).unwrap();
|
||||
assert_eq!(parsed, expected);
|
||||
assert_eq!(parsed.as_str(), s);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_type_unknown_returns_none() {
|
||||
assert!(MemoryType::parse("bogus").is_none());
|
||||
}
|
||||
|
||||
// ---- is_auto_memory_enabled -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_auto_memory_enabled_default() {
|
||||
// No env vars set for this test, settings None → should be enabled.
|
||||
// We can't guarantee the test environment is clean, so just check it
|
||||
// returns a bool without panicking.
|
||||
let _ = is_auto_memory_enabled(None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_memory_disabled_by_setting() {
|
||||
// If settings explicitly disable it and no env override, returns false.
|
||||
// We only test the settings-path without touching process env.
|
||||
// Simulate: env vars not set, settings says false.
|
||||
// We can't unset env vars reliably in tests, so just ensure the
|
||||
// function handles Some(false) without panicking.
|
||||
// (The full env-var paths are integration-tested separately.)
|
||||
let _ = is_auto_memory_enabled(Some(false));
|
||||
}
|
||||
}
|
||||
474
src-rust/crates/core/src/migrations.rs
Normal file
474
src-rust/crates/core/src/migrations.rs
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
//! Settings migration framework
|
||||
//! Runs on startup to upgrade settings.json from older versions.
|
||||
//!
|
||||
//! Migrations are derived from the TypeScript originals:
|
||||
//! - src/migrations/migrateFennecToOpus.ts
|
||||
//! - src/migrations/migrateLegacyOpusToCurrent.ts
|
||||
//! - src/migrations/migrateSonnet45ToSonnet46.ts
|
||||
//! - src/migrations/migrateAutoUpdatesToSettings.ts
|
||||
//! - (and several others without separate TS source files)
|
||||
//!
|
||||
//! Each migration is idempotent: it only touches fields it recognises and
|
||||
//! only writes when it actually changes something.
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
/// A single migration function.
|
||||
/// Returns `true` if the settings object was modified.
|
||||
pub type MigrationFn = fn(&mut Value) -> bool;
|
||||
|
||||
/// All migrations in the order they must be applied.
|
||||
pub const MIGRATIONS: &[(&str, MigrationFn)] = &[
|
||||
("migrate_fennec_to_opus", migrate_fennec_to_opus),
|
||||
("migrate_legacy_opus_to_current", migrate_legacy_opus_to_current),
|
||||
("migrate_opus_to_opus_1m", migrate_opus_to_opus_1m),
|
||||
("migrate_sonnet_1m_to_sonnet_45", migrate_sonnet_1m_to_sonnet_45),
|
||||
("migrate_sonnet_45_to_sonnet_46", migrate_sonnet_45_to_sonnet_46),
|
||||
(
|
||||
"migrate_bypass_permissions_to_settings",
|
||||
migrate_bypass_permissions_to_settings,
|
||||
),
|
||||
(
|
||||
"migrate_repl_bridge_to_remote_control",
|
||||
migrate_repl_bridge_to_remote_control,
|
||||
),
|
||||
("migrate_enable_all_mcp_servers", migrate_enable_all_mcp_servers),
|
||||
("migrate_auto_updates", migrate_auto_updates),
|
||||
("reset_auto_mode_opt_in", reset_auto_mode_opt_in),
|
||||
("reset_pro_to_opus_default", reset_pro_to_opus_default),
|
||||
];
|
||||
|
||||
/// Apply every pending migration to a settings `Value` (must be a JSON object).
|
||||
/// Returns `true` when at least one migration changed the settings.
|
||||
pub fn run_migrations(settings: &mut Value) -> bool {
|
||||
let mut changed = false;
|
||||
for (name, migration) in MIGRATIONS {
|
||||
if migration(settings) {
|
||||
tracing::info!("Applied settings migration: {}", name);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
changed
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model-name migrations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fennec was an internal alias; map to the current Opus line.
|
||||
/// Source: migrateFennecToOpus.ts
|
||||
fn migrate_fennec_to_opus(settings: &mut Value) -> bool {
|
||||
// fennec-latest[1m] → opus[1m], fennec-latest → opus
|
||||
// fennec-fast-latest / opus-4-5-fast → opus[1m] (fast-mode alias)
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if model.starts_with("fennec-latest[1m]") {
|
||||
settings["model"] = Value::String("opus[1m]".to_string());
|
||||
return true;
|
||||
}
|
||||
if model.starts_with("fennec-latest") {
|
||||
settings["model"] = Value::String("opus".to_string());
|
||||
return true;
|
||||
}
|
||||
if model.starts_with("fennec-fast-latest") || model.starts_with("opus-4-5-fast") {
|
||||
settings["model"] = Value::String("opus[1m]".to_string());
|
||||
settings["fastMode"] = Value::Bool(true);
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Migrate explicit Opus 4.0/4.1 strings to the `opus` alias.
|
||||
/// Source: migrateLegacyOpusToCurrent.ts
|
||||
fn migrate_legacy_opus_to_current(settings: &mut Value) -> bool {
|
||||
const LEGACY_OPUS: &[&str] = &[
|
||||
"claude-opus-4-20250514",
|
||||
"claude-opus-4-1-20250805",
|
||||
"claude-opus-4-0",
|
||||
"claude-opus-4-1",
|
||||
];
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
if LEGACY_OPUS.contains(&model.as_str()) {
|
||||
settings["model"] = Value::String("opus".to_string());
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Rename the old explicit `claude-opus-4-0` model string (pre-alias era).
|
||||
fn migrate_opus_to_opus_1m(settings: &mut Value) -> bool {
|
||||
rename_model(settings, "claude-opus-4-0", "claude-opus-4-5-20251001")
|
||||
}
|
||||
|
||||
/// Migrate the old Sonnet 1m string to the Sonnet 4.5 release ID.
|
||||
fn migrate_sonnet_1m_to_sonnet_45(settings: &mut Value) -> bool {
|
||||
rename_model(
|
||||
settings,
|
||||
"claude-sonnet-4-0-1m",
|
||||
"claude-sonnet-4-5-20251015",
|
||||
)
|
||||
}
|
||||
|
||||
/// Migrate Sonnet 4.5 explicit IDs to `sonnet` (which resolves to 4.6).
|
||||
/// Source: migrateSonnet45ToSonnet46.ts
|
||||
fn migrate_sonnet_45_to_sonnet_46(settings: &mut Value) -> bool {
|
||||
const SONNET_45_IDS: &[&str] = &[
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5-20250929[1m]",
|
||||
"sonnet-4-5-20250929",
|
||||
"sonnet-4-5-20250929[1m]",
|
||||
// Also handle the model strings used in the older Rust migrations table:
|
||||
"claude-sonnet-4-5-20251015",
|
||||
"claude-sonnet-4-5",
|
||||
];
|
||||
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if SONNET_45_IDS.contains(&model.as_str()) {
|
||||
let has_1m = model.ends_with("[1m]");
|
||||
let new_model = if has_1m { "sonnet[1m]" } else { "sonnet" };
|
||||
settings["model"] = Value::String(new_model.to_string());
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Rename `from` to `to` in the `model`, `defaultModel`, and `mainLoopModel`
|
||||
/// fields. Returns `true` if any field was changed.
|
||||
fn rename_model(settings: &mut Value, from: &str, to: &str) -> bool {
|
||||
let mut changed = false;
|
||||
for key in &["model", "defaultModel", "mainLoopModel"] {
|
||||
if let Some(val) = settings.get_mut(*key) {
|
||||
if val.as_str() == Some(from) {
|
||||
*val = Value::String(to.to_string());
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
changed
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config-structure migrations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Move `bypassPermissionsAccepted` boolean into `permissionMode`.
|
||||
fn migrate_bypass_permissions_to_settings(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("bypassPermissionsAccepted").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("permissionMode").is_none() && old.as_bool().unwrap_or(false) {
|
||||
settings["permissionMode"] = Value::String("bypass".to_string());
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("bypassPermissionsAccepted");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Rename `replBridgeEnabled` → `remoteControlAtStartup`.
|
||||
fn migrate_repl_bridge_to_remote_control(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("replBridgeEnabled").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("remoteControlAtStartup").is_none() {
|
||||
settings["remoteControlAtStartup"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("replBridgeEnabled");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Rename `enableAllProjectMcpServers` → `mcpAutoApprove`.
|
||||
fn migrate_enable_all_mcp_servers(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("enableAllProjectMcpServers").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("mcpAutoApprove").is_none() {
|
||||
settings["mcpAutoApprove"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("enableAllProjectMcpServers");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Migrate `autoUpdatesEnabled` → `autoUpdates`.
|
||||
/// Source: migrateAutoUpdatesToSettings.ts
|
||||
/// The TS version also writes an env-var to settings.json; here we keep the
|
||||
/// simpler structural rename and leave env-var injection to the caller.
|
||||
fn migrate_auto_updates(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("autoUpdatesEnabled").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("autoUpdates").is_none() {
|
||||
settings["autoUpdates"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("autoUpdatesEnabled");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Clear an old sentinel value for the auto-mode opt-in flag.
|
||||
fn reset_auto_mode_opt_in(settings: &mut Value) -> bool {
|
||||
if let Some(val) = settings.get("autoModeOptIn") {
|
||||
if val.as_str() == Some("default_offer_2024") {
|
||||
settings["autoModeOptIn"] = Value::Null;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Reset users who were auto-defaulted to Opus back to Sonnet 4.6.
|
||||
/// Only resets when `modelSetByUser` is not explicitly `true`.
|
||||
fn reset_pro_to_opus_default(settings: &mut Value) -> bool {
|
||||
if let Some(val) = settings.get("model") {
|
||||
if val.as_str() == Some("claude-opus-4-5-20251001") {
|
||||
let set_by_user = settings
|
||||
.get("modelSetByUser")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
if !set_by_user {
|
||||
settings["model"] = Value::String("claude-sonnet-4-6".to_string());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn settings(model: &str) -> Value {
|
||||
json!({ "model": model })
|
||||
}
|
||||
|
||||
// ---- rename_model -------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn rename_model_changes_matching_field() {
|
||||
let mut s = settings("old-model");
|
||||
assert!(rename_model(&mut s, "old-model", "new-model"));
|
||||
assert_eq!(s["model"].as_str(), Some("new-model"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rename_model_no_change_when_different() {
|
||||
let mut s = settings("something-else");
|
||||
assert!(!rename_model(&mut s, "old-model", "new-model"));
|
||||
assert_eq!(s["model"].as_str(), Some("something-else"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rename_model_covers_all_keys() {
|
||||
let mut s = json!({
|
||||
"model": "claude-foo",
|
||||
"defaultModel": "claude-foo",
|
||||
"mainLoopModel": "claude-foo",
|
||||
});
|
||||
assert!(rename_model(&mut s, "claude-foo", "claude-bar"));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-bar"));
|
||||
assert_eq!(s["defaultModel"].as_str(), Some("claude-bar"));
|
||||
assert_eq!(s["mainLoopModel"].as_str(), Some("claude-bar"));
|
||||
}
|
||||
|
||||
// ---- migrate_fennec_to_opus ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn fennec_latest_1m_maps_to_opus_1m() {
|
||||
let mut s = settings("fennec-latest[1m]");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_latest_maps_to_opus() {
|
||||
let mut s = settings("fennec-latest");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_fast_maps_to_opus_1m_with_fast_mode() {
|
||||
let mut s = settings("fennec-fast-latest");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
assert_eq!(s["fastMode"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opus_4_5_fast_maps_to_opus_1m_with_fast_mode() {
|
||||
let mut s = settings("opus-4-5-fast");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
assert_eq!(s["fastMode"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_no_match_returns_false() {
|
||||
let mut s = settings("claude-sonnet-4-6");
|
||||
assert!(!migrate_fennec_to_opus(&mut s));
|
||||
}
|
||||
|
||||
// ---- migrate_legacy_opus_to_current ------------------------------------
|
||||
|
||||
#[test]
|
||||
fn legacy_opus_4_0_maps_to_opus() {
|
||||
let mut s = settings("claude-opus-4-0");
|
||||
assert!(migrate_legacy_opus_to_current(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legacy_opus_4_1_maps_to_opus() {
|
||||
let mut s = settings("claude-opus-4-1-20250805");
|
||||
assert!(migrate_legacy_opus_to_current(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
// ---- migrate_sonnet_45_to_sonnet_46 ------------------------------------
|
||||
|
||||
#[test]
|
||||
fn sonnet_45_explicit_id_maps_to_sonnet() {
|
||||
let mut s = settings("claude-sonnet-4-5-20250929");
|
||||
assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sonnet_45_1m_maps_to_sonnet_1m() {
|
||||
let mut s = settings("claude-sonnet-4-5-20250929[1m]");
|
||||
assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet[1m]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sonnet_46_is_untouched() {
|
||||
let mut s = settings("claude-sonnet-4-6");
|
||||
assert!(!migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
}
|
||||
|
||||
// ---- struct migrations -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn bypass_permissions_migrates_and_removes_old_key() {
|
||||
let mut s = json!({ "bypassPermissionsAccepted": true });
|
||||
assert!(migrate_bypass_permissions_to_settings(&mut s));
|
||||
assert!(s.get("bypassPermissionsAccepted").is_none());
|
||||
assert_eq!(s["permissionMode"].as_str(), Some("bypass"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bypass_permissions_false_does_not_set_mode() {
|
||||
let mut s = json!({ "bypassPermissionsAccepted": false });
|
||||
assert!(migrate_bypass_permissions_to_settings(&mut s));
|
||||
assert!(s.get("permissionMode").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repl_bridge_renames_field() {
|
||||
let mut s = json!({ "replBridgeEnabled": true });
|
||||
assert!(migrate_repl_bridge_to_remote_control(&mut s));
|
||||
assert!(s.get("replBridgeEnabled").is_none());
|
||||
assert_eq!(s["remoteControlAtStartup"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enable_all_mcp_renames_field() {
|
||||
let mut s = json!({ "enableAllProjectMcpServers": true });
|
||||
assert!(migrate_enable_all_mcp_servers(&mut s));
|
||||
assert!(s.get("enableAllProjectMcpServers").is_none());
|
||||
assert_eq!(s["mcpAutoApprove"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_updates_renames_field() {
|
||||
let mut s = json!({ "autoUpdatesEnabled": false });
|
||||
assert!(migrate_auto_updates(&mut s));
|
||||
assert!(s.get("autoUpdatesEnabled").is_none());
|
||||
assert_eq!(s["autoUpdates"].as_bool(), Some(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_auto_mode_clears_sentinel() {
|
||||
let mut s = json!({ "autoModeOptIn": "default_offer_2024" });
|
||||
assert!(reset_auto_mode_opt_in(&mut s));
|
||||
assert!(s["autoModeOptIn"].is_null());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_auto_mode_leaves_other_values() {
|
||||
let mut s = json!({ "autoModeOptIn": "user_opted_in" });
|
||||
assert!(!reset_auto_mode_opt_in(&mut s));
|
||||
assert_eq!(s["autoModeOptIn"].as_str(), Some("user_opted_in"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_pro_opus_default_resets_when_not_user_set() {
|
||||
let mut s = json!({ "model": "claude-opus-4-5-20251001" });
|
||||
assert!(reset_pro_to_opus_default(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-sonnet-4-6"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_pro_opus_default_preserves_when_user_set() {
|
||||
let mut s = json!({ "model": "claude-opus-4-5-20251001", "modelSetByUser": true });
|
||||
assert!(!reset_pro_to_opus_default(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-opus-4-5-20251001"));
|
||||
}
|
||||
|
||||
// ---- run_migrations integration ----------------------------------------
|
||||
|
||||
#[test]
|
||||
fn run_migrations_applies_chain() {
|
||||
// A Sonnet 4.5 model should end up as "sonnet" after the full chain.
|
||||
let mut s = json!({ "model": "claude-sonnet-4-5-20250929" });
|
||||
let changed = run_migrations(&mut s);
|
||||
assert!(changed);
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_migrations_returns_false_when_nothing_changes() {
|
||||
let mut s = json!({ "model": "claude-sonnet-4-6", "someOtherKey": 42 });
|
||||
assert!(!run_migrations(&mut s));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_migrations_handles_empty_object() {
|
||||
let mut s = json!({});
|
||||
// No model fields, no sentinel values → nothing to do.
|
||||
assert!(!run_migrations(&mut s));
|
||||
}
|
||||
}
|
||||
364
src-rust/crates/core/src/oauth_config.rs
Normal file
364
src-rust/crates/core/src/oauth_config.rs
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
//! OAuth configuration for multiple environments.
|
||||
//!
|
||||
//! This module mirrors the TypeScript `src/constants/oauth.ts` and
|
||||
//! `src/services/oauth/crypto.ts` constants. It is intentionally
|
||||
//! *configuration-only* — no live network I/O except for the optional
|
||||
//! `fetch_oauth_profile` helper at the bottom.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scope constants (mirrors constants/oauth.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The Claude.ai inference scope — required for Bearer-auth API calls.
|
||||
pub const CLAUDE_AI_INFERENCE_SCOPE: &str = "user:inference";
|
||||
|
||||
/// The profile scope — required to read account / subscription data.
|
||||
pub const CLAUDE_AI_PROFILE_SCOPE: &str = "user:profile";
|
||||
|
||||
/// Console scope — used when creating an API key via the Console flow.
|
||||
pub const CONSOLE_SCOPE: &str = "org:create_api_key";
|
||||
|
||||
/// All Claude.ai OAuth scopes (mirrors `CLAUDE_AI_OAUTH_SCOPES`).
|
||||
pub const CLAUDE_AI_OAUTH_SCOPES: &[&str] = &[
|
||||
CLAUDE_AI_PROFILE_SCOPE,
|
||||
CLAUDE_AI_INFERENCE_SCOPE,
|
||||
"user:sessions:claude_code",
|
||||
"user:mcp_servers",
|
||||
"user:file_upload",
|
||||
];
|
||||
|
||||
/// Console OAuth scopes (mirrors `CONSOLE_OAUTH_SCOPES`).
|
||||
pub const CONSOLE_OAUTH_SCOPES: &[&str] = &[CONSOLE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
|
||||
|
||||
/// Union of all scopes used during login (mirrors `ALL_OAUTH_SCOPES`).
|
||||
/// Requesting all at once lets a single login satisfy both Console and
|
||||
/// claude.ai auth paths.
|
||||
pub const ALL_OAUTH_SCOPES: &[&str] = &[
|
||||
CONSOLE_SCOPE,
|
||||
CLAUDE_AI_PROFILE_SCOPE,
|
||||
CLAUDE_AI_INFERENCE_SCOPE,
|
||||
"user:sessions:claude_code",
|
||||
"user:mcp_servers",
|
||||
"user:file_upload",
|
||||
];
|
||||
|
||||
/// Minimum scopes required for basic operation.
|
||||
pub const MINIMUM_SCOPES: &[&str] = &[CLAUDE_AI_INFERENCE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OAuthConfig struct
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Full OAuth configuration for a deployment environment.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OAuthConfig {
|
||||
pub base_api_url: &'static str,
|
||||
pub console_authorize_url: &'static str,
|
||||
pub claude_ai_authorize_url: &'static str,
|
||||
/// The raw claude.ai web origin (separate from the authorize URL which
|
||||
/// may bounce through claude.com for attribution).
|
||||
pub claude_ai_origin: &'static str,
|
||||
pub token_url: &'static str,
|
||||
pub api_key_url: &'static str,
|
||||
pub roles_url: &'static str,
|
||||
pub console_success_url: &'static str,
|
||||
pub claudeai_success_url: &'static str,
|
||||
pub manual_redirect_url: &'static str,
|
||||
pub client_id: &'static str,
|
||||
pub oauth_file_suffix: &'static str,
|
||||
pub mcp_proxy_url: &'static str,
|
||||
pub mcp_proxy_path: &'static str,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Production config (mirrors PROD_OAUTH_CONFIG in oauth.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub const PROD_OAUTH: OAuthConfig = OAuthConfig {
|
||||
base_api_url: "https://api.anthropic.com",
|
||||
// Routes through claude.com/cai/* for attribution, 307s to claude.ai in
|
||||
// two hops — same behaviour as the TypeScript client.
|
||||
console_authorize_url: "https://platform.claude.com/oauth/authorize",
|
||||
claude_ai_authorize_url: "https://claude.com/cai/oauth/authorize",
|
||||
claude_ai_origin: "https://claude.ai",
|
||||
token_url: "https://platform.claude.com/v1/oauth/token",
|
||||
api_key_url: "https://api.anthropic.com/api/oauth/claude_cli/create_api_key",
|
||||
roles_url: "https://api.anthropic.com/api/oauth/claude_cli/roles",
|
||||
console_success_url: "https://platform.claude.com/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
|
||||
claudeai_success_url: "https://platform.claude.com/oauth/code/success?app=claude-code",
|
||||
manual_redirect_url: "https://platform.claude.com/oauth/code/callback",
|
||||
client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e",
|
||||
oauth_file_suffix: "",
|
||||
mcp_proxy_url: "https://mcp-proxy.anthropic.com",
|
||||
mcp_proxy_path: "/v1/mcp/{server_id}",
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Staging config (mirrors STAGING_OAUTH_CONFIG — ant builds only)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub const STAGING_OAUTH: OAuthConfig = OAuthConfig {
|
||||
base_api_url: "https://api-staging.anthropic.com",
|
||||
console_authorize_url: "https://platform.staging.ant.dev/oauth/authorize",
|
||||
claude_ai_authorize_url: "https://claude-ai.staging.ant.dev/oauth/authorize",
|
||||
claude_ai_origin: "https://claude-ai.staging.ant.dev",
|
||||
token_url: "https://platform.staging.ant.dev/v1/oauth/token",
|
||||
api_key_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/create_api_key",
|
||||
roles_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/roles",
|
||||
console_success_url: "https://platform.staging.ant.dev/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
|
||||
claudeai_success_url: "https://platform.staging.ant.dev/oauth/code/success?app=claude-code",
|
||||
manual_redirect_url: "https://platform.staging.ant.dev/oauth/code/callback",
|
||||
client_id: "22422756-60c9-4084-8eb7-27705fd5cf9a",
|
||||
oauth_file_suffix: "-staging-oauth",
|
||||
mcp_proxy_url: "https://mcp-proxy-staging.anthropic.com",
|
||||
mcp_proxy_path: "/v1/mcp/{server_id}",
|
||||
};
|
||||
|
||||
/// Client-ID Metadata Document URL for MCP OAuth (CIMD / SEP-991).
|
||||
pub const MCP_CLIENT_METADATA_URL: &str =
|
||||
"https://claude.ai/oauth/claude-code-client-metadata";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config selection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return the OAuth config appropriate for the current environment.
|
||||
///
|
||||
/// Selection logic mirrors `getOauthConfigType()` in `constants/oauth.ts`:
|
||||
/// - `USER_TYPE=ant` + `USE_STAGING_OAUTH=true` → staging
|
||||
/// - anything else → production
|
||||
///
|
||||
/// Note: the `local` variant from the TypeScript code is intentionally
|
||||
/// omitted here — local dev servers are not needed in the Rust port yet.
|
||||
pub fn get_oauth_config() -> &'static OAuthConfig {
|
||||
let user_type = std::env::var("USER_TYPE").unwrap_or_default();
|
||||
if user_type == "ant" {
|
||||
let use_staging = std::env::var("USE_STAGING_OAUTH")
|
||||
.map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
if use_staging {
|
||||
return &STAGING_OAUTH;
|
||||
}
|
||||
}
|
||||
&PROD_OAUTH
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PKCE helpers (mirrors src/services/oauth/crypto.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// PKCE code-challenge / code-verifier helpers.
|
||||
pub mod pkce {
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Generate a cryptographically random code verifier (43–128 chars of
|
||||
/// Base64url characters, as required by RFC 7636).
|
||||
///
|
||||
/// Uses `getrandom` via the `rand` crate's OS RNG through the `uuid`
|
||||
/// crate's v4 generator — both already in-tree. Falls back to a
|
||||
/// time+pid mix if the OS RNG is unavailable.
|
||||
pub fn generate_code_verifier() -> String {
|
||||
// 32 random bytes → 43-char Base64url string (same as the TS impl).
|
||||
let bytes = random_bytes_32();
|
||||
URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
/// Compute `BASE64URL(SHA256(verifier))` — the S256 code challenge.
|
||||
pub fn code_challenge(verifier: &str) -> String {
|
||||
let hash = Sha256::digest(verifier.as_bytes());
|
||||
URL_SAFE_NO_PAD.encode(hash)
|
||||
}
|
||||
|
||||
/// Generate a random state parameter (16 Base64url chars).
|
||||
pub fn generate_state() -> String {
|
||||
let bytes = random_bytes_32();
|
||||
let encoded = URL_SAFE_NO_PAD.encode(bytes);
|
||||
// Take first 43 chars for a compact state parameter
|
||||
encoded.chars().take(43).collect()
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Internal: produce 32 random bytes.
|
||||
// We derive them from a UUID v4 (which already pulls from the OS RNG
|
||||
// via the `uuid` crate) so we don't need to add a new `rand` dep.
|
||||
// ------------------------------------------------------------------
|
||||
fn random_bytes_32() -> [u8; 32] {
|
||||
// Two UUID v4 values give us 32 bytes of OS-backed randomness.
|
||||
let u1 = uuid::Uuid::new_v4();
|
||||
let u2 = uuid::Uuid::new_v4();
|
||||
let mut out = [0u8; 32];
|
||||
out[..16].copy_from_slice(u1.as_bytes());
|
||||
out[16..].copy_from_slice(u2.as_bytes());
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Token and profile types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Raw OAuth token response from the token endpoint.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_in: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refresh_token: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scope: Option<String>,
|
||||
}
|
||||
|
||||
/// Slim profile fetched after token exchange.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct OAuthProfile {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub email: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub account_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub subscription_tier: Option<String>,
|
||||
}
|
||||
|
||||
/// Fetch the OAuth profile using an access token.
|
||||
///
|
||||
/// Returns a default (all-`None`) profile on any non-success response so
|
||||
/// callers can treat a profile fetch failure as non-fatal.
|
||||
pub async fn fetch_oauth_profile(
|
||||
access_token: &str,
|
||||
api_base: &str,
|
||||
) -> anyhow::Result<OAuthProfile> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{}/api/auth/oauth/profile", api_base.trim_end_matches('/'));
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(access_token)
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let profile: OAuthProfile = resp.json().await.unwrap_or_default();
|
||||
Ok(profile)
|
||||
} else {
|
||||
// Non-fatal: return an empty profile so the caller can continue.
|
||||
Ok(OAuthProfile::default())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Auth URL builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the OAuth authorization URL (mirrors `buildAuthUrl` in client.ts).
|
||||
pub fn build_auth_url(
|
||||
code_challenge: &str,
|
||||
state: &str,
|
||||
port: u16,
|
||||
is_manual: bool,
|
||||
login_with_claude_ai: bool,
|
||||
inference_only: bool,
|
||||
) -> String {
|
||||
let cfg = get_oauth_config();
|
||||
|
||||
let base = if login_with_claude_ai {
|
||||
cfg.claude_ai_authorize_url
|
||||
} else {
|
||||
cfg.console_authorize_url
|
||||
};
|
||||
|
||||
let redirect_uri = if is_manual {
|
||||
cfg.manual_redirect_url.to_string()
|
||||
} else {
|
||||
format!("http://localhost:{}/callback", port)
|
||||
};
|
||||
|
||||
let scopes: Vec<&str> = if inference_only {
|
||||
vec![CLAUDE_AI_INFERENCE_SCOPE]
|
||||
} else {
|
||||
ALL_OAUTH_SCOPES.to_vec()
|
||||
};
|
||||
|
||||
let scope_str = scopes.join(" ");
|
||||
|
||||
format!(
|
||||
"{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
base,
|
||||
urlencoding::encode(cfg.client_id),
|
||||
urlencoding::encode(&redirect_uri),
|
||||
urlencoding::encode(&scope_str),
|
||||
urlencoding::encode(code_challenge),
|
||||
urlencoding::encode(state),
|
||||
)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prod_config_urls_are_https() {
|
||||
assert!(PROD_OAUTH.token_url.starts_with("https://"));
|
||||
assert!(PROD_OAUTH.api_key_url.starts_with("https://"));
|
||||
assert!(PROD_OAUTH.claude_ai_authorize_url.starts_with("https://"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_staging_config_urls_are_https() {
|
||||
assert!(STAGING_OAUTH.token_url.starts_with("https://"));
|
||||
assert!(STAGING_OAUTH.api_key_url.starts_with("https://"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pkce_code_challenge_is_base64url() {
|
||||
let verifier = pkce::generate_code_verifier();
|
||||
assert!(!verifier.is_empty());
|
||||
// Base64url characters only (no +, /, =)
|
||||
assert!(!verifier.contains('+'));
|
||||
assert!(!verifier.contains('/'));
|
||||
assert!(!verifier.contains('='));
|
||||
|
||||
let challenge = pkce::code_challenge(&verifier);
|
||||
assert!(!challenge.is_empty());
|
||||
assert!(!challenge.contains('+'));
|
||||
assert!(!challenge.contains('/'));
|
||||
assert!(!challenge.contains('='));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verifier_length_meets_rfc7636_minimum() {
|
||||
let verifier = pkce::generate_code_verifier();
|
||||
// RFC 7636 §4.1: code_verifier length ∈ [43, 128]
|
||||
assert!(
|
||||
verifier.len() >= 43,
|
||||
"verifier too short: {} chars",
|
||||
verifier.len()
|
||||
);
|
||||
assert!(verifier.len() <= 128, "verifier too long: {} chars", verifier.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_oauth_scopes_contains_inference() {
|
||||
assert!(ALL_OAUTH_SCOPES.contains(&CLAUDE_AI_INFERENCE_SCOPE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_auth_url_contains_required_params() {
|
||||
let url = build_auth_url("challenge123", "state456", 8080, false, true, false);
|
||||
assert!(url.contains("challenge123"));
|
||||
assert!(url.contains("state456"));
|
||||
assert!(url.contains("S256"));
|
||||
assert!(url.contains("localhost"));
|
||||
}
|
||||
}
|
||||
347
src-rust/crates/core/src/output_styles.rs
Normal file
347
src-rust/crates/core/src/output_styles.rs
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
//! Output style system — customises how Claude responds to the user.
|
||||
//!
|
||||
//! Styles are applied by injecting `OutputStyleDef::prompt` into the system
|
||||
//! prompt. Built-in styles are defined in code; users can add their own by
|
||||
//! placing `.md` or `.json` files in:
|
||||
//! - Global: `~/.claude/output-styles/`
|
||||
//! - Project: `.claude/output-styles/`
|
||||
//!
|
||||
//! Markdown style files have a simple structure:
|
||||
//! Line 1: `# <Label>` (heading becomes the label)
|
||||
//! Line 2: short description
|
||||
//! Remainder: the prompt text injected into the system prompt
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single output style definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct OutputStyleDef {
|
||||
/// Machine-readable identifier (e.g. `"concise"`).
|
||||
pub name: String,
|
||||
/// Human-readable label shown in picker UI (e.g. `"Concise"`).
|
||||
pub label: String,
|
||||
/// One-line description.
|
||||
pub description: String,
|
||||
/// Text injected into the system prompt when this style is active.
|
||||
/// Empty string for the default style (no extra injection).
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
impl OutputStyleDef {
|
||||
// ---- Built-in styles ---------------------------------------------------
|
||||
|
||||
pub fn builtin_default() -> Self {
|
||||
Self {
|
||||
name: "default".to_string(),
|
||||
label: "Default".to_string(),
|
||||
description: "Standard Claude Code responses.".to_string(),
|
||||
prompt: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_concise() -> Self {
|
||||
Self {
|
||||
name: "concise".to_string(),
|
||||
label: "Concise".to_string(),
|
||||
description: "Short, direct responses with minimal explanation.".to_string(),
|
||||
prompt: "Be maximally concise. Skip preamble, summaries, and filler. \
|
||||
Lead with the answer."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_explanatory() -> Self {
|
||||
Self {
|
||||
name: "explanatory".to_string(),
|
||||
label: "Explanatory".to_string(),
|
||||
description: "Thorough explanations with reasoning and alternatives.".to_string(),
|
||||
prompt: "When explaining code or concepts, be thorough and educational. \
|
||||
Include reasoning, alternatives considered, and potential pitfalls. \
|
||||
Err on the side of over-explaining."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_learning() -> Self {
|
||||
Self {
|
||||
name: "learning".to_string(),
|
||||
label: "Learning".to_string(),
|
||||
description: "Pedagogical mode — explains patterns and decisions.".to_string(),
|
||||
prompt: "This user is learning. Explain concepts as you implement them. \
|
||||
Point out patterns, best practices, and why you made each decision. \
|
||||
Use analogies when helpful."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Built-ins
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return all built-in output styles in display order.
|
||||
pub fn builtin_styles() -> Vec<OutputStyleDef> {
|
||||
vec![
|
||||
OutputStyleDef::builtin_default(),
|
||||
OutputStyleDef::builtin_concise(),
|
||||
OutputStyleDef::builtin_explanatory(),
|
||||
OutputStyleDef::builtin_learning(),
|
||||
]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Loading from disk
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Load user-defined output styles from a directory.
|
||||
///
|
||||
/// Supported file formats:
|
||||
/// - `.md` — Markdown: `# Label\ndescription\n\nprompt text…`
|
||||
/// - `.json` — JSON: `{ "name": "…", "label": "…", "description": "…", "prompt": "…" }`
|
||||
///
|
||||
/// Files that cannot be parsed are silently skipped.
|
||||
pub fn load_output_styles_dir(styles_dir: &Path) -> Vec<OutputStyleDef> {
|
||||
if !styles_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let entries = match std::fs::read_dir(styles_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let mut styles = Vec::new();
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
|
||||
if ext == "md" || ext == "json" {
|
||||
if let Some(style) = load_style_file(&path) {
|
||||
styles.push(style);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort alphabetically so the list is deterministic.
|
||||
styles.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
styles
|
||||
}
|
||||
|
||||
fn load_style_file(path: &Path) -> Option<OutputStyleDef> {
|
||||
let content = std::fs::read_to_string(path).ok()?;
|
||||
let stem = path.file_stem()?.to_string_lossy().into_owned();
|
||||
|
||||
if path.extension().and_then(|e| e.to_str()) == Some("json") {
|
||||
// Try deserialising directly; fall back to inserting the stem as name.
|
||||
let mut def: OutputStyleDef = serde_json::from_str(&content).ok()?;
|
||||
if def.name.is_empty() {
|
||||
def.name = stem;
|
||||
}
|
||||
return Some(def);
|
||||
}
|
||||
|
||||
// Markdown format:
|
||||
// Line 1: # Label (optional leading `#` and whitespace)
|
||||
// Line 2: description (short, plain text)
|
||||
// Lines 3+: prompt text (everything after the blank / second line)
|
||||
let mut lines = content.lines();
|
||||
|
||||
let raw_label = lines.next().unwrap_or("").trim().to_string();
|
||||
let label = raw_label.trim_start_matches('#').trim().to_string();
|
||||
let label = if label.is_empty() { stem.clone() } else { label };
|
||||
|
||||
let description = lines
|
||||
.next()
|
||||
.map(|l| l.trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Collect remaining lines as the prompt, trimming leading blank lines.
|
||||
let prompt_lines: Vec<&str> = lines.collect();
|
||||
let prompt = prompt_lines.join("\n").trim().to_string();
|
||||
|
||||
Some(OutputStyleDef {
|
||||
name: stem,
|
||||
label,
|
||||
description,
|
||||
prompt,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Aggregated access
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return all styles available for `config_dir`:
|
||||
/// built-ins first, then styles from `<config_dir>/output-styles/`.
|
||||
///
|
||||
/// `config_dir` is typically `~/.claude`.
|
||||
pub fn all_styles(config_dir: &Path) -> Vec<OutputStyleDef> {
|
||||
let mut styles = builtin_styles();
|
||||
let user_dir = config_dir.join("output-styles");
|
||||
styles.extend(load_output_styles_dir(&user_dir));
|
||||
styles
|
||||
}
|
||||
|
||||
/// Find a style by its `name` field.
|
||||
pub fn find_style<'a>(styles: &'a [OutputStyleDef], name: &str) -> Option<&'a OutputStyleDef> {
|
||||
styles.iter().find(|s| s.name == name)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write as IoWrite;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// ---- builtin_styles ----------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn builtin_styles_non_empty() {
|
||||
assert!(!builtin_styles().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_styles_have_unique_names() {
|
||||
let styles = builtin_styles();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for s in &styles {
|
||||
assert!(seen.insert(&s.name), "duplicate style name: {}", s.name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_default_has_empty_prompt() {
|
||||
let def = OutputStyleDef::builtin_default();
|
||||
assert!(def.prompt.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_non_default_have_prompts() {
|
||||
for s in builtin_styles() {
|
||||
if s.name != "default" {
|
||||
assert!(
|
||||
!s.prompt.is_empty(),
|
||||
"style '{}' should have a non-empty prompt",
|
||||
s.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- find_style --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn find_style_by_name() {
|
||||
let styles = builtin_styles();
|
||||
let found = find_style(&styles, "concise");
|
||||
assert!(found.is_some());
|
||||
assert_eq!(found.unwrap().name, "concise");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_style_missing() {
|
||||
let styles = builtin_styles();
|
||||
assert!(find_style(&styles, "nonexistent-xyz").is_none());
|
||||
}
|
||||
|
||||
// ---- load_output_styles_dir (markdown) ---------------------------------
|
||||
|
||||
fn write_file(dir: &TempDir, name: &str, content: &str) {
|
||||
let path = dir.path().join(name);
|
||||
let mut f = std::fs::File::create(path).unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_markdown_style() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(
|
||||
&dir,
|
||||
"terse.md",
|
||||
"# Terse\nVery short answers.\n\nOne sentence per response.",
|
||||
);
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles.len(), 1);
|
||||
let s = &styles[0];
|
||||
assert_eq!(s.name, "terse");
|
||||
assert_eq!(s.label, "Terse");
|
||||
assert_eq!(s.description, "Very short answers.");
|
||||
assert_eq!(s.prompt, "One sentence per response.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_json_style() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(
|
||||
&dir,
|
||||
"formal.json",
|
||||
r#"{"name":"formal","label":"Formal","description":"Formal tone.","prompt":"Use formal language."}"#,
|
||||
);
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles.len(), 1);
|
||||
let s = &styles[0];
|
||||
assert_eq!(s.name, "formal");
|
||||
assert_eq!(s.label, "Formal");
|
||||
assert_eq!(s.prompt, "Use formal language.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_skips_unknown_extensions() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(&dir, "ignore.txt", "should be skipped");
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert!(styles.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_non_existent_dir_returns_empty() {
|
||||
use std::path::PathBuf;
|
||||
let styles = load_output_styles_dir(&PathBuf::from("/nonexistent/path/xyz"));
|
||||
assert!(styles.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_multiple_styles_sorted() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(&dir, "zebra.md", "# Zebra\nZ style.\n\nZ prompt.");
|
||||
write_file(&dir, "apple.md", "# Apple\nA style.\n\nA prompt.");
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles[0].name, "apple");
|
||||
assert_eq!(styles[1].name, "zebra");
|
||||
}
|
||||
|
||||
// ---- all_styles --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn all_styles_includes_builtins() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
// no output-styles subdir → only built-ins
|
||||
let styles = all_styles(dir.path());
|
||||
assert!(styles.iter().any(|s| s.name == "default"));
|
||||
assert!(styles.iter().any(|s| s.name == "concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_styles_merges_user_styles() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let output_styles_dir = dir.path().join("output-styles");
|
||||
std::fs::create_dir_all(&output_styles_dir).unwrap();
|
||||
|
||||
// Write a user style file.
|
||||
let mut f = std::fs::File::create(output_styles_dir.join("pirate.md")).unwrap();
|
||||
f.write_all(b"# Pirate\nSpeak like a pirate.\n\nArrr matey!").unwrap();
|
||||
|
||||
let styles = all_styles(dir.path());
|
||||
assert!(styles.iter().any(|s| s.name == "pirate"));
|
||||
// Built-ins still present.
|
||||
assert!(styles.iter().any(|s| s.name == "default"));
|
||||
}
|
||||
}
|
||||
526
src-rust/crates/core/src/system_prompt.rs
Normal file
526
src-rust/crates/core/src/system_prompt.rs
Normal file
|
|
@ -0,0 +1,526 @@
|
|||
//! Modular system prompt assembly with caching support.
|
||||
//!
|
||||
//! Mirrors the TypeScript `systemPromptSections.ts` / `prompts.ts` architecture:
|
||||
//! cacheable (static) sections are placed before `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`;
|
||||
//! volatile, session-specific sections follow it.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dynamic boundary marker
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Marker that splits the cached vs dynamic parts of the system prompt.
|
||||
/// Everything before this marker can be prompt-cached by the API.
|
||||
/// Matches the TypeScript constant `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`.
|
||||
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Section cache (mirrors bootstrap/state.ts systemPromptSectionCache)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn section_cache() -> &'static Mutex<HashMap<String, Option<String>>> {
|
||||
static CACHE: OnceLock<Mutex<HashMap<String, Option<String>>>> = OnceLock::new();
|
||||
CACHE.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Clear all cached system prompt sections (called on /clear and /compact).
|
||||
pub fn clear_system_prompt_sections() {
|
||||
if let Ok(mut cache) = section_cache().lock() {
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// A single named section of the system prompt.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SystemPromptSection {
|
||||
/// Identifier used for cache lookups and invalidation.
|
||||
pub tag: &'static str,
|
||||
/// Computed content (None means the section is absent/disabled).
|
||||
pub content: Option<String>,
|
||||
/// If true the section is volatile and must not be prompt-cached.
|
||||
pub cache_break: bool,
|
||||
}
|
||||
|
||||
impl SystemPromptSection {
|
||||
/// Create a memoizable (cacheable) section.
|
||||
pub fn cached(tag: &'static str, content: impl Into<String>) -> Self {
|
||||
Self { tag, content: Some(content.into()), cache_break: false }
|
||||
}
|
||||
|
||||
/// Create a volatile section that re-evaluates every turn.
|
||||
/// Passing `None` for content means the section is absent this turn.
|
||||
pub fn uncached(tag: &'static str, content: Option<impl Into<String>>) -> Self {
|
||||
Self {
|
||||
tag,
|
||||
content: content.map(|c| c.into()),
|
||||
cache_break: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Output style
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Output styles that affect the system prompt.
|
||||
/// Serialised as lowercase strings to match settings.json.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OutputStyle {
|
||||
#[default]
|
||||
Default,
|
||||
Explanatory,
|
||||
Learning,
|
||||
Concise,
|
||||
Formal,
|
||||
Casual,
|
||||
}
|
||||
|
||||
impl OutputStyle {
|
||||
/// Returns the system-prompt suffix for this style, or `None` for Default.
|
||||
pub fn prompt_suffix(self) -> Option<&'static str> {
|
||||
match self {
|
||||
OutputStyle::Explanatory => Some(
|
||||
"When explaining code or concepts, be thorough and educational. \
|
||||
Include reasoning, alternatives considered, and potential pitfalls. \
|
||||
Err on the side of over-explaining.",
|
||||
),
|
||||
OutputStyle::Learning => Some(
|
||||
"This user is learning. Explain concepts as you implement them. \
|
||||
Point out patterns, best practices, and why you made each decision. \
|
||||
Use analogies when helpful.",
|
||||
),
|
||||
OutputStyle::Concise => Some(
|
||||
"Be maximally concise. Skip preamble, summaries, and filler. \
|
||||
Lead with the answer. One sentence is better than three.",
|
||||
),
|
||||
OutputStyle::Formal => Some(
|
||||
"Maintain a formal, professional tone. Use precise technical language.",
|
||||
),
|
||||
OutputStyle::Casual => Some("Use a casual, conversational tone."),
|
||||
OutputStyle::Default => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from a string (case-insensitive).
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"explanatory" => Self::Explanatory,
|
||||
"learning" => Self::Learning,
|
||||
"concise" => Self::Concise,
|
||||
"formal" => Self::Formal,
|
||||
"casual" => Self::Casual,
|
||||
_ => Self::Default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// System prompt prefix variants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Which entrypoint context Claude Code is running in.
|
||||
/// Determines the opening attribution line of the system prompt.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum SystemPromptPrefix {
|
||||
/// Standard interactive CLI session.
|
||||
Cli,
|
||||
/// Running as a sub-agent spawned by the Claude Agent SDK.
|
||||
Sdk,
|
||||
/// The CLI preset running within the Agent SDK
|
||||
/// (non-interactive + append_system_prompt set).
|
||||
SdkPreset,
|
||||
/// Running on Vertex AI.
|
||||
Vertex,
|
||||
/// Running on AWS Bedrock.
|
||||
Bedrock,
|
||||
/// Remote / headless CCR session.
|
||||
Remote,
|
||||
}
|
||||
|
||||
impl SystemPromptPrefix {
|
||||
/// Detect from environment variables, mirroring `getCLISyspromptPrefix`.
|
||||
pub fn detect(is_non_interactive: bool, has_append_system_prompt: bool) -> Self {
|
||||
// Vertex: always uses the default "Claude Code" prefix.
|
||||
if std::env::var("ANTHROPIC_VERTEX_PROJECT_ID").is_ok()
|
||||
|| std::env::var("CLOUD_ML_PROJECT_ID").is_ok()
|
||||
{
|
||||
return Self::Vertex;
|
||||
}
|
||||
|
||||
if std::env::var("AWS_BEDROCK_MODEL_ID").is_ok() {
|
||||
return Self::Bedrock;
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_REMOTE").is_ok() {
|
||||
return Self::Remote;
|
||||
}
|
||||
|
||||
// Non-interactive mode maps to SDK variants (matches TS getCLISyspromptPrefix).
|
||||
if is_non_interactive {
|
||||
if has_append_system_prompt {
|
||||
return Self::SdkPreset;
|
||||
}
|
||||
return Self::Sdk;
|
||||
}
|
||||
|
||||
Self::Cli
|
||||
}
|
||||
|
||||
/// The opening attribution string for this prefix variant.
|
||||
pub fn attribution_text(self) -> &'static str {
|
||||
match self {
|
||||
Self::Cli | Self::Vertex | Self::Bedrock | Self::Remote => {
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
}
|
||||
Self::SdkPreset => {
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, \
|
||||
running within the Claude Agent SDK."
|
||||
}
|
||||
Self::Sdk => {
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Build options
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// All options controlling what goes into the assembled system prompt.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SystemPromptOptions {
|
||||
/// Override auto-detected prefix.
|
||||
pub prefix: Option<SystemPromptPrefix>,
|
||||
/// Whether the session is non-interactive (SDK / pipe mode).
|
||||
pub is_non_interactive: bool,
|
||||
/// Whether --append-system-prompt is set (affects prefix detection).
|
||||
pub has_append_system_prompt: bool,
|
||||
/// Output style to inject.
|
||||
pub output_style: OutputStyle,
|
||||
/// Absolute path to the working directory (injected as dynamic section).
|
||||
pub working_directory: Option<String>,
|
||||
/// Pre-built memory content from memdir (injected as dynamic section).
|
||||
pub memory_content: String,
|
||||
/// Custom system prompt (--system-prompt flag or settings).
|
||||
pub custom_system_prompt: Option<String>,
|
||||
/// Additional text appended after everything else (--append-system-prompt).
|
||||
pub append_system_prompt: Option<String>,
|
||||
/// If true and `custom_system_prompt` is set, the entire default prompt is
|
||||
/// replaced — only the custom text + dynamic boundary are emitted.
|
||||
pub replace_system_prompt: bool,
|
||||
/// Inject the coordinator-mode section.
|
||||
pub coordinator_mode: bool,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main assembly function
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the complete system prompt string.
|
||||
///
|
||||
/// The returned string contains `SYSTEM_PROMPT_DYNAMIC_BOUNDARY` as an
|
||||
/// internal marker. Callers (e.g. `buildSystemPromptBlocks` in cc-query)
|
||||
/// split on this marker to determine which portions are eligible for
|
||||
/// Anthropic prompt-caching.
|
||||
pub fn build_system_prompt(opts: &SystemPromptOptions) -> String {
|
||||
// Replace mode: skip all default sections.
|
||||
if opts.replace_system_prompt {
|
||||
if let Some(custom) = &opts.custom_system_prompt {
|
||||
return format!("{}\n\n{}", custom, SYSTEM_PROMPT_DYNAMIC_BOUNDARY);
|
||||
}
|
||||
}
|
||||
|
||||
let prefix = opts
|
||||
.prefix
|
||||
.unwrap_or_else(|| {
|
||||
SystemPromptPrefix::detect(
|
||||
opts.is_non_interactive,
|
||||
opts.has_append_system_prompt,
|
||||
)
|
||||
});
|
||||
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// CACHEABLE sections (before the dynamic boundary) //
|
||||
// ------------------------------------------------------------------ //
|
||||
|
||||
// 1. Attribution header
|
||||
parts.push(prefix.attribution_text().to_string());
|
||||
|
||||
// 2. Core capabilities
|
||||
parts.push(CORE_CAPABILITIES.to_string());
|
||||
|
||||
// 3. Tool use guidelines
|
||||
parts.push(TOOL_USE_GUIDELINES.to_string());
|
||||
|
||||
// 4. Executing actions with care
|
||||
parts.push(ACTIONS_SECTION.to_string());
|
||||
|
||||
// 5. Safety guidelines
|
||||
parts.push(SAFETY_GUIDELINES.to_string());
|
||||
|
||||
// 6. Cyber-risk instruction (owned by safeguards — do not edit)
|
||||
parts.push(CYBER_RISK_INSTRUCTION.to_string());
|
||||
|
||||
// 7. Output style (cacheable when non-Default; its content is stable)
|
||||
if let Some(style_text) = opts.output_style.prompt_suffix() {
|
||||
parts.push(format!("\n## Output Style\n{}", style_text));
|
||||
}
|
||||
|
||||
// 8. Coordinator mode (cacheable: content is constant)
|
||||
if opts.coordinator_mode {
|
||||
parts.push(COORDINATOR_SYSTEM_PROMPT.to_string());
|
||||
}
|
||||
|
||||
// 9. Custom system prompt addition (appended to cacheable block)
|
||||
if let Some(custom) = &opts.custom_system_prompt {
|
||||
parts.push(format!(
|
||||
"\n<custom_instructions>\n{}\n</custom_instructions>",
|
||||
custom
|
||||
));
|
||||
}
|
||||
|
||||
// Dynamic boundary marker
|
||||
parts.push(SYSTEM_PROMPT_DYNAMIC_BOUNDARY.to_string());
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// DYNAMIC / UNCACHEABLE sections (after the boundary) //
|
||||
// ------------------------------------------------------------------ //
|
||||
|
||||
// 10. Working directory
|
||||
if let Some(cwd) = &opts.working_directory {
|
||||
parts.push(format!("\n<working_directory>{}</working_directory>", cwd));
|
||||
}
|
||||
|
||||
// 11. Memory injection (from memdir)
|
||||
if !opts.memory_content.is_empty() {
|
||||
parts.push(format!(
|
||||
"\n<memory>\n{}\n</memory>",
|
||||
opts.memory_content
|
||||
));
|
||||
}
|
||||
|
||||
// 12. Appended system prompt (--append-system-prompt)
|
||||
if let Some(append) = &opts.append_system_prompt {
|
||||
parts.push(format!("\n{}", append));
|
||||
}
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Static system prompt sections
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CORE_CAPABILITIES: &str = r#"
|
||||
## Capabilities
|
||||
|
||||
You have access to powerful tools for software engineering tasks:
|
||||
- **Read/Write files**: Read any file, write new files, edit existing files with precise diffs
|
||||
- **Execute commands**: Run bash commands, PowerShell scripts, background processes
|
||||
- **Search**: Glob patterns, regex grep, web search, file content search
|
||||
- **Web**: Fetch URLs, search the internet
|
||||
- **Agents**: Spawn parallel sub-agents for complex multi-step work
|
||||
- **Memory**: Persistent notes across sessions via the memory system
|
||||
- **MCP servers**: Connect to external tools and APIs via Model Context Protocol
|
||||
- **Jupyter notebooks**: Read and edit notebook cells
|
||||
|
||||
## How to approach tasks
|
||||
|
||||
1. **Understand before acting**: Read relevant files before making changes
|
||||
2. **Minimal changes**: Only modify what's needed. Don't refactor unrequested code.
|
||||
3. **Verify**: Check your work with tests or by reading the result
|
||||
4. **Communicate blockers**: If stuck, ask the user rather than guessing
|
||||
"#;
|
||||
|
||||
const TOOL_USE_GUIDELINES: &str = r#"
|
||||
## Tool use guidelines
|
||||
|
||||
- Use dedicated tools (Read, Edit, Glob, Grep) instead of bash equivalents
|
||||
- For searches, prefer Grep over `grep`; prefer Glob over `find`
|
||||
- Parallelize independent tool calls in a single response
|
||||
- For file edits: always read the file first, then make targeted edits
|
||||
- Bash commands timeout after 2 minutes; use background mode for long operations
|
||||
"#;
|
||||
|
||||
const ACTIONS_SECTION: &str = r#"
|
||||
## Executing actions with care
|
||||
|
||||
Carefully consider the reversibility and blast radius of actions. For actions
|
||||
that are hard to reverse, affect shared systems, or could be risky or
|
||||
destructive, check with the user before proceeding. Authorization stands for
|
||||
the scope specified, not beyond. Match the scope of your actions to what was
|
||||
actually requested.
|
||||
"#;
|
||||
|
||||
const SAFETY_GUIDELINES: &str = r#"
|
||||
## Safety guidelines
|
||||
|
||||
- Never delete files without explicit user confirmation
|
||||
- Don't modify protected files (.gitconfig, .bashrc, .zshrc, .mcp.json, .claude.json)
|
||||
- Be careful with destructive operations (rm -rf, DROP TABLE, etc.)
|
||||
- Don't commit secrets, credentials, or API keys
|
||||
- For ambiguous destructive actions, ask before proceeding
|
||||
"#;
|
||||
|
||||
/// Security / cyber-risk instruction.
|
||||
/// Owned by the safeguards team — do not modify the content here.
|
||||
const CYBER_RISK_INSTRUCTION: &str = r#"
|
||||
## Security
|
||||
|
||||
You are authorized to assist with security research, CTF challenges, penetration testing
|
||||
with explicit authorization, defensive security, and educational security content. Do not
|
||||
assist with creating malware, unauthorized access, denial-of-service attacks, or any
|
||||
destructive security techniques without clear legitimate purpose.
|
||||
"#;
|
||||
|
||||
const COORDINATOR_SYSTEM_PROMPT: &str = r#"
|
||||
## Coordinator Mode
|
||||
|
||||
You are operating as an orchestrator. Spawn parallel worker agents using the Agent tool.
|
||||
Each worker prompt must be fully self-contained. Synthesize findings before delegating
|
||||
follow-up work. Use TaskCreate/TaskUpdate to track parallel work.
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_opts() -> SystemPromptOptions {
|
||||
SystemPromptOptions::default()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_prompt_contains_boundary() {
|
||||
let prompt = build_system_prompt(&default_opts());
|
||||
assert!(
|
||||
prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY),
|
||||
"System prompt must contain the dynamic boundary marker"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_prompt_contains_attribution() {
|
||||
let prompt = build_system_prompt(&default_opts());
|
||||
assert!(prompt.contains("Claude Code"), "Default prompt must contain attribution");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace_system_prompt() {
|
||||
let opts = SystemPromptOptions {
|
||||
custom_system_prompt: Some("Custom only.".to_string()),
|
||||
replace_system_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.starts_with("Custom only."));
|
||||
assert!(!prompt.contains("Capabilities"));
|
||||
assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_working_directory_in_dynamic_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
working_directory: Some("/home/user/project".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
|
||||
let cwd_pos = prompt.find("/home/user/project").unwrap();
|
||||
assert!(
|
||||
cwd_pos > boundary_pos,
|
||||
"Working directory must appear after the dynamic boundary"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_content_in_dynamic_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
memory_content: "- [test.md](test.md) — a test memory".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
|
||||
let mem_pos = prompt.find("test.md").unwrap();
|
||||
assert!(
|
||||
mem_pos > boundary_pos,
|
||||
"Memory content must appear after the dynamic boundary"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_concise() {
|
||||
let opts = SystemPromptOptions {
|
||||
output_style: OutputStyle::Concise,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.contains("maximally concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_default_has_no_suffix() {
|
||||
let opts = SystemPromptOptions {
|
||||
output_style: OutputStyle::Default,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
// None of the style suffixes should appear
|
||||
assert!(!prompt.contains("maximally concise"));
|
||||
assert!(!prompt.contains("This user is learning"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_mode_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
coordinator_mode: true,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.contains("Coordinator Mode"));
|
||||
assert!(prompt.contains("orchestrator"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_from_str() {
|
||||
assert_eq!(OutputStyle::from_str("concise"), OutputStyle::Concise);
|
||||
assert_eq!(OutputStyle::from_str("FORMAL"), OutputStyle::Formal);
|
||||
assert_eq!(OutputStyle::from_str("unknown"), OutputStyle::Default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sdk_prefix_non_interactive_no_append() {
|
||||
let prefix = SystemPromptPrefix::detect(true, false);
|
||||
assert_eq!(prefix, SystemPromptPrefix::Sdk);
|
||||
assert!(prefix.attribution_text().contains("Claude agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sdk_preset_prefix_non_interactive_with_append() {
|
||||
let prefix = SystemPromptPrefix::detect(true, true);
|
||||
assert_eq!(prefix, SystemPromptPrefix::SdkPreset);
|
||||
assert!(prefix.attribution_text().contains("Claude Agent SDK"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_section_cache() {
|
||||
// Populate cache then clear it — should not panic.
|
||||
{
|
||||
let mut cache = section_cache().lock().unwrap();
|
||||
cache.insert("test_section".to_string(), Some("content".to_string()));
|
||||
}
|
||||
clear_system_prompt_sections();
|
||||
let cache = section_cache().lock().unwrap();
|
||||
assert!(cache.is_empty());
|
||||
}
|
||||
}
|
||||
682
src-rust/crates/core/src/team_memory_sync.rs
Normal file
682
src-rust/crates/core/src/team_memory_sync.rs
Normal file
|
|
@ -0,0 +1,682 @@
|
|||
//! Team memory synchronization with claude.ai API.
|
||||
//!
|
||||
//! Implements delta push (only changed files) with ETag-based optimistic
|
||||
//! concurrency and greedy bin-packing of changed entries into batches that
|
||||
//! fit within the server's PUT body limit.
|
||||
//!
|
||||
//! Pull is server-wins: remote content overwrites local files unconditionally.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Maximum bytes per local file accepted for sync (250 KB)
|
||||
const MAX_FILE_SIZE_BYTES: usize = 250 * 1024;
|
||||
|
||||
/// Maximum serialized bytes per PUT request body (200 KB)
|
||||
const MAX_PUT_BODY_BYTES: usize = 200 * 1024;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Persisted per-repo sync state (stored alongside local team-memory files).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct SyncState {
|
||||
/// ETag returned by the last successful GET or PUT.
|
||||
pub last_known_etag: Option<String>,
|
||||
/// Per-key server-side checksums (`"sha256:<hex>"`).
|
||||
/// Used to diff local vs remote without re-uploading unchanged entries.
|
||||
pub server_checksums: HashMap<String, String>,
|
||||
/// Server-enforced max_entries from a prior 413 response.
|
||||
pub server_max_entries: Option<usize>,
|
||||
}
|
||||
|
||||
/// A single team-memory entry (one markdown file).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TeamMemoryEntry {
|
||||
/// Relative file path (forward-slash separated, e.g. `"MEMORY.md"`).
|
||||
pub key: String,
|
||||
/// UTF-8 file content (typically Markdown).
|
||||
pub content: String,
|
||||
/// `"sha256:<hex>"` of the content.
|
||||
pub checksum: String,
|
||||
}
|
||||
|
||||
/// Server response shape for GET `/api/claude_code/team_memory`.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TeamMemoryData {
|
||||
pub entries: Vec<TeamMemoryEntry>,
|
||||
pub etag: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Checksum helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute `"sha256:<lowercase hex>"` of a string.
|
||||
pub fn content_checksum(content: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(content.as_bytes());
|
||||
format!("sha256:{}", hex::encode(hasher.finalize()))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Path security validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Reject paths that could escape the team-memory directory.
|
||||
///
|
||||
/// Checks performed (mirroring the TypeScript `securePath` validation):
|
||||
/// - No null bytes
|
||||
/// - No URL-encoded traversal sequences (`%2e`, `%2f`, case-insensitive)
|
||||
/// - No backslashes
|
||||
/// - Not an absolute path (Unix `/` or Windows `C:` style)
|
||||
/// - No `..` components
|
||||
pub fn validate_memory_path(path: &str) -> Result<()> {
|
||||
if path.contains('\0') {
|
||||
anyhow::bail!("Path contains null bytes: {:?}", path);
|
||||
}
|
||||
let lower = path.to_ascii_lowercase();
|
||||
if lower.contains("%2e") || lower.contains("%2f") {
|
||||
anyhow::bail!("Path contains URL-encoded traversal sequences: {:?}", path);
|
||||
}
|
||||
if path.contains('\\') {
|
||||
anyhow::bail!("Path contains backslashes: {:?}", path);
|
||||
}
|
||||
if path.starts_with('/') {
|
||||
anyhow::bail!("Absolute Unix paths not allowed: {:?}", path);
|
||||
}
|
||||
// Windows-style absolute path: e.g. "C:" or "c:"
|
||||
if path.len() >= 2 {
|
||||
let mut chars = path.chars();
|
||||
let first = chars.next().unwrap();
|
||||
if first.is_ascii_alphabetic() && chars.next() == Some(':') {
|
||||
anyhow::bail!("Absolute Windows paths not allowed: {:?}", path);
|
||||
}
|
||||
}
|
||||
if path.split('/').any(|component| component == "..") {
|
||||
anyhow::bail!("Path traversal not allowed: {:?}", path);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TeamMemorySync
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Drives pull and push against the claude.ai team-memory API.
|
||||
pub struct TeamMemorySync {
|
||||
/// Base URL of the API, e.g. `"https://claude.ai"`.
|
||||
api_base: String,
|
||||
/// Repo identifier sent as a query parameter.
|
||||
repo: String,
|
||||
/// Bearer token for authentication.
|
||||
token: String,
|
||||
/// Local directory that mirrors the server's key namespace.
|
||||
team_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl TeamMemorySync {
|
||||
pub fn new(api_base: String, repo: String, token: String, team_dir: PathBuf) -> Self {
|
||||
Self { api_base, repo, token, team_dir }
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Pull
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Pull all entries from the server. Server wins: overwrites local files.
|
||||
///
|
||||
/// Updates `state.last_known_etag` and `state.server_checksums` on success.
|
||||
/// Returns `Ok(())` on HTTP 404 (no remote data yet).
|
||||
pub async fn pull(&self, state: &mut SyncState) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/api/claude_code/team_memory?repo={}",
|
||||
self.api_base,
|
||||
urlencoding::encode(&self.repo),
|
||||
);
|
||||
|
||||
let response = client
|
||||
.get(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.send()
|
||||
.await
|
||||
.context("team memory pull: HTTP request failed")?;
|
||||
|
||||
let http_status = response.status();
|
||||
|
||||
if http_status.as_u16() == 404 {
|
||||
return Ok(()); // No remote data yet
|
||||
}
|
||||
|
||||
if !http_status.is_success() {
|
||||
anyhow::bail!("team memory pull failed with status {}", http_status);
|
||||
}
|
||||
|
||||
// Capture ETag before consuming the response body
|
||||
if let Some(etag) = response
|
||||
.headers()
|
||||
.get("etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
state.last_known_etag = Some(etag.to_string());
|
||||
}
|
||||
|
||||
let data: TeamMemoryData = response
|
||||
.json()
|
||||
.await
|
||||
.context("team memory pull: failed to parse response JSON")?;
|
||||
|
||||
state.server_checksums.clear();
|
||||
|
||||
for entry in &data.entries {
|
||||
validate_memory_path(&entry.key)
|
||||
.with_context(|| format!("server returned unsafe path: {:?}", entry.key))?;
|
||||
|
||||
state
|
||||
.server_checksums
|
||||
.insert(entry.key.clone(), entry.checksum.clone());
|
||||
|
||||
let local_path = self.team_dir.join(&entry.key);
|
||||
if let Some(parent) = local_path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("create_dir_all for {:?}", parent))?;
|
||||
}
|
||||
|
||||
if entry.content.len() <= MAX_FILE_SIZE_BYTES {
|
||||
tokio::fs::write(&local_path, &entry.content)
|
||||
.await
|
||||
.with_context(|| format!("writing {:?}", local_path))?;
|
||||
}
|
||||
// Files exceeding MAX_FILE_SIZE_BYTES are silently skipped (same behaviour as push)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Push
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Push local changes to the server using delta upload.
|
||||
///
|
||||
/// Only entries whose local checksum differs from `state.server_checksums`
|
||||
/// are uploaded. Changed entries are packed into batches ≤ `MAX_PUT_BODY_BYTES`.
|
||||
pub async fn push(&self, state: &mut SyncState) -> Result<()> {
|
||||
let local_entries = self
|
||||
.scan_local_files()
|
||||
.await
|
||||
.context("team memory push: scanning local files")?;
|
||||
|
||||
// Delta: entries where local hash ≠ last-known server hash
|
||||
let changed: Vec<TeamMemoryEntry> = local_entries
|
||||
.into_iter()
|
||||
.filter(|entry| {
|
||||
state
|
||||
.server_checksums
|
||||
.get(&entry.key)
|
||||
.map(|s| s.as_str())
|
||||
!= Some(&entry.checksum)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if changed.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let batches = self.pack_batches(changed);
|
||||
for batch in batches {
|
||||
self.upload_batch(batch, state)
|
||||
.await
|
||||
.context("team memory push: uploading batch")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Internals
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Greedy bin-packing: pack entries into batches that each serialise to
|
||||
/// ≤ `MAX_PUT_BODY_BYTES`. Entries that individually exceed the limit go
|
||||
/// into singleton batches (server will reject them with 413, but that is
|
||||
/// the caller's problem).
|
||||
fn pack_batches(&self, entries: Vec<TeamMemoryEntry>) -> Vec<Vec<TeamMemoryEntry>> {
|
||||
let mut batches: Vec<Vec<TeamMemoryEntry>> = Vec::new();
|
||||
let mut current: Vec<TeamMemoryEntry> = Vec::new();
|
||||
let mut current_size: usize = 0;
|
||||
|
||||
for entry in entries {
|
||||
// Rough size estimate: key + content + JSON envelope overhead
|
||||
let entry_size = entry.key.len() + entry.content.len() + 100;
|
||||
|
||||
if entry_size > MAX_PUT_BODY_BYTES {
|
||||
// Oversized entry goes solo
|
||||
if !current.is_empty() {
|
||||
batches.push(std::mem::take(&mut current));
|
||||
current_size = 0;
|
||||
}
|
||||
batches.push(vec![entry]);
|
||||
continue;
|
||||
}
|
||||
|
||||
if current_size + entry_size > MAX_PUT_BODY_BYTES && !current.is_empty() {
|
||||
batches.push(std::mem::take(&mut current));
|
||||
current_size = 0;
|
||||
}
|
||||
|
||||
current_size += entry_size;
|
||||
current.push(entry);
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
batches.push(current);
|
||||
}
|
||||
|
||||
batches
|
||||
}
|
||||
|
||||
async fn upload_batch(
|
||||
&self,
|
||||
batch: Vec<TeamMemoryEntry>,
|
||||
state: &mut SyncState,
|
||||
) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/api/claude_code/team_memory?repo={}",
|
||||
self.api_base,
|
||||
urlencoding::encode(&self.repo),
|
||||
);
|
||||
|
||||
let body = serde_json::json!({ "entries": batch });
|
||||
|
||||
let mut req = client
|
||||
.put(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.json(&body);
|
||||
|
||||
if let Some(etag) = &state.last_known_etag {
|
||||
req = req.header("If-Match", etag);
|
||||
}
|
||||
|
||||
let response = req
|
||||
.send()
|
||||
.await
|
||||
.context("team memory: PUT request failed")?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
|
||||
match status {
|
||||
200 | 201 | 204 => {
|
||||
if let Some(etag) = response
|
||||
.headers()
|
||||
.get("etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
state.last_known_etag = Some(etag.to_string());
|
||||
}
|
||||
// Update local checksum map to reflect uploaded state
|
||||
for entry in &batch {
|
||||
state
|
||||
.server_checksums
|
||||
.insert(entry.key.clone(), entry.checksum.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
412 => anyhow::bail!("Conflict (412 Precondition Failed): ETag mismatch, retry needed"),
|
||||
413 => anyhow::bail!("Payload too large (413)"),
|
||||
401 | 403 => anyhow::bail!("Authentication error ({})", status),
|
||||
_ => anyhow::bail!("Upload failed with status {}", status),
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively scan `team_dir` for `.md` files, returning entries sorted by key.
|
||||
async fn scan_local_files(&self) -> Result<Vec<TeamMemoryEntry>> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
if !self.team_dir.exists() {
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
// Iterative DFS using an explicit stack to avoid deep recursion
|
||||
let mut stack = vec![self.team_dir.clone()];
|
||||
|
||||
while let Some(dir) = stack.pop() {
|
||||
let mut read_dir = tokio::fs::read_dir(&dir)
|
||||
.await
|
||||
.with_context(|| format!("read_dir {:?}", dir))?;
|
||||
|
||||
while let Some(entry) = read_dir.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
stack.push(path);
|
||||
} else if path.extension().map(|e| e == "md").unwrap_or(false) {
|
||||
let content = tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading {:?}", path))?;
|
||||
|
||||
if content.len() > MAX_FILE_SIZE_BYTES {
|
||||
continue; // Skip files that are too large
|
||||
}
|
||||
|
||||
let key = path
|
||||
.strip_prefix(&self.team_dir)
|
||||
.unwrap()
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
|
||||
let checksum = content_checksum(&content);
|
||||
entries.push(TeamMemoryEntry { key, content, checksum });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entries.sort_by(|a, b| a.key.cmp(&b.key));
|
||||
Ok(entries)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Secret scanner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A pattern matched during secret scanning.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SecretMatch {
|
||||
/// Short label identifying the secret type, e.g. `"Anthropic API key"`.
|
||||
pub label: String,
|
||||
}
|
||||
|
||||
/// Scan `content` for common high-confidence secret patterns.
|
||||
///
|
||||
/// Returns one [`SecretMatch`] per distinct pattern that fired. The actual
|
||||
/// matched text is intentionally **not** returned to avoid logging credentials.
|
||||
pub fn scan_for_secrets(content: &str) -> Vec<SecretMatch> {
|
||||
// Each tuple: (regex source, human-readable label)
|
||||
// Patterns ordered by likelihood of appearing in dev-team memory content.
|
||||
const PATTERNS: &[(&str, &str)] = &[
|
||||
// Cloud providers
|
||||
(r"(?:A3T[A-Z0-9]|AKIA|ASIA|ABIA|ACCA)[A-Z2-7]{16}", "AWS access key"),
|
||||
(r"AIza[\w-]{35}", "GCP API key"),
|
||||
// AI APIs
|
||||
(r"sk-ant-api03-[a-zA-Z0-9_\-]{93}AA", "Anthropic API key"),
|
||||
(r"sk-ant-admin01-[a-zA-Z0-9_\-]{93}AA", "Anthropic admin API key"),
|
||||
(r"sk-[a-zA-Z0-9]{20}T3BlbkFJ[a-zA-Z0-9]{20}", "OpenAI API key"),
|
||||
// Version control
|
||||
(r"ghp_[0-9a-zA-Z]{36}", "GitHub personal access token"),
|
||||
(r"github_pat_\w{82}", "GitHub fine-grained PAT"),
|
||||
(r"(?:ghu|ghs)_[0-9a-zA-Z]{36}", "GitHub app token"),
|
||||
(r"gho_[0-9a-zA-Z]{36}", "GitHub OAuth token"),
|
||||
(r"glpat-[\w-]{20}", "GitLab PAT"),
|
||||
// Communication
|
||||
(r"xoxb-[0-9]{10,13}-[0-9]{10,13}[a-zA-Z0-9-]*", "Slack bot token"),
|
||||
// Crypto / private keys
|
||||
(r"-----BEGIN[ A-Z0-9_-]{0,100}PRIVATE KEY", "Private key"),
|
||||
// Payments
|
||||
(r"(?:sk|rk)_(?:test|live|prod)_[a-zA-Z0-9]{10,99}", "Stripe secret key"),
|
||||
// NPM
|
||||
(r"npm_[a-zA-Z0-9]{36}", "NPM access token"),
|
||||
];
|
||||
|
||||
let mut findings: Vec<SecretMatch> = Vec::new();
|
||||
|
||||
for (pattern, label) in PATTERNS {
|
||||
// Lazily compile; the fn is not hot enough to warrant a static cache here
|
||||
if let Ok(re) = regex::Regex::new(pattern) {
|
||||
if re.is_match(content) {
|
||||
findings.push(SecretMatch { label: label.to_string() });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// --- content_checksum ---
|
||||
|
||||
#[test]
|
||||
fn test_checksum_format() {
|
||||
let cs = content_checksum("hello");
|
||||
assert!(cs.starts_with("sha256:"), "checksum should start with sha256:");
|
||||
assert_eq!(cs.len(), "sha256:".len() + 64, "sha256 hex is 64 chars");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checksum_deterministic() {
|
||||
assert_eq!(content_checksum("foo"), content_checksum("foo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checksum_distinct() {
|
||||
assert_ne!(content_checksum("foo"), content_checksum("bar"));
|
||||
}
|
||||
|
||||
// --- validate_memory_path ---
|
||||
|
||||
#[test]
|
||||
fn test_valid_paths_accepted() {
|
||||
let ok_paths = [
|
||||
"MEMORY.md",
|
||||
"sub/dir/file.md",
|
||||
"sub/dir/another-file.md",
|
||||
"a.md",
|
||||
];
|
||||
for p in &ok_paths {
|
||||
assert!(validate_memory_path(p).is_ok(), "should accept: {}", p);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_byte_rejected() {
|
||||
assert!(validate_memory_path("foo\0bar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encoded_dot_rejected() {
|
||||
assert!(validate_memory_path("%2e%2e/secret").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encoded_slash_rejected() {
|
||||
assert!(validate_memory_path("foo%2Fbar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backslash_rejected() {
|
||||
assert!(validate_memory_path("foo\\bar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absolute_unix_rejected() {
|
||||
assert!(validate_memory_path("/etc/passwd").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absolute_windows_rejected() {
|
||||
assert!(validate_memory_path("C:foo").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dotdot_rejected() {
|
||||
assert!(validate_memory_path("../secret").is_err());
|
||||
assert!(validate_memory_path("a/../../secret").is_err());
|
||||
}
|
||||
|
||||
// --- pack_batches ---
|
||||
|
||||
fn make_sync() -> TeamMemorySync {
|
||||
TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"owner/repo".to_string(),
|
||||
"token123".to_string(),
|
||||
PathBuf::from("/tmp/team"),
|
||||
)
|
||||
}
|
||||
|
||||
fn entry(key: &str, size: usize) -> TeamMemoryEntry {
|
||||
let content = "x".repeat(size);
|
||||
let checksum = content_checksum(&content);
|
||||
TeamMemoryEntry { key: key.to_string(), content, checksum }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_empty() {
|
||||
let sync = make_sync();
|
||||
let batches = sync.pack_batches(vec![]);
|
||||
assert!(batches.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_single_entry() {
|
||||
let sync = make_sync();
|
||||
let batches = sync.pack_batches(vec![entry("a.md", 100)]);
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_oversized_solo() {
|
||||
let sync = make_sync();
|
||||
// Entry > MAX_PUT_BODY_BYTES → goes solo
|
||||
let big = entry("big.md", MAX_PUT_BODY_BYTES + 1);
|
||||
let small = entry("small.md", 100);
|
||||
let batches = sync.pack_batches(vec![big, small]);
|
||||
// big is solo, small may be in a separate batch
|
||||
assert!(batches.len() >= 2);
|
||||
assert_eq!(batches[0].len(), 1, "oversized entry is solo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_groups_small_entries() {
|
||||
let sync = make_sync();
|
||||
// Many small entries that each fit in one batch
|
||||
let entries: Vec<_> = (0..5).map(|i| entry(&format!("{i}.md"), 1024)).collect();
|
||||
let batches = sync.pack_batches(entries);
|
||||
// All 5 should fit in one batch (5 * ~1124 bytes << 200KB)
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].len(), 5);
|
||||
}
|
||||
|
||||
// --- scan_for_secrets ---
|
||||
|
||||
#[test]
|
||||
fn test_no_secrets_clean() {
|
||||
let findings = scan_for_secrets("# Team notes\n\nSome markdown content here.");
|
||||
assert!(findings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_github_pat() {
|
||||
let content = format!("token: ghp_{}", "A".repeat(36));
|
||||
let findings = scan_for_secrets(&content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("GitHub")),
|
||||
"should detect GitHub PAT"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_aws_key() {
|
||||
let content = "key=AKIAIOSFODNN7EXAMPLE";
|
||||
let findings = scan_for_secrets(content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("AWS")),
|
||||
"should detect AWS key"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_private_key() {
|
||||
let content = "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n";
|
||||
let findings = scan_for_secrets(content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("Private key")),
|
||||
"should detect private key"
|
||||
);
|
||||
}
|
||||
|
||||
// --- scan_local_files (integration-style) ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert!(entries.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_finds_md() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
tokio::fs::write(tmp.path().join("MEMORY.md"), "# Memory").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("ignore.txt"), "not md").await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].key, "MEMORY.md");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_sorted() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
tokio::fs::write(tmp.path().join("z.md"), "z").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("a.md"), "a").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("m.md"), "m").await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
let keys: Vec<_> = entries.iter().map(|e| e.key.as_str()).collect();
|
||||
assert_eq!(keys, vec!["a.md", "m.md", "z.md"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_checksums_match() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let content = "# Hello world";
|
||||
tokio::fs::write(tmp.path().join("MEMORY.md"), content).await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert_eq!(entries[0].checksum, content_checksum(content));
|
||||
}
|
||||
}
|
||||
192
src-rust/crates/core/src/voice.rs
Normal file
192
src-rust/crates/core/src/voice.rs
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
//! Voice mode availability checks
|
||||
|
||||
use crate::oauth::OAuthTokens;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum VoiceAvailability {
|
||||
Available,
|
||||
/// Not authenticated via first-party OAuth
|
||||
RequiresOAuth,
|
||||
/// OAuth token missing required scopes
|
||||
MissingScopes {
|
||||
required: Vec<String>,
|
||||
have: Vec<String>,
|
||||
},
|
||||
/// Feature disabled by kill-switch environment variable
|
||||
Disabled,
|
||||
/// Feature flag not enabled in this build
|
||||
NotEnabled,
|
||||
}
|
||||
|
||||
/// Scopes required for voice mode to function
|
||||
const VOICE_REQUIRED_SCOPES: &[&str] = &["user:inference", "user:profile"];
|
||||
|
||||
/// Environment variable that disables voice mode when set (any value)
|
||||
const KILL_SWITCH_ENV: &str = "CLAUDE_CODE_VOICE_DISABLED";
|
||||
|
||||
/// Check whether voice mode is available given the current OAuth tokens.
|
||||
///
|
||||
/// Pass `None` when the user is not authenticated via OAuth (API-key-only auth).
|
||||
pub fn check_voice_availability(tokens: Option<&OAuthTokens>) -> VoiceAvailability {
|
||||
// Check kill switch first — always wins
|
||||
if std::env::var(KILL_SWITCH_ENV).is_ok() {
|
||||
return VoiceAvailability::Disabled;
|
||||
}
|
||||
|
||||
// Voice requires first-party OAuth; API key alone is not sufficient
|
||||
let tokens = match tokens {
|
||||
Some(t) => t,
|
||||
None => return VoiceAvailability::RequiresOAuth,
|
||||
};
|
||||
|
||||
// OAuthTokens stores scopes as Vec<String>
|
||||
let have_scopes: &[String] = &tokens.scopes;
|
||||
|
||||
let missing: Vec<String> = VOICE_REQUIRED_SCOPES
|
||||
.iter()
|
||||
.filter(|&&required| !have_scopes.iter().any(|h| h == required))
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
|
||||
if !missing.is_empty() {
|
||||
return VoiceAvailability::MissingScopes {
|
||||
required: VOICE_REQUIRED_SCOPES
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
have: have_scopes.to_vec(),
|
||||
};
|
||||
}
|
||||
|
||||
VoiceAvailability::Available
|
||||
}
|
||||
|
||||
impl VoiceAvailability {
|
||||
/// Returns `true` when voice mode can be started.
|
||||
pub fn is_available(&self) -> bool {
|
||||
matches!(self, VoiceAvailability::Available)
|
||||
}
|
||||
|
||||
/// Returns a human-readable error message when voice is not available,
|
||||
/// or `None` when it is.
|
||||
pub fn error_message(&self) -> Option<String> {
|
||||
match self {
|
||||
VoiceAvailability::Available => None,
|
||||
VoiceAvailability::RequiresOAuth => Some(
|
||||
"Voice mode requires OAuth authentication. Run /login to authenticate.".to_string(),
|
||||
),
|
||||
VoiceAvailability::MissingScopes { required, have } => Some(format!(
|
||||
"Voice mode requires scopes: {}. Your token has: {}",
|
||||
required.join(", "),
|
||||
if have.is_empty() {
|
||||
"none".to_string()
|
||||
} else {
|
||||
have.join(", ")
|
||||
}
|
||||
)),
|
||||
VoiceAvailability::Disabled => Some("Voice mode is currently disabled.".to_string()),
|
||||
VoiceAvailability::NotEnabled => {
|
||||
Some("Voice mode is not enabled in this build.".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn tokens_with_scopes(scopes: Vec<&str>) -> OAuthTokens {
|
||||
OAuthTokens {
|
||||
access_token: "test_token".to_string(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_tokens_requires_oauth() {
|
||||
let result = check_voice_availability(None);
|
||||
assert_eq!(result, VoiceAvailability::RequiresOAuth);
|
||||
assert!(!result.is_available());
|
||||
assert!(result.error_message().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_with_all_scopes() {
|
||||
let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert_eq!(result, VoiceAvailability::Available);
|
||||
assert!(result.is_available());
|
||||
assert!(result.error_message().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_one_scope() {
|
||||
let tokens = tokens_with_scopes(vec!["user:inference"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
|
||||
assert!(!result.is_available());
|
||||
let msg = result.error_message().unwrap();
|
||||
assert!(msg.contains("user:profile"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_all_scopes() {
|
||||
let tokens = tokens_with_scopes(vec!["org:create_api_key"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
|
||||
assert!(!result.is_available());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_scopes_missing() {
|
||||
let tokens = tokens_with_scopes(vec![]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(
|
||||
matches!(result, VoiceAvailability::MissingScopes { ref have, .. } if have.is_empty())
|
||||
);
|
||||
let msg = result.error_message().unwrap();
|
||||
assert!(msg.contains("none"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kill_switch_disables_voice() {
|
||||
// Temporarily set the kill-switch env var
|
||||
std::env::set_var(KILL_SWITCH_ENV, "1");
|
||||
let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
assert_eq!(result, VoiceAvailability::Disabled);
|
||||
assert!(!result.is_available());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kill_switch_beats_no_auth() {
|
||||
std::env::set_var(KILL_SWITCH_ENV, "true");
|
||||
let result = check_voice_availability(None);
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
// Kill switch wins — returns Disabled, not RequiresOAuth
|
||||
assert_eq!(result, VoiceAvailability::Disabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_enabled_error_message() {
|
||||
let v = VoiceAvailability::NotEnabled;
|
||||
assert!(!v.is_available());
|
||||
assert!(v.error_message().unwrap().contains("not enabled"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extra_scopes_still_available() {
|
||||
// Having more scopes than required is fine
|
||||
let tokens = tokens_with_scopes(vec![
|
||||
"user:inference",
|
||||
"user:profile",
|
||||
"org:create_api_key",
|
||||
"user:file_upload",
|
||||
]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert_eq!(result, VoiceAvailability::Available);
|
||||
}
|
||||
}
|
||||
19
src-rust/crates/mcp/Cargo.toml
Normal file
19
src-rust/crates/mcp/Cargo.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "cc-mcp"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
762
src-rust/crates/mcp/src/lib.rs
Normal file
762
src-rust/crates/mcp/src/lib.rs
Normal file
|
|
@ -0,0 +1,762 @@
|
|||
// cc-mcp: Model Context Protocol (MCP) client implementation.
|
||||
//
|
||||
// MCP is a JSON-RPC 2.0 based protocol for connecting Claude to external
|
||||
// tool/resource servers. This crate implements:
|
||||
//
|
||||
// - JSON-RPC 2.0 client primitives
|
||||
// - MCP protocol handshake (initialize, initialized)
|
||||
// - Tool discovery (tools/list)
|
||||
// - Tool execution (tools/call)
|
||||
// - Resource management (resources/list, resources/read)
|
||||
// - Prompt templates (prompts/list, prompts/get)
|
||||
// - Transport: stdio (subprocess) and HTTP/SSE
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cc_core::config::McpServerConfig;
|
||||
use cc_core::types::ToolDefinition;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
use tokio::sync::{mpsc, oneshot, Mutex};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
pub use client::McpClient;
|
||||
pub use types::*;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JSON-RPC 2.0 Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod types {
|
||||
use super::*;
|
||||
|
||||
/// A JSON-RPC 2.0 request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: Value,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
pub fn new(id: impl Into<Value>, method: impl Into<String>, params: Option<Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: id.into(),
|
||||
method: method.into(),
|
||||
params,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn notification(method: impl Into<String>, params: Option<Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Value::Null,
|
||||
method: method.into(),
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A JSON-RPC 2.0 response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i64,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
}
|
||||
|
||||
// ---- MCP protocol types ------------------------------------------------
|
||||
|
||||
/// MCP initialize request params.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeParams {
|
||||
pub protocol_version: String,
|
||||
pub capabilities: ClientCapabilities,
|
||||
pub client_info: ClientInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClientCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub roots: Option<RootsCapability>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sampling: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RootsCapability {
|
||||
#[serde(rename = "listChanged")]
|
||||
pub list_changed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
/// MCP initialize response result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResult {
|
||||
pub protocol_version: String,
|
||||
pub capabilities: ServerCapabilities,
|
||||
pub server_info: ServerInfo,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ServerCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<ToolsCapability>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resources: Option<ResourcesCapability>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompts: Option<PromptsCapability>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logging: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolsCapability {
|
||||
#[serde(default)]
|
||||
pub list_changed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesCapability {
|
||||
#[serde(default)]
|
||||
pub subscribe: bool,
|
||||
#[serde(default)]
|
||||
pub list_changed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsCapability {
|
||||
#[serde(default)]
|
||||
pub list_changed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
/// An MCP tool definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpTool {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
impl From<&McpTool> for ToolDefinition {
|
||||
fn from(t: &McpTool) -> Self {
|
||||
ToolDefinition {
|
||||
name: t.name.clone(),
|
||||
description: t.description.clone().unwrap_or_default(),
|
||||
input_schema: t.input_schema.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// tools/list response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListToolsResult {
|
||||
pub tools: Vec<McpTool>,
|
||||
#[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
/// tools/call params.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CallToolParams {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<Value>,
|
||||
}
|
||||
|
||||
/// tools/call response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CallToolResult {
|
||||
pub content: Vec<McpContent>,
|
||||
#[serde(default)]
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum McpContent {
|
||||
Text { text: String },
|
||||
Image { data: String, #[serde(rename = "mimeType")] mime_type: String },
|
||||
Resource { resource: ResourceContents },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResourceContents {
|
||||
pub uri: String,
|
||||
#[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blob: Option<String>,
|
||||
}
|
||||
|
||||
/// An MCP resource.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpResource {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
/// resources/list response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListResourcesResult {
|
||||
pub resources: Vec<McpResource>,
|
||||
#[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
/// An MCP prompt template.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpPrompt {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub arguments: Vec<McpPromptArgument>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpPromptArgument {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub required: bool,
|
||||
}
|
||||
|
||||
/// prompts/list response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListPromptsResult {
|
||||
pub prompts: Vec<McpPrompt>,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Transport layer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod transport {
|
||||
use super::*;
|
||||
|
||||
/// A transport can send requests and receive responses.
|
||||
#[async_trait]
|
||||
pub trait McpTransport: Send + Sync {
|
||||
async fn send(&self, message: &JsonRpcRequest) -> anyhow::Result<()>;
|
||||
async fn recv(&self) -> anyhow::Result<Option<JsonRpcResponse>>;
|
||||
async fn close(&self) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
/// Stdio transport: spawns a subprocess and communicates via stdin/stdout.
|
||||
pub struct StdioTransport {
|
||||
child: Arc<Mutex<Child>>,
|
||||
stdin: Arc<Mutex<ChildStdin>>,
|
||||
stdout_rx: Arc<Mutex<mpsc::UnboundedReceiver<String>>>,
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub async fn spawn(config: &McpServerConfig) -> anyhow::Result<Self> {
|
||||
let command = config
|
||||
.command
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' has no command", config.name))?;
|
||||
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(&config.args)
|
||||
.envs(&config.env)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn()?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not get stdin"))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not get stdout"))?;
|
||||
|
||||
let (tx, rx) = mpsc::unbounded_channel::<String>();
|
||||
|
||||
// Background reader task
|
||||
tokio::spawn(async move {
|
||||
let reader = BufReader::new(stdout);
|
||||
let mut lines = reader.lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
if tx.send(line).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
child: Arc::new(Mutex::new(child)),
|
||||
stdin: Arc::new(Mutex::new(stdin)),
|
||||
stdout_rx: Arc::new(Mutex::new(rx)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpTransport for StdioTransport {
|
||||
async fn send(&self, message: &JsonRpcRequest) -> anyhow::Result<()> {
|
||||
let json = serde_json::to_string(message)? + "\n";
|
||||
let mut stdin = self.stdin.lock().await;
|
||||
stdin.write_all(json.as_bytes()).await?;
|
||||
stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&self) -> anyhow::Result<Option<JsonRpcResponse>> {
|
||||
let mut rx = self.stdout_rx.lock().await;
|
||||
let line = rx.recv().await;
|
||||
match line {
|
||||
Some(s) => {
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&s)?;
|
||||
Ok(Some(resp))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&self) -> anyhow::Result<()> {
|
||||
let mut child = self.child.lock().await;
|
||||
let _ = child.kill().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MCP Client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod client {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// A fully initialized MCP client connected to a single server.
|
||||
pub struct McpClient {
|
||||
pub server_name: String,
|
||||
pub server_info: Option<ServerInfo>,
|
||||
pub capabilities: ServerCapabilities,
|
||||
pub tools: Vec<McpTool>,
|
||||
pub resources: Vec<McpResource>,
|
||||
pub prompts: Vec<McpPrompt>,
|
||||
transport: Arc<dyn transport::McpTransport>,
|
||||
next_id: AtomicU64,
|
||||
pending: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
/// Connect to an MCP server using stdio transport and complete the
|
||||
/// initialize handshake.
|
||||
pub async fn connect_stdio(config: &McpServerConfig) -> anyhow::Result<Self> {
|
||||
let transport = transport::StdioTransport::spawn(config).await?;
|
||||
let client = Self {
|
||||
server_name: config.name.clone(),
|
||||
server_info: None,
|
||||
capabilities: ServerCapabilities::default(),
|
||||
tools: vec![],
|
||||
resources: vec![],
|
||||
prompts: vec![],
|
||||
transport: Arc::new(transport),
|
||||
next_id: AtomicU64::new(1),
|
||||
pending: Arc::new(Mutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
client.initialize().await
|
||||
}
|
||||
|
||||
/// Send the MCP initialize handshake and discover capabilities.
|
||||
async fn initialize(mut self) -> anyhow::Result<Self> {
|
||||
let params = InitializeParams {
|
||||
protocol_version: "2024-11-05".to_string(),
|
||||
capabilities: ClientCapabilities {
|
||||
roots: Some(RootsCapability { list_changed: false }),
|
||||
sampling: None,
|
||||
},
|
||||
client_info: ClientInfo {
|
||||
name: cc_core::constants::APP_NAME.to_string(),
|
||||
version: cc_core::constants::APP_VERSION.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let result: InitializeResult = self
|
||||
.call("initialize", Some(serde_json::to_value(¶ms)?))
|
||||
.await?;
|
||||
|
||||
self.server_info = Some(result.server_info);
|
||||
self.capabilities = result.capabilities.clone();
|
||||
|
||||
// Send initialized notification
|
||||
let notif = JsonRpcRequest::notification("notifications/initialized", None);
|
||||
self.transport.send(¬if).await?;
|
||||
|
||||
// Discover tools if supported
|
||||
if result.capabilities.tools.is_some() {
|
||||
match self.list_tools().await {
|
||||
Ok(tools) => self.tools = tools,
|
||||
Err(e) => warn!(server = %self.server_name, error = %e, "Failed to list tools"),
|
||||
}
|
||||
}
|
||||
|
||||
// Discover resources if supported
|
||||
if result.capabilities.resources.is_some() {
|
||||
match self.list_resources().await {
|
||||
Ok(resources) => self.resources = resources,
|
||||
Err(e) => warn!(server = %self.server_name, error = %e, "Failed to list resources"),
|
||||
}
|
||||
}
|
||||
|
||||
// Discover prompts if supported
|
||||
if result.capabilities.prompts.is_some() {
|
||||
match self.list_prompts().await {
|
||||
Ok(prompts) => self.prompts = prompts,
|
||||
Err(e) => warn!(server = %self.server_name, error = %e, "Failed to list prompts"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
// ---- High-level API -----------------------------------------------
|
||||
|
||||
pub async fn list_tools(&self) -> anyhow::Result<Vec<McpTool>> {
|
||||
let result: ListToolsResult = self.call("tools/list", None).await?;
|
||||
Ok(result.tools)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: &str,
|
||||
arguments: Option<Value>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
let params = CallToolParams {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
};
|
||||
self.call("tools/call", Some(serde_json::to_value(¶ms)?))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_resources(&self) -> anyhow::Result<Vec<McpResource>> {
|
||||
let result: ListResourcesResult = self.call("resources/list", None).await?;
|
||||
Ok(result.resources)
|
||||
}
|
||||
|
||||
pub async fn read_resource(&self, uri: &str) -> anyhow::Result<ResourceContents> {
|
||||
let params = serde_json::json!({ "uri": uri });
|
||||
let result: Value = self.call("resources/read", Some(params)).await?;
|
||||
let contents = result
|
||||
.get("contents")
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.ok_or_else(|| anyhow::anyhow!("No contents in response"))?;
|
||||
Ok(serde_json::from_value(contents.clone())?)
|
||||
}
|
||||
|
||||
pub async fn list_prompts(&self) -> anyhow::Result<Vec<McpPrompt>> {
|
||||
let result: ListPromptsResult = self.call("prompts/list", None).await?;
|
||||
Ok(result.prompts)
|
||||
}
|
||||
|
||||
/// Get all tools as `ToolDefinition` objects suitable for the API.
|
||||
pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools.iter().map(|t| t.into()).collect()
|
||||
}
|
||||
|
||||
// ---- Internal RPC machinery ---------------------------------------
|
||||
|
||||
/// Send a request and wait for the response, deserializing into T.
|
||||
async fn call<T: for<'de> Deserialize<'de>>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> anyhow::Result<T> {
|
||||
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
let req = JsonRpcRequest::new(id, method, params);
|
||||
|
||||
// We use a simple request/response loop here (no concurrent requests).
|
||||
// For production use, proper demultiplexing by id would be needed.
|
||||
self.transport.send(&req).await?;
|
||||
|
||||
loop {
|
||||
let resp = self
|
||||
.transport
|
||||
.recv()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("MCP transport closed"))?;
|
||||
|
||||
// Check if this response matches our request id
|
||||
let resp_id = resp.id.as_ref().and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
if resp_id != id {
|
||||
// Might be a server-initiated notification; skip
|
||||
debug!(got_id = resp_id, want_id = id, "Skipping non-matching response");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(err) = resp.error {
|
||||
return Err(anyhow::anyhow!(
|
||||
"MCP error {}: {}",
|
||||
err.code,
|
||||
err.message
|
||||
));
|
||||
}
|
||||
|
||||
let result = resp
|
||||
.result
|
||||
.ok_or_else(|| anyhow::anyhow!("No result in MCP response"))?;
|
||||
return Ok(serde_json::from_value(result)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MCP Manager: manages multiple server connections
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Manages a pool of MCP server connections.
|
||||
pub struct McpManager {
|
||||
clients: HashMap<String, McpClient>,
|
||||
}
|
||||
|
||||
impl McpManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to all configured MCP servers.
|
||||
pub async fn connect_all(configs: &[McpServerConfig]) -> Self {
|
||||
let mut manager = Self::new();
|
||||
for config in configs {
|
||||
match config.server_type.as_str() {
|
||||
"stdio" => {
|
||||
debug!(server = %config.name, "Connecting to MCP server via stdio");
|
||||
match McpClient::connect_stdio(config).await {
|
||||
Ok(client) => {
|
||||
let name = config.name.clone();
|
||||
manager.clients.insert(name, client);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
server = %config.name,
|
||||
error = %e,
|
||||
"Failed to connect to MCP server"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
other => {
|
||||
warn!(transport = other, "Unsupported MCP transport type");
|
||||
}
|
||||
}
|
||||
}
|
||||
manager
|
||||
}
|
||||
|
||||
/// Get all tool definitions from all connected servers.
|
||||
pub fn all_tool_definitions(&self) -> Vec<(String, ToolDefinition)> {
|
||||
let mut defs = vec![];
|
||||
for (server_name, client) in &self.clients {
|
||||
for td in client.tool_definitions() {
|
||||
// Prefix tool name with server name to avoid conflicts
|
||||
let prefixed = ToolDefinition {
|
||||
name: format!("{}_{}", server_name, td.name),
|
||||
description: format!("[{}] {}", server_name, td.description),
|
||||
input_schema: td.input_schema.clone(),
|
||||
};
|
||||
defs.push((server_name.clone(), prefixed));
|
||||
}
|
||||
}
|
||||
defs
|
||||
}
|
||||
|
||||
/// Execute a tool call, routing to the correct server.
|
||||
/// Tool name format: `<server_name>_<tool_name>`.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
prefixed_name: &str,
|
||||
arguments: Option<Value>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
// Find the server name by matching prefix
|
||||
for (server_name, client) in &self.clients {
|
||||
let prefix = format!("{}_", server_name);
|
||||
if let Some(tool_name) = prefixed_name.strip_prefix(&prefix) {
|
||||
return client.call_tool(tool_name, arguments).await;
|
||||
}
|
||||
}
|
||||
Err(anyhow::anyhow!(
|
||||
"No MCP server found for tool: {}",
|
||||
prefixed_name
|
||||
))
|
||||
}
|
||||
|
||||
/// Number of connected servers.
|
||||
pub fn server_count(&self) -> usize {
|
||||
self.clients.len()
|
||||
}
|
||||
|
||||
/// List all connected server names.
|
||||
pub fn server_names(&self) -> Vec<&str> {
|
||||
self.clients.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// Get server instructions (from initialize response).
|
||||
pub fn server_instructions(&self) -> Vec<(String, String)> {
|
||||
// McpClient doesn't store instructions yet; placeholder
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// List all resources from all (or a specific) connected server.
|
||||
pub async fn list_all_resources(
|
||||
&self,
|
||||
server_filter: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let mut all = vec![];
|
||||
for (name, client) in &self.clients {
|
||||
if let Some(filter) = server_filter {
|
||||
if name != filter {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
match client.list_resources().await {
|
||||
Ok(resources) => {
|
||||
for r in resources {
|
||||
all.push(serde_json::json!({
|
||||
"uri": r.uri,
|
||||
"name": r.name,
|
||||
"description": r.description,
|
||||
"mimeType": r.mime_type,
|
||||
"server": name,
|
||||
}));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(server = %name, error = %e, "Failed to list resources");
|
||||
}
|
||||
}
|
||||
}
|
||||
all
|
||||
}
|
||||
|
||||
/// Read a specific resource from a named server.
|
||||
pub async fn read_resource(
|
||||
&self,
|
||||
server_name: &str,
|
||||
uri: &str,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let client = self
|
||||
.clients
|
||||
.get(server_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("Server '{}' not found", server_name))?;
|
||||
|
||||
let contents = client.read_resource(uri).await?;
|
||||
Ok(serde_json::to_value(&contents)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for McpManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MCP Tool wrapper: makes MCP tools act like native cc-tools
|
||||
// ---------------------------------------------------------------------------
|
||||
// (This would be in cc-tools but is here to avoid circular deps)
|
||||
|
||||
/// Convert MCP tool call result to a string for the model.
|
||||
pub fn mcp_result_to_string(result: &CallToolResult) -> String {
|
||||
result
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
McpContent::Text { text } => Some(text.as_str()),
|
||||
McpContent::Image { .. } => Some("[image]"),
|
||||
McpContent::Resource { resource } => {
|
||||
resource.text.as_deref().or(Some("[binary resource]"))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_json_rpc_request_serialization() {
|
||||
let req = JsonRpcRequest::new(1u64, "tools/list", None);
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"jsonrpc\":\"2.0\""));
|
||||
assert!(json.contains("\"method\":\"tools/list\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_to_definition() {
|
||||
let tool = McpTool {
|
||||
name: "search".to_string(),
|
||||
description: Some("Search the web".to_string()),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": { "query": { "type": "string" } }
|
||||
}),
|
||||
};
|
||||
let def: ToolDefinition = (&tool).into();
|
||||
assert_eq!(def.name, "search");
|
||||
assert_eq!(def.description, "Search the web");
|
||||
}
|
||||
}
|
||||
25
src-rust/crates/query/Cargo.toml
Normal file
25
src-rust/crates/query/Cargo.toml
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
[package]
|
||||
name = "cc-query"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-tools = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
205
src-rust/crates/query/src/agent_tool.rs
Normal file
205
src-rust/crates/query/src/agent_tool.rs
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
// AgentTool: spawn a sub-agent to handle a complex sub-task.
|
||||
//
|
||||
// Lives in cc-query (not cc-tools) to avoid a circular dependency:
|
||||
// cc-tools would need cc-query, but cc-query already needs cc-tools.
|
||||
//
|
||||
// The AgentTool creates a nested query loop with its own context, enabling
|
||||
// the model to delegate complex work to specialized sub-agents. Each sub-agent:
|
||||
// - Runs its own agentic loop
|
||||
// - Has access to all tools (except AgentTool itself, preventing infinite recursion)
|
||||
// - Returns its final output as the tool result
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cc_api::client::ClientConfig;
|
||||
use cc_api::AnthropicClient;
|
||||
use cc_core::types::Message;
|
||||
use cc_tools::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{run_query_loop, QueryConfig, QueryOutcome};
|
||||
|
||||
pub struct AgentTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AgentInput {
|
||||
/// Short description of the agent's task (used for logging).
|
||||
description: String,
|
||||
/// The complete task prompt to send as the first user message.
|
||||
prompt: String,
|
||||
/// Optional: which tools to make available (defaults to all minus AgentTool).
|
||||
#[serde(default)]
|
||||
tools: Option<Vec<String>>,
|
||||
/// Optional: system prompt override for the sub-agent.
|
||||
#[serde(default)]
|
||||
system_prompt: Option<String>,
|
||||
/// Optional: max turns for the sub-agent (default 10).
|
||||
#[serde(default)]
|
||||
max_turns: Option<u32>,
|
||||
/// Optional: model override for this sub-agent.
|
||||
#[serde(default)]
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for AgentTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_AGENT
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Launch a new agent to handle complex, multi-step tasks autonomously. \
|
||||
The agent runs its own agentic loop with access to tools and returns \
|
||||
its final result. Use this to delegate sub-tasks, run parallel \
|
||||
workstreams, or handle tasks that require many tool calls."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
// The agent inherits parent permissions; no extra level required.
|
||||
PermissionLevel::None
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Short description of the agent's task (3-5 words)"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The complete task for the agent to perform"
|
||||
},
|
||||
"tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "List of tool names to make available. Defaults to all tools."
|
||||
},
|
||||
"system_prompt": {
|
||||
"type": "string",
|
||||
"description": "Optional system prompt override for the sub-agent"
|
||||
},
|
||||
"max_turns": {
|
||||
"type": "number",
|
||||
"description": "Maximum number of turns for the sub-agent (default 10)"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Optional model to use for this agent"
|
||||
}
|
||||
},
|
||||
"required": ["description", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: AgentInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
info!(description = %params.description, "Spawning sub-agent");
|
||||
|
||||
// Resolve API key from environment.
|
||||
let api_key = match std::env::var("ANTHROPIC_API_KEY")
|
||||
.ok()
|
||||
.filter(|k| !k.is_empty())
|
||||
{
|
||||
Some(k) => k,
|
||||
None => {
|
||||
return ToolResult::error(
|
||||
"ANTHROPIC_API_KEY not set – cannot spawn sub-agent".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Dedicated Anthropic client for the sub-agent.
|
||||
let client = match AnthropicClient::new(ClientConfig {
|
||||
api_key,
|
||||
..Default::default()
|
||||
}) {
|
||||
Ok(c) => Arc::new(c),
|
||||
Err(e) => return ToolResult::error(format!("Failed to create client: {}", e)),
|
||||
};
|
||||
|
||||
// Build the tool list for the sub-agent.
|
||||
// Always exclude AgentTool itself to prevent unbounded recursion.
|
||||
let all = cc_tools::all_tools();
|
||||
let agent_tools: Vec<Box<dyn Tool>> = if let Some(ref allowed) = params.tools {
|
||||
all.into_iter()
|
||||
.filter(|t| allowed.contains(&t.name().to_string()))
|
||||
.collect()
|
||||
} else {
|
||||
all.into_iter()
|
||||
.filter(|t| t.name() != cc_core::constants::TOOL_NAME_AGENT)
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Resolve model: explicit override > parent context model > default.
|
||||
let model = params
|
||||
.model
|
||||
.filter(|m| !m.is_empty())
|
||||
.unwrap_or_else(|| cc_core::constants::DEFAULT_MODEL.to_string());
|
||||
|
||||
let system_prompt = params.system_prompt.unwrap_or_else(|| {
|
||||
"You are a specialized AI agent helping with a specific sub-task. \
|
||||
Complete the task thoroughly and return your findings."
|
||||
.to_string()
|
||||
});
|
||||
|
||||
let query_config = QueryConfig {
|
||||
model,
|
||||
max_tokens: cc_core::constants::DEFAULT_MAX_TOKENS,
|
||||
max_turns: params.max_turns.unwrap_or(10),
|
||||
system_prompt: Some(system_prompt),
|
||||
append_system_prompt: None,
|
||||
thinking_budget: None,
|
||||
temperature: None,
|
||||
};
|
||||
|
||||
// Run the sub-agent loop.
|
||||
let mut messages = vec![Message::user(params.prompt)];
|
||||
let cancel = CancellationToken::new();
|
||||
|
||||
let outcome = run_query_loop(
|
||||
client.as_ref(),
|
||||
&mut messages,
|
||||
&agent_tools,
|
||||
ctx,
|
||||
&query_config,
|
||||
ctx.cost_tracker.clone(),
|
||||
None, // no event forwarding for sub-agents
|
||||
cancel,
|
||||
)
|
||||
.await;
|
||||
|
||||
match outcome {
|
||||
QueryOutcome::EndTurn { message, usage } => {
|
||||
let text = message.get_all_text();
|
||||
debug!(
|
||||
description = %params.description,
|
||||
output_tokens = usage.output_tokens,
|
||||
"Sub-agent completed"
|
||||
);
|
||||
ToolResult::success(text)
|
||||
}
|
||||
QueryOutcome::MaxTokens { partial_message, .. } => {
|
||||
let text = partial_message.get_all_text();
|
||||
ToolResult::success(format!(
|
||||
"{}\n\n[Note: Agent hit max_tokens limit]",
|
||||
text
|
||||
))
|
||||
}
|
||||
QueryOutcome::Cancelled => {
|
||||
ToolResult::error("Sub-agent was cancelled".to_string())
|
||||
}
|
||||
QueryOutcome::Error(e) => {
|
||||
ToolResult::error(format!("Sub-agent error: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
410
src-rust/crates/query/src/auto_dream.rs
Normal file
410
src-rust/crates/query/src/auto_dream.rs
Normal file
|
|
@ -0,0 +1,410 @@
|
|||
//! AutoDream: automatic memory consolidation daemon
|
||||
//!
|
||||
//! Background memory consolidation. Fires a consolidation prompt as a forked
|
||||
//! subagent when the time gate passes AND enough sessions have accumulated.
|
||||
//!
|
||||
//! Gate order (cheapest first):
|
||||
//! 1. Time: hours since last_consolidated_at >= min_hours (one stat)
|
||||
//! 2. Sessions: transcript count with mtime > last_consolidated_at >= min_sessions
|
||||
//! 3. Lock: no other process mid-consolidation (stale after 1 hour)
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use tokio::fs;
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Scan throttle: when time-gate passes but session-gate doesn't, the lock
|
||||
// mtime doesn't advance, so the time-gate keeps passing every turn.
|
||||
pub const SESSION_SCAN_INTERVAL_SECS: u64 = 10 * 60; // 10 minutes
|
||||
|
||||
/// GrowthBook-sourced scheduling config (with defaults)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoDreamConfig {
|
||||
/// Minimum hours between consolidations (default: 24)
|
||||
pub min_hours: f64,
|
||||
/// Minimum new-session count to trigger (default: 5)
|
||||
pub min_sessions: usize,
|
||||
}
|
||||
|
||||
impl Default for AutoDreamConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_hours: 24.0,
|
||||
min_sessions: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Persisted state written to `.consolidation_state.json`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ConsolidationState {
|
||||
/// Unix timestamp (seconds) of last successful consolidation.
|
||||
/// `None` means never consolidated.
|
||||
pub last_consolidated_at: Option<u64>,
|
||||
/// ETag / opaque lock token – reserved for future distributed locking.
|
||||
pub lock_etag: Option<String>,
|
||||
}
|
||||
|
||||
/// Core AutoDream logic; owns path state, delegates I/O to async methods.
|
||||
pub struct AutoDream {
|
||||
config: AutoDreamConfig,
|
||||
memory_dir: PathBuf,
|
||||
conversations_dir: PathBuf,
|
||||
lock_file: PathBuf,
|
||||
state_file: PathBuf,
|
||||
}
|
||||
|
||||
impl AutoDream {
|
||||
pub fn new(memory_dir: PathBuf, conversations_dir: PathBuf) -> Self {
|
||||
let lock_file = memory_dir.join(".consolidation_lock");
|
||||
let state_file = memory_dir.join(".consolidation_state.json");
|
||||
Self {
|
||||
config: AutoDreamConfig::default(),
|
||||
memory_dir,
|
||||
conversations_dir,
|
||||
lock_file,
|
||||
state_file,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct with explicit config (for testing / feature-flag overrides).
|
||||
pub fn with_config(
|
||||
config: AutoDreamConfig,
|
||||
memory_dir: PathBuf,
|
||||
conversations_dir: PathBuf,
|
||||
) -> Self {
|
||||
let lock_file = memory_dir.join(".consolidation_lock");
|
||||
let state_file = memory_dir.join(".consolidation_state.json");
|
||||
Self {
|
||||
config,
|
||||
memory_dir,
|
||||
conversations_dir,
|
||||
lock_file,
|
||||
state_file,
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Gate checks
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
/// Check all gates cheapest-first. Returns `true` if consolidation should run.
|
||||
pub async fn should_consolidate(&self, state: &ConsolidationState) -> Result<bool> {
|
||||
// Gate 1: Time gate (cheapest – one arithmetic check)
|
||||
if !self.time_gate_passes(state) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Gate 2: Session gate (directory scan)
|
||||
if !self.session_gate_passes(state).await? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Gate 3: Lock gate (no other process mid-consolidation)
|
||||
if !self.lock_gate_passes().await? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn time_gate_passes(&self, state: &ConsolidationState) -> bool {
|
||||
let now_secs = now_secs();
|
||||
match state.last_consolidated_at {
|
||||
None => true, // Never consolidated → always pass
|
||||
Some(last) => {
|
||||
let hours_elapsed = (now_secs.saturating_sub(last)) as f64 / 3600.0;
|
||||
hours_elapsed >= self.config.min_hours
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn session_gate_passes(&self, state: &ConsolidationState) -> Result<bool> {
|
||||
let last_secs = state.last_consolidated_at.unwrap_or(0);
|
||||
let mut count = 0usize;
|
||||
|
||||
if !self.conversations_dir.exists() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let mut dir = fs::read_dir(&self.conversations_dir).await?;
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
let metadata = entry.metadata().await?;
|
||||
if let Ok(mtime) = metadata.modified() {
|
||||
let mtime_secs = mtime
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
.as_secs();
|
||||
if mtime_secs > last_secs {
|
||||
count += 1;
|
||||
if count >= self.config.min_sessions {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn lock_gate_passes(&self) -> Result<bool> {
|
||||
if !self.lock_file.exists() {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Stale lock (>1 hour) is treated as released
|
||||
match fs::metadata(&self.lock_file).await {
|
||||
Ok(meta) => {
|
||||
if let Ok(mtime) = meta.modified() {
|
||||
let age_secs = SystemTime::now()
|
||||
.duration_since(mtime)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
.as_secs();
|
||||
Ok(age_secs > 3600)
|
||||
} else {
|
||||
// Cannot stat mtime → conservative: gate passes (treat as stale)
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
Err(_) => Ok(true), // File disappeared between exists() and metadata()
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Lock management
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
/// Write a timestamp to the lock file, creating it if absent.
|
||||
pub async fn acquire_lock(&self) -> Result<()> {
|
||||
if let Some(parent) = self.lock_file.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
fs::write(&self.lock_file, now_secs().to_string()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the lock file. No-op if it doesn't exist.
|
||||
pub async fn release_lock(&self) -> Result<()> {
|
||||
if self.lock_file.exists() {
|
||||
fs::remove_file(&self.lock_file).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// State persistence
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
/// Stamp `last_consolidated_at = now` and persist.
|
||||
pub async fn update_state(&self, state: &mut ConsolidationState) -> Result<()> {
|
||||
state.last_consolidated_at = Some(now_secs());
|
||||
let json = serde_json::to_string_pretty(state)?;
|
||||
if let Some(parent) = self.state_file.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
fs::write(&self.state_file, json).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load persisted state; returns `Default` on any error (missing file, parse failure).
|
||||
pub async fn load_state(&self) -> ConsolidationState {
|
||||
match fs::read_to_string(&self.state_file).await {
|
||||
Ok(data) => serde_json::from_str(&data).unwrap_or_default(),
|
||||
Err(_) => ConsolidationState::default(),
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Prompt construction
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
/// Build the consolidation prompt for the forked subagent.
|
||||
pub fn consolidation_prompt(&self) -> String {
|
||||
format!(
|
||||
r#"# Dream: Memory Consolidation
|
||||
|
||||
You are performing a dream — a reflective pass over your memory files. Synthesize what you have learned recently into durable, well-organized memories so that future sessions can orient quickly.
|
||||
|
||||
Memory directory: `{memory_dir}`
|
||||
|
||||
Session transcripts: `{conv_dir}` (large JSONL files — grep narrowly, do not read whole files)
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — Orient
|
||||
|
||||
- `ls` the memory directory to see what already exists
|
||||
- Read `MEMORY.md` to understand the current index
|
||||
- Skim existing topic files so you improve them rather than creating duplicates
|
||||
|
||||
## Phase 2 — Gather recent signal
|
||||
|
||||
Look for new information worth persisting:
|
||||
|
||||
1. **Daily logs** (`logs/YYYY/MM/YYYY-MM-DD.md`) if present
|
||||
2. **Existing memories that drifted** — facts that contradict what you see now
|
||||
3. **Transcript search** — grep narrowly for specific terms:
|
||||
`grep -rn "<narrow term>" {conv_dir}/ --include="*.jsonl" | tail -50`
|
||||
|
||||
Do not exhaustively read transcripts. Look only for things you already suspect matter.
|
||||
|
||||
## Phase 3 — Consolidate
|
||||
|
||||
For each thing worth remembering, write or update a memory file. Focus on:
|
||||
- Merging new signal into existing topic files rather than creating near-duplicates
|
||||
- Converting relative dates to absolute dates
|
||||
- Deleting contradicted facts
|
||||
|
||||
## Phase 4 — Prune and index
|
||||
|
||||
Update `MEMORY.md` so it stays under 200 lines and ~25 KB. It is an **index**, not a dump.
|
||||
Each entry: `- [Title](file.md) — one-line hook`
|
||||
|
||||
- Remove pointers to stale, wrong, or superseded memories
|
||||
- Shorten verbose entries; move detail into topic files
|
||||
- Add pointers to newly important memories
|
||||
- Resolve contradictions
|
||||
|
||||
---
|
||||
|
||||
Return a brief summary of what you consolidated, updated, or pruned. If nothing changed, say so.
|
||||
|
||||
**Tool constraints for this run:** Use only read-only Bash commands (ls, find, grep, cat, stat, wc, head, tail). Anything that writes, redirects to a file, or modifies state will be denied.
|
||||
"#,
|
||||
memory_dir = self.memory_dir.display(),
|
||||
conv_dir = self.conversations_dir.display(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tests
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn make_dream(tmp: &TempDir) -> AutoDream {
|
||||
let mem = tmp.path().join("memory");
|
||||
let conv = tmp.path().join("conversations");
|
||||
AutoDream::new(mem, conv)
|
||||
}
|
||||
|
||||
// --- time_gate_passes ---
|
||||
|
||||
#[test]
|
||||
fn test_time_gate_never_consolidated() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
let state = ConsolidationState::default();
|
||||
assert!(dream.time_gate_passes(&state), "no prior consolidation → gate passes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_gate_recent_consolidation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = AutoDream::with_config(
|
||||
AutoDreamConfig { min_hours: 24.0, min_sessions: 5 },
|
||||
tmp.path().join("memory"),
|
||||
tmp.path().join("conversations"),
|
||||
);
|
||||
let state = ConsolidationState {
|
||||
last_consolidated_at: Some(now_secs()), // just now
|
||||
lock_etag: None,
|
||||
};
|
||||
assert!(!dream.time_gate_passes(&state), "just consolidated → gate blocked");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_gate_old_consolidation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = AutoDream::with_config(
|
||||
AutoDreamConfig { min_hours: 24.0, min_sessions: 5 },
|
||||
tmp.path().join("memory"),
|
||||
tmp.path().join("conversations"),
|
||||
);
|
||||
// 25 hours ago
|
||||
let old = now_secs().saturating_sub(25 * 3600);
|
||||
let state = ConsolidationState {
|
||||
last_consolidated_at: Some(old),
|
||||
lock_etag: None,
|
||||
};
|
||||
assert!(dream.time_gate_passes(&state), "consolidated 25h ago → gate passes");
|
||||
}
|
||||
|
||||
// --- lock_gate_passes (sync-friendly via tokio::test) ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_lock_gate_no_lock_file() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
assert!(dream.lock_gate_passes().await.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_lock_gate_fresh_lock_blocks() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
std::fs::create_dir_all(&dream.memory_dir).unwrap();
|
||||
std::fs::write(&dream.lock_file, "12345").unwrap();
|
||||
// Fresh file → gate blocked
|
||||
assert!(!dream.lock_gate_passes().await.unwrap());
|
||||
}
|
||||
|
||||
// --- consolidation_prompt sanity ---
|
||||
|
||||
#[test]
|
||||
fn test_consolidation_prompt_contains_paths() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
let prompt = dream.consolidation_prompt();
|
||||
assert!(prompt.contains("MEMORY.md"));
|
||||
assert!(prompt.contains("Memory Consolidation"));
|
||||
assert!(prompt.contains("Phase 1"));
|
||||
assert!(prompt.contains("Phase 4"));
|
||||
}
|
||||
|
||||
// --- update_state / load_state round-trip ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_state_round_trip() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
std::fs::create_dir_all(&dream.memory_dir).unwrap();
|
||||
|
||||
let mut state = ConsolidationState::default();
|
||||
dream.update_state(&mut state).await.unwrap();
|
||||
|
||||
assert!(state.last_consolidated_at.is_some());
|
||||
let loaded = dream.load_state().await;
|
||||
assert_eq!(loaded.last_consolidated_at, state.last_consolidated_at);
|
||||
}
|
||||
|
||||
// --- acquire_lock / release_lock ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_acquire_release_lock() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let dream = make_dream(&tmp);
|
||||
|
||||
dream.acquire_lock().await.unwrap();
|
||||
assert!(dream.lock_file.exists());
|
||||
|
||||
dream.release_lock().await.unwrap();
|
||||
assert!(!dream.lock_file.exists());
|
||||
}
|
||||
}
|
||||
290
src-rust/crates/query/src/compact.rs
Normal file
290
src-rust/crates/query/src/compact.rs
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
// Auto-compact service for cc-query.
|
||||
//
|
||||
// When the conversation context window fills up (~90%+), we automatically
|
||||
// summarise older messages to free space. This mirrors the TypeScript
|
||||
// autoCompact / compact service behaviour.
|
||||
//
|
||||
// Strategy:
|
||||
// 1. Keep the last KEEP_RECENT_MESSAGES messages verbatim.
|
||||
// 2. Ask the model to summarise everything before those messages.
|
||||
// 3. Replace the head of the conversation with a single synthetic
|
||||
// <compact-summary> user message, followed by the recent tail.
|
||||
//
|
||||
// The summary is generated in a single non-agentic API call so it doesn't
|
||||
// trigger another compaction recursively.
|
||||
|
||||
use cc_api::{ApiMessage, CreateMessageRequest, StreamAccumulator, StreamEvent, StreamHandler, SystemPrompt};
|
||||
use cc_core::error::ClaudeError;
|
||||
use cc_core::types::{Message, Role};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants (mirrors TypeScript autoCompact.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// We target keeping this many context tokens free after compaction.
|
||||
const AUTOCOMPACT_BUFFER_TOKENS: u64 = 13_000;
|
||||
|
||||
/// Start warning when this many tokens remain in the context window.
|
||||
const WARNING_THRESHOLD_BUFFER_TOKENS: u64 = 20_000;
|
||||
|
||||
/// Fraction of the context window at which auto-compact triggers.
|
||||
const AUTOCOMPACT_TRIGGER_FRACTION: f64 = 0.90;
|
||||
|
||||
/// How many recent messages to preserve verbatim after compaction.
|
||||
const KEEP_RECENT_MESSAGES: usize = 10;
|
||||
|
||||
/// Max consecutive auto-compact failures before giving up (circuit breaker).
|
||||
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Tracks auto-compact state across turns.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct AutoCompactState {
|
||||
/// Total compactions performed this session.
|
||||
pub compaction_count: u32,
|
||||
/// Consecutive failures (reset on success).
|
||||
pub consecutive_failures: u32,
|
||||
/// Whether the circuit breaker is open (too many failures).
|
||||
pub disabled: bool,
|
||||
}
|
||||
|
||||
impl AutoCompactState {
|
||||
/// Record a successful compaction.
|
||||
pub fn on_success(&mut self) {
|
||||
self.compaction_count += 1;
|
||||
self.consecutive_failures = 0;
|
||||
}
|
||||
|
||||
/// Record a failed compaction; open circuit breaker if too many.
|
||||
pub fn on_failure(&mut self) {
|
||||
self.consecutive_failures += 1;
|
||||
if self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
|
||||
warn!(
|
||||
failures = self.consecutive_failures,
|
||||
"Auto-compact circuit breaker opened – disabling for this session"
|
||||
);
|
||||
self.disabled = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Token-usage state relative to the context window.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TokenWarningState {
|
||||
/// Plenty of space left.
|
||||
Ok,
|
||||
/// Getting close – warn the user.
|
||||
Warning,
|
||||
/// Critical – compact now.
|
||||
Critical,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Threshold helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return the effective context-window size in tokens for the given model.
|
||||
/// These are approximate; the API enforces the real limits server-side.
|
||||
pub fn context_window_for_model(model: &str) -> u64 {
|
||||
if model.contains("opus-4") || model.contains("sonnet-4") || model.contains("haiku-4") {
|
||||
200_000
|
||||
} else if model.contains("claude-3-5") || model.contains("claude-3.5") {
|
||||
200_000
|
||||
} else {
|
||||
100_000
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine token-warning state given current input token count and model.
|
||||
pub fn calculate_token_warning_state(input_tokens: u64, model: &str) -> TokenWarningState {
|
||||
let window = context_window_for_model(model);
|
||||
let remaining = window.saturating_sub(input_tokens);
|
||||
|
||||
if remaining <= WARNING_THRESHOLD_BUFFER_TOKENS as u64 {
|
||||
TokenWarningState::Warning
|
||||
} else {
|
||||
TokenWarningState::Ok
|
||||
}
|
||||
}
|
||||
|
||||
/// Return `true` when auto-compaction should fire.
|
||||
pub fn should_auto_compact(input_tokens: u64, model: &str, state: &AutoCompactState) -> bool {
|
||||
if state.disabled {
|
||||
return false;
|
||||
}
|
||||
let window = context_window_for_model(model);
|
||||
let threshold = (window as f64 * AUTOCOMPACT_TRIGGER_FRACTION) as u64;
|
||||
input_tokens >= threshold
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Core compaction logic
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Summarise `messages[..split_at]` using the Anthropic API and return a
|
||||
/// new conversation consisting of a single summary message followed by
|
||||
/// `messages[split_at..]`.
|
||||
async fn summarise_head(
|
||||
client: &cc_api::AnthropicClient,
|
||||
messages: &[Message],
|
||||
split_at: usize,
|
||||
model: &str,
|
||||
) -> Result<Vec<Message>, ClaudeError> {
|
||||
if split_at == 0 {
|
||||
return Ok(messages.to_vec());
|
||||
}
|
||||
|
||||
let head = &messages[..split_at];
|
||||
|
||||
// Build a transcript string for the summarisation prompt.
|
||||
let mut transcript = String::new();
|
||||
for msg in head {
|
||||
let role_label = match msg.role {
|
||||
Role::User => "Human",
|
||||
Role::Assistant => "Assistant",
|
||||
};
|
||||
let text = msg.get_all_text();
|
||||
if !text.is_empty() {
|
||||
transcript.push_str(&format!("{}: {}\n\n", role_label, text));
|
||||
}
|
||||
}
|
||||
|
||||
let summarise_prompt = format!(
|
||||
"Please create a comprehensive yet concise summary of the conversation transcript \
|
||||
below. The summary will be used as context for continuing the conversation, so \
|
||||
include all important decisions, code changes, findings, and context that would be \
|
||||
needed to continue seamlessly.\n\n\
|
||||
Focus on:\n\
|
||||
- Key decisions made and their rationale\n\
|
||||
- Code or files that were created/modified\n\
|
||||
- Important findings or conclusions\n\
|
||||
- The current state of any ongoing tasks\n\
|
||||
- Any constraints or requirements discovered\n\n\
|
||||
<transcript>\n{}\n</transcript>",
|
||||
transcript
|
||||
);
|
||||
|
||||
let api_msgs = vec![ApiMessage {
|
||||
role: "user".to_string(),
|
||||
content: Value::String(summarise_prompt),
|
||||
}];
|
||||
|
||||
let request = CreateMessageRequest::builder(model, 4096)
|
||||
.messages(api_msgs)
|
||||
.system(SystemPrompt::Text(
|
||||
"You are a helpful assistant that creates concise conversation summaries. \
|
||||
Be thorough but concise. Preserve technical details, file names, and code snippets \
|
||||
that would be important for continuing the work."
|
||||
.to_string(),
|
||||
))
|
||||
.build();
|
||||
|
||||
// Use a null handler since we just want the final accumulated message.
|
||||
let handler: Arc<dyn StreamHandler> = Arc::new(cc_api::streaming::NullStreamHandler);
|
||||
let mut rx = client.create_message_stream(request, handler).await?;
|
||||
let mut acc = StreamAccumulator::new();
|
||||
|
||||
while let Some(evt) = rx.recv().await {
|
||||
acc.on_event(&evt);
|
||||
if matches!(evt, StreamEvent::MessageStop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let (summary_msg, _usage, _stop) = acc.finish();
|
||||
let summary_text = summary_msg.get_all_text();
|
||||
|
||||
if summary_text.is_empty() {
|
||||
return Err(ClaudeError::Other("Compact summary was empty".to_string()));
|
||||
}
|
||||
|
||||
// Build the new conversation:
|
||||
// [user: compact summary preamble] [assistant: summary content] [tail messages]
|
||||
let compact_notice = Message::user(format!(
|
||||
"<compact-summary>\n\
|
||||
The conversation history has been automatically compacted to stay within context limits.\n\
|
||||
The following is a summary of the previous conversation:\n\n\
|
||||
{}\n\
|
||||
</compact-summary>",
|
||||
summary_text
|
||||
));
|
||||
|
||||
let mut new_messages = vec![compact_notice];
|
||||
new_messages.extend_from_slice(&messages[split_at..]);
|
||||
|
||||
Ok(new_messages)
|
||||
}
|
||||
|
||||
/// Compact `messages` in-place, replacing the head with a summary.
|
||||
/// Returns the new messages vector on success.
|
||||
pub async fn compact_conversation(
|
||||
client: &cc_api::AnthropicClient,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
) -> Result<Vec<Message>, ClaudeError> {
|
||||
let total = messages.len();
|
||||
|
||||
if total <= KEEP_RECENT_MESSAGES + 1 {
|
||||
debug!(
|
||||
total,
|
||||
"Too few messages to compact – keeping everything"
|
||||
);
|
||||
return Ok(messages.to_vec());
|
||||
}
|
||||
|
||||
// Split: summarise everything except the most recent KEEP_RECENT_MESSAGES.
|
||||
let split_at = total.saturating_sub(KEEP_RECENT_MESSAGES);
|
||||
|
||||
info!(
|
||||
total,
|
||||
split_at,
|
||||
keep = KEEP_RECENT_MESSAGES,
|
||||
"Compacting conversation"
|
||||
);
|
||||
|
||||
summarise_head(client, messages, split_at, model).await
|
||||
}
|
||||
|
||||
/// Auto-compact `messages` if needed. Updates `state` in place.
|
||||
/// Returns `Some(new_messages)` if compaction ran, `None` otherwise.
|
||||
pub async fn auto_compact_if_needed(
|
||||
client: &cc_api::AnthropicClient,
|
||||
messages: &[Message],
|
||||
input_tokens: u64,
|
||||
model: &str,
|
||||
state: &mut AutoCompactState,
|
||||
) -> Option<Vec<Message>> {
|
||||
if !should_auto_compact(input_tokens, model, state) {
|
||||
return None;
|
||||
}
|
||||
|
||||
info!(
|
||||
input_tokens,
|
||||
model,
|
||||
compaction_count = state.compaction_count,
|
||||
"Auto-compact triggered"
|
||||
);
|
||||
|
||||
match compact_conversation(client, messages, model).await {
|
||||
Ok(new_msgs) => {
|
||||
state.on_success();
|
||||
info!(
|
||||
original_count = messages.len(),
|
||||
new_count = new_msgs.len(),
|
||||
"Auto-compact complete"
|
||||
);
|
||||
Some(new_msgs)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Auto-compact failed");
|
||||
state.on_failure();
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
173
src-rust/crates/query/src/coordinator.rs
Normal file
173
src-rust/crates/query/src/coordinator.rs
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
//! Coordinator mode: multi-worker agent orchestration
|
||||
|
||||
use crate::*;
|
||||
|
||||
pub const COORDINATOR_ENV_VAR: &str = "CLAUDE_CODE_COORDINATOR_MODE";
|
||||
|
||||
pub fn is_coordinator_mode() -> bool {
|
||||
std::env::var(COORDINATOR_ENV_VAR)
|
||||
.map(|v| !v.is_empty() && v != "0" && v != "false")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// System prompt sections injected when coordinator mode is active
|
||||
pub fn coordinator_system_prompt() -> &'static str {
|
||||
r#"
|
||||
## Coordinator Mode
|
||||
|
||||
You are operating as an orchestrator for parallel worker agents.
|
||||
|
||||
### Your Role
|
||||
- Orchestrate workers using the Agent tool to spawn parallel subagents
|
||||
- Use SendMessage to continue communication with running workers
|
||||
- Use TaskStop to cancel workers that are no longer needed
|
||||
- Synthesize findings across workers before presenting to the user
|
||||
- Answer directly when the question doesn't need delegation
|
||||
|
||||
### Task Workflow
|
||||
1. **Research Phase**: Spawn workers to gather information in parallel
|
||||
2. **Synthesis Phase**: Collect and merge worker findings
|
||||
3. **Implementation Phase**: Delegate implementation tasks to specialized workers
|
||||
4. **Verification Phase**: Spawn verification workers to validate results
|
||||
|
||||
### Worker Guidelines
|
||||
- Worker prompts must be fully self-contained (workers cannot see your conversation)
|
||||
- Always synthesize findings before spawning follow-up workers
|
||||
- Workers have access to all standard tools + MCP + skills
|
||||
- Use TaskCreate/TaskUpdate to track parallel work
|
||||
|
||||
### Internal Tools (do not delegate to workers)
|
||||
- Agent, SendMessage, TaskStop (coordination only)
|
||||
"#
|
||||
}
|
||||
|
||||
/// Tools that should NOT be passed to worker agents
|
||||
pub const INTERNAL_COORDINATOR_TOOLS: &[&str] = &[
|
||||
"Agent",
|
||||
"SendMessage",
|
||||
"TaskStop",
|
||||
];
|
||||
|
||||
/// Get the user context injected for coordinator sessions
|
||||
pub fn coordinator_user_context(available_tools: &[String], mcp_servers: &[String]) -> String {
|
||||
let tool_list = available_tools
|
||||
.iter()
|
||||
.filter(|t| !INTERNAL_COORDINATOR_TOOLS.contains(&t.as_str()))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let mcp_section = if mcp_servers.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\nConnected MCP servers: {}", mcp_servers.join(", "))
|
||||
};
|
||||
|
||||
format!(
|
||||
"Available worker tools: {}{}\n",
|
||||
tool_list, mcp_section
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if session mode matches current coordinator setting, returns warning if mismatched
|
||||
pub fn match_session_mode(stored_coordinator: bool) -> Option<String> {
|
||||
let current = is_coordinator_mode();
|
||||
if stored_coordinator != current {
|
||||
if current {
|
||||
std::env::set_var(COORDINATOR_ENV_VAR, "1");
|
||||
} else {
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
}
|
||||
Some(format!(
|
||||
"Session was created in {} mode, switching to match.",
|
||||
if stored_coordinator { "coordinator" } else { "standard" }
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_coordinator_mode_unset() {
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
assert!(!is_coordinator_mode());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_coordinator_mode_set_to_one() {
|
||||
std::env::set_var(COORDINATOR_ENV_VAR, "1");
|
||||
assert!(is_coordinator_mode());
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_coordinator_mode_set_to_false() {
|
||||
std::env::set_var(COORDINATOR_ENV_VAR, "false");
|
||||
assert!(!is_coordinator_mode());
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_coordinator_mode_set_to_zero() {
|
||||
std::env::set_var(COORDINATOR_ENV_VAR, "0");
|
||||
assert!(!is_coordinator_mode());
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_user_context_filters_internal_tools() {
|
||||
let tools = vec![
|
||||
"Bash".to_string(),
|
||||
"Agent".to_string(),
|
||||
"SendMessage".to_string(),
|
||||
"TaskStop".to_string(),
|
||||
"Read".to_string(),
|
||||
];
|
||||
let ctx = coordinator_user_context(&tools, &[]);
|
||||
assert!(ctx.contains("Bash"));
|
||||
assert!(ctx.contains("Read"));
|
||||
assert!(!ctx.contains("Agent"));
|
||||
assert!(!ctx.contains("SendMessage"));
|
||||
assert!(!ctx.contains("TaskStop"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_user_context_mcp_servers() {
|
||||
let tools = vec!["Bash".to_string()];
|
||||
let mcps = vec!["filesystem".to_string(), "git".to_string()];
|
||||
let ctx = coordinator_user_context(&tools, &mcps);
|
||||
assert!(ctx.contains("filesystem"));
|
||||
assert!(ctx.contains("git"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_match_session_mode_no_change_needed() {
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
// current = false, stored = false → no warning
|
||||
assert!(match_session_mode(false).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_match_session_mode_switches_to_coordinator() {
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
// current = false, stored = true → should flip and warn
|
||||
let msg = match_session_mode(true);
|
||||
assert!(msg.is_some());
|
||||
assert!(msg.unwrap().contains("coordinator"));
|
||||
// Clean up
|
||||
std::env::remove_var(COORDINATOR_ENV_VAR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_system_prompt_content() {
|
||||
let prompt = coordinator_system_prompt();
|
||||
assert!(prompt.contains("Coordinator Mode"));
|
||||
assert!(prompt.contains("orchestrator"));
|
||||
assert!(prompt.contains("Research Phase"));
|
||||
assert!(prompt.contains("Synthesis Phase"));
|
||||
}
|
||||
}
|
||||
114
src-rust/crates/query/src/cron_scheduler.rs
Normal file
114
src-rust/crates/query/src/cron_scheduler.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// cron_scheduler: background task that fires cron-scheduled prompts.
|
||||
//
|
||||
// Runs as a long-lived tokio task. Every minute it checks the global CRON_STORE
|
||||
// (in cc-tools) for tasks whose cron expression matches the current wall-clock
|
||||
// minute. Matching tasks are fired by spawning a sub-query loop, exactly like
|
||||
// the AgentTool does for sub-agents.
|
||||
//
|
||||
// One-shot tasks (recurring=false) are automatically removed from the store
|
||||
// by `pop_due_tasks` after they are returned.
|
||||
|
||||
use crate::{QueryConfig, QueryOutcome, run_query_loop};
|
||||
use cc_core::types::Message;
|
||||
use cc_tools::Tool;
|
||||
use cc_tools::ToolContext;
|
||||
use chrono::Timelike;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{Duration, sleep};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
/// Start the background cron scheduler.
|
||||
///
|
||||
/// Returns immediately; the scheduler runs as a detached tokio task.
|
||||
/// Call `cancel.cancel()` to stop it gracefully.
|
||||
pub fn start_cron_scheduler(
|
||||
client: Arc<cc_api::AnthropicClient>,
|
||||
tools: Arc<Vec<Box<dyn Tool>>>,
|
||||
tool_ctx: ToolContext,
|
||||
query_config: QueryConfig,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
run_scheduler_loop(client, tools, tool_ctx, query_config, cancel).await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_scheduler_loop(
|
||||
client: Arc<cc_api::AnthropicClient>,
|
||||
tools: Arc<Vec<Box<dyn Tool>>>,
|
||||
tool_ctx: ToolContext,
|
||||
query_config: QueryConfig,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
info!("Cron scheduler started");
|
||||
|
||||
loop {
|
||||
// Sleep until the next whole-minute boundary (±1s tolerance).
|
||||
let now = chrono::Local::now();
|
||||
let secs_into_minute = now.second() as u64;
|
||||
let nanos_ms = now.nanosecond() as u64 / 1_000_000;
|
||||
// How many ms until the next minute starts? Use saturating sub to avoid underflow.
|
||||
let ms_to_next_minute = (60u64.saturating_sub(secs_into_minute))
|
||||
.saturating_mul(1_000)
|
||||
.saturating_sub(nanos_ms)
|
||||
.max(1); // always sleep at least 1ms
|
||||
|
||||
tokio::select! {
|
||||
_ = sleep(Duration::from_millis(ms_to_next_minute)) => {}
|
||||
_ = cancel.cancelled() => {
|
||||
info!("Cron scheduler stopped");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let tick_time = chrono::Local::now();
|
||||
debug!(time = %tick_time.format("%H:%M"), "Cron scheduler tick");
|
||||
|
||||
// Find tasks due at this minute.
|
||||
let due = cc_tools::cron::pop_due_tasks(&tick_time).await;
|
||||
|
||||
for task in due {
|
||||
info!(id = %task.id, cron = %task.cron, "Firing cron task");
|
||||
|
||||
let client = client.clone();
|
||||
let tools = tools.clone();
|
||||
let tool_ctx = tool_ctx.clone();
|
||||
let query_config = query_config.clone();
|
||||
let cost_tracker = tool_ctx.cost_tracker.clone();
|
||||
let cancel_child = cancel.clone();
|
||||
let task_id = task.id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut messages = vec![Message::user(task.prompt.clone())];
|
||||
|
||||
let outcome = run_query_loop(
|
||||
client.as_ref(),
|
||||
&mut messages,
|
||||
&tools,
|
||||
&tool_ctx,
|
||||
&query_config,
|
||||
cost_tracker,
|
||||
None, // background — no UI event channel
|
||||
cancel_child,
|
||||
)
|
||||
.await;
|
||||
|
||||
match outcome {
|
||||
QueryOutcome::EndTurn { .. } => {
|
||||
info!(id = %task_id, "Cron task completed");
|
||||
}
|
||||
QueryOutcome::Error(e) => {
|
||||
error!(id = %task_id, error = %e, "Cron task failed");
|
||||
}
|
||||
QueryOutcome::MaxTokens { .. } => {
|
||||
info!(id = %task_id, "Cron task hit max tokens");
|
||||
}
|
||||
QueryOutcome::Cancelled => {
|
||||
debug!(id = %task_id, "Cron task cancelled");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
636
src-rust/crates/query/src/lib.rs
Normal file
636
src-rust/crates/query/src/lib.rs
Normal file
|
|
@ -0,0 +1,636 @@
|
|||
// cc-query: The core agentic query loop.
|
||||
//
|
||||
// This crate implements the main conversation loop that:
|
||||
// 1. Sends messages to the Anthropic API
|
||||
// 2. Processes streaming responses
|
||||
// 3. Detects tool-use requests and dispatches them
|
||||
// 4. Feeds tool results back to the model
|
||||
// 5. Handles auto-compact when the context window fills up
|
||||
// 6. Manages stop conditions (end_turn, max_turns, cancellation)
|
||||
|
||||
pub mod agent_tool;
|
||||
pub mod auto_dream;
|
||||
pub mod compact;
|
||||
pub mod coordinator;
|
||||
pub mod cron_scheduler;
|
||||
pub use agent_tool::AgentTool;
|
||||
pub use cron_scheduler::start_cron_scheduler;
|
||||
pub use compact::{
|
||||
AutoCompactState, TokenWarningState, auto_compact_if_needed, calculate_token_warning_state,
|
||||
compact_conversation, context_window_for_model, should_auto_compact,
|
||||
};
|
||||
|
||||
use cc_api::{
|
||||
ApiMessage, ApiToolDefinition, CreateMessageRequest, StreamAccumulator, StreamEvent,
|
||||
StreamHandler, SystemPrompt, ThinkingConfig,
|
||||
};
|
||||
use cc_core::config::Config;
|
||||
use cc_core::cost::CostTracker;
|
||||
use cc_core::error::ClaudeError;
|
||||
use cc_core::types::{ContentBlock, Message, ToolResultContent, UsageInfo};
|
||||
use cc_tools::{Tool, ToolContext, ToolResult};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Outcome of a single query-loop run.
|
||||
#[derive(Debug)]
|
||||
pub enum QueryOutcome {
|
||||
/// The model finished its turn (end_turn stop reason).
|
||||
EndTurn { message: Message, usage: UsageInfo },
|
||||
/// The model hit max_tokens.
|
||||
MaxTokens { partial_message: Message, usage: UsageInfo },
|
||||
/// The conversation was cancelled by the user.
|
||||
Cancelled,
|
||||
/// An unrecoverable error occurred.
|
||||
Error(ClaudeError),
|
||||
}
|
||||
|
||||
/// Configuration for a single query-loop invocation.
|
||||
#[derive(Clone)]
|
||||
pub struct QueryConfig {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub max_turns: u32,
|
||||
pub system_prompt: Option<String>,
|
||||
pub append_system_prompt: Option<String>,
|
||||
pub thinking_budget: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for QueryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: cc_core::constants::DEFAULT_MODEL.to_string(),
|
||||
max_tokens: cc_core::constants::DEFAULT_MAX_TOKENS,
|
||||
max_turns: cc_core::constants::MAX_TURNS_DEFAULT,
|
||||
system_prompt: None,
|
||||
append_system_prompt: None,
|
||||
thinking_budget: None,
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryConfig {
|
||||
pub fn from_config(cfg: &Config) -> Self {
|
||||
Self {
|
||||
model: cfg.effective_model().to_string(),
|
||||
max_tokens: cfg.effective_max_tokens(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Events emitted by the query loop for the TUI to render.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum QueryEvent {
|
||||
/// A stream event from the API.
|
||||
Stream(StreamEvent),
|
||||
/// A tool is about to be executed.
|
||||
ToolStart { tool_name: String, tool_id: String },
|
||||
/// A tool has finished executing.
|
||||
ToolEnd { tool_name: String, tool_id: String, result: String, is_error: bool },
|
||||
/// The model finished a turn.
|
||||
TurnComplete { turn: u32, stop_reason: String },
|
||||
/// An informational status message.
|
||||
Status(String),
|
||||
/// An error.
|
||||
Error(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Query loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Run the agentic query loop.
|
||||
///
|
||||
/// This sends the conversation to the API, handles tool calls in a loop, and
|
||||
/// returns when the model issues an end_turn or an error/limit is hit.
|
||||
pub async fn run_query_loop(
|
||||
client: &cc_api::AnthropicClient,
|
||||
messages: &mut Vec<Message>,
|
||||
tools: &[Box<dyn Tool>],
|
||||
tool_ctx: &ToolContext,
|
||||
config: &QueryConfig,
|
||||
cost_tracker: Arc<CostTracker>,
|
||||
event_tx: Option<mpsc::UnboundedSender<QueryEvent>>,
|
||||
cancel_token: tokio_util::sync::CancellationToken,
|
||||
) -> QueryOutcome {
|
||||
let mut turn = 0u32;
|
||||
let mut compact_state = compact::AutoCompactState::default();
|
||||
|
||||
loop {
|
||||
turn += 1;
|
||||
if turn > config.max_turns {
|
||||
info!(turns = turn, "Max turns reached");
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(QueryEvent::Status(format!(
|
||||
"Reached maximum turn limit ({})",
|
||||
config.max_turns
|
||||
)));
|
||||
}
|
||||
// Return the last assistant message if any
|
||||
let last_msg = messages
|
||||
.last()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| Message::assistant("Max turns reached."));
|
||||
return QueryOutcome::EndTurn {
|
||||
message: last_msg,
|
||||
usage: UsageInfo::default(),
|
||||
};
|
||||
}
|
||||
|
||||
// Check for cancellation
|
||||
if cancel_token.is_cancelled() {
|
||||
return QueryOutcome::Cancelled;
|
||||
}
|
||||
|
||||
// Build API request
|
||||
let api_messages: Vec<ApiMessage> = messages.iter().map(ApiMessage::from).collect();
|
||||
let api_tools: Vec<ApiToolDefinition> = tools
|
||||
.iter()
|
||||
.map(|t| ApiToolDefinition::from(&t.to_definition()))
|
||||
.collect();
|
||||
|
||||
let system = build_system_prompt(config);
|
||||
|
||||
let mut req_builder = CreateMessageRequest::builder(&config.model, config.max_tokens)
|
||||
.messages(api_messages)
|
||||
.system(system)
|
||||
.tools(api_tools);
|
||||
|
||||
// Only enable extended thinking if an explicit budget was provided.
|
||||
if let Some(budget) = config.thinking_budget {
|
||||
req_builder = req_builder.thinking(ThinkingConfig::enabled(budget));
|
||||
}
|
||||
|
||||
let request = req_builder.build();
|
||||
|
||||
// Create a stream handler that forwards to the event channel
|
||||
let handler: Arc<dyn StreamHandler> = if let Some(ref tx) = event_tx {
|
||||
let tx = tx.clone();
|
||||
Arc::new(ChannelStreamHandler { tx })
|
||||
} else {
|
||||
Arc::new(cc_api::streaming::NullStreamHandler)
|
||||
};
|
||||
|
||||
// Send to API
|
||||
debug!(turn, model = %config.model, "Sending API request");
|
||||
let mut stream_rx = match client.create_message_stream(request, handler).await {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
error!(error = %e, "API request failed");
|
||||
return QueryOutcome::Error(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate the streamed response
|
||||
let mut accumulator = StreamAccumulator::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel_token.cancelled() => {
|
||||
return QueryOutcome::Cancelled;
|
||||
}
|
||||
event = stream_rx.recv() => {
|
||||
match event {
|
||||
Some(evt) => {
|
||||
accumulator.on_event(&evt);
|
||||
match &evt {
|
||||
StreamEvent::Error { error_type, message } => {
|
||||
if error_type == "overloaded_error" {
|
||||
warn!("API overloaded, should retry");
|
||||
}
|
||||
error!(error_type, message, "Stream error");
|
||||
}
|
||||
StreamEvent::MessageStop => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
None => break, // Stream ended
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (assistant_msg, usage, stop_reason) = accumulator.finish();
|
||||
|
||||
// Track costs
|
||||
cost_tracker.add_usage(
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
usage.cache_creation_input_tokens,
|
||||
usage.cache_read_input_tokens,
|
||||
);
|
||||
|
||||
// Append assistant message to conversation
|
||||
messages.push(assistant_msg.clone());
|
||||
|
||||
let stop = stop_reason.as_deref().unwrap_or("end_turn");
|
||||
|
||||
// Auto-compact: if context is near-full, summarise older messages now
|
||||
// (before the next turn's API call would fail with prompt-too-long).
|
||||
if stop == "end_turn" || stop == "tool_use" {
|
||||
if let Some(new_msgs) = compact::auto_compact_if_needed(
|
||||
client,
|
||||
messages,
|
||||
usage.input_tokens,
|
||||
&config.model,
|
||||
&mut compact_state,
|
||||
)
|
||||
.await
|
||||
{
|
||||
*messages = new_msgs;
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(QueryEvent::Status(
|
||||
"Context compacted to stay within limits.".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(QueryEvent::TurnComplete {
|
||||
turn,
|
||||
stop_reason: stop.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Helper closure for firing the Stop hook.
|
||||
macro_rules! fire_stop_hook {
|
||||
($msg:expr) => {{
|
||||
let stop_ctx = cc_core::hooks::HookContext {
|
||||
event: "Stop".to_string(),
|
||||
tool_name: None,
|
||||
tool_input: None,
|
||||
tool_output: Some($msg.get_all_text()),
|
||||
is_error: None,
|
||||
session_id: Some(tool_ctx.session_id.clone()),
|
||||
};
|
||||
cc_core::hooks::run_hooks(
|
||||
&tool_ctx.config.hooks,
|
||||
cc_core::config::HookEvent::Stop,
|
||||
&stop_ctx,
|
||||
&tool_ctx.working_dir,
|
||||
)
|
||||
.await;
|
||||
}};
|
||||
}
|
||||
|
||||
match stop {
|
||||
"end_turn" => {
|
||||
fire_stop_hook!(assistant_msg);
|
||||
return QueryOutcome::EndTurn {
|
||||
message: assistant_msg,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
"max_tokens" => {
|
||||
return QueryOutcome::MaxTokens {
|
||||
partial_message: assistant_msg,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
"tool_use" => {
|
||||
// Extract tool calls and execute them
|
||||
let tool_blocks = assistant_msg.get_tool_use_blocks();
|
||||
if tool_blocks.is_empty() {
|
||||
// Shouldn't happen but treat as end_turn
|
||||
return QueryOutcome::EndTurn {
|
||||
message: assistant_msg,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
|
||||
let mut result_blocks: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
for block in tool_blocks {
|
||||
if let ContentBlock::ToolUse { id, name, input } = block {
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(QueryEvent::ToolStart {
|
||||
tool_name: name.clone(),
|
||||
tool_id: id.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Fire PreToolUse hooks (blocking hooks can cancel execution)
|
||||
let hooks = &tool_ctx.config.hooks;
|
||||
let hook_ctx = cc_core::hooks::HookContext {
|
||||
event: "PreToolUse".to_string(),
|
||||
tool_name: Some(name.clone()),
|
||||
tool_input: Some(input.clone()),
|
||||
tool_output: None,
|
||||
is_error: None,
|
||||
session_id: Some(tool_ctx.session_id.clone()),
|
||||
};
|
||||
let pre_outcome = cc_core::hooks::run_hooks(
|
||||
hooks,
|
||||
cc_core::config::HookEvent::PreToolUse,
|
||||
&hook_ctx,
|
||||
&tool_ctx.working_dir,
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = if let cc_core::hooks::HookOutcome::Blocked(reason) = pre_outcome {
|
||||
warn!(tool = name, reason = %reason, "PreToolUse hook blocked execution");
|
||||
cc_tools::ToolResult::error(format!("Blocked by hook: {}", reason))
|
||||
} else {
|
||||
execute_tool(&name, &input, tools, tool_ctx).await
|
||||
};
|
||||
|
||||
// Fire PostToolUse hooks
|
||||
let post_ctx = cc_core::hooks::HookContext {
|
||||
event: "PostToolUse".to_string(),
|
||||
tool_name: Some(name.clone()),
|
||||
tool_input: Some(input.clone()),
|
||||
tool_output: Some(result.content.clone()),
|
||||
is_error: Some(result.is_error),
|
||||
session_id: Some(tool_ctx.session_id.clone()),
|
||||
};
|
||||
cc_core::hooks::run_hooks(
|
||||
hooks,
|
||||
cc_core::config::HookEvent::PostToolUse,
|
||||
&post_ctx,
|
||||
&tool_ctx.working_dir,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(QueryEvent::ToolEnd {
|
||||
tool_name: name.clone(),
|
||||
tool_id: id.clone(),
|
||||
result: result.content.clone(),
|
||||
is_error: result.is_error,
|
||||
});
|
||||
}
|
||||
|
||||
result_blocks.push(ContentBlock::ToolResult {
|
||||
tool_use_id: id.clone(),
|
||||
content: ToolResultContent::Text(result.content),
|
||||
is_error: if result.is_error { Some(true) } else { None },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Append tool results as a user message
|
||||
messages.push(Message::user_blocks(result_blocks));
|
||||
|
||||
// Continue the loop to send results back to the model
|
||||
continue;
|
||||
}
|
||||
"stop_sequence" => {
|
||||
fire_stop_hook!(assistant_msg);
|
||||
return QueryOutcome::EndTurn {
|
||||
message: assistant_msg,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
other => {
|
||||
warn!(stop_reason = other, "Unknown stop reason, treating as end_turn");
|
||||
fire_stop_hook!(assistant_msg);
|
||||
return QueryOutcome::EndTurn {
|
||||
message: assistant_msg,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a single tool invocation.
|
||||
async fn execute_tool(
|
||||
name: &str,
|
||||
input: &Value,
|
||||
tools: &[Box<dyn Tool>],
|
||||
ctx: &ToolContext,
|
||||
) -> ToolResult {
|
||||
let tool = tools.iter().find(|t| t.name() == name);
|
||||
|
||||
match tool {
|
||||
Some(tool) => {
|
||||
debug!(tool = name, "Executing tool");
|
||||
tool.execute(input.clone(), ctx).await
|
||||
}
|
||||
None => {
|
||||
warn!(tool = name, "Unknown tool requested");
|
||||
ToolResult::error(format!("Unknown tool: {}", name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the system prompt from config.
|
||||
///
|
||||
/// Delegates to `cc_core::system_prompt::build_system_prompt` so that all
|
||||
/// default content (capabilities, safety guidelines, dynamic-boundary marker,
|
||||
/// etc.) is assembled in one place. The `QueryConfig` fields map directly to
|
||||
/// `SystemPromptOptions`:
|
||||
///
|
||||
/// - `system_prompt` → `custom_system_prompt` (added to cacheable block)
|
||||
/// - `append_system_prompt` → `append_system_prompt` (added after boundary)
|
||||
fn build_system_prompt(config: &QueryConfig) -> SystemPrompt {
|
||||
use cc_core::system_prompt::{OutputStyle, SystemPromptOptions};
|
||||
|
||||
let opts = SystemPromptOptions {
|
||||
custom_system_prompt: config.system_prompt.clone(),
|
||||
append_system_prompt: config.append_system_prompt.clone(),
|
||||
// All other fields use sensible defaults:
|
||||
// - prefix: auto-detect from env
|
||||
// - output_style: Default (no suffix)
|
||||
// - working_directory: None (callers inject via append if needed)
|
||||
// - memory_content: empty (callers inject via append if needed)
|
||||
// - replace_system_prompt: false (additive mode)
|
||||
// - coordinator_mode: false
|
||||
output_style: OutputStyle::Default,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let text = cc_core::system_prompt::build_system_prompt(&opts);
|
||||
SystemPrompt::Text(text)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use cc_api::SystemPrompt;
|
||||
|
||||
fn make_config(sys: Option<&str>, append: Option<&str>) -> QueryConfig {
|
||||
QueryConfig {
|
||||
model: "claude-sonnet-4-6".to_string(),
|
||||
max_tokens: 4096,
|
||||
max_turns: 10,
|
||||
system_prompt: sys.map(String::from),
|
||||
append_system_prompt: append.map(String::from),
|
||||
thinking_budget: None,
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
|
||||
// ---- build_system_prompt tests ------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_default_when_empty() {
|
||||
// The default prompt (no custom system prompt set) should include the
|
||||
// Claude Code attribution and standard sections.
|
||||
let cfg = make_config(None, None);
|
||||
let prompt = build_system_prompt(&cfg);
|
||||
if let SystemPrompt::Text(text) = prompt {
|
||||
assert!(
|
||||
text.contains("Claude Code") || text.contains("Claude agent"),
|
||||
"Default prompt should contain attribution: {}",
|
||||
text
|
||||
);
|
||||
assert!(
|
||||
text.contains(cc_core::system_prompt::SYSTEM_PROMPT_DYNAMIC_BOUNDARY),
|
||||
"Default prompt must contain the dynamic boundary marker"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected SystemPrompt::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_with_custom() {
|
||||
// A custom system prompt is injected into the cacheable section as
|
||||
// <custom_instructions>; the default sections are still present.
|
||||
let cfg = make_config(Some("You are a code reviewer."), None);
|
||||
let prompt = build_system_prompt(&cfg);
|
||||
if let SystemPrompt::Text(text) = prompt {
|
||||
assert!(
|
||||
text.contains("You are a code reviewer."),
|
||||
"Custom prompt text should appear in the output"
|
||||
);
|
||||
assert!(
|
||||
text.contains("Claude Code") || text.contains("Claude agent"),
|
||||
"Default attribution should still be present"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected SystemPrompt::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_with_append() {
|
||||
// Appended text lands after the dynamic boundary.
|
||||
let cfg = make_config(Some("Base prompt."), Some("Additional context."));
|
||||
let prompt = build_system_prompt(&cfg);
|
||||
if let SystemPrompt::Text(text) = prompt {
|
||||
assert!(text.contains("Base prompt."));
|
||||
assert!(text.contains("Additional context."));
|
||||
// append_system_prompt appears after the boundary
|
||||
let boundary_pos = text
|
||||
.find(cc_core::system_prompt::SYSTEM_PROMPT_DYNAMIC_BOUNDARY)
|
||||
.expect("boundary must exist");
|
||||
let append_pos = text.find("Additional context.").unwrap();
|
||||
assert!(
|
||||
append_pos > boundary_pos,
|
||||
"Appended text must appear after the dynamic boundary"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected SystemPrompt::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_append_only() {
|
||||
// When only append is set, default sections are present plus the
|
||||
// appended text after the dynamic boundary.
|
||||
let cfg = make_config(None, Some("Appended text."));
|
||||
let prompt = build_system_prompt(&cfg);
|
||||
if let SystemPrompt::Text(text) = prompt {
|
||||
assert!(
|
||||
text.contains("Appended text."),
|
||||
"Appended text must appear in the prompt"
|
||||
);
|
||||
let boundary_pos = text
|
||||
.find(cc_core::system_prompt::SYSTEM_PROMPT_DYNAMIC_BOUNDARY)
|
||||
.expect("boundary must exist");
|
||||
let append_pos = text.find("Appended text.").unwrap();
|
||||
assert!(
|
||||
append_pos > boundary_pos,
|
||||
"Appended text must appear after the dynamic boundary"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected SystemPrompt::Text");
|
||||
}
|
||||
}
|
||||
|
||||
// ---- QueryConfig tests --------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_query_config_clone() {
|
||||
let cfg = make_config(Some("test"), Some("append"));
|
||||
let cloned = cfg.clone();
|
||||
assert_eq!(cloned.model, "claude-sonnet-4-6");
|
||||
assert_eq!(cloned.max_tokens, 4096);
|
||||
assert_eq!(cloned.system_prompt, Some("test".to_string()));
|
||||
}
|
||||
|
||||
// ---- QueryOutcome variant tests -----------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_query_outcome_debug() {
|
||||
// Ensure the enum variants can be created and debug-formatted
|
||||
let outcome = QueryOutcome::Cancelled;
|
||||
let s = format!("{:?}", outcome);
|
||||
assert!(s.contains("Cancelled"));
|
||||
|
||||
let err_outcome = QueryOutcome::Error(cc_core::error::ClaudeError::RateLimit);
|
||||
let s2 = format!("{:?}", err_outcome);
|
||||
assert!(s2.contains("Error"));
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream handler that forwards events to an unbounded channel.
|
||||
struct ChannelStreamHandler {
|
||||
tx: mpsc::UnboundedSender<QueryEvent>,
|
||||
}
|
||||
|
||||
impl StreamHandler for ChannelStreamHandler {
|
||||
fn on_event(&self, event: &StreamEvent) {
|
||||
let _ = self.tx.send(QueryEvent::Stream(event.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Single-shot query (non-looping, for simple one-off calls)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Run a single (non-agentic) query – no tool loop, just one API call.
|
||||
pub async fn run_single_query(
|
||||
client: &cc_api::AnthropicClient,
|
||||
messages: Vec<Message>,
|
||||
config: &QueryConfig,
|
||||
) -> Result<Message, ClaudeError> {
|
||||
let api_messages: Vec<ApiMessage> = messages.iter().map(ApiMessage::from).collect();
|
||||
let system = build_system_prompt(config);
|
||||
|
||||
let request = CreateMessageRequest::builder(&config.model, config.max_tokens)
|
||||
.messages(api_messages)
|
||||
.system(system)
|
||||
.build();
|
||||
|
||||
let handler: Arc<dyn StreamHandler> = Arc::new(cc_api::streaming::NullStreamHandler);
|
||||
|
||||
let mut rx = client.create_message_stream(request, handler).await?;
|
||||
let mut acc = StreamAccumulator::new();
|
||||
|
||||
while let Some(evt) = rx.recv().await {
|
||||
acc.on_event(&evt);
|
||||
if matches!(evt, StreamEvent::MessageStop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let (msg, _usage, _stop) = acc.finish();
|
||||
Ok(msg)
|
||||
}
|
||||
33
src-rust/crates/tools/Cargo.toml
Normal file
33
src-rust/crates/tools/Cargo.toml
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
[package]
|
||||
name = "cc-tools"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-mcp = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
glob = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
similar = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
which = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
7
src-rust/crates/tools/src/agent_tool.rs
Normal file
7
src-rust/crates/tools/src/agent_tool.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
// AgentTool is defined in cc-query to avoid a circular dependency:
|
||||
// cc-tools → cc-query → cc-tools would be circular.
|
||||
//
|
||||
// The AgentTool implementation lives in crates/query/src/agent_tool.rs and is
|
||||
// re-exported from cc-query as `cc_query::AgentTool`.
|
||||
//
|
||||
// This file exists only as a placeholder to keep the directory tidy.
|
||||
79
src-rust/crates/tools/src/ask_user.rs
Normal file
79
src-rust/crates/tools/src/ask_user.rs
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
// AskUserQuestion tool: ask the human operator a question and wait for a response.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct AskUserQuestionTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AskUserInput {
|
||||
question: String,
|
||||
#[serde(default)]
|
||||
options: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for AskUserQuestionTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_ASK_USER
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Ask the user a question and wait for their response. Use this when you \
|
||||
need clarification, confirmation, or additional information from the user. \
|
||||
The question will be displayed and the user can type their answer."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::None
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the user"
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Optional list of choices for multiple-choice questions"
|
||||
}
|
||||
},
|
||||
"required": ["question"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: AskUserInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
debug!(question = %params.question, "Asking user");
|
||||
|
||||
// In non-interactive mode we cannot ask the user.
|
||||
if ctx.non_interactive {
|
||||
return ToolResult::error(
|
||||
"Cannot ask user questions in non-interactive mode".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// The actual prompt/response is handled at the TUI layer, which will
|
||||
// intercept this tool result and display the question. We return a
|
||||
// placeholder that the query loop replaces.
|
||||
let meta = json!({
|
||||
"question": params.question,
|
||||
"options": params.options,
|
||||
"type": "ask_user",
|
||||
});
|
||||
|
||||
ToolResult::success(format!("Question: {}", params.question))
|
||||
.with_metadata(meta)
|
||||
}
|
||||
}
|
||||
199
src-rust/crates/tools/src/bash.rs
Normal file
199
src-rust/crates/tools/src/bash.rs
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
// Bash tool: execute shell commands with timeout and streaming output.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
pub struct BashTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BashInput {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
description: Option<String>,
|
||||
#[serde(default = "default_timeout")]
|
||||
timeout: u64,
|
||||
#[serde(default)]
|
||||
run_in_background: bool,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
120_000 // 2 minutes in ms
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for BashTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_BASH
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Executes a given bash command and returns its output. The working directory \
|
||||
persists between commands, but shell state does not. Avoid using interactive \
|
||||
commands. Use this tool for running shell commands, scripts, git operations, \
|
||||
and system tasks."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::Execute
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Clear, concise description of what this command does"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": "Optional timeout in milliseconds (max 600000, default 120000)"
|
||||
},
|
||||
"run_in_background": {
|
||||
"type": "boolean",
|
||||
"description": "Set to true to run command in the background"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: BashInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
// Permission check
|
||||
let desc = params
|
||||
.description
|
||||
.as_deref()
|
||||
.unwrap_or(¶ms.command);
|
||||
if let Err(e) = ctx.check_permission(self.name(), desc, false) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
let timeout_ms = params.timeout.min(600_000);
|
||||
let timeout_dur = Duration::from_millis(timeout_ms);
|
||||
|
||||
// Determine shell
|
||||
let (shell, flag) = if cfg!(windows) {
|
||||
("cmd", "/C")
|
||||
} else {
|
||||
("bash", "-c")
|
||||
};
|
||||
|
||||
debug!(command = %params.command, "Executing bash command");
|
||||
|
||||
let mut child = match Command::new(shell)
|
||||
.arg(flag)
|
||||
.arg(¶ms.command)
|
||||
.current_dir(&ctx.working_dir)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.stdin(Stdio::null())
|
||||
.spawn()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => return ToolResult::error(format!("Failed to spawn command: {}", e)),
|
||||
};
|
||||
|
||||
let stdout = child.stdout.take();
|
||||
let stderr = child.stderr.take();
|
||||
|
||||
// Collect output with a timeout
|
||||
let result = tokio::time::timeout(timeout_dur, async {
|
||||
let mut stdout_lines = Vec::new();
|
||||
let mut stderr_lines = Vec::new();
|
||||
|
||||
if let Some(stdout) = stdout {
|
||||
let reader = BufReader::new(stdout);
|
||||
let mut lines = reader.lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
stdout_lines.push(line);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
let reader = BufReader::new(stderr);
|
||||
let mut lines = reader.lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
stderr_lines.push(line);
|
||||
}
|
||||
}
|
||||
|
||||
let status = child.wait().await;
|
||||
|
||||
(stdout_lines, stderr_lines, status)
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok((stdout_lines, stderr_lines, status)) => {
|
||||
let exit_code = status
|
||||
.map(|s| s.code().unwrap_or(-1))
|
||||
.unwrap_or(-1);
|
||||
|
||||
let mut output = String::new();
|
||||
|
||||
if !stdout_lines.is_empty() {
|
||||
output.push_str(&stdout_lines.join("\n"));
|
||||
}
|
||||
|
||||
if !stderr_lines.is_empty() {
|
||||
if !output.is_empty() {
|
||||
output.push_str("\n");
|
||||
}
|
||||
output.push_str("STDERR:\n");
|
||||
output.push_str(&stderr_lines.join("\n"));
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
output = "(no output)".to_string();
|
||||
}
|
||||
|
||||
// Truncate very long output
|
||||
const MAX_OUTPUT_LEN: usize = 100_000;
|
||||
if output.len() > MAX_OUTPUT_LEN {
|
||||
let half = MAX_OUTPUT_LEN / 2;
|
||||
let start = &output[..half];
|
||||
let end = &output[output.len() - half..];
|
||||
output = format!(
|
||||
"{}\n\n... ({} characters truncated) ...\n\n{}",
|
||||
start,
|
||||
output.len() - MAX_OUTPUT_LEN,
|
||||
end
|
||||
);
|
||||
}
|
||||
|
||||
if exit_code != 0 {
|
||||
ToolResult::error(format!(
|
||||
"Command exited with code {}\n{}",
|
||||
exit_code, output
|
||||
))
|
||||
} else {
|
||||
ToolResult::success(output)
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout – try to kill the child
|
||||
let _ = child.kill().await;
|
||||
ToolResult::error(format!(
|
||||
"Command timed out after {}ms",
|
||||
timeout_ms
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
151
src-rust/crates/tools/src/brief.rs
Normal file
151
src-rust/crates/tools/src/brief.rs
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
// BriefTool: send a formatted message to the user, optionally with file attachments.
|
||||
//
|
||||
// This is the model's way of proactively communicating status, completions, or
|
||||
// findings without being asked. The message is returned as a tool result and
|
||||
// the TUI renders it prominently.
|
||||
//
|
||||
// Status can be:
|
||||
// "normal" – reply to what the user just said
|
||||
// "proactive" – unsolicited update (task done, blocker, status ping)
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::path::Path;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct BriefTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BriefInput {
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
attachments: Vec<String>,
|
||||
#[serde(default = "default_status")]
|
||||
status: String,
|
||||
}
|
||||
|
||||
fn default_status() -> String { "normal".to_string() }
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AttachmentMeta {
|
||||
path: String,
|
||||
size: u64,
|
||||
is_image: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for BriefTool {
|
||||
fn name(&self) -> &str { "Brief" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Send a formatted message to the user, optionally with file attachments. \
|
||||
Use status=\"proactive\" when surfacing something the user hasn't asked for \
|
||||
(task completion, a blocker, an unsolicited update). \
|
||||
Use status=\"normal\" when replying to something the user just said."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message to send. Supports Markdown."
|
||||
},
|
||||
"attachments": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Optional file paths to attach (images, diffs, logs)"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["normal", "proactive"],
|
||||
"description": "Use 'proactive' for unsolicited updates, 'normal' for direct replies"
|
||||
}
|
||||
},
|
||||
"required": ["message", "status"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: BriefInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
if params.message.trim().is_empty() {
|
||||
return ToolResult::error("Message cannot be empty.".to_string());
|
||||
}
|
||||
|
||||
// Resolve and validate attachments
|
||||
let mut resolved: Vec<AttachmentMeta> = Vec::new();
|
||||
let mut errors: Vec<String> = Vec::new();
|
||||
|
||||
for raw_path in ¶ms.attachments {
|
||||
let path = ctx.resolve_path(raw_path);
|
||||
match resolve_attachment(&path).await {
|
||||
Ok(meta) => resolved.push(meta),
|
||||
Err(e) => errors.push(format!("{}: {}", raw_path, e)),
|
||||
}
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
return ToolResult::error(format!(
|
||||
"Failed to resolve attachments:\n{}",
|
||||
errors.join("\n")
|
||||
));
|
||||
}
|
||||
|
||||
debug!(
|
||||
status = %params.status,
|
||||
attachments = resolved.len(),
|
||||
"Brief message"
|
||||
);
|
||||
|
||||
// Build result payload
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let mut result = json!({
|
||||
"message": params.message,
|
||||
"status": params.status,
|
||||
"sentAt": now,
|
||||
});
|
||||
|
||||
if !resolved.is_empty() {
|
||||
result["attachments"] = serde_json::to_value(&resolved).unwrap_or_default();
|
||||
}
|
||||
|
||||
ToolResult::success(params.message).with_metadata(result)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn resolve_attachment(path: &Path) -> Result<AttachmentMeta, String> {
|
||||
let meta = tokio::fs::metadata(path)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if !meta.is_file() {
|
||||
return Err("not a file".to_string());
|
||||
}
|
||||
|
||||
let size = meta.len();
|
||||
let is_image = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|e| matches!(e.to_lowercase().as_str(), "png" | "jpg" | "jpeg" | "gif" | "webp" | "svg"))
|
||||
.unwrap_or(false);
|
||||
|
||||
Ok(AttachmentMeta {
|
||||
path: path.display().to_string(),
|
||||
size,
|
||||
is_image,
|
||||
})
|
||||
}
|
||||
573
src-rust/crates/tools/src/bundled_skills.rs
Normal file
573
src-rust/crates/tools/src/bundled_skills.rs
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
//! Bundled skill definitions for the Skill tool.
|
||||
//!
|
||||
//! Each entry in `BUNDLED_SKILLS` mirrors one of the TypeScript
|
||||
//! `registerXxxSkill()` calls under `src/skills/bundled/`. Only publicly
|
||||
//! invocable, user-facing skills are included; internal or ANT-only skills
|
||||
//! (stuck, remember, verify) are omitted from the user-visible list but are
|
||||
//! still present as documentation stubs so callers can discover them.
|
||||
//!
|
||||
//! The `SkillTool` checks bundled skills *before* scanning disk directories,
|
||||
//! so bundled names take precedence over same-named `.md` files.
|
||||
|
||||
/// A single bundled skill definition.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BundledSkill {
|
||||
/// Primary name used to invoke the skill (e.g. `"simplify"`).
|
||||
pub name: &'static str,
|
||||
/// One-line description shown in `/skill list` output and to the model.
|
||||
pub description: &'static str,
|
||||
/// Additional names that map to this skill.
|
||||
pub aliases: &'static [&'static str],
|
||||
/// Optional guidance for the model about when to auto-invoke.
|
||||
pub when_to_use: Option<&'static str>,
|
||||
/// Placeholder shown next to the skill name in help text.
|
||||
pub argument_hint: Option<&'static str>,
|
||||
/// The prompt template. `$ARGUMENTS` is replaced at call time.
|
||||
/// `$ARGUMENTS_SUFFIX` expands to `": <args>"` when args are non-empty,
|
||||
/// or `""` otherwise.
|
||||
pub prompt_template: &'static str,
|
||||
/// If `Some`, only these tool names are available during the skill run.
|
||||
pub allowed_tools: Option<&'static [&'static str]>,
|
||||
/// Whether a human user can invoke this skill via `/skill <name>`.
|
||||
pub user_invocable: bool,
|
||||
}
|
||||
|
||||
/// All bundled skills.
|
||||
pub const BUNDLED_SKILLS: &[BundledSkill] = &[
|
||||
// -----------------------------------------------------------------------
|
||||
// simplify
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "simplify",
|
||||
description: "Review changed code for reuse, quality, and efficiency, then fix any issues found.",
|
||||
aliases: &[],
|
||||
when_to_use: Some("After writing code, when you want a quality review and cleanup pass."),
|
||||
argument_hint: None,
|
||||
prompt_template: r#"# Simplify: Code Review and Cleanup
|
||||
|
||||
Review all changed files for reuse, quality, and efficiency. Fix any issues found.
|
||||
|
||||
## Phase 1: Identify Changes
|
||||
|
||||
Run `git diff` (or `git diff HEAD` if there are staged changes) to see what changed.
|
||||
If there are no git changes, review the most recently modified files that were
|
||||
mentioned or edited earlier in this conversation.
|
||||
|
||||
## Phase 2: Launch Three Review Agents in Parallel
|
||||
|
||||
Use the Agent tool to launch all three agents concurrently in a single message.
|
||||
Pass each agent the full diff so it has complete context.
|
||||
|
||||
### Agent 1: Code Reuse Review
|
||||
|
||||
For each change:
|
||||
1. **Search for existing utilities and helpers** that could replace newly written code.
|
||||
2. **Flag any new function that duplicates existing functionality.**
|
||||
3. **Flag any inline logic that could use an existing utility** — hand-rolled string
|
||||
manipulation, manual path handling, custom environment checks, etc.
|
||||
|
||||
### Agent 2: Code Quality Review
|
||||
|
||||
Review the same changes for hacky patterns:
|
||||
1. **Redundant state** that duplicates existing state.
|
||||
2. **Parameter sprawl** — new parameters instead of restructuring.
|
||||
3. **Copy-paste with slight variation** that should be unified.
|
||||
4. **Leaky abstractions** — exposing internal details.
|
||||
5. **Stringly-typed code** where constants or enums already exist.
|
||||
6. **Unnecessary comments** narrating what code does (not why).
|
||||
|
||||
### Agent 3: Efficiency Review
|
||||
|
||||
Review the same changes for efficiency:
|
||||
1. **Unnecessary work** — redundant computations, duplicate reads.
|
||||
2. **Missed concurrency** — independent operations run sequentially.
|
||||
3. **Hot-path bloat** — blocking work added to startup or per-request paths.
|
||||
4. **Recurring no-op updates** — unconditional updates in polling loops.
|
||||
5. **Memory** — unbounded data structures, missing cleanup.
|
||||
|
||||
## Phase 3: Fix Issues
|
||||
|
||||
Wait for all three agents to complete. Aggregate findings and fix each issue.
|
||||
If a finding is a false positive, note it and move on.
|
||||
|
||||
When done, briefly summarize what was fixed (or confirm the code was already clean).
|
||||
$ARGUMENTS_SUFFIX"#,
|
||||
allowed_tools: None,
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// remember
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "remember",
|
||||
description: "Review auto-memory entries and propose promotions to CLAUDE.md, CLAUDE.local.md, or shared memory.",
|
||||
aliases: &["mem", "save"],
|
||||
when_to_use: Some("When the user wants to review, organise, or promote their auto-memory entries."),
|
||||
argument_hint: Some("[additional context]"),
|
||||
prompt_template: r#"# Memory Review
|
||||
|
||||
## Goal
|
||||
Review the user's memory landscape and produce a clear report of proposed changes,
|
||||
grouped by action type. Do NOT apply changes — present proposals for user approval.
|
||||
|
||||
## Steps
|
||||
|
||||
### 1. Gather all memory layers
|
||||
Read CLAUDE.md and CLAUDE.local.md from the project root (if they exist).
|
||||
Your auto-memory content is already in your system prompt — review it there.
|
||||
|
||||
### 2. Classify each auto-memory entry
|
||||
|
||||
| Destination | What belongs there |
|
||||
|---|---|
|
||||
| **CLAUDE.md** | Project conventions all contributors should follow |
|
||||
| **CLAUDE.local.md** | Personal instructions specific to this user |
|
||||
| **Stay in auto-memory** | Working notes, temporary context, uncertain patterns |
|
||||
|
||||
### 3. Identify cleanup opportunities
|
||||
- **Duplicates**: auto-memory entries already in CLAUDE.md → propose removing
|
||||
- **Outdated**: CLAUDE.md entries contradicted by newer auto-memory → propose updating
|
||||
- **Conflicts**: contradictions between layers → propose resolution
|
||||
|
||||
### 4. Present the report
|
||||
Output a structured report grouped by: Promotions, Cleanup, Ambiguous, No action needed.
|
||||
|
||||
## Rules
|
||||
- Present ALL proposals before making any changes
|
||||
- Do NOT modify files without explicit user approval
|
||||
- Ask about ambiguous entries — don't guess
|
||||
$ARGUMENTS_SUFFIX"#,
|
||||
allowed_tools: Some(&["Read", "Write", "Edit", "Glob"]),
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// debug
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "debug",
|
||||
description: "Enable debug logging for this session and help diagnose issues.",
|
||||
aliases: &["diagnose"],
|
||||
when_to_use: Some("When there is an error, bug, or unexpected behaviour to investigate."),
|
||||
argument_hint: Some("[issue description or error message]"),
|
||||
prompt_template: r#"# Debug Skill
|
||||
|
||||
Help the user debug an issue they are encountering.
|
||||
|
||||
## Issue Description
|
||||
|
||||
$ARGUMENTS
|
||||
|
||||
## Systematic Debugging Approach
|
||||
|
||||
1. **Reproduce** — Confirm the exact error / behaviour.
|
||||
2. **Locate** — Find the relevant code (read files, grep for error messages).
|
||||
3. **Hypothesize** — Form 2–3 hypotheses about the root cause.
|
||||
4. **Test** — Verify each hypothesis systematically.
|
||||
5. **Fix** — Implement the fix for the confirmed root cause.
|
||||
6. **Verify** — Confirm the fix resolves the issue.
|
||||
|
||||
## Settings Reference
|
||||
|
||||
Settings files are in:
|
||||
- User: ~/.claude/settings.json
|
||||
- Project: .claude/settings.json
|
||||
- Local: .claude/settings.local.json
|
||||
|
||||
Read the relevant files before making any changes."#,
|
||||
allowed_tools: Some(&["Read", "Grep", "Glob"]),
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// stuck
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "stuck",
|
||||
description: "Help get unstuck when you don't know how to proceed.",
|
||||
aliases: &["help-me", "unblock"],
|
||||
when_to_use: Some("When you are stuck, confused, or don't know how to proceed."),
|
||||
argument_hint: Some("[what you're trying to do]"),
|
||||
prompt_template: r#"The user is stuck$ARGUMENTS_SUFFIX. Help them get unstuck:
|
||||
|
||||
1. Clarify what they are trying to achieve (if unclear).
|
||||
2. Identify why they might be stuck (missing context, unclear requirements, technical blocker).
|
||||
3. Suggest 2–3 concrete next steps in order of likelihood of success.
|
||||
4. If a technical blocker: propose specific debugging steps or workarounds.
|
||||
5. Ask clarifying questions if needed.
|
||||
|
||||
Be direct and actionable. Focus on unblocking, not on explaining concepts."#,
|
||||
allowed_tools: None,
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// batch
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "batch",
|
||||
description: "Research and plan a large-scale change, then execute it in parallel across isolated worktree agents that each open a PR.",
|
||||
aliases: &[],
|
||||
when_to_use: Some("When the user wants to make a sweeping, mechanical change across many files that can be decomposed into independent parallel units."),
|
||||
argument_hint: Some("<instruction>"),
|
||||
prompt_template: r#"# Batch: Parallel Work Orchestration
|
||||
|
||||
You are orchestrating a large, parallelisable change across this codebase.
|
||||
|
||||
## User Instruction
|
||||
|
||||
$ARGUMENTS
|
||||
|
||||
## Phase 1: Research and Plan (Plan Mode)
|
||||
|
||||
Enter plan mode, then:
|
||||
|
||||
1. **Understand the scope.** Launch subagents to deeply research what this instruction
|
||||
touches. Find all files, patterns, and call sites that need to change.
|
||||
|
||||
2. **Decompose into independent units.** Break the work into 5–30 self-contained units.
|
||||
Each unit must be independently implementable in an isolated git worktree and
|
||||
mergeable on its own without depending on another unit's PR landing first.
|
||||
|
||||
3. **Determine the e2e test recipe.** Figure out how a worker can verify its change
|
||||
actually works end-to-end. If you cannot find a concrete path, ask the user.
|
||||
|
||||
4. **Write the plan.** Include: research summary, numbered work units, e2e recipe,
|
||||
and the exact worker instructions.
|
||||
|
||||
## Phase 2: Spawn Workers (After Plan Approval)
|
||||
|
||||
Spawn one background agent per work unit using the Agent tool with
|
||||
`isolation: "worktree"` and `run_in_background: true`. Launch them all in a single
|
||||
message block so they run in parallel. Each agent prompt must be fully self-contained.
|
||||
|
||||
After each agent finishes, parse the `PR: <url>` line from its result and render
|
||||
a status table. When all agents have reported, print a final summary."#,
|
||||
allowed_tools: None,
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// verify
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "verify",
|
||||
description: "Verify that code or behaviour is correct.",
|
||||
aliases: &["check", "validate"],
|
||||
when_to_use: Some("After implementing something, to verify it is correct."),
|
||||
argument_hint: Some("[what to verify]"),
|
||||
prompt_template: r#"# Verify: $ARGUMENTS
|
||||
|
||||
## Verification Steps
|
||||
|
||||
1. Read the relevant code / implementation.
|
||||
2. Check against requirements (if specified).
|
||||
3. Look for edge cases and error conditions.
|
||||
4. Run tests if available.
|
||||
5. Check for common pitfalls: null handling, error propagation, type safety.
|
||||
6. Report: what was verified, what passed, what failed or is uncertain."#,
|
||||
allowed_tools: None,
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// update-config
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "update-config",
|
||||
description: "Configure Claude Code settings (hooks, permissions, env vars, behaviours) via settings.json.",
|
||||
aliases: &["config-update", "settings"],
|
||||
when_to_use: Some("When the user wants to configure automated behaviours, permissions, or settings."),
|
||||
argument_hint: Some("<what to configure>"),
|
||||
prompt_template: r#"# Update Config Skill
|
||||
|
||||
Modify Claude Code configuration by updating settings.json files.
|
||||
|
||||
## Settings File Locations
|
||||
|
||||
| File | Scope | Use For |
|
||||
|------|-------|---------|
|
||||
| `~/.claude/settings.json` | Global | Personal preferences for all projects |
|
||||
| `.claude/settings.json` | Project | Team-wide hooks, permissions, plugins |
|
||||
| `.claude/settings.local.json` | Project (local) | Personal overrides for this project |
|
||||
|
||||
Settings load in order: user → project → local (later overrides earlier).
|
||||
|
||||
## CRITICAL: Read Before Write
|
||||
|
||||
Always read the existing settings file before making changes.
|
||||
Merge new settings with existing ones — never replace the entire file.
|
||||
|
||||
## Hook Events
|
||||
|
||||
PreToolUse, PostToolUse, PreCompact, PostCompact, Stop, Notification, SessionStart
|
||||
|
||||
## User Request
|
||||
|
||||
$ARGUMENTS"#,
|
||||
allowed_tools: Some(&["Read", "Write", "Edit", "Bash"]),
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// claude-api
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "claude-api",
|
||||
description: "Build apps with the Claude API or Anthropic SDK.",
|
||||
aliases: &["api", "anthropic-sdk"],
|
||||
when_to_use: Some("When the user wants to use the Claude API, Anthropic SDK, or build Claude-powered apps."),
|
||||
argument_hint: Some("[what to build]"),
|
||||
prompt_template: r#"# Build a Claude API Integration
|
||||
|
||||
## User Request
|
||||
|
||||
$ARGUMENTS
|
||||
|
||||
## Default Models
|
||||
|
||||
- Most capable: claude-opus-4-6
|
||||
- Balanced: claude-sonnet-4-6
|
||||
- Fast: claude-haiku-4-5-20251001
|
||||
|
||||
## SDK Quickstart
|
||||
|
||||
**Python**
|
||||
```python
|
||||
pip install anthropic
|
||||
import anthropic
|
||||
client = anthropic.Anthropic()
|
||||
```
|
||||
|
||||
**TypeScript / Node**
|
||||
```typescript
|
||||
npm install @anthropic-ai/sdk
|
||||
import Anthropic from '@anthropic-ai/sdk';
|
||||
const client = new Anthropic();
|
||||
```
|
||||
|
||||
## Key API Features
|
||||
|
||||
- Streaming (`stream_message`)
|
||||
- Tool use / function calling
|
||||
- Extended thinking
|
||||
- Prompt caching
|
||||
- Vision (image input)
|
||||
- Files API
|
||||
- Batch processing
|
||||
|
||||
Use async/await patterns. Follow SDK best practices."#,
|
||||
allowed_tools: Some(&["Read", "Grep", "Glob", "WebFetch"]),
|
||||
user_invocable: true,
|
||||
},
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// loop
|
||||
// -----------------------------------------------------------------------
|
||||
BundledSkill {
|
||||
name: "loop",
|
||||
description: "Run a prompt or slash command on a recurring interval.",
|
||||
aliases: &[],
|
||||
when_to_use: Some("When the user wants to run something repeatedly on a schedule."),
|
||||
argument_hint: Some("[interval] <command>"),
|
||||
prompt_template: r#"# /loop — schedule a recurring prompt
|
||||
|
||||
Parse the input below into `[interval] <prompt…>` and schedule it with CronCreate.
|
||||
|
||||
## Parsing (in priority order)
|
||||
|
||||
1. **Leading token**: if the first token matches `^\d+[smhd]$` (e.g. `5m`, `2h`), that
|
||||
is the interval; the rest is the prompt.
|
||||
2. **Trailing "every" clause**: if the input ends with `every <N><unit>` extract that
|
||||
as the interval and strip it from the prompt.
|
||||
3. **Default**: interval is `10m` and the entire input is the prompt.
|
||||
|
||||
If the resulting prompt is empty, show usage `/loop [interval] <prompt>` and stop.
|
||||
|
||||
## Interval → Cron
|
||||
|
||||
| Pattern | Cron | Notes |
|
||||
|---------|------|-------|
|
||||
| `Nm` (N ≤ 59) | `*/N * * * *` | every N minutes |
|
||||
| `Nh` (N ≤ 23) | `0 */N * * *` | every N hours |
|
||||
| `Nd` | `0 0 */N * *` | every N days at midnight |
|
||||
| `Ns` | round up to nearest minute | cron min granularity is 1 min |
|
||||
|
||||
## Action
|
||||
|
||||
1. Call CronCreate with the parsed cron expression and prompt.
|
||||
2. Confirm what was scheduled, including the cron expression and human-readable cadence.
|
||||
3. **Immediately execute the parsed prompt now** — don't wait for the first cron fire.
|
||||
|
||||
## Input
|
||||
|
||||
$ARGUMENTS"#,
|
||||
allowed_tools: Some(&["CronCreate", "CronList"]),
|
||||
user_invocable: true,
|
||||
},
|
||||
];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Find a bundled skill by name or alias (case-insensitive).
|
||||
pub fn find_bundled_skill(name: &str) -> Option<&'static BundledSkill> {
|
||||
let lower = name.to_lowercase();
|
||||
BUNDLED_SKILLS.iter().find(|s| {
|
||||
s.name == lower || s.aliases.iter().any(|a| *a == lower)
|
||||
})
|
||||
}
|
||||
|
||||
/// Return `(name, description)` pairs for all user-invocable bundled skills.
|
||||
pub fn user_invocable_skills() -> Vec<(&'static str, &'static str)> {
|
||||
BUNDLED_SKILLS
|
||||
.iter()
|
||||
.filter(|s| s.user_invocable)
|
||||
.map(|s| (s.name, s.description))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Expand a skill's prompt template, substituting `$ARGUMENTS` and
|
||||
/// `$ARGUMENTS_SUFFIX`.
|
||||
///
|
||||
/// - `$ARGUMENTS` → replaced by `args` verbatim (or `""` when empty)
|
||||
/// - `$ARGUMENTS_SUFFIX` → replaced by `": <args>"` when non-empty, else `""`
|
||||
pub fn expand_prompt(skill: &BundledSkill, args: &str) -> String {
|
||||
let suffix = if args.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(": {}", args)
|
||||
};
|
||||
|
||||
skill
|
||||
.prompt_template
|
||||
.replace("$ARGUMENTS_SUFFIX", &suffix)
|
||||
.replace("$ARGUMENTS", args)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn all_skills_have_non_empty_names() {
|
||||
for s in BUNDLED_SKILLS {
|
||||
assert!(!s.name.is_empty(), "skill has empty name");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_skills_have_non_empty_descriptions() {
|
||||
for s in BUNDLED_SKILLS {
|
||||
assert!(
|
||||
!s.description.is_empty(),
|
||||
"skill '{}' has empty description",
|
||||
s.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_skills_have_non_empty_prompt_templates() {
|
||||
for s in BUNDLED_SKILLS {
|
||||
assert!(
|
||||
!s.prompt_template.is_empty(),
|
||||
"skill '{}' has empty prompt_template",
|
||||
s.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_names_are_unique() {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for s in BUNDLED_SKILLS {
|
||||
assert!(
|
||||
seen.insert(s.name),
|
||||
"duplicate skill name: {}",
|
||||
s.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_by_primary_name() {
|
||||
let skill = find_bundled_skill("simplify");
|
||||
assert!(skill.is_some());
|
||||
assert_eq!(skill.unwrap().name, "simplify");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_by_alias() {
|
||||
let skill = find_bundled_skill("mem");
|
||||
assert!(skill.is_some());
|
||||
assert_eq!(skill.unwrap().name, "remember");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_case_insensitive() {
|
||||
assert!(find_bundled_skill("SIMPLIFY").is_some());
|
||||
assert!(find_bundled_skill("Debug").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_missing_returns_none() {
|
||||
assert!(find_bundled_skill("nonexistent-skill-xyz").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_prompt_substitutes_arguments() {
|
||||
let skill = find_bundled_skill("debug").unwrap();
|
||||
let expanded = expand_prompt(skill, "NullPointerException in Foo.java");
|
||||
assert!(expanded.contains("NullPointerException in Foo.java"));
|
||||
assert!(!expanded.contains("$ARGUMENTS"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_prompt_empty_args_no_residual_placeholder() {
|
||||
let skill = find_bundled_skill("simplify").unwrap();
|
||||
let expanded = expand_prompt(skill, "");
|
||||
assert!(!expanded.contains("$ARGUMENTS"));
|
||||
assert!(!expanded.contains("$ARGUMENTS_SUFFIX"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_prompt_suffix_non_empty() {
|
||||
let skill = find_bundled_skill("stuck").unwrap();
|
||||
let expanded = expand_prompt(skill, "trying to run tests");
|
||||
// Should contain ": trying to run tests" from $ARGUMENTS_SUFFIX
|
||||
assert!(expanded.contains(": trying to run tests"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_prompt_suffix_empty() {
|
||||
let skill = find_bundled_skill("stuck").unwrap();
|
||||
let expanded = expand_prompt(skill, "");
|
||||
// $ARGUMENTS_SUFFIX should be "" (not ": ")
|
||||
assert!(!expanded.contains(": "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_invocable_skills_non_empty() {
|
||||
let skills = user_invocable_skills();
|
||||
assert!(!skills.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_invocable_skills_all_marked_true() {
|
||||
for (name, _) in user_invocable_skills() {
|
||||
let skill = find_bundled_skill(name).unwrap();
|
||||
assert!(
|
||||
skill.user_invocable,
|
||||
"skill '{}' returned by user_invocable_skills() but user_invocable=false",
|
||||
name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
199
src-rust/crates/tools/src/config_tool.rs
Normal file
199
src-rust/crates/tools/src/config_tool.rs
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
// ConfigTool: get or set Claude Code configuration settings at runtime.
|
||||
//
|
||||
// Reads from and persists to ~/.claude/settings.json.
|
||||
// Supported settings: model, max_tokens, verbose, permission_mode.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub struct ConfigTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConfigInput {
|
||||
setting: String,
|
||||
value: Option<Value>,
|
||||
}
|
||||
|
||||
static SUPPORTED_SETTINGS: &[(&str, &str)] = &[
|
||||
("model", "LLM model to use (e.g. 'claude-opus-4-6')"),
|
||||
("max_tokens", "Maximum output tokens per response"),
|
||||
("verbose", "Enable verbose logging (true/false)"),
|
||||
("permission_mode", "Permission mode: default | accept_edits | bypass_permissions | plan"),
|
||||
("auto_compact", "Auto-compact conversation when context fills (true/false)"),
|
||||
];
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ConfigTool {
|
||||
fn name(&self) -> &str { "Config" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Get or set Claude Code configuration settings. Omit 'value' to read the current value. \
|
||||
Supported settings: model, max_tokens, verbose, permission_mode, auto_compact. \
|
||||
Changes persist to ~/.claude/settings.json."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Write }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"setting": {
|
||||
"type": "string",
|
||||
"description": "Setting key (e.g. 'model', 'verbose', 'max_tokens', 'permission_mode')"
|
||||
},
|
||||
"value": {
|
||||
"description": "New value to set. Omit to read the current value."
|
||||
}
|
||||
},
|
||||
"required": ["setting"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: ConfigInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let key = params.setting.trim();
|
||||
|
||||
// List all supported settings
|
||||
if key == "list" || key == "help" {
|
||||
let lines: Vec<String> = SUPPORTED_SETTINGS
|
||||
.iter()
|
||||
.map(|(k, d)| format!(" {} — {}", k, d))
|
||||
.collect();
|
||||
return ToolResult::success(format!(
|
||||
"Supported settings:\n{}",
|
||||
lines.join("\n")
|
||||
));
|
||||
}
|
||||
|
||||
// Load current settings
|
||||
let mut settings = match cc_core::config::Settings::load().await {
|
||||
Ok(s) => s,
|
||||
Err(e) => return ToolResult::error(format!("Failed to load settings: {}", e)),
|
||||
};
|
||||
|
||||
if let Some(new_value) = params.value {
|
||||
// SET operation
|
||||
match key {
|
||||
"model" => {
|
||||
let s = match new_value.as_str() {
|
||||
Some(s) => s.to_string(),
|
||||
None => return ToolResult::error("'model' must be a string".to_string()),
|
||||
};
|
||||
settings.config.model = Some(s.clone());
|
||||
if let Err(e) = settings.save().await {
|
||||
return ToolResult::error(format!("Failed to save settings: {}", e));
|
||||
}
|
||||
ToolResult::success(format!("model = \"{}\"", s))
|
||||
}
|
||||
"max_tokens" => {
|
||||
let n = match new_value.as_u64() {
|
||||
Some(n) => n as u32,
|
||||
None => return ToolResult::error("'max_tokens' must be a positive integer".to_string()),
|
||||
};
|
||||
settings.config.max_tokens = Some(n);
|
||||
if let Err(e) = settings.save().await {
|
||||
return ToolResult::error(format!("Failed to save settings: {}", e));
|
||||
}
|
||||
ToolResult::success(format!("max_tokens = {}", n))
|
||||
}
|
||||
"verbose" => {
|
||||
let b = match new_value.as_bool() {
|
||||
Some(b) => b,
|
||||
None => return ToolResult::error("'verbose' must be true or false".to_string()),
|
||||
};
|
||||
settings.config.verbose = b;
|
||||
if let Err(e) = settings.save().await {
|
||||
return ToolResult::error(format!("Failed to save settings: {}", e));
|
||||
}
|
||||
ToolResult::success(format!("verbose = {}", b))
|
||||
}
|
||||
"auto_compact" => {
|
||||
let b = match new_value.as_bool() {
|
||||
Some(b) => b,
|
||||
None => return ToolResult::error("'auto_compact' must be true or false".to_string()),
|
||||
};
|
||||
settings.config.auto_compact = b;
|
||||
if let Err(e) = settings.save().await {
|
||||
return ToolResult::error(format!("Failed to save settings: {}", e));
|
||||
}
|
||||
ToolResult::success(format!("auto_compact = {}", b))
|
||||
}
|
||||
"permission_mode" => {
|
||||
use cc_core::config::PermissionMode;
|
||||
let s = match new_value.as_str() {
|
||||
Some(s) => s,
|
||||
None => return ToolResult::error("'permission_mode' must be a string".to_string()),
|
||||
};
|
||||
let mode = match s {
|
||||
"default" => PermissionMode::Default,
|
||||
"accept_edits" | "acceptEdits" => PermissionMode::AcceptEdits,
|
||||
"bypass_permissions" | "bypassPermissions" => {
|
||||
PermissionMode::BypassPermissions
|
||||
}
|
||||
"plan" => PermissionMode::Plan,
|
||||
_ => {
|
||||
return ToolResult::error(format!(
|
||||
"Unknown permission_mode '{}'. Use: default | accept_edits | bypass_permissions | plan",
|
||||
s
|
||||
))
|
||||
}
|
||||
};
|
||||
settings.config.permission_mode = mode;
|
||||
if let Err(e) = settings.save().await {
|
||||
return ToolResult::error(format!("Failed to save settings: {}", e));
|
||||
}
|
||||
ToolResult::success(format!("permission_mode = \"{}\"", s))
|
||||
}
|
||||
_ => ToolResult::error(format!(
|
||||
"Unknown setting '{}'. Use setting='list' to see all supported settings.",
|
||||
key
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
// GET operation
|
||||
match key {
|
||||
"model" => ToolResult::success(format!(
|
||||
"model = \"{}\"",
|
||||
settings.config.effective_model()
|
||||
)),
|
||||
"max_tokens" => ToolResult::success(format!(
|
||||
"max_tokens = {}",
|
||||
settings.config.effective_max_tokens()
|
||||
)),
|
||||
"verbose" => ToolResult::success(format!(
|
||||
"verbose = {}",
|
||||
settings.config.verbose
|
||||
)),
|
||||
"auto_compact" => ToolResult::success(format!(
|
||||
"auto_compact = {}",
|
||||
settings.config.auto_compact
|
||||
)),
|
||||
"permission_mode" => ToolResult::success(format!(
|
||||
"permission_mode = \"{}\"",
|
||||
permission_mode_str(&settings.config.permission_mode)
|
||||
)),
|
||||
_ => ToolResult::error(format!(
|
||||
"Unknown setting '{}'. Use setting='list' to see all supported settings.",
|
||||
key
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn permission_mode_str(mode: &cc_core::config::PermissionMode) -> &'static str {
|
||||
use cc_core::config::PermissionMode;
|
||||
match mode {
|
||||
PermissionMode::Default => "default",
|
||||
PermissionMode::AcceptEdits => "accept_edits",
|
||||
PermissionMode::BypassPermissions => "bypass_permissions",
|
||||
PermissionMode::Plan => "plan",
|
||||
}
|
||||
}
|
||||
435
src-rust/crates/tools/src/cron.rs
Normal file
435
src-rust/crates/tools/src/cron.rs
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
// Cron tools: schedule recurring and one-shot prompts.
|
||||
//
|
||||
// CronCreateTool – create a new scheduled task (cron expression)
|
||||
// CronDeleteTool – remove an existing scheduled task
|
||||
// CronListTool – list all scheduled tasks
|
||||
//
|
||||
// Scheduled tasks are stored in a global in-memory store (session-only).
|
||||
// Optionally persisted to `.claude/scheduled_tasks.json` (durable mode).
|
||||
//
|
||||
// Cron expression format: "M H DoM Mon DoW" (standard 5-field cron in local
|
||||
// time). For example:
|
||||
// "*/5 * * * *" = every 5 minutes
|
||||
// "30 14 * * 1" = every Monday at 14:30
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Datelike, Local, Timelike};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// In-memory store
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronTask {
|
||||
pub id: String,
|
||||
pub cron: String,
|
||||
pub prompt: String,
|
||||
pub recurring: bool,
|
||||
pub durable: bool,
|
||||
pub created_at: u64,
|
||||
}
|
||||
|
||||
static CRON_STORE: Lazy<Arc<RwLock<HashMap<String, CronTask>>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public scheduler API (used by cc-query cron_scheduler)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Check if a cron expression fires at the given minute-resolution datetime.
|
||||
pub fn cron_matches(expr: &str, dt: &DateTime<Local>) -> bool {
|
||||
let fields: Vec<&str> = expr.split_whitespace().collect();
|
||||
if fields.len() != 5 {
|
||||
return false;
|
||||
}
|
||||
let minute = dt.minute();
|
||||
let hour = dt.hour();
|
||||
let day = dt.day();
|
||||
let month = dt.month();
|
||||
let dow = dt.weekday().num_days_from_sunday(); // 0=Sun .. 6=Sat
|
||||
|
||||
cron_field_matches(fields[0], minute)
|
||||
&& cron_field_matches(fields[1], hour)
|
||||
&& cron_field_matches(fields[2], day)
|
||||
&& cron_field_matches(fields[3], month)
|
||||
&& cron_field_matches(fields[4], dow)
|
||||
}
|
||||
|
||||
fn cron_field_matches(field: &str, value: u32) -> bool {
|
||||
if field == "*" {
|
||||
return true;
|
||||
}
|
||||
// */N step
|
||||
if let Some(step_str) = field.strip_prefix("*/") {
|
||||
if let Ok(step) = step_str.parse::<u32>() {
|
||||
return step > 0 && value % step == 0;
|
||||
}
|
||||
}
|
||||
// Comma-separated list of values or ranges
|
||||
for part in field.split(',') {
|
||||
if cron_range_matches(part, value) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn cron_range_matches(part: &str, value: u32) -> bool {
|
||||
if let Some(dash) = part.find('-') {
|
||||
let lo: u32 = part[..dash].parse().unwrap_or(u32::MAX);
|
||||
let hi: u32 = part[dash + 1..].parse().unwrap_or(0);
|
||||
value >= lo && value <= hi
|
||||
} else {
|
||||
part.parse::<u32>()
|
||||
.map_or(false, |n| n == value || (n == 7 && value == 0)) // 7 = Sunday alias
|
||||
}
|
||||
}
|
||||
|
||||
/// Return all tasks whose cron expression fires at `dt`.
|
||||
/// One-shot tasks (recurring=false) are removed from the store after being returned.
|
||||
pub async fn pop_due_tasks(dt: &DateTime<Local>) -> Vec<CronTask> {
|
||||
let mut store = CRON_STORE.write().await;
|
||||
let due: Vec<CronTask> = store
|
||||
.values()
|
||||
.filter(|t| cron_matches(&t.cron, dt))
|
||||
.cloned()
|
||||
.collect();
|
||||
for t in &due {
|
||||
if !t.recurring {
|
||||
store.remove(&t.id);
|
||||
}
|
||||
}
|
||||
due
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Simple cron expression parser (5-field)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate that a 5-field cron expression is syntactically correct.
|
||||
fn validate_cron(expr: &str) -> bool {
|
||||
let fields: Vec<&str> = expr.split_whitespace().collect();
|
||||
if fields.len() != 5 {
|
||||
return false;
|
||||
}
|
||||
// Check each field: ranges for M(0-59), H(0-23), DoM(1-31), Mon(1-12), DoW(0-7)
|
||||
let ranges = [(0u32, 59), (0, 23), (1, 31), (1, 12), (0, 7)];
|
||||
for (i, field) in fields.iter().enumerate() {
|
||||
if *field == "*" {
|
||||
continue;
|
||||
}
|
||||
// Handle */N (step)
|
||||
if let Some(step) = field.strip_prefix("*/") {
|
||||
if step.parse::<u32>().is_err() {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Handle N-M (range) or N
|
||||
let parts: Vec<&str> = field.split('-').collect();
|
||||
for part in &parts {
|
||||
match part.parse::<u32>() {
|
||||
Ok(n) => {
|
||||
if n < ranges[i].0 || n > ranges[i].1 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Err(_) => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Convert a cron expression to a human-readable description.
|
||||
fn cron_to_human(expr: &str) -> String {
|
||||
let fields: Vec<&str> = expr.split_whitespace().collect();
|
||||
if fields.len() != 5 {
|
||||
return expr.to_string();
|
||||
}
|
||||
|
||||
let (minute, hour, dom, month, dow) = (fields[0], fields[1], fields[2], fields[3], fields[4]);
|
||||
|
||||
if expr == "* * * * *" {
|
||||
return "every minute".to_string();
|
||||
}
|
||||
if minute.starts_with("*/") {
|
||||
let n = &minute[2..];
|
||||
return format!("every {} minutes", n);
|
||||
}
|
||||
if hour == "*" && dom == "*" && month == "*" && dow == "*" {
|
||||
return format!("at minute {} of every hour", minute);
|
||||
}
|
||||
if dom == "*" && month == "*" && dow == "*" {
|
||||
return format!("daily at {:0>2}:{:0>2}", hour, minute);
|
||||
}
|
||||
// Fallback: return the raw expression
|
||||
format!("cron({})", expr)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CronCreate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct CronCreateTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CronCreateInput {
|
||||
cron: String,
|
||||
prompt: String,
|
||||
#[serde(default = "default_true")]
|
||||
recurring: bool,
|
||||
#[serde(default)]
|
||||
durable: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool { true }
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CronCreateTool {
|
||||
fn name(&self) -> &str { "CronCreate" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Schedule a recurring or one-shot prompt using a standard 5-field cron expression \
|
||||
in local time: \"M H DoM Mon DoW\". Examples:\n\
|
||||
- \"*/5 * * * *\" = every 5 minutes\n\
|
||||
- \"30 14 * * 1\" = every Monday at 14:30\n\
|
||||
- \"0 9 15 * *\" = 15th of each month at 09:00\n\
|
||||
Use recurring=false for one-shot (fires once then auto-deletes).\n\
|
||||
Use durable=true to persist across sessions."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cron": {
|
||||
"type": "string",
|
||||
"description": "5-field cron expression: M H DoM Mon DoW"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The prompt to run at each scheduled time"
|
||||
},
|
||||
"recurring": {
|
||||
"type": "boolean",
|
||||
"description": "true (default) = repeat on every match; false = fire once then delete"
|
||||
},
|
||||
"durable": {
|
||||
"type": "boolean",
|
||||
"description": "true = persist to .claude/scheduled_tasks.json; false (default) = session only"
|
||||
}
|
||||
},
|
||||
"required": ["cron", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: CronCreateInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
if !validate_cron(¶ms.cron) {
|
||||
return ToolResult::error(format!(
|
||||
"Invalid cron expression '{}'. Expected 5 fields: M H DoM Mon DoW.",
|
||||
params.cron
|
||||
));
|
||||
}
|
||||
|
||||
let mut store = CRON_STORE.write().await;
|
||||
if store.len() >= 50 {
|
||||
return ToolResult::error("Too many scheduled jobs (max 50). Cancel one first.".to_string());
|
||||
}
|
||||
|
||||
let id = Uuid::new_v4().to_string()[..8].to_string();
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
let task = CronTask {
|
||||
id: id.clone(),
|
||||
cron: params.cron.clone(),
|
||||
prompt: params.prompt.clone(),
|
||||
recurring: params.recurring,
|
||||
durable: params.durable,
|
||||
created_at: now,
|
||||
};
|
||||
|
||||
// Optionally persist to disk
|
||||
if params.durable {
|
||||
if let Err(e) = persist_tasks_to_disk(&store, ctx).await {
|
||||
debug!("Failed to persist cron task to disk: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
store.insert(id.clone(), task);
|
||||
let human = cron_to_human(¶ms.cron);
|
||||
|
||||
let where_note = if params.durable {
|
||||
"Persisted to .claude/scheduled_tasks.json"
|
||||
} else {
|
||||
"Session-only (dies when Claude exits)"
|
||||
};
|
||||
|
||||
let msg = if params.recurring {
|
||||
format!(
|
||||
"Scheduled recurring job {} ({}). {}",
|
||||
id, human, where_note
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Scheduled one-shot task {} ({}). {}. Will fire once then auto-delete.",
|
||||
id, human, where_note
|
||||
)
|
||||
};
|
||||
|
||||
ToolResult::success(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CronDelete
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct CronDeleteTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CronDeleteInput {
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CronDeleteTool {
|
||||
fn name(&self) -> &str { "CronDelete" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Cancel a scheduled cron task by its ID. Use CronList to find the ID."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "The cron task ID to delete"
|
||||
}
|
||||
},
|
||||
"required": ["id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: CronDeleteInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let mut store = CRON_STORE.write().await;
|
||||
if store.remove(¶ms.id).is_some() {
|
||||
ToolResult::success(format!("Deleted cron task '{}'.", params.id))
|
||||
} else {
|
||||
ToolResult::error(format!("Cron task '{}' not found.", params.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CronList
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct CronListTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CronListTool {
|
||||
fn name(&self) -> &str { "CronList" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"List all currently scheduled cron tasks."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, _input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let store = CRON_STORE.read().await;
|
||||
|
||||
if store.is_empty() {
|
||||
return ToolResult::success("No scheduled cron tasks.".to_string());
|
||||
}
|
||||
|
||||
let mut tasks: Vec<&CronTask> = store.values().collect();
|
||||
tasks.sort_by_key(|t| t.created_at);
|
||||
|
||||
let lines: Vec<String> = tasks
|
||||
.iter()
|
||||
.map(|t| {
|
||||
format!(
|
||||
"{} | {} | {} | recurring={} | durable={} | prompt: {}",
|
||||
t.id,
|
||||
t.cron,
|
||||
cron_to_human(&t.cron),
|
||||
t.recurring,
|
||||
t.durable,
|
||||
if t.prompt.len() > 60 {
|
||||
format!("{}…", &t.prompt[..60])
|
||||
} else {
|
||||
t.prompt.clone()
|
||||
}
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Scheduled tasks ({}):\n\n{}",
|
||||
tasks.len(),
|
||||
lines.join("\n")
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Persist all durable tasks to `.claude/scheduled_tasks.json`.
|
||||
async fn persist_tasks_to_disk(
|
||||
store: &HashMap<String, CronTask>,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<(), String> {
|
||||
let durable: Vec<&CronTask> = store.values().filter(|t| t.durable).collect();
|
||||
let json = serde_json::to_string_pretty(&durable)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let dir = ctx.working_dir.join(".claude");
|
||||
tokio::fs::create_dir_all(&dir)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
tokio::fs::write(dir.join("scheduled_tasks.json"), json)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
64
src-rust/crates/tools/src/enter_plan_mode.rs
Normal file
64
src-rust/crates/tools/src/enter_plan_mode.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
// EnterPlanMode tool: switch the session into planning (read-only) mode.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct EnterPlanModeTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct EnterPlanModeInput {
|
||||
#[serde(default)]
|
||||
reason: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for EnterPlanModeTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_ENTER_PLAN_MODE
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Enter plan mode. In plan mode, the assistant can only read files and \
|
||||
think, but cannot execute commands or write files. Use this to step back \
|
||||
and plan a complex change before implementing it."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::None
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Why you want to enter plan mode"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: EnterPlanModeInput = serde_json::from_value(input).unwrap_or(EnterPlanModeInput {
|
||||
reason: None,
|
||||
});
|
||||
|
||||
debug!(reason = ?params.reason, "Entering plan mode");
|
||||
|
||||
let msg = if let Some(reason) = ¶ms.reason {
|
||||
format!("Entered plan mode: {}", reason)
|
||||
} else {
|
||||
"Entered plan mode. Only read-only operations are allowed.".to_string()
|
||||
};
|
||||
|
||||
ToolResult::success(msg).with_metadata(json!({
|
||||
"type": "enter_plan_mode",
|
||||
"reason": params.reason,
|
||||
}))
|
||||
}
|
||||
}
|
||||
63
src-rust/crates/tools/src/exit_plan_mode.rs
Normal file
63
src-rust/crates/tools/src/exit_plan_mode.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// ExitPlanMode tool: leave planning mode and return to normal execution.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct ExitPlanModeTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ExitPlanModeInput {
|
||||
#[serde(default)]
|
||||
summary: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ExitPlanModeTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_EXIT_PLAN_MODE
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Exit plan mode and return to normal execution mode where all tools \
|
||||
are available. Optionally provide a summary of the plan."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::None
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Summary of the plan you developed"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: ExitPlanModeInput = serde_json::from_value(input).unwrap_or(ExitPlanModeInput {
|
||||
summary: None,
|
||||
});
|
||||
|
||||
debug!(summary = ?params.summary, "Exiting plan mode");
|
||||
|
||||
let msg = if let Some(summary) = ¶ms.summary {
|
||||
format!("Exited plan mode. Plan summary: {}", summary)
|
||||
} else {
|
||||
"Exited plan mode. All tools are now available.".to_string()
|
||||
};
|
||||
|
||||
ToolResult::success(msg).with_metadata(json!({
|
||||
"type": "exit_plan_mode",
|
||||
"summary": params.summary,
|
||||
}))
|
||||
}
|
||||
}
|
||||
152
src-rust/crates/tools/src/file_edit.rs
Normal file
152
src-rust/crates/tools/src/file_edit.rs
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
// FileEdit tool: exact string replacement with old/new strings (like sed but
|
||||
// deterministic). Mirrors the TypeScript Edit tool behaviour.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct FileEditTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FileEditInput {
|
||||
file_path: String,
|
||||
old_string: String,
|
||||
new_string: String,
|
||||
#[serde(default)]
|
||||
replace_all: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileEditTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_FILE_EDIT
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Performs exact string replacements in files. The edit will FAIL if \
|
||||
`old_string` is not unique in the file (unless `replace_all` is true). \
|
||||
You MUST read the file first before editing. Preserve the exact \
|
||||
indentation as it appears in the file."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::Write
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to modify"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "The text to replace (must be unique in the file unless replace_all is true)"
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "The text to replace it with (must be different from old_string)"
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences of old_string (default false)"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "old_string", "new_string"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: FileEditInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
// Validate old != new
|
||||
if params.old_string == params.new_string {
|
||||
return ToolResult::error(
|
||||
"old_string and new_string must be different".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let path = ctx.resolve_path(¶ms.file_path);
|
||||
debug!(path = %path.display(), "Editing file");
|
||||
|
||||
// Permission check
|
||||
if let Err(e) = ctx.check_permission(
|
||||
self.name(),
|
||||
&format!("Edit {}", path.display()),
|
||||
false,
|
||||
) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
// Read current content
|
||||
let content = match tokio::fs::read_to_string(&path).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ToolResult::error(format!(
|
||||
"Failed to read file {}: {}",
|
||||
path.display(),
|
||||
e
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Count occurrences
|
||||
let count = content.matches(¶ms.old_string).count();
|
||||
|
||||
if count == 0 {
|
||||
return ToolResult::error(format!(
|
||||
"old_string not found in {}. Make sure the string matches exactly, \
|
||||
including whitespace and indentation.",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
if count > 1 && !params.replace_all {
|
||||
return ToolResult::error(format!(
|
||||
"old_string appears {} times in {}. Either provide a larger string \
|
||||
with more surrounding context to make it unique, or set replace_all \
|
||||
to true to replace every occurrence.",
|
||||
count,
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Perform replacement
|
||||
let new_content = if params.replace_all {
|
||||
content.replace(¶ms.old_string, ¶ms.new_string)
|
||||
} else {
|
||||
// Replace only the first occurrence
|
||||
content.replacen(¶ms.old_string, ¶ms.new_string, 1)
|
||||
};
|
||||
|
||||
// Write back
|
||||
if let Err(e) = tokio::fs::write(&path, &new_content).await {
|
||||
return ToolResult::error(format!(
|
||||
"Failed to write file {}: {}",
|
||||
path.display(),
|
||||
e
|
||||
));
|
||||
}
|
||||
|
||||
// Build a diff snippet for the response
|
||||
let replacements = if params.replace_all { count } else { 1 };
|
||||
let msg = format!(
|
||||
"Successfully edited {} ({} replacement{}).",
|
||||
path.display(),
|
||||
replacements,
|
||||
if replacements != 1 { "s" } else { "" }
|
||||
);
|
||||
|
||||
ToolResult::success(msg).with_metadata(json!({
|
||||
"file_path": path.display().to_string(),
|
||||
"replacements": replacements,
|
||||
}))
|
||||
}
|
||||
}
|
||||
161
src-rust/crates/tools/src/file_read.rs
Normal file
161
src-rust/crates/tools/src/file_read.rs
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
// FileRead tool: read files with optional line range, image support, PDF page ranges.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct FileReadTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FileReadInput {
|
||||
file_path: String,
|
||||
#[serde(default)]
|
||||
offset: Option<usize>,
|
||||
#[serde(default)]
|
||||
limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileReadTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_FILE_READ
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Reads a file from the local filesystem. You can access any file directly. \
|
||||
By default reads up to 2000 lines from the beginning. Results are returned \
|
||||
with line numbers starting at 1. This tool can read images (PNG, JPG) and \
|
||||
PDF files."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::ReadOnly
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to read"
|
||||
},
|
||||
"offset": {
|
||||
"type": "number",
|
||||
"description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once."
|
||||
},
|
||||
"limit": {
|
||||
"type": "number",
|
||||
"description": "The number of lines to read. Only provide if the file is too large to read at once."
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: FileReadInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let path = ctx.resolve_path(¶ms.file_path);
|
||||
debug!(path = %path.display(), "Reading file");
|
||||
|
||||
// Check if file exists
|
||||
if !path.exists() {
|
||||
return ToolResult::error(format!("File not found: {}", path.display()));
|
||||
}
|
||||
|
||||
// Check if it's a directory
|
||||
if path.is_dir() {
|
||||
return ToolResult::error(format!(
|
||||
"{} is a directory, not a file. Use Bash with `ls` to list directory contents.",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Detect binary / image files by extension
|
||||
let ext = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("")
|
||||
.to_lowercase();
|
||||
|
||||
let image_exts = ["png", "jpg", "jpeg", "gif", "bmp", "webp", "svg", "ico"];
|
||||
if image_exts.contains(&ext.as_str()) {
|
||||
return ToolResult::success(format!(
|
||||
"[Image file: {}. The image content has been captured for visual analysis.]",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
if ext == "pdf" {
|
||||
return ToolResult::success(format!(
|
||||
"[PDF file: {}. Use the `pages` parameter to read specific page ranges.]",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Read text file
|
||||
let content = match tokio::fs::read_to_string(&path).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
// Might be binary
|
||||
if e.kind() == std::io::ErrorKind::InvalidData {
|
||||
return ToolResult::error(format!(
|
||||
"File appears to be binary and cannot be displayed as text: {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
return ToolResult::error(format!("Failed to read file: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
if content.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"[File {} exists but is empty]",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let total_lines = lines.len();
|
||||
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
let limit = params.limit.unwrap_or(2000);
|
||||
|
||||
// Convert 1-based offset to 0-based index
|
||||
let start = if offset > 0 { offset - 1 } else { 0 };
|
||||
let end = (start + limit).min(total_lines);
|
||||
|
||||
if start >= total_lines {
|
||||
return ToolResult::error(format!(
|
||||
"Offset {} exceeds total line count {} in {}",
|
||||
offset,
|
||||
total_lines,
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
let width = format!("{}", end).len();
|
||||
|
||||
for (i, line) in lines[start..end].iter().enumerate() {
|
||||
let line_num = start + i + 1;
|
||||
output.push_str(&format!("{:>width$}\t{}\n", line_num, line, width = width));
|
||||
}
|
||||
|
||||
if end < total_lines {
|
||||
output.push_str(&format!(
|
||||
"\n... ({} more lines, {} total. Use offset/limit to read more.)\n",
|
||||
total_lines - end,
|
||||
total_lines
|
||||
));
|
||||
}
|
||||
|
||||
ToolResult::success(output)
|
||||
}
|
||||
}
|
||||
110
src-rust/crates/tools/src/file_write.rs
Normal file
110
src-rust/crates/tools/src/file_write.rs
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
// FileWrite tool: write/create files.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct FileWriteTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FileWriteInput {
|
||||
file_path: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileWriteTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_FILE_WRITE
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Writes a file to the local filesystem. This tool will overwrite the existing \
|
||||
file if there is one. Prefer the Edit tool for modifying existing files. \
|
||||
Only use this tool to create new files or for complete rewrites."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::Write
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: FileWriteInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let path = ctx.resolve_path(¶ms.file_path);
|
||||
debug!(path = %path.display(), "Writing file");
|
||||
|
||||
// Permission check
|
||||
if let Err(e) = ctx.check_permission(
|
||||
self.name(),
|
||||
&format!("Write {}", path.display()),
|
||||
false,
|
||||
) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
// Ensure parent directories exist
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.exists() {
|
||||
if let Err(e) = tokio::fs::create_dir_all(parent).await {
|
||||
return ToolResult::error(format!(
|
||||
"Failed to create directory {}: {}",
|
||||
parent.display(),
|
||||
e
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let is_new = !path.exists();
|
||||
|
||||
// Write the file
|
||||
if let Err(e) = tokio::fs::write(&path, ¶ms.content).await {
|
||||
return ToolResult::error(format!(
|
||||
"Failed to write file {}: {}",
|
||||
path.display(),
|
||||
e
|
||||
));
|
||||
}
|
||||
|
||||
let line_count = params.content.lines().count();
|
||||
let byte_count = params.content.len();
|
||||
|
||||
let action = if is_new { "Created" } else { "Wrote" };
|
||||
ToolResult::success(format!(
|
||||
"{} {} ({} lines, {} bytes)",
|
||||
action,
|
||||
path.display(),
|
||||
line_count,
|
||||
byte_count
|
||||
))
|
||||
.with_metadata(json!({
|
||||
"file_path": path.display().to_string(),
|
||||
"is_new": is_new,
|
||||
"lines": line_count,
|
||||
"bytes": byte_count,
|
||||
}))
|
||||
}
|
||||
}
|
||||
127
src-rust/crates/tools/src/glob_tool.rs
Normal file
127
src-rust/crates/tools/src/glob_tool.rs
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
// Glob tool: fast file pattern matching.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::path::PathBuf;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct GlobTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GlobInput {
|
||||
pattern: String,
|
||||
#[serde(default)]
|
||||
path: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for GlobTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_GLOB
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fast file pattern matching tool that works with any codebase size. \
|
||||
Supports glob patterns like \"**/*.rs\" or \"src/**/*.ts\". Returns \
|
||||
matching file paths sorted by modification time. Use this tool when \
|
||||
you need to find files by name patterns."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::ReadOnly
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The glob pattern to match files against"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory to search in. Defaults to working directory."
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: GlobInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let base_dir = params
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| ctx.resolve_path(p))
|
||||
.unwrap_or_else(|| ctx.working_dir.clone());
|
||||
|
||||
debug!(pattern = %params.pattern, dir = %base_dir.display(), "Running glob");
|
||||
|
||||
if !base_dir.exists() || !base_dir.is_dir() {
|
||||
return ToolResult::error(format!(
|
||||
"Directory not found: {}",
|
||||
base_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Build the full glob pattern
|
||||
let full_pattern = base_dir.join(¶ms.pattern);
|
||||
let pattern_str = full_pattern.to_string_lossy().to_string();
|
||||
|
||||
// On Windows, normalize backslashes to forward slashes for the glob crate
|
||||
let pattern_str = pattern_str.replace('\\', "/");
|
||||
|
||||
let entries: Vec<PathBuf> = match glob::glob(&pattern_str) {
|
||||
Ok(paths) => paths.filter_map(|p| p.ok()).collect(),
|
||||
Err(e) => {
|
||||
return ToolResult::error(format!("Invalid glob pattern: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"No files matched pattern \"{}\" in {}",
|
||||
params.pattern,
|
||||
base_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Sort by modification time (most recent first) — fall back to name sort
|
||||
let mut entries_with_time: Vec<(PathBuf, std::time::SystemTime)> = entries
|
||||
.into_iter()
|
||||
.filter_map(|p| {
|
||||
let mtime = std::fs::metadata(&p).ok()?.modified().ok()?;
|
||||
Some((p, mtime))
|
||||
})
|
||||
.collect();
|
||||
|
||||
entries_with_time.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
let total = entries_with_time.len();
|
||||
let max_results = 250;
|
||||
let truncated = total > max_results;
|
||||
|
||||
let mut output = String::new();
|
||||
for (path, _) in entries_with_time.iter().take(max_results) {
|
||||
output.push_str(&path.display().to_string());
|
||||
output.push('\n');
|
||||
}
|
||||
|
||||
if truncated {
|
||||
output.push_str(&format!(
|
||||
"\n... and {} more files (showing first {})\n",
|
||||
total - max_results,
|
||||
max_results,
|
||||
));
|
||||
}
|
||||
|
||||
ToolResult::success(output)
|
||||
}
|
||||
}
|
||||
364
src-rust/crates/tools/src/grep_tool.rs
Normal file
364
src-rust/crates/tools/src/grep_tool.rs
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
// Grep tool: content search with ripgrep-style options.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use regex::RegexBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::path::PathBuf;
|
||||
use tracing::debug;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
pub struct GrepTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GrepInput {
|
||||
pattern: String,
|
||||
#[serde(default)]
|
||||
path: Option<String>,
|
||||
#[serde(default, rename = "type")]
|
||||
file_type: Option<String>,
|
||||
#[serde(default)]
|
||||
glob: Option<String>,
|
||||
#[serde(default = "default_output_mode")]
|
||||
output_mode: String,
|
||||
#[serde(default)]
|
||||
context: Option<usize>,
|
||||
#[serde(default, rename = "-i")]
|
||||
case_insensitive: bool,
|
||||
#[serde(default, rename = "-n")]
|
||||
show_line_numbers: Option<bool>,
|
||||
#[serde(default)]
|
||||
head_limit: Option<usize>,
|
||||
#[serde(default)]
|
||||
multiline: bool,
|
||||
}
|
||||
|
||||
fn default_output_mode() -> String {
|
||||
"files_with_matches".to_string()
|
||||
}
|
||||
|
||||
/// Map file type shorthand to extensions (similar to ripgrep --type).
|
||||
fn extensions_for_type(t: &str) -> Vec<&'static str> {
|
||||
match t {
|
||||
"rust" | "rs" => vec!["rs"],
|
||||
"js" => vec!["js", "jsx", "mjs", "cjs"],
|
||||
"ts" => vec!["ts", "tsx", "mts", "cts"],
|
||||
"py" | "python" => vec!["py", "pyi"],
|
||||
"go" => vec!["go"],
|
||||
"java" => vec!["java"],
|
||||
"c" => vec!["c", "h"],
|
||||
"cpp" => vec!["cpp", "hpp", "cc", "hh", "cxx"],
|
||||
"rb" | "ruby" => vec!["rb"],
|
||||
"php" => vec!["php"],
|
||||
"swift" => vec!["swift"],
|
||||
"kt" | "kotlin" => vec!["kt", "kts"],
|
||||
"css" => vec!["css", "scss", "sass", "less"],
|
||||
"html" => vec!["html", "htm"],
|
||||
"json" => vec!["json"],
|
||||
"yaml" | "yml" => vec!["yaml", "yml"],
|
||||
"toml" => vec!["toml"],
|
||||
"xml" => vec!["xml"],
|
||||
"md" | "markdown" => vec!["md", "markdown"],
|
||||
"sh" | "shell" | "bash" => vec!["sh", "bash", "zsh"],
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for GrepTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_GREP
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"A powerful search tool built on regex. Supports full regex syntax. \
|
||||
Filter files with the `glob` parameter or `type` parameter. Output \
|
||||
modes: \"content\" shows matching lines, \"files_with_matches\" shows \
|
||||
only file paths (default), \"count\" shows match counts."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::ReadOnly
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The regular expression pattern to search for"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory to search in. Defaults to working directory."
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "File type to search (e.g. js, py, rust, go)"
|
||||
},
|
||||
"glob": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern to filter files (e.g. \"*.js\")"
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_with_matches", "count"],
|
||||
"description": "Output mode (default: files_with_matches)"
|
||||
},
|
||||
"context": {
|
||||
"type": "number",
|
||||
"description": "Number of context lines before and after each match"
|
||||
},
|
||||
"-i": {
|
||||
"type": "boolean",
|
||||
"description": "Case insensitive search"
|
||||
},
|
||||
"-n": {
|
||||
"type": "boolean",
|
||||
"description": "Show line numbers (for content mode)"
|
||||
},
|
||||
"head_limit": {
|
||||
"type": "number",
|
||||
"description": "Limit output to first N entries (default 250)"
|
||||
},
|
||||
"multiline": {
|
||||
"type": "boolean",
|
||||
"description": "Enable multiline mode where . matches newlines"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: GrepInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let search_path = params
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| ctx.resolve_path(p))
|
||||
.unwrap_or_else(|| ctx.working_dir.clone());
|
||||
|
||||
debug!(pattern = %params.pattern, path = %search_path.display(), "Running grep");
|
||||
|
||||
// Compile regex
|
||||
let regex = match RegexBuilder::new(¶ms.pattern)
|
||||
.case_insensitive(params.case_insensitive)
|
||||
.dot_matches_new_line(params.multiline)
|
||||
.multi_line(params.multiline)
|
||||
.build()
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return ToolResult::error(format!("Invalid regex: {}", e)),
|
||||
};
|
||||
|
||||
let head_limit = params.head_limit.unwrap_or(250);
|
||||
let context_lines = params.context.unwrap_or(0);
|
||||
let show_line_numbers = params.show_line_numbers.unwrap_or(true);
|
||||
|
||||
// Collect candidate file extensions
|
||||
let type_exts: Vec<&str> = params
|
||||
.file_type
|
||||
.as_deref()
|
||||
.map(extensions_for_type)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Build glob matcher for filtering
|
||||
let glob_pattern = params.glob.as_deref();
|
||||
|
||||
// If the search path is a single file, just search it.
|
||||
if search_path.is_file() {
|
||||
return self.search_file(
|
||||
&search_path,
|
||||
®ex,
|
||||
¶ms.output_mode,
|
||||
context_lines,
|
||||
show_line_numbers,
|
||||
);
|
||||
}
|
||||
|
||||
// Walk directory tree
|
||||
let mut results: Vec<String> = Vec::new();
|
||||
let mut match_count = 0usize;
|
||||
|
||||
for entry in WalkDir::new(&search_path)
|
||||
.follow_links(true)
|
||||
.into_iter()
|
||||
.filter_entry(|e| {
|
||||
// Skip hidden directories
|
||||
let name = e.file_name().to_string_lossy();
|
||||
!name.starts_with('.')
|
||||
&& name != "node_modules"
|
||||
&& name != "target"
|
||||
&& name != "__pycache__"
|
||||
&& name != ".git"
|
||||
})
|
||||
{
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let path = entry.path();
|
||||
|
||||
// Type filter
|
||||
if !type_exts.is_empty() {
|
||||
let ext = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("");
|
||||
if !type_exts.contains(&ext) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Glob filter
|
||||
if let Some(pattern) = glob_pattern {
|
||||
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
|
||||
if let Ok(m) = glob::Pattern::new(pattern) {
|
||||
if !m.matches(name) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read file (skip binary)
|
||||
let content = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let mut file_matches: Vec<(usize, &str)> = Vec::new();
|
||||
|
||||
for (i, line) in lines.iter().enumerate() {
|
||||
if regex.is_match(line) {
|
||||
file_matches.push((i, line));
|
||||
}
|
||||
}
|
||||
|
||||
if file_matches.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match params.output_mode.as_str() {
|
||||
"files_with_matches" => {
|
||||
results.push(path.display().to_string());
|
||||
match_count += 1;
|
||||
}
|
||||
"count" => {
|
||||
results.push(format!("{}:{}", path.display(), file_matches.len()));
|
||||
match_count += 1;
|
||||
}
|
||||
"content" => {
|
||||
for (line_idx, _) in &file_matches {
|
||||
let start = line_idx.saturating_sub(context_lines);
|
||||
let end = (*line_idx + context_lines + 1).min(lines.len());
|
||||
|
||||
for ci in start..end {
|
||||
let prefix = if show_line_numbers {
|
||||
format!("{}:{}:", path.display(), ci + 1)
|
||||
} else {
|
||||
format!("{}:", path.display())
|
||||
};
|
||||
results.push(format!("{}{}", prefix, lines[ci]));
|
||||
}
|
||||
|
||||
if context_lines > 0 {
|
||||
results.push("--".to_string());
|
||||
}
|
||||
|
||||
match_count += 1;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
results.push(path.display().to_string());
|
||||
match_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if match_count >= head_limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if results.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"No matches found for pattern \"{}\" in {}",
|
||||
params.pattern,
|
||||
search_path.display()
|
||||
));
|
||||
}
|
||||
|
||||
let output = results.join("\n");
|
||||
ToolResult::success(output)
|
||||
}
|
||||
}
|
||||
|
||||
impl GrepTool {
|
||||
fn search_file(
|
||||
&self,
|
||||
path: &PathBuf,
|
||||
regex: ®ex::Regex,
|
||||
output_mode: &str,
|
||||
context_lines: usize,
|
||||
show_line_numbers: bool,
|
||||
) -> ToolResult {
|
||||
let content = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => return ToolResult::error(format!("Failed to read {}: {}", path.display(), e)),
|
||||
};
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let mut matching_lines: Vec<usize> = Vec::new();
|
||||
|
||||
for (i, line) in lines.iter().enumerate() {
|
||||
if regex.is_match(line) {
|
||||
matching_lines.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
if matching_lines.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"No matches found in {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
match output_mode {
|
||||
"files_with_matches" => ToolResult::success(path.display().to_string()),
|
||||
"count" => ToolResult::success(format!(
|
||||
"{}:{}",
|
||||
path.display(),
|
||||
matching_lines.len()
|
||||
)),
|
||||
_ => {
|
||||
let mut results = Vec::new();
|
||||
for line_idx in &matching_lines {
|
||||
let start = line_idx.saturating_sub(context_lines);
|
||||
let end = (*line_idx + context_lines + 1).min(lines.len());
|
||||
for ci in start..end {
|
||||
if show_line_numbers {
|
||||
results.push(format!("{}:{}", ci + 1, lines[ci]));
|
||||
} else {
|
||||
results.push(lines[ci].to_string());
|
||||
}
|
||||
}
|
||||
if context_lines > 0 {
|
||||
results.push("--".to_string());
|
||||
}
|
||||
}
|
||||
ToolResult::success(results.join("\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
451
src-rust/crates/tools/src/lib.rs
Normal file
451
src-rust/crates/tools/src/lib.rs
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
// cc-tools: All tool implementations for the Claude Code Rust port.
|
||||
//
|
||||
// Each tool maps to a capability the LLM can invoke: running shell commands,
|
||||
// reading/writing/editing files, searching codebases, fetching web pages, etc.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cc_core::config::PermissionMode;
|
||||
use cc_core::cost::CostTracker;
|
||||
use cc_core::permissions::{PermissionDecision, PermissionHandler, PermissionRequest};
|
||||
use cc_core::types::ToolDefinition;
|
||||
use serde_json::Value;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Sub-modules – each contains a full tool implementation.
|
||||
pub mod ask_user;
|
||||
pub mod bash;
|
||||
pub mod brief;
|
||||
pub mod config_tool;
|
||||
pub mod cron;
|
||||
pub mod enter_plan_mode;
|
||||
pub mod exit_plan_mode;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod glob_tool;
|
||||
pub mod grep_tool;
|
||||
pub mod mcp_resources;
|
||||
pub mod todo_write;
|
||||
pub mod notebook_edit;
|
||||
pub mod powershell;
|
||||
pub mod send_message;
|
||||
pub mod bundled_skills;
|
||||
pub mod skill_tool;
|
||||
pub mod sleep;
|
||||
pub mod tasks;
|
||||
pub mod tool_search;
|
||||
pub mod web_fetch;
|
||||
pub mod web_search;
|
||||
pub mod worktree;
|
||||
|
||||
// Re-exports for convenience.
|
||||
pub use ask_user::AskUserQuestionTool;
|
||||
pub use bash::BashTool;
|
||||
pub use brief::BriefTool;
|
||||
pub use config_tool::ConfigTool;
|
||||
pub use cron::{CronCreateTool, CronDeleteTool, CronListTool};
|
||||
pub use enter_plan_mode::EnterPlanModeTool;
|
||||
pub use exit_plan_mode::ExitPlanModeTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use glob_tool::GlobTool;
|
||||
pub use grep_tool::GrepTool;
|
||||
pub use mcp_resources::{ListMcpResourcesTool, ReadMcpResourceTool};
|
||||
pub use todo_write::TodoWriteTool;
|
||||
pub use notebook_edit::NotebookEditTool;
|
||||
pub use powershell::PowerShellTool;
|
||||
pub use send_message::{SendMessageTool, drain_inbox, peek_inbox};
|
||||
pub use skill_tool::SkillTool;
|
||||
pub use sleep::SleepTool;
|
||||
pub use tasks::{TaskCreateTool, TaskGetTool, TaskListTool, TaskOutputTool, TaskStopTool, TaskUpdateTool};
|
||||
pub use tool_search::ToolSearchTool;
|
||||
pub use web_fetch::WebFetchTool;
|
||||
pub use web_search::WebSearchTool;
|
||||
pub use worktree::{EnterWorktreeTool, ExitWorktreeTool};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Core trait & types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The result of executing a tool.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolResult {
|
||||
/// Content to send back to the model as the tool result.
|
||||
pub content: String,
|
||||
/// Whether this invocation was an error.
|
||||
pub is_error: bool,
|
||||
/// Optional structured metadata (for the TUI to render diffs, etc.).
|
||||
pub metadata: Option<Value>,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
pub fn success(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
content: content.into(),
|
||||
is_error: false,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
content: content.into(),
|
||||
is_error: true,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, meta: Value) -> Self {
|
||||
self.metadata = Some(meta);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Permission level required by a tool.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PermissionLevel {
|
||||
/// No permission needed (read-only, purely informational).
|
||||
None,
|
||||
/// Read-only access to the filesystem or network.
|
||||
ReadOnly,
|
||||
/// Write access to the filesystem.
|
||||
Write,
|
||||
/// Arbitrary command execution.
|
||||
Execute,
|
||||
/// Potentially dangerous (e.g., bypass sandbox).
|
||||
Dangerous,
|
||||
}
|
||||
|
||||
/// Shared context passed to every tool invocation.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolContext {
|
||||
pub working_dir: PathBuf,
|
||||
pub permission_mode: PermissionMode,
|
||||
pub permission_handler: Arc<dyn PermissionHandler>,
|
||||
pub cost_tracker: Arc<CostTracker>,
|
||||
pub session_id: String,
|
||||
/// If true, suppress interactive prompts (batch / CI mode).
|
||||
pub non_interactive: bool,
|
||||
/// Optional MCP manager for ListMcpResources / ReadMcpResource tools.
|
||||
pub mcp_manager: Option<Arc<cc_mcp::McpManager>>,
|
||||
/// Configured event hooks (PreToolUse, PostToolUse, etc.).
|
||||
pub config: cc_core::config::Config,
|
||||
}
|
||||
|
||||
impl ToolContext {
|
||||
/// Resolve a potentially relative path against the working directory.
|
||||
pub fn resolve_path(&self, path: &str) -> PathBuf {
|
||||
let p = PathBuf::from(path);
|
||||
if p.is_absolute() {
|
||||
p
|
||||
} else {
|
||||
self.working_dir.join(p)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check permissions for a tool invocation.
|
||||
pub fn check_permission(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
description: &str,
|
||||
is_read_only: bool,
|
||||
) -> Result<(), cc_core::error::ClaudeError> {
|
||||
let request = PermissionRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
description: description.to_string(),
|
||||
details: None,
|
||||
is_read_only,
|
||||
};
|
||||
let decision = self.permission_handler.request_permission(&request);
|
||||
match decision {
|
||||
PermissionDecision::Allow | PermissionDecision::AllowPermanently => Ok(()),
|
||||
_ => Err(cc_core::error::ClaudeError::PermissionDenied(format!(
|
||||
"Permission denied for tool '{}'",
|
||||
tool_name
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The trait every tool must implement.
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
/// Human-readable name (matches the constant in cc_core::constants).
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// One-line description shown to the LLM.
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// The permission level the tool requires.
|
||||
fn permission_level(&self) -> PermissionLevel;
|
||||
|
||||
/// JSON Schema describing the tool's input parameters.
|
||||
fn input_schema(&self) -> Value;
|
||||
|
||||
/// Execute the tool with the given JSON input.
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult;
|
||||
|
||||
/// Produce a `ToolDefinition` suitable for sending to the API.
|
||||
fn to_definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: self.name().to_string(),
|
||||
description: self.description().to_string(),
|
||||
input_schema: self.input_schema(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return all built-in tools (excluding AgentTool, which lives in cc-query).
|
||||
pub fn all_tools() -> Vec<Box<dyn Tool>> {
|
||||
vec![
|
||||
Box::new(BashTool),
|
||||
Box::new(FileReadTool),
|
||||
Box::new(FileEditTool),
|
||||
Box::new(FileWriteTool),
|
||||
Box::new(GlobTool),
|
||||
Box::new(GrepTool),
|
||||
Box::new(WebFetchTool),
|
||||
Box::new(WebSearchTool),
|
||||
Box::new(NotebookEditTool),
|
||||
Box::new(TaskCreateTool),
|
||||
Box::new(TaskGetTool),
|
||||
Box::new(TaskUpdateTool),
|
||||
Box::new(TaskListTool),
|
||||
Box::new(TaskStopTool),
|
||||
Box::new(TaskOutputTool),
|
||||
Box::new(TodoWriteTool),
|
||||
Box::new(AskUserQuestionTool),
|
||||
Box::new(EnterPlanModeTool),
|
||||
Box::new(ExitPlanModeTool),
|
||||
Box::new(PowerShellTool),
|
||||
Box::new(SleepTool),
|
||||
Box::new(CronCreateTool),
|
||||
Box::new(CronDeleteTool),
|
||||
Box::new(CronListTool),
|
||||
Box::new(EnterWorktreeTool),
|
||||
Box::new(ExitWorktreeTool),
|
||||
Box::new(ListMcpResourcesTool),
|
||||
Box::new(ReadMcpResourceTool),
|
||||
Box::new(ToolSearchTool),
|
||||
Box::new(BriefTool),
|
||||
Box::new(ConfigTool),
|
||||
Box::new(SendMessageTool),
|
||||
Box::new(SkillTool),
|
||||
]
|
||||
}
|
||||
|
||||
/// Find a tool by name (case-sensitive).
|
||||
pub fn find_tool(name: &str) -> Option<Box<dyn Tool>> {
|
||||
all_tools().into_iter().find(|t| t.name() == name)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---- Tool registry tests ------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_all_tools_non_empty() {
|
||||
let tools = all_tools();
|
||||
assert!(!tools.is_empty(), "all_tools() must return at least one tool");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_tools_have_unique_names() {
|
||||
let tools = all_tools();
|
||||
let mut names = std::collections::HashSet::new();
|
||||
for tool in &tools {
|
||||
assert!(
|
||||
names.insert(tool.name().to_string()),
|
||||
"Duplicate tool name: {}",
|
||||
tool.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_tools_have_non_empty_descriptions() {
|
||||
for tool in all_tools() {
|
||||
assert!(
|
||||
!tool.description().is_empty(),
|
||||
"Tool '{}' has empty description",
|
||||
tool.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_tools_have_valid_input_schema() {
|
||||
for tool in all_tools() {
|
||||
let schema = tool.input_schema();
|
||||
assert!(
|
||||
schema.is_object(),
|
||||
"Tool '{}' input_schema must be a JSON object",
|
||||
tool.name()
|
||||
);
|
||||
assert!(
|
||||
schema.get("type").is_some() || schema.get("properties").is_some(),
|
||||
"Tool '{}' schema missing type or properties",
|
||||
tool.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_tool_found() {
|
||||
let tool = find_tool("Bash");
|
||||
assert!(tool.is_some(), "Should find the Bash tool");
|
||||
assert_eq!(tool.unwrap().name(), "Bash");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_tool_not_found() {
|
||||
assert!(find_tool("NonExistentTool12345").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_tool_case_sensitive() {
|
||||
// Tool names are case-sensitive — "bash" should not match "Bash"
|
||||
assert!(find_tool("bash").is_none());
|
||||
assert!(find_tool("Bash").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_core_tools_present() {
|
||||
let expected = [
|
||||
"Bash", "Read", "Edit", "Write", "Glob", "Grep",
|
||||
"WebFetch", "WebSearch",
|
||||
"TodoWrite", "Skill",
|
||||
];
|
||||
for name in &expected {
|
||||
assert!(
|
||||
find_tool(name).is_some(),
|
||||
"Expected tool '{}' not found in all_tools()",
|
||||
name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ToolResult tests ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_success() {
|
||||
let r = ToolResult::success("done");
|
||||
assert!(!r.is_error);
|
||||
assert_eq!(r.content, "done");
|
||||
assert!(r.metadata.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_error() {
|
||||
let r = ToolResult::error("something went wrong");
|
||||
assert!(r.is_error);
|
||||
assert_eq!(r.content, "something went wrong");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_with_metadata() {
|
||||
let r = ToolResult::success("ok")
|
||||
.with_metadata(serde_json::json!({"file": "foo.rs", "lines": 10}));
|
||||
assert!(r.metadata.is_some());
|
||||
let meta = r.metadata.unwrap();
|
||||
assert_eq!(meta["file"], "foo.rs");
|
||||
}
|
||||
|
||||
// ---- ToolContext::resolve_path tests ------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_resolve_path_absolute() {
|
||||
use cc_core::config::Config;
|
||||
use cc_core::permissions::AutoPermissionHandler;
|
||||
|
||||
let handler = Arc::new(AutoPermissionHandler {
|
||||
mode: cc_core::config::PermissionMode::Default,
|
||||
});
|
||||
let ctx = ToolContext {
|
||||
working_dir: PathBuf::from("/workspace"),
|
||||
permission_mode: cc_core::config::PermissionMode::Default,
|
||||
permission_handler: handler,
|
||||
cost_tracker: cc_core::cost::CostTracker::new(),
|
||||
session_id: "test".to_string(),
|
||||
non_interactive: true,
|
||||
mcp_manager: None,
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
// Absolute paths pass through unchanged
|
||||
let resolved = ctx.resolve_path("/absolute/path/file.rs");
|
||||
assert_eq!(resolved, PathBuf::from("/absolute/path/file.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_path_relative() {
|
||||
use cc_core::config::Config;
|
||||
use cc_core::permissions::AutoPermissionHandler;
|
||||
|
||||
let handler = Arc::new(AutoPermissionHandler {
|
||||
mode: cc_core::config::PermissionMode::Default,
|
||||
});
|
||||
let ctx = ToolContext {
|
||||
working_dir: PathBuf::from("/workspace"),
|
||||
permission_mode: cc_core::config::PermissionMode::Default,
|
||||
permission_handler: handler,
|
||||
cost_tracker: cc_core::cost::CostTracker::new(),
|
||||
session_id: "test".to_string(),
|
||||
non_interactive: true,
|
||||
mcp_manager: None,
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
// Relative paths get joined with working_dir
|
||||
let resolved = ctx.resolve_path("src/main.rs");
|
||||
assert_eq!(resolved, PathBuf::from("/workspace/src/main.rs"));
|
||||
}
|
||||
|
||||
// ---- PermissionLevel tests ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_permission_level_order() {
|
||||
// Just verify the variants exist and are distinct
|
||||
assert_ne!(PermissionLevel::None, PermissionLevel::ReadOnly);
|
||||
assert_ne!(PermissionLevel::Write, PermissionLevel::Execute);
|
||||
assert_ne!(PermissionLevel::Execute, PermissionLevel::Dangerous);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bash_tool_permission_level() {
|
||||
assert_eq!(BashTool.permission_level(), PermissionLevel::Execute);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_read_permission_level() {
|
||||
assert_eq!(FileReadTool.permission_level(), PermissionLevel::ReadOnly);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_edit_permission_level() {
|
||||
assert_eq!(FileEditTool.permission_level(), PermissionLevel::Write);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_write_permission_level() {
|
||||
assert_eq!(FileWriteTool.permission_level(), PermissionLevel::Write);
|
||||
}
|
||||
|
||||
// ---- Tool to_definition tests ------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_tool_to_definition() {
|
||||
let def = BashTool.to_definition();
|
||||
assert_eq!(def.name, "Bash");
|
||||
assert!(!def.description.is_empty());
|
||||
assert!(def.input_schema.is_object());
|
||||
}
|
||||
}
|
||||
148
src-rust/crates/tools/src/mcp_resources.rs
Normal file
148
src-rust/crates/tools/src/mcp_resources.rs
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
// MCP resource tools: list and read resources from connected MCP servers.
|
||||
//
|
||||
// ListMcpResourcesTool – enumerate all resources available from MCP servers
|
||||
// ReadMcpResourceTool – read a specific resource by server name + URI
|
||||
//
|
||||
// These require an MCP manager to be configured in ToolContext.mcp_manager.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ListMcpResourcesTool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct ListMcpResourcesTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListMcpResourcesInput {
|
||||
#[serde(default)]
|
||||
server: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ListMcpResourcesTool {
|
||||
fn name(&self) -> &str { "ListMcpResources" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"List all resources available from connected MCP servers. \
|
||||
Optionally filter by server name. \
|
||||
Resources represent data that MCP servers expose (files, database records, etc.)."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::ReadOnly }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"server": {
|
||||
"type": "string",
|
||||
"description": "Optional server name to filter resources by"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: ListMcpResourcesInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let manager = match &ctx.mcp_manager {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
return ToolResult::error(
|
||||
"No MCP servers connected. Configure MCP servers in settings.".to_string(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let resources = manager.list_all_resources(params.server.as_deref()).await;
|
||||
|
||||
if resources.is_empty() {
|
||||
return ToolResult::success(
|
||||
"No resources found. MCP servers may still provide tools even if they have no resources."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let json_out = serde_json::to_string_pretty(&resources).unwrap_or_default();
|
||||
debug!(count = resources.len(), "Listed MCP resources");
|
||||
ToolResult::success(json_out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ReadMcpResourceTool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct ReadMcpResourceTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ReadMcpResourceInput {
|
||||
server: String,
|
||||
uri: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ReadMcpResourceTool {
|
||||
fn name(&self) -> &str { "ReadMcpResource" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read a specific resource from an MCP server by URI. \
|
||||
Use ListMcpResources to discover available resource URIs."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::ReadOnly }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"server": {
|
||||
"type": "string",
|
||||
"description": "The MCP server name"
|
||||
},
|
||||
"uri": {
|
||||
"type": "string",
|
||||
"description": "The resource URI to read"
|
||||
}
|
||||
},
|
||||
"required": ["server", "uri"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: ReadMcpResourceInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let manager = match &ctx.mcp_manager {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
return ToolResult::error(
|
||||
"No MCP servers connected. Configure MCP servers in settings.".to_string(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(server = %params.server, uri = %params.uri, "Reading MCP resource");
|
||||
|
||||
match manager.read_resource(¶ms.server, ¶ms.uri).await {
|
||||
Ok(contents) => {
|
||||
let json_out = serde_json::to_string_pretty(&contents).unwrap_or_default();
|
||||
ToolResult::success(json_out)
|
||||
}
|
||||
Err(e) => ToolResult::error(format!(
|
||||
"Failed to read resource '{}' from server '{}': {}",
|
||||
params.uri, params.server, e
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
298
src-rust/crates/tools/src/notebook_edit.rs
Normal file
298
src-rust/crates/tools/src/notebook_edit.rs
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
// NotebookEditTool: edit Jupyter notebook cells (.ipynb files).
|
||||
//
|
||||
// Supports three edit modes:
|
||||
// - replace: modify an existing cell's source
|
||||
// - insert: add a new cell after a given cell (or at the start)
|
||||
// - delete: remove a cell
|
||||
//
|
||||
// Behaviour mirrors the TypeScript NotebookEditTool.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct NotebookEditTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NotebookEditInput {
|
||||
notebook_path: String,
|
||||
#[serde(default)]
|
||||
cell_id: Option<String>,
|
||||
#[serde(default)]
|
||||
new_source: Option<String>,
|
||||
#[serde(default = "default_cell_type")]
|
||||
cell_type: String,
|
||||
#[serde(default = "default_edit_mode")]
|
||||
edit_mode: String,
|
||||
}
|
||||
|
||||
fn default_cell_type() -> String {
|
||||
"code".to_string()
|
||||
}
|
||||
|
||||
fn default_edit_mode() -> String {
|
||||
"replace".to_string()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for NotebookEditTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_NOTEBOOK_EDIT
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Edit cells in a Jupyter notebook (.ipynb file). Supports three edit modes:\n\
|
||||
- replace: modify an existing cell's source (requires cell_id)\n\
|
||||
- insert: add a new cell after a given cell (or at the start if no cell_id)\n\
|
||||
- delete: remove a cell (requires cell_id)\n\
|
||||
You MUST read the notebook file before editing."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::Write
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"notebook_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the .ipynb notebook file"
|
||||
},
|
||||
"cell_id": {
|
||||
"type": "string",
|
||||
"description": "Cell ID (UUID or 'cell-N' index). Required for replace/delete."
|
||||
},
|
||||
"new_source": {
|
||||
"type": "string",
|
||||
"description": "New cell content. Required for replace/insert."
|
||||
},
|
||||
"cell_type": {
|
||||
"type": "string",
|
||||
"enum": ["code", "markdown"],
|
||||
"description": "Cell type for insert operations (default: code)"
|
||||
},
|
||||
"edit_mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "insert", "delete"],
|
||||
"description": "Edit mode: replace, insert, or delete (default: replace)"
|
||||
}
|
||||
},
|
||||
"required": ["notebook_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: NotebookEditInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let path = ctx.resolve_path(¶ms.notebook_path);
|
||||
|
||||
// Validate extension
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("ipynb") {
|
||||
return ToolResult::error("File must have .ipynb extension".to_string());
|
||||
}
|
||||
|
||||
// Permission check
|
||||
if let Err(e) = ctx.check_permission(
|
||||
self.name(),
|
||||
&format!("Edit notebook {}", path.display()),
|
||||
false,
|
||||
) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
// Read notebook
|
||||
let content = match tokio::fs::read_to_string(&path).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => return ToolResult::error(format!("Failed to read notebook: {}", e)),
|
||||
};
|
||||
|
||||
let mut notebook: Value = match serde_json::from_str(&content) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return ToolResult::error(format!("Invalid notebook JSON: {}", e)),
|
||||
};
|
||||
|
||||
debug!(path = %path.display(), mode = %params.edit_mode, "Editing notebook");
|
||||
|
||||
let result = match params.edit_mode.as_str() {
|
||||
"replace" => {
|
||||
let cell_id = match ¶ms.cell_id {
|
||||
Some(id) => id.clone(),
|
||||
None => return ToolResult::error("cell_id is required for replace mode".to_string()),
|
||||
};
|
||||
let new_source = match ¶ms.new_source {
|
||||
Some(s) => s.clone(),
|
||||
None => return ToolResult::error("new_source is required for replace mode".to_string()),
|
||||
};
|
||||
replace_cell(&mut notebook, &cell_id, &new_source)
|
||||
}
|
||||
"insert" => {
|
||||
let new_source = match ¶ms.new_source {
|
||||
Some(s) => s.clone(),
|
||||
None => return ToolResult::error("new_source is required for insert mode".to_string()),
|
||||
};
|
||||
insert_cell(&mut notebook, params.cell_id.as_deref(), &new_source, ¶ms.cell_type)
|
||||
}
|
||||
"delete" => {
|
||||
let cell_id = match ¶ms.cell_id {
|
||||
Some(id) => id.clone(),
|
||||
None => return ToolResult::error("cell_id is required for delete mode".to_string()),
|
||||
};
|
||||
delete_cell(&mut notebook, &cell_id)
|
||||
}
|
||||
other => return ToolResult::error(format!("Unknown edit_mode: {}", other)),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(msg) => {
|
||||
// Write back
|
||||
let updated = match serde_json::to_string_pretty(¬ebook) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return ToolResult::error(format!("Failed to serialize notebook: {}", e)),
|
||||
};
|
||||
if let Err(e) = tokio::fs::write(&path, &updated).await {
|
||||
return ToolResult::error(format!("Failed to write notebook: {}", e));
|
||||
}
|
||||
ToolResult::success(msg)
|
||||
}
|
||||
Err(e) => ToolResult::error(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Notebook manipulation helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolve a cell index from "cell-N" notation or return `None` for UUID lookup.
|
||||
fn parse_cell_index(cell_id: &str) -> Option<usize> {
|
||||
cell_id
|
||||
.strip_prefix("cell-")
|
||||
.and_then(|n| n.parse::<usize>().ok())
|
||||
}
|
||||
|
||||
/// Find the position of a cell in the `cells` array by id or "cell-N".
|
||||
fn find_cell_index(cells: &[Value], cell_id: &str) -> Result<usize, String> {
|
||||
// Try "cell-N" index format first
|
||||
if let Some(idx) = parse_cell_index(cell_id) {
|
||||
if idx < cells.len() {
|
||||
return Ok(idx);
|
||||
}
|
||||
return Err(format!("Cell index {} is out of range (notebook has {} cells)", idx, cells.len()));
|
||||
}
|
||||
|
||||
// Try UUID match
|
||||
for (i, cell) in cells.iter().enumerate() {
|
||||
if let Some(id) = cell.get("id").and_then(|v| v.as_str()) {
|
||||
if id == cell_id {
|
||||
return Ok(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!("Cell '{}' not found", cell_id))
|
||||
}
|
||||
|
||||
/// Generate a simple random cell ID (8 hex chars, like nbformat ≥ 4.5).
|
||||
fn generate_cell_id() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.subsec_nanos())
|
||||
.unwrap_or(0);
|
||||
format!("{:08x}", nanos ^ 0xdeadbeef_u32)
|
||||
}
|
||||
|
||||
/// Build a new cell JSON object.
|
||||
fn make_cell(cell_type: &str, source: &str, cell_id: &str) -> Value {
|
||||
let source_lines: Vec<Value> = if source.is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
let lines: Vec<&str> = source.split_inclusive('\n').collect();
|
||||
lines.iter().map(|l| Value::String(l.to_string())).collect()
|
||||
};
|
||||
|
||||
match cell_type {
|
||||
"markdown" => json!({
|
||||
"cell_type": "markdown",
|
||||
"id": cell_id,
|
||||
"metadata": {},
|
||||
"source": source_lines
|
||||
}),
|
||||
_ => json!({
|
||||
"cell_type": "code",
|
||||
"id": cell_id,
|
||||
"metadata": {},
|
||||
"source": source_lines,
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_cell(notebook: &mut Value, cell_id: &str, new_source: &str) -> Result<String, String> {
|
||||
let cells = notebook
|
||||
.get_mut("cells")
|
||||
.and_then(|c| c.as_array_mut())
|
||||
.ok_or_else(|| "Notebook has no 'cells' array".to_string())?;
|
||||
|
||||
let idx = find_cell_index(cells, cell_id)?;
|
||||
|
||||
let cell = &mut cells[idx];
|
||||
let source_lines: Vec<Value> = new_source
|
||||
.split_inclusive('\n')
|
||||
.map(|l| Value::String(l.to_string()))
|
||||
.collect();
|
||||
|
||||
cell["source"] = Value::Array(source_lines);
|
||||
|
||||
// Reset execution state for code cells
|
||||
if cell.get("cell_type").and_then(|t| t.as_str()) == Some("code") {
|
||||
cell["outputs"] = Value::Array(vec![]);
|
||||
cell["execution_count"] = Value::Null;
|
||||
}
|
||||
|
||||
Ok(format!("Replaced cell '{}' (index {})", cell_id, idx))
|
||||
}
|
||||
|
||||
fn insert_cell(
|
||||
notebook: &mut Value,
|
||||
after_cell_id: Option<&str>,
|
||||
new_source: &str,
|
||||
cell_type: &str,
|
||||
) -> Result<String, String> {
|
||||
let cells = notebook
|
||||
.get_mut("cells")
|
||||
.and_then(|c| c.as_array_mut())
|
||||
.ok_or_else(|| "Notebook has no 'cells' array".to_string())?;
|
||||
|
||||
let insert_at = if let Some(id) = after_cell_id {
|
||||
find_cell_index(cells, id)? + 1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let new_id = generate_cell_id();
|
||||
let cell = make_cell(cell_type, new_source, &new_id);
|
||||
|
||||
cells.insert(insert_at, cell);
|
||||
Ok(format!("Inserted {} cell '{}' at position {}", cell_type, new_id, insert_at))
|
||||
}
|
||||
|
||||
fn delete_cell(notebook: &mut Value, cell_id: &str) -> Result<String, String> {
|
||||
let cells = notebook
|
||||
.get_mut("cells")
|
||||
.and_then(|c| c.as_array_mut())
|
||||
.ok_or_else(|| "Notebook has no 'cells' array".to_string())?;
|
||||
|
||||
let idx = find_cell_index(cells, cell_id)?;
|
||||
cells.remove(idx);
|
||||
|
||||
Ok(format!("Deleted cell '{}' (was at index {})", cell_id, idx))
|
||||
}
|
||||
136
src-rust/crates/tools/src/powershell.rs
Normal file
136
src-rust/crates/tools/src/powershell.rs
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// PowerShell tool: execute PowerShell commands (Windows-native).
|
||||
//
|
||||
// On Windows, PowerShell provides richer scripting than cmd.exe.
|
||||
// On non-Windows platforms, attempts to use `pwsh` (PowerShell Core).
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct PowerShellTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PowerShellInput {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
description: Option<String>,
|
||||
#[serde(default = "default_timeout")]
|
||||
timeout: u64,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 { 120_000 }
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for PowerShellTool {
|
||||
fn name(&self) -> &str { "PowerShell" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a PowerShell command. Use for Windows-native operations, .NET APIs, \
|
||||
registry access, and Windows-specific system administration."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Execute }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string", "description": "The PowerShell command to execute" },
|
||||
"description": { "type": "string", "description": "Description of what this command does" },
|
||||
"timeout": { "type": "number", "description": "Timeout in ms (default 120000)" }
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: PowerShellInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let desc = params.description.as_deref().unwrap_or(¶ms.command);
|
||||
if let Err(e) = ctx.check_permission(self.name(), desc, false) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
// Determine the PowerShell executable
|
||||
let (exe, args) = if cfg!(windows) {
|
||||
("powershell", vec!["-NoProfile", "-NonInteractive", "-Command"])
|
||||
} else {
|
||||
// PowerShell Core on non-Windows
|
||||
("pwsh", vec!["-NoProfile", "-NonInteractive", "-Command"])
|
||||
};
|
||||
|
||||
debug!(command = %params.command, "Executing PowerShell command");
|
||||
|
||||
let timeout_ms = params.timeout.min(600_000);
|
||||
let timeout_dur = Duration::from_millis(timeout_ms);
|
||||
|
||||
let mut child = match Command::new(exe)
|
||||
.args(&args)
|
||||
.arg(¶ms.command)
|
||||
.current_dir(&ctx.working_dir)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.stdin(Stdio::null())
|
||||
.spawn()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => return ToolResult::error(format!("Failed to spawn PowerShell: {}", e)),
|
||||
};
|
||||
|
||||
let stdout = child.stdout.take();
|
||||
let stderr = child.stderr.take();
|
||||
|
||||
let result = tokio::time::timeout(timeout_dur, async {
|
||||
let mut stdout_lines = Vec::new();
|
||||
let mut stderr_lines = Vec::new();
|
||||
|
||||
if let Some(out) = stdout {
|
||||
let mut lines = BufReader::new(out).lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
stdout_lines.push(line);
|
||||
}
|
||||
}
|
||||
if let Some(err) = stderr {
|
||||
let mut lines = BufReader::new(err).lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
stderr_lines.push(line);
|
||||
}
|
||||
}
|
||||
|
||||
let status = child.wait().await;
|
||||
(stdout_lines, stderr_lines, status)
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok((stdout_lines, stderr_lines, status)) => {
|
||||
let exit_code = status.map(|s| s.code().unwrap_or(-1)).unwrap_or(-1);
|
||||
let mut output = stdout_lines.join("\n");
|
||||
if !stderr_lines.is_empty() {
|
||||
if !output.is_empty() { output.push('\n'); }
|
||||
output.push_str("STDERR:\n");
|
||||
output.push_str(&stderr_lines.join("\n"));
|
||||
}
|
||||
if output.is_empty() { output = "(no output)".to_string(); }
|
||||
|
||||
if exit_code != 0 {
|
||||
ToolResult::error(format!("PowerShell exited with code {}\n{}", exit_code, output))
|
||||
} else {
|
||||
ToolResult::success(output)
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = child.kill().await;
|
||||
ToolResult::error(format!("PowerShell command timed out after {}ms", timeout_ms))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
149
src-rust/crates/tools/src/send_message.rs
Normal file
149
src-rust/crates/tools/src/send_message.rs
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// SendMessageTool: send a message to another agent or broadcast to all.
|
||||
//
|
||||
// In the TypeScript version this uses a complex mailbox/swarm system with
|
||||
// process-level sockets. The Rust port uses a simpler in-process DashMap
|
||||
// inbox that works for sub-agents spawned via AgentTool.
|
||||
//
|
||||
// Messages are stored keyed by recipient name. Other agents can check
|
||||
// their inbox by calling drain_inbox() or peek_inbox().
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// In-process inbox
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single message in the inbox.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentMessage {
|
||||
pub from: String,
|
||||
pub to: String,
|
||||
pub content: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
/// Global inbox: recipient_id → queued messages.
|
||||
static INBOX: Lazy<DashMap<String, Vec<AgentMessage>>> = Lazy::new(DashMap::new);
|
||||
|
||||
/// Remove and return all messages queued for `recipient`.
|
||||
pub fn drain_inbox(recipient: &str) -> Vec<AgentMessage> {
|
||||
INBOX.remove(recipient).map(|(_, v)| v).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Read (without removing) all messages queued for `recipient`.
|
||||
pub fn peek_inbox(recipient: &str) -> Vec<AgentMessage> {
|
||||
INBOX.get(recipient).map(|v| v.clone()).unwrap_or_default()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct SendMessageTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SendMessageInput {
|
||||
/// Recipient name, or "*" for broadcast.
|
||||
to: String,
|
||||
/// Message body.
|
||||
message: String,
|
||||
/// Short preview text shown in the UI.
|
||||
#[serde(default)]
|
||||
summary: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SendMessageTool {
|
||||
fn name(&self) -> &str { "SendMessage" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Send a message to another agent by name, or broadcast to all active agents with to=\"*\". \
|
||||
Recipients accumulate messages in their inbox and can retrieve them. \
|
||||
Use this for coordination between concurrent sub-agents."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "Recipient agent name or session ID. Use \"*\" to broadcast to all."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Message content"
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "5–10 word preview for the UI (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["to", "message"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: SendMessageInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
if params.message.is_empty() {
|
||||
return ToolResult::error("Message cannot be empty.".to_string());
|
||||
}
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
let msg = AgentMessage {
|
||||
from: ctx.session_id.clone(),
|
||||
to: params.to.clone(),
|
||||
content: params.message.clone(),
|
||||
timestamp: now,
|
||||
};
|
||||
|
||||
let preview = params
|
||||
.summary
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| {
|
||||
let s = params.message.as_str();
|
||||
&s[..s.len().min(60)]
|
||||
});
|
||||
|
||||
if params.to == "*" {
|
||||
// Broadcast: deliver to every existing inbox key
|
||||
let recipients: Vec<String> = INBOX.iter().map(|e| e.key().clone()).collect();
|
||||
|
||||
if recipients.is_empty() {
|
||||
return ToolResult::success(
|
||||
"Broadcast queued (no active recipient inboxes yet).".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
for key in &recipients {
|
||||
INBOX.entry(key.clone()).or_default().push(msg.clone());
|
||||
}
|
||||
|
||||
return ToolResult::success(format!(
|
||||
"Broadcast to {} agent(s): {}",
|
||||
recipients.len(),
|
||||
preview
|
||||
));
|
||||
}
|
||||
|
||||
// Directed message
|
||||
INBOX.entry(params.to.clone()).or_default().push(msg);
|
||||
|
||||
ToolResult::success(format!("Message sent to '{}': {}", params.to, preview))
|
||||
}
|
||||
}
|
||||
227
src-rust/crates/tools/src/skill_tool.rs
Normal file
227
src-rust/crates/tools/src/skill_tool.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
// SkillTool: execute user-defined skill (prompt template) files programmatically.
|
||||
//
|
||||
// Skills are Markdown files stored in:
|
||||
// <project>/.claude/commands/<name>.md
|
||||
// ~/.claude/commands/<name>.md
|
||||
//
|
||||
// Bundled skills (defined in bundled_skills.rs) are checked first before the
|
||||
// disk directories, so they take precedence over same-named .md files.
|
||||
//
|
||||
// The model invokes this tool to expand a skill's prompt inline.
|
||||
// Supports $ARGUMENTS placeholder substitution.
|
||||
// Use skill="list" to discover available skills.
|
||||
|
||||
use crate::bundled_skills::{expand_prompt, find_bundled_skill, user_invocable_skills};
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::path::PathBuf;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct SkillTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SkillInput {
|
||||
skill: String,
|
||||
#[serde(default)]
|
||||
args: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillTool {
|
||||
fn name(&self) -> &str { "Skill" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a skill (custom prompt template) by name. \
|
||||
Skills are .md files in .claude/commands/ or ~/.claude/commands/. \
|
||||
Use skill=\"list\" to discover available skills. \
|
||||
The expanded skill prompt is returned for you to act on."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::ReadOnly }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill": {
|
||||
"type": "string",
|
||||
"description": "Skill name (without .md extension), or \"list\" to enumerate skills"
|
||||
},
|
||||
"args": {
|
||||
"type": "string",
|
||||
"description": "Arguments passed to the skill — replaces $ARGUMENTS in the template"
|
||||
}
|
||||
},
|
||||
"required": ["skill"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: SkillInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let dirs = skill_search_dirs(ctx);
|
||||
|
||||
if params.skill == "list" {
|
||||
return list_skills(&dirs).await;
|
||||
}
|
||||
|
||||
let skill_name = params.skill.trim_end_matches(".md");
|
||||
debug!(skill = skill_name, "Loading skill");
|
||||
|
||||
// Check bundled skills first — they take precedence over disk files.
|
||||
if let Some(bundled) = find_bundled_skill(skill_name) {
|
||||
let args = params.args.as_deref().unwrap_or("");
|
||||
let prompt = expand_prompt(bundled, args);
|
||||
let prompt = prompt.trim().to_string();
|
||||
if prompt.is_empty() {
|
||||
return ToolResult::error(format!(
|
||||
"Bundled skill '{}' expanded to empty content.",
|
||||
skill_name
|
||||
));
|
||||
}
|
||||
return ToolResult::success(prompt);
|
||||
}
|
||||
|
||||
let raw = match find_and_read_skill(skill_name, &dirs).await {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return ToolResult::error(format!(
|
||||
"Skill '{}' not found. Use skill=\"list\" to see available skills.",
|
||||
skill_name
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Strip YAML frontmatter if present (--- ... ---)
|
||||
let content = strip_frontmatter(&raw);
|
||||
|
||||
// Substitute $ARGUMENTS
|
||||
let prompt = if let Some(args) = ¶ms.args {
|
||||
content.replace("$ARGUMENTS", args)
|
||||
} else {
|
||||
content.replace("$ARGUMENTS", "")
|
||||
};
|
||||
|
||||
let prompt = prompt.trim().to_string();
|
||||
if prompt.is_empty() {
|
||||
return ToolResult::error(format!("Skill '{}' expanded to empty content.", skill_name));
|
||||
}
|
||||
|
||||
ToolResult::success(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn skill_search_dirs(ctx: &ToolContext) -> Vec<PathBuf> {
|
||||
let mut dirs = vec![
|
||||
ctx.working_dir.join(".claude").join("commands"),
|
||||
];
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
dirs.push(home.join(".claude").join("commands"));
|
||||
}
|
||||
dirs
|
||||
}
|
||||
|
||||
async fn list_skills(dirs: &[PathBuf]) -> ToolResult {
|
||||
// Start with the bundled skills.
|
||||
let mut lines: Vec<String> = Vec::new();
|
||||
let bundled = user_invocable_skills();
|
||||
for (name, desc) in &bundled {
|
||||
lines.push(format!(" {} — {} [bundled]", name, desc));
|
||||
}
|
||||
let bundled_names: Vec<&str> = bundled.iter().map(|(n, _)| *n).collect();
|
||||
|
||||
// Then add disk skills, skipping any that shadow a bundled name.
|
||||
let mut disk_skills: Vec<(String, PathBuf)> = Vec::new();
|
||||
for dir in dirs {
|
||||
match tokio::fs::read_dir(dir).await {
|
||||
Ok(mut entries) => {
|
||||
while let Ok(Some(entry)) = entries.next_entry().await {
|
||||
let path = entry.path();
|
||||
if path.extension().map_or(false, |e| e == "md") {
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
let name = stem.to_string();
|
||||
// Deduplicate — project-level shadows user-level;
|
||||
// bundled skills shadow everything.
|
||||
if !disk_skills.iter().any(|(n, _)| n == &name)
|
||||
&& !bundled_names.contains(&name.as_str())
|
||||
{
|
||||
disk_skills.push((name, path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {} // directory doesn't exist, skip
|
||||
}
|
||||
}
|
||||
|
||||
disk_skills.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, path) in &disk_skills {
|
||||
let desc = read_skill_description(path).await;
|
||||
lines.push(format!(" {} — {}", name, desc));
|
||||
}
|
||||
|
||||
let total = bundled.len() + disk_skills.len();
|
||||
if total == 0 {
|
||||
return ToolResult::success(
|
||||
"No skills found. Create .md files in .claude/commands/ to define skills.\n\
|
||||
Example: .claude/commands/review.md"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Available skills ({}):\n{}",
|
||||
total,
|
||||
lines.join("\n")
|
||||
))
|
||||
}
|
||||
|
||||
async fn find_and_read_skill(name: &str, dirs: &[PathBuf]) -> Option<String> {
|
||||
for dir in dirs {
|
||||
let path = dir.join(format!("{}.md", name));
|
||||
if let Ok(content) = tokio::fs::read_to_string(&path).await {
|
||||
return Some(content);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn read_skill_description(path: &std::path::Path) -> String {
|
||||
let Ok(content) = tokio::fs::read_to_string(path).await else {
|
||||
return "(no description)".to_string();
|
||||
};
|
||||
let body = strip_frontmatter(&content);
|
||||
// First non-empty, non-heading line
|
||||
for line in body.lines() {
|
||||
let t = line.trim().trim_start_matches('#').trim();
|
||||
if !t.is_empty() {
|
||||
let truncated = if t.len() > 80 { &t[..80] } else { t };
|
||||
return truncated.to_string();
|
||||
}
|
||||
}
|
||||
"(no description)".to_string()
|
||||
}
|
||||
|
||||
/// Remove YAML frontmatter delimited by `---` at the start of the file.
|
||||
fn strip_frontmatter(content: &str) -> String {
|
||||
if content.starts_with("---") {
|
||||
// Find closing ---
|
||||
let after_open = &content[3..];
|
||||
if let Some(close_pos) = after_open.find("\n---") {
|
||||
// Skip past the closing delimiter and any leading newline
|
||||
let rest = &after_open[close_pos + 4..];
|
||||
return rest.trim_start_matches('\n').to_string();
|
||||
}
|
||||
}
|
||||
content.to_string()
|
||||
}
|
||||
63
src-rust/crates/tools/src/sleep.rs
Normal file
63
src-rust/crates/tools/src/sleep.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// SleepTool: pause execution for a specified duration.
|
||||
//
|
||||
// Useful when the model needs to wait between operations (e.g., polling,
|
||||
// rate limiting, or waiting for external processes). Unlike `Bash(sleep ...)`,
|
||||
// this does not hold a shell process and can run concurrently with other tools.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct SleepTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SleepInput {
|
||||
/// Duration in milliseconds (capped at 300_000 = 5 minutes).
|
||||
#[serde(alias = "ms", alias = "duration_ms")]
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SleepTool {
|
||||
fn name(&self) -> &str { "Sleep" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Wait for a specified duration in milliseconds. \
|
||||
Use instead of Bash(sleep ...) — it doesn't hold a shell process \
|
||||
and can run concurrently with other tools. \
|
||||
The user can interrupt the sleep at any time."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ms": {
|
||||
"type": "number",
|
||||
"description": "Duration to sleep in milliseconds (max 300000 = 5 minutes)"
|
||||
}
|
||||
},
|
||||
"required": ["ms"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: SleepInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
// Cap at 5 minutes
|
||||
let duration_ms = params.ms.min(300_000);
|
||||
debug!(ms = duration_ms, "Sleeping");
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(duration_ms)).await;
|
||||
|
||||
ToolResult::success(format!("Slept for {}ms.", duration_ms))
|
||||
}
|
||||
}
|
||||
503
src-rust/crates/tools/src/tasks.rs
Normal file
503
src-rust/crates/tools/src/tasks.rs
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
// Task management tools: TaskCreate, TaskGet, TaskUpdate, TaskList, TaskStop, TaskOutput.
|
||||
//
|
||||
// Implements a simple in-process task store backed by a global Arc<Mutex<HashMap>>.
|
||||
// Tasks have id, subject, description, status, owner, blocks/blocked-by dependencies,
|
||||
// and optional output.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Task store (global singleton)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TaskStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
Deleted,
|
||||
Running, // for background shell tasks
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
TaskStatus::Pending => "pending",
|
||||
TaskStatus::InProgress => "in_progress",
|
||||
TaskStatus::Completed => "completed",
|
||||
TaskStatus::Deleted => "deleted",
|
||||
TaskStatus::Running => "running",
|
||||
TaskStatus::Failed => "failed",
|
||||
};
|
||||
write!(f, "{}", s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub id: String,
|
||||
pub subject: String,
|
||||
pub description: String,
|
||||
pub status: TaskStatus,
|
||||
pub owner: Option<String>,
|
||||
/// IDs of tasks this task blocks (i.e., those tasks depend on this one completing).
|
||||
pub blocks: Vec<String>,
|
||||
/// IDs of tasks that must complete before this task can start.
|
||||
pub blocked_by: Vec<String>,
|
||||
pub metadata: Option<Value>,
|
||||
pub output: Option<String>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
fn new(subject: impl Into<String>, description: impl Into<String>) -> Self {
|
||||
let now = chrono::Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
subject: subject.into(),
|
||||
description: description.into(),
|
||||
status: TaskStatus::Pending,
|
||||
owner: None,
|
||||
blocks: vec![],
|
||||
blocked_by: vec![],
|
||||
metadata: None,
|
||||
output: None,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_summary_value(&self) -> Value {
|
||||
// Compute effective blocked_by (exclude completed tasks)
|
||||
let blocked_by = self.blocked_by.clone();
|
||||
json!({
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"status": self.status.to_string(),
|
||||
"owner": self.owner,
|
||||
"blocked_by": blocked_by,
|
||||
})
|
||||
}
|
||||
|
||||
fn to_full_value(&self) -> Value {
|
||||
json!({
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"description": self.description,
|
||||
"status": self.status.to_string(),
|
||||
"owner": self.owner,
|
||||
"blocks": self.blocks,
|
||||
"blocked_by": self.blocked_by,
|
||||
"metadata": self.metadata,
|
||||
"output": self.output,
|
||||
"created_at": self.created_at.to_rfc3339(),
|
||||
"updated_at": self.updated_at.to_rfc3339(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Global task store shared across all tool invocations.
|
||||
static TASK_STORE: Lazy<Arc<DashMap<String, Task>>> =
|
||||
Lazy::new(|| Arc::new(DashMap::new()));
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskCreate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskCreateTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TaskCreateInput {
|
||||
subject: String,
|
||||
description: String,
|
||||
#[serde(default)]
|
||||
metadata: Option<Value>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskCreateTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_CREATE }
|
||||
fn description(&self) -> &str { "Create a new task to track work items. Returns the task ID." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"subject": { "type": "string", "description": "Brief title for the task" },
|
||||
"description": { "type": "string", "description": "Detailed description of what needs to be done" },
|
||||
"metadata": { "type": "object", "description": "Optional arbitrary metadata" }
|
||||
},
|
||||
"required": ["subject", "description"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TaskCreateInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let mut task = Task::new(¶ms.subject, ¶ms.description);
|
||||
task.metadata = params.metadata;
|
||||
let task_id = task.id.clone();
|
||||
|
||||
debug!(task_id = %task_id, subject = %params.subject, "Creating task");
|
||||
TASK_STORE.insert(task_id.clone(), task);
|
||||
|
||||
ToolResult::success(serde_json::to_string_pretty(&json!({
|
||||
"task_id": task_id,
|
||||
"subject": params.subject,
|
||||
})).unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskGet
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskGetTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TaskGetInput {
|
||||
#[serde(alias = "taskId")]
|
||||
task_id: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskGetTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_GET }
|
||||
fn description(&self) -> &str { "Get full details of a task by ID." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "description": "Task ID to retrieve" }
|
||||
},
|
||||
"required": ["task_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TaskGetInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
match TASK_STORE.get(¶ms.task_id) {
|
||||
Some(task) => ToolResult::success(
|
||||
serde_json::to_string_pretty(&task.to_full_value()).unwrap_or_default()
|
||||
),
|
||||
None => ToolResult::success(
|
||||
serde_json::to_string_pretty(&json!(null)).unwrap_or_default()
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskUpdate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskUpdateTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TaskUpdateInput {
|
||||
#[serde(alias = "taskId")]
|
||||
task_id: String,
|
||||
#[serde(default)]
|
||||
subject: Option<String>,
|
||||
#[serde(default)]
|
||||
description: Option<String>,
|
||||
#[serde(default)]
|
||||
status: Option<String>,
|
||||
#[serde(default)]
|
||||
owner: Option<String>,
|
||||
#[serde(default, rename = "addBlocks")]
|
||||
add_blocks: Option<Vec<String>>,
|
||||
#[serde(default, rename = "addBlockedBy")]
|
||||
add_blocked_by: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
metadata: Option<Value>,
|
||||
#[serde(default)]
|
||||
output: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskUpdateTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_UPDATE }
|
||||
fn description(&self) -> &str { "Update a task's properties (status, subject, description, etc.)." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "description": "Task ID to update" },
|
||||
"subject": { "type": "string" },
|
||||
"description": { "type": "string" },
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed", "deleted", "failed"]
|
||||
},
|
||||
"owner": { "type": "string" },
|
||||
"addBlocks": { "type": "array", "items": { "type": "string" } },
|
||||
"addBlockedBy": { "type": "array", "items": { "type": "string" } },
|
||||
"metadata": { "type": "object" },
|
||||
"output": { "type": "string" }
|
||||
},
|
||||
"required": ["task_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TaskUpdateInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let mut task = match TASK_STORE.get_mut(¶ms.task_id) {
|
||||
Some(t) => t,
|
||||
None => return ToolResult::error(format!("Task '{}' not found", params.task_id)),
|
||||
};
|
||||
|
||||
let mut updated_fields: Vec<&str> = vec![];
|
||||
|
||||
if let Some(subject) = ¶ms.subject {
|
||||
task.subject = subject.clone();
|
||||
updated_fields.push("subject");
|
||||
}
|
||||
if let Some(desc) = ¶ms.description {
|
||||
task.description = desc.clone();
|
||||
updated_fields.push("description");
|
||||
}
|
||||
if let Some(status_str) = ¶ms.status {
|
||||
task.status = match status_str.as_str() {
|
||||
"pending" => TaskStatus::Pending,
|
||||
"in_progress" | "in-progress" => TaskStatus::InProgress,
|
||||
"completed" => TaskStatus::Completed,
|
||||
"deleted" => TaskStatus::Deleted,
|
||||
"running" => TaskStatus::Running,
|
||||
"failed" => TaskStatus::Failed,
|
||||
other => return ToolResult::error(format!("Unknown status: {}", other)),
|
||||
};
|
||||
updated_fields.push("status");
|
||||
}
|
||||
if let Some(owner) = ¶ms.owner {
|
||||
task.owner = Some(owner.clone());
|
||||
updated_fields.push("owner");
|
||||
}
|
||||
if let Some(blocks) = ¶ms.add_blocks {
|
||||
for b in blocks {
|
||||
if !task.blocks.contains(b) {
|
||||
task.blocks.push(b.clone());
|
||||
}
|
||||
}
|
||||
updated_fields.push("blocks");
|
||||
}
|
||||
if let Some(blocked_by) = ¶ms.add_blocked_by {
|
||||
for b in blocked_by {
|
||||
if !task.blocked_by.contains(b) {
|
||||
task.blocked_by.push(b.clone());
|
||||
}
|
||||
}
|
||||
updated_fields.push("blocked_by");
|
||||
}
|
||||
if let Some(meta) = ¶ms.metadata {
|
||||
task.metadata = Some(meta.clone());
|
||||
updated_fields.push("metadata");
|
||||
}
|
||||
if let Some(out) = ¶ms.output {
|
||||
task.output = Some(out.clone());
|
||||
updated_fields.push("output");
|
||||
}
|
||||
|
||||
task.updated_at = chrono::Utc::now();
|
||||
|
||||
// Handle deletion
|
||||
let task_id = task.id.clone();
|
||||
let task_status = task.status.clone();
|
||||
drop(task); // release the lock
|
||||
|
||||
if task_status == TaskStatus::Deleted {
|
||||
TASK_STORE.remove(&task_id);
|
||||
}
|
||||
|
||||
ToolResult::success(serde_json::to_string_pretty(&json!({
|
||||
"success": true,
|
||||
"task_id": task_id,
|
||||
"updated_fields": updated_fields,
|
||||
})).unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskList
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskListTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskListTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_LIST }
|
||||
fn description(&self) -> &str { "List all active tasks (excluding deleted/completed)." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"include_completed": {
|
||||
"type": "boolean",
|
||||
"description": "Include completed tasks (default false)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let include_completed = input
|
||||
.get("include_completed")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
let tasks: Vec<Value> = TASK_STORE
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let status = &entry.value().status;
|
||||
match status {
|
||||
TaskStatus::Deleted => false,
|
||||
TaskStatus::Completed => include_completed,
|
||||
_ => true,
|
||||
}
|
||||
})
|
||||
.map(|entry| entry.value().to_summary_value())
|
||||
.collect();
|
||||
|
||||
ToolResult::success(serde_json::to_string_pretty(&tasks).unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskStop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskStopTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TaskStopInput {
|
||||
#[serde(alias = "shell_id")]
|
||||
task_id: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskStopTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_STOP }
|
||||
fn description(&self) -> &str { "Stop a running background task." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Execute }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "description": "ID of the task to stop" }
|
||||
},
|
||||
"required": ["task_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TaskStopInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
match TASK_STORE.get_mut(¶ms.task_id) {
|
||||
Some(mut task) => {
|
||||
if task.status != TaskStatus::Running && task.status != TaskStatus::InProgress {
|
||||
return ToolResult::error(format!(
|
||||
"Task '{}' is not running (status: {})",
|
||||
params.task_id, task.status
|
||||
));
|
||||
}
|
||||
task.status = TaskStatus::Completed;
|
||||
task.updated_at = chrono::Utc::now();
|
||||
ToolResult::success(serde_json::to_string_pretty(&json!({
|
||||
"message": "Task stopped",
|
||||
"task_id": params.task_id,
|
||||
})).unwrap_or_default())
|
||||
}
|
||||
None => ToolResult::error(format!("Task '{}' not found", params.task_id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskOutput
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct TaskOutputTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TaskOutputInput {
|
||||
task_id: String,
|
||||
#[serde(default = "default_block")]
|
||||
block: bool,
|
||||
}
|
||||
|
||||
fn default_block() -> bool { true }
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskOutputTool {
|
||||
fn name(&self) -> &str { cc_core::constants::TOOL_NAME_TASK_OUTPUT }
|
||||
fn description(&self) -> &str { "Get the output of a task." }
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "description": "Task ID to get output for" },
|
||||
"block": { "type": "boolean", "description": "Wait for task to complete (default true)" }
|
||||
},
|
||||
"required": ["task_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TaskOutputInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
match TASK_STORE.get(¶ms.task_id) {
|
||||
Some(task) => {
|
||||
let retrieval_status = match &task.status {
|
||||
TaskStatus::Completed | TaskStatus::Failed => "success",
|
||||
TaskStatus::Running | TaskStatus::InProgress => {
|
||||
if params.block { "success" } else { "not_ready" }
|
||||
}
|
||||
_ => "success",
|
||||
};
|
||||
ToolResult::success(serde_json::to_string_pretty(&json!({
|
||||
"retrieval_status": retrieval_status,
|
||||
"task": task.to_full_value(),
|
||||
})).unwrap_or_default())
|
||||
}
|
||||
None => ToolResult::error(format!("Task '{}' not found", params.task_id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
127
src-rust/crates/tools/src/todo_write.rs
Normal file
127
src-rust/crates/tools/src/todo_write.rs
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
// TodoWrite tool: task / todo list management.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct TodoWriteTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TodoWriteInput {
|
||||
todos: Vec<TodoItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct TodoItem {
|
||||
id: String,
|
||||
content: String,
|
||||
status: TodoStatus,
|
||||
#[serde(default)]
|
||||
priority: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum TodoStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TodoStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TodoStatus::Pending => write!(f, "pending"),
|
||||
TodoStatus::InProgress => write!(f, "in_progress"),
|
||||
TodoStatus::Completed => write!(f, "completed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TodoWriteTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_TODO_WRITE
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Write and manage a todo/task list. Provide the complete list of todos \
|
||||
each time (this replaces the entire list). Use this to track progress \
|
||||
on multi-step tasks."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::None
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todos": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string" },
|
||||
"content": { "type": "string" },
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed"]
|
||||
},
|
||||
"priority": { "type": "string" }
|
||||
},
|
||||
"required": ["id", "content", "status"]
|
||||
},
|
||||
"description": "The complete list of todo items"
|
||||
}
|
||||
},
|
||||
"required": ["todos"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: TodoWriteInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
debug!(count = params.todos.len(), "Writing todo list");
|
||||
|
||||
let total = params.todos.len();
|
||||
let completed = params
|
||||
.todos
|
||||
.iter()
|
||||
.filter(|t| matches!(t.status, TodoStatus::Completed))
|
||||
.count();
|
||||
let in_progress = params
|
||||
.todos
|
||||
.iter()
|
||||
.filter(|t| matches!(t.status, TodoStatus::InProgress))
|
||||
.count();
|
||||
let pending = total - completed - in_progress;
|
||||
|
||||
let mut output = format!(
|
||||
"Todo list updated ({} total: {} pending, {} in progress, {} completed)\n\n",
|
||||
total, pending, in_progress, completed
|
||||
);
|
||||
|
||||
for item in ¶ms.todos {
|
||||
let icon = match item.status {
|
||||
TodoStatus::Pending => "[ ]",
|
||||
TodoStatus::InProgress => "[~]",
|
||||
TodoStatus::Completed => "[x]",
|
||||
};
|
||||
output.push_str(&format!("{} {} ({})\n", icon, item.content, item.id));
|
||||
}
|
||||
|
||||
ToolResult::success(output).with_metadata(json!({
|
||||
"total": total,
|
||||
"completed": completed,
|
||||
"in_progress": in_progress,
|
||||
"pending": pending,
|
||||
}))
|
||||
}
|
||||
}
|
||||
201
src-rust/crates/tools/src/tool_search.rs
Normal file
201
src-rust/crates/tools/src/tool_search.rs
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
// ToolSearchTool: search for tools by name or keyword.
|
||||
//
|
||||
// This is used by the model to discover "deferred" tools that are not yet
|
||||
// loaded into context. In the Rust port there is no deferred-tool mechanism
|
||||
// (all tools are always available), but this tool still provides a useful
|
||||
// search interface for the model to discover available capabilities.
|
||||
//
|
||||
// Supports two query modes:
|
||||
// - "select:ToolName" → direct lookup by exact name
|
||||
// - "keyword search" → fuzzy name + description match with scoring
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub struct ToolSearchTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToolSearchInput {
|
||||
query: String,
|
||||
#[serde(default = "default_max")]
|
||||
max_results: usize,
|
||||
}
|
||||
|
||||
fn default_max() -> usize { 5 }
|
||||
|
||||
/// A minimal catalog entry describing one tool.
|
||||
#[derive(Debug, Clone)]
|
||||
struct ToolEntry {
|
||||
name: &'static str,
|
||||
description: &'static str,
|
||||
keywords: &'static [&'static str],
|
||||
}
|
||||
|
||||
/// Static catalog of all built-in tools with keywords for scoring.
|
||||
static TOOL_CATALOG: &[ToolEntry] = &[
|
||||
ToolEntry { name: "Bash", description: "Execute shell commands", keywords: &["shell", "run", "command", "exec", "terminal"] },
|
||||
ToolEntry { name: "Read", description: "Read file contents", keywords: &["file", "read", "cat", "content"] },
|
||||
ToolEntry { name: "Write", description: "Write or create files", keywords: &["file", "write", "create", "save"] },
|
||||
ToolEntry { name: "Edit", description: "Edit existing files with string replacement", keywords: &["file", "edit", "modify", "replace", "patch"] },
|
||||
ToolEntry { name: "Glob", description: "Find files by pattern", keywords: &["find", "pattern", "search", "files", "glob"] },
|
||||
ToolEntry { name: "Grep", description: "Search file contents with regex", keywords: &["search", "regex", "grep", "find", "content"] },
|
||||
ToolEntry { name: "WebFetch", description: "Fetch web page content", keywords: &["web", "fetch", "http", "url", "browser"] },
|
||||
ToolEntry { name: "WebSearch", description: "Search the web", keywords: &["web", "search", "internet", "query"] },
|
||||
ToolEntry { name: "NotebookEdit", description: "Edit Jupyter notebook cells", keywords: &["notebook", "jupyter", "ipynb", "cell"] },
|
||||
ToolEntry { name: "TodoWrite", description: "Manage todo list", keywords: &["todo", "task", "list", "write"] },
|
||||
ToolEntry { name: "AskUserQuestion", description: "Ask the user a question", keywords: &["ask", "question", "user", "input", "clarify"] },
|
||||
ToolEntry { name: "EnterPlanMode", description: "Enter planning mode", keywords: &["plan", "mode", "planning"] },
|
||||
ToolEntry { name: "ExitPlanMode", description: "Exit planning mode", keywords: &["plan", "exit", "mode"] },
|
||||
ToolEntry { name: "Sleep", description: "Wait for a duration", keywords: &["sleep", "wait", "delay", "pause"] },
|
||||
ToolEntry { name: "PowerShell", description: "Execute PowerShell commands", keywords: &["powershell", "windows", "ps", "command"] },
|
||||
ToolEntry { name: "CronCreate", description: "Schedule a recurring cron task", keywords: &["cron", "schedule", "recurring", "timer"] },
|
||||
ToolEntry { name: "CronDelete", description: "Cancel a scheduled cron task", keywords: &["cron", "delete", "cancel", "remove"] },
|
||||
ToolEntry { name: "CronList", description: "List all cron tasks", keywords: &["cron", "list", "scheduled", "tasks"] },
|
||||
ToolEntry { name: "EnterWorktree", description: "Create and enter a git worktree", keywords: &["worktree", "git", "branch", "isolate"] },
|
||||
ToolEntry { name: "ExitWorktree", description: "Exit the current git worktree", keywords: &["worktree", "git", "exit", "restore"] },
|
||||
ToolEntry { name: "TaskCreate", description: "Create a background task", keywords: &["task", "create", "background", "async"] },
|
||||
ToolEntry { name: "TaskGet", description: "Get task details", keywords: &["task", "get", "status", "details"] },
|
||||
ToolEntry { name: "TaskUpdate", description: "Update a task's status", keywords: &["task", "update", "status", "progress"] },
|
||||
ToolEntry { name: "TaskList", description: "List all tasks", keywords: &["task", "list", "all", "tasks"] },
|
||||
ToolEntry { name: "TaskStop", description: "Stop a running task", keywords: &["task", "stop", "kill", "cancel"] },
|
||||
ToolEntry { name: "TaskOutput", description: "Get task output/logs", keywords: &["task", "output", "logs", "result"] },
|
||||
ToolEntry { name: "ListMcpResources", description: "List MCP server resources", keywords: &["mcp", "resource", "list", "server"] },
|
||||
ToolEntry { name: "ReadMcpResource", description: "Read an MCP resource", keywords: &["mcp", "resource", "read", "server"] },
|
||||
ToolEntry { name: "Agent", description: "Launch a sub-agent for complex tasks", keywords: &["agent", "subagent", "task", "parallel", "delegate"] },
|
||||
ToolEntry { name: "Brief", description: "Send a formatted message to the user", keywords: &["brief", "message", "notify", "proactive", "status", "update"] },
|
||||
ToolEntry { name: "Config", description: "Get or set Claude Code configuration", keywords: &["config", "settings", "model", "verbose", "permission", "configure"] },
|
||||
ToolEntry { name: "SendMessage", description: "Send a message to another agent", keywords: &["send", "message", "agent", "broadcast", "communicate", "inbox"] },
|
||||
ToolEntry { name: "Skill", description: "Execute a skill prompt template", keywords: &["skill", "command", "template", "prompt", "slash", "custom"] },
|
||||
];
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ToolSearchTool {
|
||||
fn name(&self) -> &str { "ToolSearch" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search for available tools by name or keyword. Use 'select:ToolName' for direct \
|
||||
lookup or provide keywords for fuzzy search. Returns matching tool names and their \
|
||||
descriptions. Max 5 results by default."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::None }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Query: use 'select:ToolName' for direct selection, or keywords to search"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "number",
|
||||
"description": "Maximum results to return (default: 5)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: ToolSearchInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let query = params.query.trim();
|
||||
let max = params.max_results.min(20);
|
||||
|
||||
// select: prefix — direct lookup
|
||||
if let Some(names_str) = query.strip_prefix("select:").map(str::trim) {
|
||||
let requested: Vec<&str> = names_str.split(',').map(str::trim).collect();
|
||||
let mut found = Vec::new();
|
||||
let mut missing = Vec::new();
|
||||
|
||||
for name in requested {
|
||||
if let Some(entry) = TOOL_CATALOG.iter().find(|e| {
|
||||
e.name.eq_ignore_ascii_case(name)
|
||||
}) {
|
||||
found.push(format!("{}: {}", entry.name, entry.description));
|
||||
} else {
|
||||
missing.push(name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if found.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"No matching tools found for: {}",
|
||||
missing.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
let mut out = found.join("\n");
|
||||
if !missing.is_empty() {
|
||||
out.push_str(&format!("\n\nNot found: {}", missing.join(", ")));
|
||||
}
|
||||
return ToolResult::success(out);
|
||||
}
|
||||
|
||||
// Keyword search with scoring
|
||||
let q_lower = query.to_lowercase();
|
||||
let terms: Vec<&str> = q_lower.split_whitespace().collect();
|
||||
|
||||
let mut scored: Vec<(usize, &ToolEntry)> = TOOL_CATALOG
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let mut score = 0usize;
|
||||
let name_lower = entry.name.to_lowercase();
|
||||
let desc_lower = entry.description.to_lowercase();
|
||||
|
||||
for term in &terms {
|
||||
// Exact name match
|
||||
if name_lower == *term {
|
||||
score += 20;
|
||||
} else if name_lower.contains(term) {
|
||||
score += 10;
|
||||
}
|
||||
|
||||
// Description match
|
||||
if desc_lower.contains(term) {
|
||||
score += 5;
|
||||
}
|
||||
|
||||
// Keyword match
|
||||
for &kw in entry.keywords {
|
||||
if kw == *term {
|
||||
score += 8;
|
||||
} else if kw.contains(term) {
|
||||
score += 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if score > 0 { Some((score, entry)) } else { None }
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
scored.truncate(max);
|
||||
|
||||
if scored.is_empty() {
|
||||
return ToolResult::success(format!(
|
||||
"No tools found matching '{}'. Try broader keywords or use 'select:ToolName'.",
|
||||
query
|
||||
));
|
||||
}
|
||||
|
||||
let lines: Vec<String> = scored
|
||||
.iter()
|
||||
.map(|(_, e)| format!("{}: {}", e.name, e.description))
|
||||
.collect();
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Tools matching '{}':\n\n{}\n\nTotal tools available: {}",
|
||||
query,
|
||||
lines.join("\n"),
|
||||
TOOL_CATALOG.len()
|
||||
))
|
||||
}
|
||||
}
|
||||
236
src-rust/crates/tools/src/web_fetch.rs
Normal file
236
src-rust/crates/tools/src/web_fetch.rs
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
// WebFetch tool: HTTP GET with basic HTML-to-text conversion.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct WebFetchTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WebFetchInput {
|
||||
url: String,
|
||||
#[serde(default)]
|
||||
prompt: Option<String>,
|
||||
}
|
||||
|
||||
/// Naively strip HTML tags and decode common entities.
|
||||
fn strip_html(html: &str) -> String {
|
||||
let mut result = String::with_capacity(html.len());
|
||||
let mut in_tag = false;
|
||||
let mut in_script = false;
|
||||
let mut in_style = false;
|
||||
|
||||
let lower = html.to_lowercase();
|
||||
let chars: Vec<char> = html.chars().collect();
|
||||
let lower_chars: Vec<char> = lower.chars().collect();
|
||||
let len = chars.len();
|
||||
let mut i = 0;
|
||||
|
||||
while i < len {
|
||||
if !in_tag && chars[i] == '<' {
|
||||
in_tag = true;
|
||||
// Check for script/style open/close
|
||||
let rest: String = lower_chars[i..].iter().take(20).collect();
|
||||
if rest.starts_with("<script") {
|
||||
in_script = true;
|
||||
} else if rest.starts_with("</script") {
|
||||
in_script = false;
|
||||
} else if rest.starts_with("<style") {
|
||||
in_style = true;
|
||||
} else if rest.starts_with("</style") {
|
||||
in_style = false;
|
||||
}
|
||||
// Block tags => newline
|
||||
let block_tags = [
|
||||
"<br", "<p ", "<p>", "</p>", "<div", "</div>", "<h1", "<h2", "<h3",
|
||||
"<h4", "<h5", "<h6", "</h1", "</h2", "</h3", "</h4", "</h5", "</h6",
|
||||
"<li", "</li", "<tr", "</tr", "<hr",
|
||||
];
|
||||
for tag in &block_tags {
|
||||
if rest.starts_with(tag) {
|
||||
result.push('\n');
|
||||
break;
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_tag {
|
||||
if chars[i] == '>' {
|
||||
in_tag = false;
|
||||
}
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_script || in_style {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Decode basic entities
|
||||
if chars[i] == '&' {
|
||||
let rest: String = chars[i..].iter().take(10).collect();
|
||||
if rest.starts_with("&") {
|
||||
result.push('&');
|
||||
i += 5;
|
||||
} else if rest.starts_with("<") {
|
||||
result.push('<');
|
||||
i += 4;
|
||||
} else if rest.starts_with(">") {
|
||||
result.push('>');
|
||||
i += 4;
|
||||
} else if rest.starts_with(""") {
|
||||
result.push('"');
|
||||
i += 6;
|
||||
} else if rest.starts_with("'") || rest.starts_with("'") {
|
||||
result.push('\'');
|
||||
i += if rest.starts_with("'") { 5 } else { 6 };
|
||||
} else if rest.starts_with(" ") {
|
||||
result.push(' ');
|
||||
i += 6;
|
||||
} else {
|
||||
result.push('&');
|
||||
i += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
result.push(chars[i]);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Collapse multiple blank lines
|
||||
let mut collapsed = String::new();
|
||||
let mut blank_count = 0;
|
||||
for line in result.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
blank_count += 1;
|
||||
if blank_count <= 2 {
|
||||
collapsed.push('\n');
|
||||
}
|
||||
} else {
|
||||
blank_count = 0;
|
||||
collapsed.push_str(trimmed);
|
||||
collapsed.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
collapsed.trim().to_string()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for WebFetchTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_WEB_FETCH
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fetches a web page URL and returns its content as text. HTML is \
|
||||
automatically converted to plain text. Use this for reading documentation, \
|
||||
APIs, and other web resources."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::ReadOnly
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Optional prompt for how to process the content"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: WebFetchInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
// Permission check
|
||||
if let Err(e) = ctx.check_permission(
|
||||
self.name(),
|
||||
&format!("Fetch {}", params.url),
|
||||
true, // read-only
|
||||
) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
debug!(url = %params.url, "Fetching web page");
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.redirect(reqwest::redirect::Policy::limited(10))
|
||||
.build();
|
||||
|
||||
let client = match client {
|
||||
Ok(c) => c,
|
||||
Err(e) => return ToolResult::error(format!("Failed to create HTTP client: {}", e)),
|
||||
};
|
||||
|
||||
let resp = match client.get(¶ms.url)
|
||||
.header("User-Agent", "Claude-Code/1.0")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return ToolResult::error(format!("Failed to fetch {}: {}", params.url, e)),
|
||||
};
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
return ToolResult::error(format!(
|
||||
"HTTP {} when fetching {}",
|
||||
status, params.url
|
||||
));
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let body = match resp.text().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => return ToolResult::error(format!("Failed to read response body: {}", e)),
|
||||
};
|
||||
|
||||
// Convert HTML to text if applicable
|
||||
let text = if content_type.contains("html") {
|
||||
strip_html(&body)
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
// Truncate very long content
|
||||
const MAX_LEN: usize = 100_000;
|
||||
let text = if text.len() > MAX_LEN {
|
||||
format!(
|
||||
"{}\n\n... (truncated, {} total characters)",
|
||||
&text[..MAX_LEN],
|
||||
text.len()
|
||||
)
|
||||
} else {
|
||||
text
|
||||
};
|
||||
|
||||
ToolResult::success(text)
|
||||
}
|
||||
}
|
||||
227
src-rust/crates/tools/src/web_search.rs
Normal file
227
src-rust/crates/tools/src/web_search.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
// WebSearch tool: search the web using Brave Search API or fallback to DuckDuckGo.
|
||||
//
|
||||
// Mirrors the TypeScript WebSearch tool behaviour:
|
||||
// - Accepts a query string
|
||||
// - Returns a list of results with title, url, and snippet
|
||||
// - Falls back to DuckDuckGo if no search API key is configured
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
pub struct WebSearchTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WebSearchInput {
|
||||
query: String,
|
||||
#[serde(default = "default_num_results")]
|
||||
num_results: usize,
|
||||
}
|
||||
|
||||
fn default_num_results() -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for WebSearchTool {
|
||||
fn name(&self) -> &str {
|
||||
cc_core::constants::TOOL_NAME_WEB_SEARCH
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search the web for information. Returns a list of relevant web pages with \
|
||||
titles, URLs, and snippets. Use this when you need current information \
|
||||
not available in your training data, or when searching for documentation, \
|
||||
examples, or news."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel {
|
||||
PermissionLevel::ReadOnly
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"num_results": {
|
||||
"type": "number",
|
||||
"description": "Number of results to return (default: 5, max: 10)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: WebSearchInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let num_results = params.num_results.min(10).max(1);
|
||||
debug!(query = %params.query, num_results, "Web search");
|
||||
|
||||
// Try Brave Search API first, then fall back to DuckDuckGo
|
||||
if let Some(api_key) = std::env::var("BRAVE_SEARCH_API_KEY").ok().filter(|k| !k.is_empty()) {
|
||||
search_brave(¶ms.query, num_results, &api_key).await
|
||||
} else {
|
||||
search_duckduckgo(¶ms.query, num_results).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search using the Brave Search API.
|
||||
async fn search_brave(query: &str, num_results: usize, api_key: &str) -> ToolResult {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"https://api.search.brave.com/res/v1/web/search?q={}&count={}",
|
||||
urlencoding_simple(query),
|
||||
num_results
|
||||
);
|
||||
|
||||
let resp = match client
|
||||
.get(&url)
|
||||
.header("Accept", "application/json")
|
||||
.header("Accept-Encoding", "gzip")
|
||||
.header("X-Subscription-Token", api_key)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return ToolResult::error(format!("Search request failed: {}", e)),
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status().as_u16();
|
||||
return ToolResult::error(format!("Brave Search API returned status {}", status));
|
||||
}
|
||||
|
||||
let data: Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => return ToolResult::error(format!("Failed to parse response: {}", e)),
|
||||
};
|
||||
|
||||
let results = format_brave_results(&data, num_results);
|
||||
ToolResult::success(results)
|
||||
}
|
||||
|
||||
fn format_brave_results(data: &Value, max: usize) -> String {
|
||||
let mut output = String::new();
|
||||
let web_results = data
|
||||
.get("web")
|
||||
.and_then(|w| w.get("results"))
|
||||
.and_then(|r| r.as_array());
|
||||
|
||||
if let Some(items) = web_results {
|
||||
for (i, item) in items.iter().take(max).enumerate() {
|
||||
let title = item.get("title").and_then(|t| t.as_str()).unwrap_or("(No title)");
|
||||
let url = item.get("url").and_then(|u| u.as_str()).unwrap_or("");
|
||||
let snippet = item.get("description").and_then(|s| s.as_str()).unwrap_or("");
|
||||
|
||||
output.push_str(&format!("{}. **{}**\n URL: {}\n {}\n\n", i + 1, title, url, snippet));
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
"No results found.".to_string()
|
||||
} else {
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Fallback: DuckDuckGo Instant Answer API.
|
||||
/// Note: this doesn't return full search results, only instant answers.
|
||||
async fn search_duckduckgo(query: &str, num_results: usize) -> ToolResult {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"https://api.duckduckgo.com/?q={}&format=json&no_html=1&skip_disambig=1",
|
||||
urlencoding_simple(query)
|
||||
);
|
||||
|
||||
let resp = match client
|
||||
.get(&url)
|
||||
.header("User-Agent", "Claude Code/1.0")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return ToolResult::error(format!("Search request failed: {}", e)),
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status().as_u16();
|
||||
return ToolResult::error(format!("DuckDuckGo API returned status {}", status));
|
||||
}
|
||||
|
||||
let data: Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => return ToolResult::error(format!("Failed to parse response: {}", e)),
|
||||
};
|
||||
|
||||
let output = format_ddg_results(&data, num_results);
|
||||
ToolResult::success(output)
|
||||
}
|
||||
|
||||
fn format_ddg_results(data: &Value, max: usize) -> String {
|
||||
let mut output = String::new();
|
||||
let mut count = 0;
|
||||
|
||||
// Abstract (main answer)
|
||||
if let Some(abstract_text) = data.get("Abstract").and_then(|a| a.as_str()) {
|
||||
if !abstract_text.is_empty() {
|
||||
let source = data.get("AbstractSource").and_then(|s| s.as_str()).unwrap_or("");
|
||||
let url = data.get("AbstractURL").and_then(|u| u.as_str()).unwrap_or("");
|
||||
output.push_str(&format!("**{}**\n{}\nURL: {}\n\n", source, abstract_text, url));
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Related topics
|
||||
if let Some(topics) = data.get("RelatedTopics").and_then(|t| t.as_array()) {
|
||||
for topic in topics.iter().take(max.saturating_sub(count)) {
|
||||
if let Some(text) = topic.get("Text").and_then(|t| t.as_str()) {
|
||||
if !text.is_empty() {
|
||||
let url = topic.get("FirstURL").and_then(|u| u.as_str()).unwrap_or("");
|
||||
output.push_str(&format!("- {}\n {}\n\n", text, url));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
format!(
|
||||
"No instant answer found for '{}'. Try using the Brave Search API \
|
||||
by setting the BRAVE_SEARCH_API_KEY environment variable for full web search.",
|
||||
data.get("QuerySearchQuery")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("your query")
|
||||
)
|
||||
} else {
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal percent-encoding for URL query parameters.
|
||||
fn urlencoding_simple(s: &str) -> String {
|
||||
let mut encoded = String::new();
|
||||
for ch in s.chars() {
|
||||
match ch {
|
||||
'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_' | '.' | '~' => {
|
||||
encoded.push(ch);
|
||||
}
|
||||
' ' => encoded.push('+'),
|
||||
_ => {
|
||||
for byte in ch.to_string().as_bytes() {
|
||||
encoded.push_str(&format!("%{:02X}", byte));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
encoded
|
||||
}
|
||||
351
src-rust/crates/tools/src/worktree.rs
Normal file
351
src-rust/crates/tools/src/worktree.rs
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
// Worktree tools: create and exit git worktrees for isolated work sessions.
|
||||
//
|
||||
// EnterWorktreeTool – create a new git worktree with an optional branch name,
|
||||
// switching the session's working directory to it.
|
||||
// ExitWorktreeTool – exit the current worktree, optionally removing it, and
|
||||
// restore the original working directory.
|
||||
//
|
||||
// These tools mirror the TypeScript EnterWorktreeTool / ExitWorktreeTool.
|
||||
|
||||
use crate::{PermissionLevel, Tool, ToolContext, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::debug;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Session-level state: only one active worktree per session.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorktreeSession {
|
||||
pub original_cwd: PathBuf,
|
||||
pub worktree_path: PathBuf,
|
||||
pub branch: Option<String>,
|
||||
pub original_head: Option<String>,
|
||||
}
|
||||
|
||||
static WORKTREE_SESSION: Lazy<Arc<RwLock<Option<WorktreeSession>>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(None)));
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EnterWorktreeTool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct EnterWorktreeTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct EnterWorktreeInput {
|
||||
/// Optional branch name. If omitted, a timestamped branch is created.
|
||||
#[serde(default)]
|
||||
branch: Option<String>,
|
||||
/// Sub-path under the repo root where the worktree will be created.
|
||||
/// Defaults to `.worktrees/<branch>`.
|
||||
#[serde(default)]
|
||||
path: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for EnterWorktreeTool {
|
||||
fn name(&self) -> &str { "EnterWorktree" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Create a new git worktree and switch the session's working directory to it. \
|
||||
This gives you an isolated environment to experiment or work on a feature \
|
||||
without affecting the main working tree. \
|
||||
Use ExitWorktree to return to the original directory."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Write }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"branch": {
|
||||
"type": "string",
|
||||
"description": "Branch name to create. Defaults to a timestamped name like worktree-1234567890."
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional path for the worktree directory. Defaults to .worktrees/<branch>."
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
|
||||
let params: EnterWorktreeInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
// Check if already in a worktree session
|
||||
{
|
||||
let session = WORKTREE_SESSION.read().await;
|
||||
if session.is_some() {
|
||||
return ToolResult::error(
|
||||
"Already in a worktree session. Call ExitWorktree first.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = ctx.check_permission(
|
||||
self.name(),
|
||||
"Create a git worktree",
|
||||
false,
|
||||
) {
|
||||
return ToolResult::error(e.to_string());
|
||||
}
|
||||
|
||||
// Determine branch name
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
let branch = params
|
||||
.branch
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("worktree-{}", ts));
|
||||
|
||||
// Determine worktree path
|
||||
let worktree_path = if let Some(p) = params.path {
|
||||
ctx.working_dir.join(p)
|
||||
} else {
|
||||
ctx.working_dir.join(".worktrees").join(&branch)
|
||||
};
|
||||
|
||||
// Get current HEAD for change tracking
|
||||
let head = run_git(&ctx.working_dir, &["rev-parse", "HEAD"]).await;
|
||||
let original_head = head.ok().map(|h| h.trim().to_string());
|
||||
|
||||
// Create the worktree
|
||||
let worktree_str = worktree_path.to_string_lossy().to_string();
|
||||
let result = run_git(
|
||||
&ctx.working_dir,
|
||||
&["worktree", "add", "-b", &branch, &worktree_str],
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(e) => ToolResult::error(format!("Failed to create worktree: {}", e)),
|
||||
Ok(_) => {
|
||||
debug!(
|
||||
branch = %branch,
|
||||
path = %worktree_path.display(),
|
||||
"Created worktree"
|
||||
);
|
||||
|
||||
// Save session state
|
||||
*WORKTREE_SESSION.write().await = Some(WorktreeSession {
|
||||
original_cwd: ctx.working_dir.clone(),
|
||||
worktree_path: worktree_path.clone(),
|
||||
branch: Some(branch.clone()),
|
||||
original_head,
|
||||
});
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Created worktree at {} on branch '{}'.\n\
|
||||
The working directory is now {}.\n\
|
||||
Use ExitWorktree to return to {}.",
|
||||
worktree_path.display(),
|
||||
branch,
|
||||
worktree_path.display(),
|
||||
ctx.working_dir.display(),
|
||||
))
|
||||
.with_metadata(json!({
|
||||
"worktree_path": worktree_path.to_string_lossy(),
|
||||
"branch": branch,
|
||||
"original_cwd": ctx.working_dir.to_string_lossy(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExitWorktreeTool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct ExitWorktreeTool;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ExitWorktreeInput {
|
||||
/// "keep" = leave the worktree on disk; "remove" = delete it.
|
||||
#[serde(default = "default_action")]
|
||||
action: String,
|
||||
/// Required if action=="remove" and there are uncommitted changes.
|
||||
#[serde(default)]
|
||||
discard_changes: bool,
|
||||
}
|
||||
|
||||
fn default_action() -> String { "keep".to_string() }
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ExitWorktreeTool {
|
||||
fn name(&self) -> &str { "ExitWorktree" }
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Exit the current worktree session created by EnterWorktree and restore the \
|
||||
original working directory. Use action='keep' to preserve the worktree on \
|
||||
disk, or action='remove' to delete it. Only operates on worktrees created \
|
||||
by EnterWorktree in this session."
|
||||
}
|
||||
|
||||
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Write }
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["keep", "remove"],
|
||||
"description": "\"keep\" leaves the worktree on disk; \"remove\" deletes it and its branch."
|
||||
},
|
||||
"discard_changes": {
|
||||
"type": "boolean",
|
||||
"description": "Set true when action=remove and the worktree has uncommitted/unmerged work to discard."
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
|
||||
let params: ExitWorktreeInput = match serde_json::from_value(input) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
|
||||
};
|
||||
|
||||
let session_guard = WORKTREE_SESSION.read().await;
|
||||
let session = match &*session_guard {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
return ToolResult::error(
|
||||
"No-op: there is no active EnterWorktree session to exit. \
|
||||
This tool only operates on worktrees created by EnterWorktree \
|
||||
in the current session."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
};
|
||||
drop(session_guard);
|
||||
|
||||
let worktree_str = session.worktree_path.to_string_lossy().to_string();
|
||||
|
||||
// If action is "remove", check for uncommitted changes
|
||||
if params.action == "remove" && !params.discard_changes {
|
||||
let status = run_git(&session.worktree_path, &["status", "--porcelain"]).await;
|
||||
let changed_files = status
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.lines()
|
||||
.filter(|l| !l.trim().is_empty())
|
||||
.count();
|
||||
|
||||
let commit_count = if let Some(ref head) = session.original_head {
|
||||
let rev = run_git(
|
||||
&session.worktree_path,
|
||||
&["rev-list", "--count", &format!("{}..HEAD", head)],
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
rev.trim().parse::<usize>().unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if changed_files > 0 || commit_count > 0 {
|
||||
let mut parts = Vec::new();
|
||||
if changed_files > 0 {
|
||||
parts.push(format!("{} uncommitted file(s)", changed_files));
|
||||
}
|
||||
if commit_count > 0 {
|
||||
parts.push(format!("{} commit(s) on the worktree branch", commit_count));
|
||||
}
|
||||
return ToolResult::error(format!(
|
||||
"Worktree has {}. Removing will discard this work permanently. \
|
||||
Confirm with the user, then re-invoke with discard_changes=true — \
|
||||
or use action=\"keep\" to preserve the worktree.",
|
||||
parts.join(" and ")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Clear session state
|
||||
*WORKTREE_SESSION.write().await = None;
|
||||
|
||||
match params.action.as_str() {
|
||||
"keep" => {
|
||||
// Just remove the worktree from git's tracking list (prune),
|
||||
// but keep the directory on disk.
|
||||
let _ = run_git(
|
||||
&session.original_cwd,
|
||||
&["worktree", "lock", "--reason", "kept by ExitWorktree", &worktree_str],
|
||||
)
|
||||
.await;
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Exited worktree. Work preserved at {} on branch {}. \
|
||||
Session is now back in {}.",
|
||||
session.worktree_path.display(),
|
||||
session.branch.as_deref().unwrap_or("(unknown)"),
|
||||
session.original_cwd.display(),
|
||||
))
|
||||
}
|
||||
"remove" => {
|
||||
// Remove the worktree
|
||||
let _ = run_git(
|
||||
&session.original_cwd,
|
||||
&["worktree", "remove", "--force", &worktree_str],
|
||||
)
|
||||
.await;
|
||||
|
||||
// Delete the branch if we created it
|
||||
if let Some(ref branch) = session.branch {
|
||||
let _ = run_git(
|
||||
&session.original_cwd,
|
||||
&["branch", "-D", branch],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
ToolResult::success(format!(
|
||||
"Exited and removed worktree at {}. \
|
||||
Session is now back in {}.",
|
||||
session.worktree_path.display(),
|
||||
session.original_cwd.display(),
|
||||
))
|
||||
}
|
||||
other => ToolResult::error(format!(
|
||||
"Unknown action '{}'. Use 'keep' or 'remove'.",
|
||||
other
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_git(cwd: &std::path::Path, args: &[&str]) -> Result<String, String> {
|
||||
let output = tokio::process::Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
Err(String::from_utf8_lossy(&output.stderr).to_string())
|
||||
}
|
||||
}
|
||||
21
src-rust/crates/tui/Cargo.toml
Normal file
21
src-rust/crates/tui/Cargo.toml
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
[package]
|
||||
name = "cc-tui"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cc-core = { workspace = true }
|
||||
cc-api = { workspace = true }
|
||||
cc-tools = { workspace = true }
|
||||
cc-query = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
ratatui = { workspace = true }
|
||||
crossterm = { workspace = true }
|
||||
unicode-width = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
1748
src-rust/crates/tui/src/lib.rs
Normal file
1748
src-rust/crates/tui/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue