hello world

This commit is contained in:
kuberwastaken 2026-04-01 01:20:27 +05:30
commit c99507ca1e
84 changed files with 54252 additions and 0 deletions

View file

@ -0,0 +1,31 @@
# Manifest for the `cc-core` crate: shared core logic for the workspace.
[package]
name = "cc-core"
# Version and edition are inherited from the workspace root manifest.
version.workspace = true
edition.workspace = true
# All dependency versions are pinned centrally via `[workspace.dependencies]`.
[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
dirs = { workspace = true }
once_cell = { workspace = true }
parking_lot = { workspace = true }
dashmap = { workspace = true }
indexmap = { workspace = true }
regex = { workspace = true }
base64 = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
url = { workspace = true }
schemars = { workspace = true }
reqwest = { workspace = true }
urlencoding = { workspace = true }
# Dependencies used only by `#[cfg(test)]` code.
[dev-dependencies]
tempfile = { workspace = true }

View file

@ -0,0 +1,403 @@
//! Analytics and telemetry (OpenTelemetry-compatible counters)
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Session-level metrics counters (mirrors TypeScript bootstrap state).
///
/// All counters use `AtomicU64` so they can be shared across threads without
/// a mutex. Cost is stored as integer millicents (cost_usd × 100_000) to
/// avoid floating-point atomic arithmetic.
#[derive(Debug, Default)]
pub struct SessionMetrics {
    /// Total cost in units of 1/100_000 USD (i.e. millicents).
    pub total_cost_usd_millicents: AtomicU64,
    pub total_input_tokens: AtomicU64,
    pub total_output_tokens: AtomicU64,
    pub total_api_duration_ms: AtomicU64,
    pub total_tool_duration_ms: AtomicU64,
    pub total_lines_added: AtomicU64,
    pub total_lines_removed: AtomicU64,
    pub session_count: AtomicU64,
    pub commit_count: AtomicU64,
    pub pr_count: AtomicU64,
    pub tool_use_count: AtomicU64,
}

impl SessionMetrics {
    /// Create a fresh, zeroed metrics instance behind an `Arc` for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Record an API cost in USD.
    ///
    /// Converts to integer millicents with rounding. The previous bare
    /// `as u64` cast truncated, systematically under-counting: e.g.
    /// `0.29_f64 * 100_000.0` is `28999.999…` and truncated to `28_999`,
    /// losing a millicent per call. Negative or NaN inputs saturate to 0
    /// under Rust's float→int cast rules, so they record zero cost.
    pub fn add_cost(&self, usd: f64) {
        let millicents = (usd * 100_000.0).round() as u64;
        self.total_cost_usd_millicents
            .fetch_add(millicents, Ordering::Relaxed);
    }

    /// Total accumulated cost in USD.
    pub fn total_cost_usd(&self) -> f64 {
        self.total_cost_usd_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
    }

    /// Accumulate input/output token counts for one API call.
    pub fn add_tokens(&self, input: u32, output: u32) {
        self.total_input_tokens
            .fetch_add(input as u64, Ordering::Relaxed);
        self.total_output_tokens
            .fetch_add(output as u64, Ordering::Relaxed);
    }

    /// Accumulate wall-clock time spent in API requests.
    pub fn add_api_duration(&self, ms: u64) {
        self.total_api_duration_ms.fetch_add(ms, Ordering::Relaxed);
    }

    /// Accumulate wall-clock time spent executing tools.
    pub fn add_tool_duration(&self, ms: u64) {
        self.total_tool_duration_ms.fetch_add(ms, Ordering::Relaxed);
    }

    /// Accumulate diff line counts. Negative values are ignored rather than
    /// cast, since the underlying counters are unsigned.
    pub fn add_lines(&self, added: i64, removed: i64) {
        if added > 0 {
            self.total_lines_added
                .fetch_add(added as u64, Ordering::Relaxed);
        }
        if removed > 0 {
            self.total_lines_removed
                .fetch_add(removed as u64, Ordering::Relaxed);
        }
    }

    /// Record one started session.
    ///
    /// Added for parity: `session_count` previously had no incrementer at
    /// all, so it could never move off zero.
    pub fn increment_sessions(&self) {
        self.session_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one git commit created during the session.
    pub fn increment_commits(&self) {
        self.commit_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one pull request created during the session.
    pub fn increment_prs(&self) {
        self.pr_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one tool invocation.
    pub fn increment_tool_use(&self) {
        self.tool_use_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Take a point-in-time snapshot of all counters.
    ///
    /// Each field is loaded independently with `Relaxed` ordering, so the
    /// snapshot is not a single atomic cut across counters — fine for
    /// display purposes, which is all this is used for.
    pub fn summary(&self) -> MetricsSummary {
        MetricsSummary {
            cost_usd: self.total_cost_usd(),
            input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
            output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
            api_duration_ms: self.total_api_duration_ms.load(Ordering::Relaxed),
            tool_duration_ms: self.total_tool_duration_ms.load(Ordering::Relaxed),
            lines_added: self.total_lines_added.load(Ordering::Relaxed),
            lines_removed: self.total_lines_removed.load(Ordering::Relaxed),
            commits: self.commit_count.load(Ordering::Relaxed),
            prs: self.pr_count.load(Ordering::Relaxed),
            tool_uses: self.tool_use_count.load(Ordering::Relaxed),
        }
    }
}

/// A point-in-time snapshot of session metrics.
#[derive(Debug, Clone)]
pub struct MetricsSummary {
    pub cost_usd: f64,
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub api_duration_ms: u64,
    pub tool_duration_ms: u64,
    pub lines_added: u64,
    pub lines_removed: u64,
    pub commits: u64,
    pub prs: u64,
    pub tool_uses: u64,
}

impl MetricsSummary {
    /// Format cost as a dollar amount string with appropriate precision:
    /// 5 decimal places below one cent, 4 otherwise.
    pub fn format_cost(&self) -> String {
        if self.cost_usd < 0.01 {
            format!("${:.5}", self.cost_usd)
        } else {
            format!("${:.4}", self.cost_usd)
        }
    }

    /// Format total (input + output) token count with K/M suffix.
    pub fn format_tokens(&self) -> String {
        let total = self.input_tokens + self.output_tokens;
        if total >= 1_000_000 {
            format!("{:.1}M tok", total as f64 / 1_000_000.0)
        } else if total >= 1_000 {
            format!("{:.1}K tok", total as f64 / 1_000.0)
        } else {
            format!("{} tok", total)
        }
    }
}
/// Event types for first-party analytics (privacy-respecting — no PII).
#[derive(Debug, Clone)]
pub enum AnalyticsEvent {
    /// Emitted once when a session begins.
    SessionStarted {
        /// Model identifier in use (e.g. "claude-opus-4-6").
        model: String,
        /// Whether this is an interactive (TUI) session vs. non-interactive.
        is_interactive: bool,
    },
    /// Emitted once when a session terminates.
    SessionEnded {
        turn_count: u32,
        cost_usd: f64,
        duration_ms: u64,
        had_errors: bool,
    },
    /// One tool invocation completed.
    ToolUsed {
        tool_name: String,
        success: bool,
        duration_ms: u64,
    },
    /// A slash-command or similar user command ran.
    CommandExecuted {
        command: String,
        success: bool,
    },
    /// Context-window compaction ran; token counts before/after.
    CompactionTriggered {
        tokens_before: u32,
        tokens_after: u32,
    },
}
/// Analytics sink — currently logs via `tracing`; can be extended to push
/// events to a first-party endpoint.
pub struct Analytics {
    enabled: bool,
    session_id: String,
}

impl Analytics {
    /// Build a sink for the given session. When `enabled` is false, every
    /// subsequent `track` call is a no-op.
    pub fn new(session_id: String, enabled: bool) -> Self {
        Self { enabled, session_id }
    }

    /// Emit a single analytics event as a debug-level `tracing` record,
    /// tagged with the owning session id. Disabled sinks drop the event.
    pub fn track(&self, event: AnalyticsEvent) {
        if self.enabled {
            tracing::debug!(
                session_id = %self.session_id,
                event = ?event,
                "analytics event"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- SessionMetrics counter behaviour ---------------------------------

    #[test]
    fn test_session_metrics_initial_zero() {
        let m = SessionMetrics::new();
        assert_eq!(m.total_cost_usd(), 0.0);
        assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 0);
        assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_add_cost_single() {
        let m = SessionMetrics::new();
        m.add_cost(0.01);
        let cost = m.total_cost_usd();
        // Allow small floating-point tolerance
        assert!((cost - 0.01).abs() < 1e-9, "cost = {}", cost);
    }

    #[test]
    fn test_add_cost_accumulates() {
        let m = SessionMetrics::new();
        m.add_cost(1.0);
        m.add_cost(2.5);
        let cost = m.total_cost_usd();
        assert!((cost - 3.5).abs() < 1e-9, "cost = {}", cost);
    }

    #[test]
    fn test_add_tokens() {
        let m = SessionMetrics::new();
        m.add_tokens(1000, 500);
        assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1000);
        assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 500);
    }

    #[test]
    fn test_add_tokens_accumulates() {
        let m = SessionMetrics::new();
        m.add_tokens(1000, 500);
        m.add_tokens(200, 100);
        assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1200);
        assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 600);
    }

    #[test]
    fn test_add_lines_positive() {
        let m = SessionMetrics::new();
        m.add_lines(10, 5);
        assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 10);
        assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 5);
    }

    // Negative deltas must be dropped, not cast into huge u64 values.
    #[test]
    fn test_add_lines_negative_ignored() {
        let m = SessionMetrics::new();
        m.add_lines(-3, -7);
        assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 0);
        assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_increment_commits_and_prs() {
        let m = SessionMetrics::new();
        m.increment_commits();
        m.increment_commits();
        m.increment_prs();
        assert_eq!(m.commit_count.load(Ordering::Relaxed), 2);
        assert_eq!(m.pr_count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_increment_tool_use() {
        let m = SessionMetrics::new();
        for _ in 0..5 {
            m.increment_tool_use();
        }
        assert_eq!(m.tool_use_count.load(Ordering::Relaxed), 5);
    }

    // summary() must reflect every counter touched above it.
    #[test]
    fn test_summary_snapshot() {
        let m = SessionMetrics::new();
        m.add_cost(1.23456);
        m.add_tokens(100, 50);
        m.add_api_duration(300);
        m.add_tool_duration(150);
        m.add_lines(8, 3);
        m.increment_commits();
        m.increment_prs();
        m.increment_tool_use();
        let s = m.summary();
        assert!((s.cost_usd - 1.23456).abs() < 1e-9);
        assert_eq!(s.input_tokens, 100);
        assert_eq!(s.output_tokens, 50);
        assert_eq!(s.api_duration_ms, 300);
        assert_eq!(s.tool_duration_ms, 150);
        assert_eq!(s.lines_added, 8);
        assert_eq!(s.lines_removed, 3);
        assert_eq!(s.commits, 1);
        assert_eq!(s.prs, 1);
        assert_eq!(s.tool_uses, 1);
    }

    // --- MetricsSummary formatting ----------------------------------------

    #[test]
    fn test_format_cost_small() {
        let s = MetricsSummary {
            cost_usd: 0.001,
            input_tokens: 0,
            output_tokens: 0,
            api_duration_ms: 0,
            tool_duration_ms: 0,
            lines_added: 0,
            lines_removed: 0,
            commits: 0,
            prs: 0,
            tool_uses: 0,
        };
        let formatted = s.format_cost();
        assert!(formatted.starts_with('$'));
        // Should have 5 decimal places for small cost
        assert!(formatted.contains('.'));
    }

    #[test]
    fn test_format_cost_large() {
        let s = MetricsSummary {
            cost_usd: 1.5,
            input_tokens: 0,
            output_tokens: 0,
            api_duration_ms: 0,
            tool_duration_ms: 0,
            lines_added: 0,
            lines_removed: 0,
            commits: 0,
            prs: 0,
            tool_uses: 0,
        };
        assert_eq!(s.format_cost(), "$1.5000");
    }

    #[test]
    fn test_format_tokens_exact() {
        let s = MetricsSummary {
            cost_usd: 0.0,
            input_tokens: 500,
            output_tokens: 300,
            api_duration_ms: 0,
            tool_duration_ms: 0,
            lines_added: 0,
            lines_removed: 0,
            commits: 0,
            prs: 0,
            tool_uses: 0,
        };
        assert_eq!(s.format_tokens(), "800 tok");
    }

    #[test]
    fn test_format_tokens_kilo() {
        let s = MetricsSummary {
            cost_usd: 0.0,
            input_tokens: 5_000,
            output_tokens: 3_000,
            api_duration_ms: 0,
            tool_duration_ms: 0,
            lines_added: 0,
            lines_removed: 0,
            commits: 0,
            prs: 0,
            tool_uses: 0,
        };
        assert!(s.format_tokens().ends_with("K tok"));
    }

    #[test]
    fn test_format_tokens_mega() {
        let s = MetricsSummary {
            cost_usd: 0.0,
            input_tokens: 1_500_000,
            output_tokens: 500_000,
            api_duration_ms: 0,
            tool_duration_ms: 0,
            lines_added: 0,
            lines_removed: 0,
            commits: 0,
            prs: 0,
            tool_uses: 0,
        };
        assert!(s.format_tokens().ends_with("M tok"));
    }

    // --- Analytics sink smoke tests (no assertable output; just no panic) -

    #[test]
    fn test_analytics_track_disabled_no_panic() {
        let a = Analytics::new("test-session".to_string(), false);
        // Should not panic even though disabled
        a.track(AnalyticsEvent::SessionStarted {
            model: "claude-opus-4-6".to_string(),
            is_interactive: true,
        });
    }

    #[test]
    fn test_analytics_track_enabled_no_panic() {
        let a = Analytics::new("test-session-2".to_string(), true);
        a.track(AnalyticsEvent::ToolUsed {
            tool_name: "Bash".to_string(),
            success: true,
            duration_ms: 42,
        });
    }
}

View file

@ -0,0 +1,423 @@
//! Configurable keyboard shortcuts system
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// All keybinding contexts
///
/// Serialized with PascalCase names (e.g. "HistorySearch") via the
/// `rename_all` attribute — the same strings users write in the
/// `context` field of keybindings.json.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum KeyContext {
    /// Bindings active everywhere, regardless of the focused component.
    Global,
    Chat,
    Autocomplete,
    Confirmation,
    Help,
    Transcript,
    HistorySearch,
    Task,
    ThemePicker,
    Settings,
    Tabs,
    Attachments,
    Footer,
    MessageSelector,
    DiffDialog,
    ModelPicker,
    Select,
    Plugin,
}
/// A single decoded keystroke: a normalized key name plus modifier flags.
///
/// Equality is supplied by a manual `PartialEq` impl further down this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedKeystroke {
    pub key: String, // normalized key name
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub meta: bool,
}

/// An ordered sequence of keystrokes, e.g. the two-step "ctrl+k ctrl+d".
pub type Chord = Vec<ParsedKeystroke>;

/// A fully-parsed binding: chord → optional action, scoped to a context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedBinding {
    pub chord: Chord,
    pub action: Option<String>, // None = unbound
    pub context: KeyContext,
}
/// Parse a keystroke string like "ctrl+shift+enter" into ParsedKeystroke.
///
/// Modifier tokens (`ctrl`/`control`, `alt`/`opt`/`option`, `shift`,
/// `meta`/`cmd`/`command`/`super`/`win`) set the corresponding flags; the
/// remaining token(s) become the key, canonicalized via `normalize_key`.
///
/// Returns `None` when no key remains — e.g. `""`, `"ctrl+"`, or a
/// modifier-only string. (The previous version pushed the empty segment
/// produced by a trailing `+` or empty input into `key_parts`, so `""` and
/// `"ctrl+"` returned `Some` with an empty key — contradicting this file's
/// own `test_parse_keystroke_empty_returns_none` test.)
pub fn parse_keystroke(s: &str) -> Option<ParsedKeystroke> {
    let s = s.trim().to_lowercase();
    let mut ctrl = false;
    let mut alt = false;
    let mut shift = false;
    let mut meta = false;
    let mut key_parts: Vec<&str> = Vec::new();
    for part in s.split('+') {
        let part = part.trim();
        // Skip empty segments (trailing '+', doubled '+', empty input) so
        // that modifier-only or empty strings fall through to `None` below.
        if part.is_empty() {
            continue;
        }
        match part {
            "ctrl" | "control" => ctrl = true,
            "alt" | "opt" | "option" => alt = true,
            "shift" => shift = true,
            "meta" | "cmd" | "command" | "super" | "win" => meta = true,
            _ => key_parts.push(part),
        }
    }
    if key_parts.is_empty() {
        return None;
    }
    let key = normalize_key(key_parts.join("+").as_str());
    Some(ParsedKeystroke {
        key,
        ctrl,
        alt,
        shift,
        meta,
    })
}
/// Canonicalize a key token to its normalized name (e.g. "esc" → "escape",
/// "pgdn" → "pagedown"). Unrecognized tokens pass through unchanged.
fn normalize_key(k: &str) -> String {
    let canonical = match k {
        "esc" | "escape" => "escape",
        "return" | "enter" => "enter",
        "del" | "delete" => "delete",
        "backspace" | "bs" => "backspace",
        "space" | " " => "space",
        "pageup" | "pgup" => "pageup",
        "pagedown" | "pgdn" | "pgdown" => "pagedown",
        // Already-canonical navigation keys map to themselves.
        "up" | "down" | "left" | "right" | "home" | "end" | "tab" => k,
        other => other,
    };
    canonical.to_string()
}
/// Parse a chord: one or more whitespace-separated keystrokes, e.g.
/// "ctrl+k ctrl+d". Segments that fail to parse are dropped; returns
/// `None` when nothing parses at all.
pub fn parse_chord(s: &str) -> Option<Chord> {
    let parsed: Chord = s
        .split_whitespace()
        .filter_map(parse_keystroke)
        .collect();
    if parsed.is_empty() {
        None
    } else {
        Some(parsed)
    }
}
/// Keys that cannot be rebound
pub const NON_REBINDABLE: &[&str] = &["ctrl+c", "ctrl+d", "ctrl+m"];

/// Default keybindings, grouped by context. Order matters downstream:
/// the resolver appends user overrides after these and picks the last
/// exact match, so defaults listed here lose to user bindings.
pub fn default_bindings() -> Vec<ParsedBinding> {
    let mut bindings = Vec::new();
    // Small helper: parse the chord string and push the binding, silently
    // skipping anything unparsable (same behaviour as the filter_map form).
    let mut bind = |chord_str: &str, action: &str, context: KeyContext| {
        if let Some(chord) = parse_chord(chord_str) {
            bindings.push(ParsedBinding {
                chord,
                action: Some(action.to_string()),
                context,
            });
        }
    };
    // Global
    bind("ctrl+c", "interrupt", KeyContext::Global);
    bind("ctrl+d", "exit", KeyContext::Global);
    bind("ctrl+l", "redraw", KeyContext::Global);
    bind("ctrl+r", "historySearch", KeyContext::Global);
    // Chat
    bind("enter", "submit", KeyContext::Chat);
    bind("up", "historyPrev", KeyContext::Chat);
    bind("down", "historyNext", KeyContext::Chat);
    bind("shift+tab", "cycleMode", KeyContext::Chat);
    bind("pageup", "scrollUp", KeyContext::Chat);
    bind("pagedown", "scrollDown", KeyContext::Chat);
    // Confirmation
    bind("y", "yes", KeyContext::Confirmation);
    bind("enter", "yes", KeyContext::Confirmation);
    bind("n", "no", KeyContext::Confirmation);
    bind("escape", "no", KeyContext::Confirmation);
    bind("up", "prevOption", KeyContext::Confirmation);
    bind("down", "nextOption", KeyContext::Confirmation);
    // Help
    bind("escape", "close", KeyContext::Help);
    bind("q", "close", KeyContext::Help);
    // HistorySearch
    bind("enter", "select", KeyContext::HistorySearch);
    bind("escape", "cancel", KeyContext::HistorySearch);
    bind("up", "prevResult", KeyContext::HistorySearch);
    bind("down", "nextResult", KeyContext::HistorySearch);
    bindings
}
/// User keybindings loaded from ~/.claude/keybindings.json
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UserKeybindings {
    /// Raw binding entries exactly as written in the JSON file.
    pub bindings: Vec<UserBinding>,
}

/// One raw user-specified binding. Chord and context strings are parsed
/// lazily by `KeybindingResolver::new`; unparsable entries are skipped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserBinding {
    pub chord: String, // e.g. "ctrl+k ctrl+d"
    pub action: Option<String>, // None = unbound
    pub context: Option<String>,
}
impl UserKeybindings {
    /// Load user keybindings from `<config_dir>/keybindings.json`.
    ///
    /// Best-effort by design: a missing or unparseable file degrades to the
    /// empty default rather than failing startup.
    pub fn load(config_dir: &Path) -> Self {
        let path = config_dir.join("keybindings.json");
        match std::fs::read_to_string(&path) {
            Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
            Err(_) => Self::default(),
        }
    }

    /// Persist keybindings to `<config_dir>/keybindings.json` as pretty JSON.
    ///
    /// Creates `config_dir` first if needed — previously `save` returned a
    /// NotFound error on a fresh machine where the config directory did not
    /// exist yet.
    pub fn save(&self, config_dir: &Path) -> anyhow::Result<()> {
        std::fs::create_dir_all(config_dir)?;
        let path = config_dir.join("keybindings.json");
        let json = serde_json::to_string_pretty(self)?;
        std::fs::write(path, json)?;
        Ok(())
    }
}
/// Resolved keybindings (defaults merged with user overrides)
pub struct KeybindingResolver {
    /// Defaults first, then user overrides appended; on an exact-match tie
    /// the LAST entry wins, which is how user overrides take precedence.
    bindings: Vec<ParsedBinding>,
    /// Keystrokes received so far toward an in-progress multi-key chord.
    pending_chord: Vec<ParsedKeystroke>,
}
impl KeybindingResolver {
    /// Build a resolver from the defaults plus the user's overrides.
    ///
    /// User bindings are appended AFTER the defaults; `process` picks the
    /// last exact match, so user bindings shadow defaults for the same
    /// chord/context. Entries whose chord fails to parse are skipped.
    /// NOTE(review): `NON_REBINDABLE` is not enforced here — presumably a
    /// settings UI filters those chords; confirm before feeding raw config.
    pub fn new(user: &UserKeybindings) -> Self {
        let mut bindings = default_bindings();
        // Apply user overrides (user bindings win, last match wins)
        for user_binding in &user.bindings {
            if let Some(chord) = parse_chord(&user_binding.chord) {
                // Context strings round-trip through serde as quoted JSON, so
                // they must use the PascalCase variant names (e.g. "Chat").
                // Unknown or missing contexts fall back to Global.
                let context = user_binding
                    .context
                    .as_deref()
                    .and_then(|c| serde_json::from_str(&format!("\"{}\"", c)).ok())
                    .unwrap_or(KeyContext::Global);
                bindings.push(ParsedBinding {
                    chord,
                    action: user_binding.action.clone(),
                    context,
                });
            }
        }
        Self {
            bindings,
            pending_chord: Vec::new(),
        }
    }

    /// Process a keystroke, returns action if binding matches
    ///
    /// The keystroke is appended to the pending chord, then candidate
    /// bindings are those in the given context OR Global whose chord starts
    /// with the pending sequence:
    /// - no candidates  → pending chord is discarded, `NoMatch`
    /// - exact-length candidate(s) → pending chord is consumed; last match
    ///   wins, yielding `Action` or `Unbound` (explicitly unbound by user)
    /// - only longer candidates → `Pending` (waiting for more keystrokes)
    pub fn process(
        &mut self,
        keystroke: ParsedKeystroke,
        context: &KeyContext,
    ) -> KeybindingResult {
        self.pending_chord.push(keystroke);
        // Find matching bindings in current context + Global
        let matches: Vec<&ParsedBinding> = self
            .bindings
            .iter()
            .filter(|b| &b.context == context || b.context == KeyContext::Global)
            .filter(|b| b.chord.starts_with(self.pending_chord.as_slice()))
            .collect();
        if matches.is_empty() {
            self.pending_chord.clear();
            return KeybindingResult::NoMatch;
        }
        let exact: Vec<&ParsedBinding> = matches
            .iter()
            .copied()
            .filter(|b| b.chord.len() == self.pending_chord.len())
            .collect();
        if !exact.is_empty() {
            // Last match wins (user overrides)
            let binding = exact.last().unwrap();
            self.pending_chord.clear();
            return match &binding.action {
                Some(action) => KeybindingResult::Action(action.clone()),
                None => KeybindingResult::Unbound,
            };
        }
        // Chord in progress
        KeybindingResult::Pending
    }

    /// Abandon any partially-entered multi-key chord.
    pub fn cancel_chord(&mut self) {
        self.pending_chord.clear();
    }
}
impl PartialEq for ParsedKeystroke {
    /// Two keystrokes are equal when the key name and all four modifier
    /// flags match.
    fn eq(&self, other: &Self) -> bool {
        (self.key.as_str(), self.ctrl, self.alt, self.shift, self.meta)
            == (other.key.as_str(), other.ctrl, other.alt, other.shift, other.meta)
    }
}
/// Outcome of feeding one keystroke to `KeybindingResolver::process`.
#[derive(Debug, Clone)]
pub enum KeybindingResult {
    /// A binding matched exactly; perform the named action.
    Action(String),
    /// A binding matched but the user explicitly unbound it (action = None).
    Unbound,
    /// A longer chord is still possible; wait for more keystrokes.
    Pending,
    /// Nothing matches the pending sequence; it has been discarded.
    NoMatch,
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_keystroke --------------------------------------------------

    #[test]
    fn test_parse_keystroke_simple() {
        let ks = parse_keystroke("enter").unwrap();
        assert_eq!(ks.key, "enter");
        assert!(!ks.ctrl);
        assert!(!ks.alt);
        assert!(!ks.shift);
        assert!(!ks.meta);
    }

    #[test]
    fn test_parse_keystroke_ctrl_c() {
        let ks = parse_keystroke("ctrl+c").unwrap();
        assert_eq!(ks.key, "c");
        assert!(ks.ctrl);
        assert!(!ks.alt);
    }

    #[test]
    fn test_parse_keystroke_ctrl_shift_enter() {
        let ks = parse_keystroke("ctrl+shift+enter").unwrap();
        assert_eq!(ks.key, "enter");
        assert!(ks.ctrl);
        assert!(ks.shift);
        assert!(!ks.alt);
    }

    #[test]
    fn test_parse_keystroke_normalizes_esc() {
        let ks = parse_keystroke("esc").unwrap();
        assert_eq!(ks.key, "escape");
    }

    #[test]
    fn test_parse_keystroke_normalizes_return() {
        let ks = parse_keystroke("return").unwrap();
        assert_eq!(ks.key, "enter");
    }

    // Modifier-only and empty inputs must not produce a keystroke.
    #[test]
    fn test_parse_keystroke_empty_returns_none() {
        assert!(parse_keystroke("ctrl+").is_none());
        assert!(parse_keystroke("").is_none());
    }

    // --- parse_chord ------------------------------------------------------

    #[test]
    fn test_parse_chord_single() {
        let chord = parse_chord("ctrl+c").unwrap();
        assert_eq!(chord.len(), 1);
        assert_eq!(chord[0].key, "c");
        assert!(chord[0].ctrl);
    }

    #[test]
    fn test_parse_chord_multi() {
        let chord = parse_chord("ctrl+k ctrl+d").unwrap();
        assert_eq!(chord.len(), 2);
        assert_eq!(chord[0].key, "k");
        assert_eq!(chord[1].key, "d");
        assert!(chord[0].ctrl);
        assert!(chord[1].ctrl);
    }

    #[test]
    fn test_parse_chord_empty_returns_none() {
        assert!(parse_chord("").is_none());
    }

    // --- default bindings and resolver ------------------------------------

    #[test]
    fn test_default_bindings_not_empty() {
        let bindings = default_bindings();
        assert!(!bindings.is_empty());
    }

    #[test]
    fn test_default_bindings_contains_ctrl_c() {
        let bindings = default_bindings();
        let ctrl_c = bindings.iter().find(|b| {
            b.chord.len() == 1
                && b.chord[0].ctrl
                && b.chord[0].key == "c"
                && b.context == KeyContext::Global
        });
        assert!(ctrl_c.is_some());
        assert_eq!(ctrl_c.unwrap().action.as_deref(), Some("interrupt"));
    }

    #[test]
    fn test_resolver_simple_action() {
        let user = UserKeybindings::default();
        let mut resolver = KeybindingResolver::new(&user);
        let ks = parse_keystroke("ctrl+c").unwrap();
        let result = resolver.process(ks, &KeyContext::Global);
        assert!(matches!(result, KeybindingResult::Action(ref a) if a == "interrupt"));
    }

    #[test]
    fn test_resolver_no_match() {
        let user = UserKeybindings::default();
        let mut resolver = KeybindingResolver::new(&user);
        // ctrl+z has no default binding
        let ks = parse_keystroke("ctrl+z").unwrap();
        let result = resolver.process(ks, &KeyContext::Chat);
        assert!(matches!(result, KeybindingResult::NoMatch));
    }

    #[test]
    fn test_resolver_context_match_global_from_chat() {
        let user = UserKeybindings::default();
        let mut resolver = KeybindingResolver::new(&user);
        // ctrl+l is Global, should match even when context is Chat
        let ks = parse_keystroke("ctrl+l").unwrap();
        let result = resolver.process(ks, &KeyContext::Chat);
        assert!(matches!(result, KeybindingResult::Action(ref a) if a == "redraw"));
    }

    #[test]
    fn test_keystroke_equality() {
        let ks1 = parse_keystroke("ctrl+enter").unwrap();
        let ks2 = parse_keystroke("ctrl+enter").unwrap();
        let ks3 = parse_keystroke("shift+enter").unwrap();
        assert_eq!(ks1, ks2);
        assert_ne!(ks1, ks3);
    }

    #[test]
    fn test_user_keybindings_default_empty() {
        let user = UserKeybindings::default();
        assert!(user.bindings.is_empty());
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,294 @@
//! Language Server Protocol client stub.
//!
//! The full LSP implementation is provided by plugins; this module defines
//! the integration interface that the rest of the codebase uses to query
//! diagnostics, register servers, and format output.
use serde::{Deserialize, Serialize};
/// Configuration for a single LSP server process.
///
/// Serializable so it can round-trip through plugin/settings JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspServerConfig {
    /// Display name, e.g. "rust-analyzer"
    pub name: String,
    /// Path or name of the server binary, e.g. "rust-analyzer"
    pub command: String,
    /// Command-line arguments passed to the server binary
    pub args: Vec<String>,
    /// Glob patterns that activate this server, e.g. `["*.rs", "*.toml"]`
    pub file_patterns: Vec<String>,
    /// Optional server-specific initialization options (passed in LSP `initialize`)
    pub initialization_options: Option<serde_json::Value>,
}
/// A single diagnostic emitted by an LSP server.
#[derive(Debug, Clone)]
pub struct LspDiagnostic {
    /// Workspace-relative or absolute file path
    pub file: String,
    /// 1-based line number
    pub line: u32,
    /// 1-based column number
    pub column: u32,
    /// Severity per the LSP spec (Error/Warning/Information/Hint).
    pub severity: DiagnosticSeverity,
    /// Human-readable diagnostic text from the server.
    pub message: String,
    /// The LSP server that produced this diagnostic (e.g. "rust-analyzer")
    pub source: Option<String>,
    /// Diagnostic code (e.g. "E0308"), if provided by the server
    pub code: Option<String>,
}
/// Severity level of a diagnostic, matching the LSP spec.
///
/// Derived `Ord` follows declaration order, so `Error < Warning <
/// Information < Hint` — most severe first, convenient for sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DiagnosticSeverity {
    Error = 1,
    Warning = 2,
    Information = 3,
    Hint = 4,
}

impl DiagnosticSeverity {
    /// Short lowercase label: "error", "warning", "info", or "hint".
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Error => "error",
            Self::Warning => "warning",
            Self::Information => "info",
            Self::Hint => "hint",
        }
    }
}

/// `Display` delegates to [`DiagnosticSeverity::as_str`], so severities can
/// be used directly in `format!`/`write!` without an explicit `as_str()`
/// call (idiomatic for user-facing text).
impl std::fmt::Display for DiagnosticSeverity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// LSP manager stub.
///
/// In the full implementation this will own LSP server processes and route
/// JSON-RPC messages. For now it is a registry that tracks configured
/// servers and returns empty diagnostic lists — the plugin system is
/// responsible for wiring up real communication.
pub struct LspManager {
    /// Registered server configurations, in registration order.
    servers: Vec<LspServerConfig>,
}
impl LspManager {
    /// Create an empty manager with no registered servers.
    pub fn new() -> Self {
        Self { servers: Vec::new() }
    }

    /// Register an LSP server configuration.
    pub fn register_server(&mut self, config: LspServerConfig) {
        self.servers.push(config);
    }

    /// Return all registered server configurations.
    pub fn servers(&self) -> &[LspServerConfig] {
        self.servers.as_slice()
    }

    /// Look up a server configuration by its display name.
    pub fn server_by_name(&self, name: &str) -> Option<&LspServerConfig> {
        self.servers.iter().find(|cfg| cfg.name == name)
    }

    /// Get diagnostics for a file.
    ///
    /// This stub always returns an empty list. When an LSP plugin connects
    /// it will replace this path with real RPC calls.
    pub async fn get_diagnostics(&self, _file: &str) -> Vec<LspDiagnostic> {
        Vec::new()
    }

    /// Render diagnostics one per line as
    /// `[SEVERITY] file:line:col - message (source) [code]`, where the
    /// source and code suffixes appear only when present. Returns the
    /// fixed string "No diagnostics." for an empty slice.
    pub fn format_diagnostics(diagnostics: &[LspDiagnostic]) -> String {
        if diagnostics.is_empty() {
            return "No diagnostics.".to_string();
        }
        let mut lines = Vec::with_capacity(diagnostics.len());
        for d in diagnostics {
            let source_suffix = match d.source.as_deref() {
                Some(src) => format!(" ({})", src),
                None => String::new(),
            };
            let code_suffix = match d.code.as_deref() {
                Some(code) => format!(" [{}]", code),
                None => String::new(),
            };
            lines.push(format!(
                "[{}] {}:{}:{} - {}{}{}",
                d.severity.as_str().to_uppercase(),
                d.file,
                d.line,
                d.column,
                d.message,
                source_suffix,
                code_suffix,
            ));
        }
        lines.join("\n")
    }
}

impl Default for LspManager {
    /// Equivalent to [`LspManager::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Build a minimal config whose name and command share `name`.
    fn make_config(name: &str) -> LspServerConfig {
        LspServerConfig {
            name: name.to_string(),
            command: name.to_string(),
            args: vec![],
            file_patterns: vec!["*.rs".to_string()],
            initialization_options: None,
        }
    }

    // Build a diagnostic with no source/code (tests set those explicitly).
    fn make_diagnostic(
        file: &str,
        line: u32,
        col: u32,
        severity: DiagnosticSeverity,
        message: &str,
    ) -> LspDiagnostic {
        LspDiagnostic {
            file: file.to_string(),
            line,
            column: col,
            severity,
            message: message.to_string(),
            source: None,
            code: None,
        }
    }

    #[test]
    fn test_new_manager_empty() {
        let mgr = LspManager::new();
        assert!(mgr.servers().is_empty());
    }

    #[test]
    fn test_register_server() {
        let mut mgr = LspManager::new();
        mgr.register_server(make_config("rust-analyzer"));
        assert_eq!(mgr.servers().len(), 1);
        assert_eq!(mgr.servers()[0].name, "rust-analyzer");
    }

    #[test]
    fn test_register_multiple_servers() {
        let mut mgr = LspManager::new();
        mgr.register_server(make_config("rust-analyzer"));
        mgr.register_server(make_config("pyright"));
        assert_eq!(mgr.servers().len(), 2);
    }

    #[test]
    fn test_server_by_name_found() {
        let mut mgr = LspManager::new();
        mgr.register_server(make_config("rust-analyzer"));
        mgr.register_server(make_config("pyright"));
        let s = mgr.server_by_name("pyright");
        assert!(s.is_some());
        assert_eq!(s.unwrap().name, "pyright");
    }

    #[test]
    fn test_server_by_name_not_found() {
        let mgr = LspManager::new();
        assert!(mgr.server_by_name("missing").is_none());
    }

    // The stub always returns no diagnostics until a plugin wires up RPC.
    #[tokio::test]
    async fn test_get_diagnostics_stub_empty() {
        let mgr = LspManager::new();
        let diags = mgr.get_diagnostics("src/main.rs").await;
        assert!(diags.is_empty());
    }

    #[test]
    fn test_format_diagnostics_empty() {
        let result = LspManager::format_diagnostics(&[]);
        assert_eq!(result, "No diagnostics.");
    }

    #[test]
    fn test_format_diagnostics_single_error() {
        let diags = vec![make_diagnostic(
            "src/lib.rs",
            10,
            5,
            DiagnosticSeverity::Error,
            "type mismatch",
        )];
        let result = LspManager::format_diagnostics(&diags);
        assert!(result.contains("[ERROR]"));
        assert!(result.contains("src/lib.rs"));
        assert!(result.contains("10:5"));
        assert!(result.contains("type mismatch"));
    }

    #[test]
    fn test_format_diagnostics_multiple() {
        let diags = vec![
            make_diagnostic("a.rs", 1, 1, DiagnosticSeverity::Error, "err1"),
            make_diagnostic("b.rs", 2, 3, DiagnosticSeverity::Warning, "warn1"),
        ];
        let result = LspManager::format_diagnostics(&diags);
        let lines: Vec<&str> = result.lines().collect();
        assert_eq!(lines.len(), 2);
        assert!(lines[0].contains("[ERROR]"));
        assert!(lines[1].contains("[WARNING]"));
    }

    // Optional source/code fields render as " (source)" and " [code]".
    #[test]
    fn test_format_diagnostics_with_source_and_code() {
        let mut d = make_diagnostic(
            "main.rs",
            5,
            1,
            DiagnosticSeverity::Error,
            "mismatched types",
        );
        d.source = Some("rust-analyzer".to_string());
        d.code = Some("E0308".to_string());
        let result = LspManager::format_diagnostics(&[d]);
        assert!(result.contains("(rust-analyzer)"), "result = {}", result);
        assert!(result.contains("[E0308]"), "result = {}", result);
    }

    // Derived Ord follows declaration order: most severe sorts first.
    #[test]
    fn test_diagnostic_severity_ordering() {
        assert!(DiagnosticSeverity::Error < DiagnosticSeverity::Warning);
        assert!(DiagnosticSeverity::Warning < DiagnosticSeverity::Information);
        assert!(DiagnosticSeverity::Information < DiagnosticSeverity::Hint);
    }

    #[test]
    fn test_diagnostic_severity_as_str() {
        assert_eq!(DiagnosticSeverity::Error.as_str(), "error");
        assert_eq!(DiagnosticSeverity::Warning.as_str(), "warning");
        assert_eq!(DiagnosticSeverity::Information.as_str(), "info");
        assert_eq!(DiagnosticSeverity::Hint.as_str(), "hint");
    }

    #[test]
    fn test_lsp_server_config_serialization() {
        let cfg = make_config("rust-analyzer");
        let json = serde_json::to_string(&cfg).unwrap();
        let back: LspServerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "rust-analyzer");
    }

    #[test]
    fn test_default_trait() {
        let mgr = LspManager::default();
        assert!(mgr.servers().is_empty());
    }
}

View file

@ -0,0 +1,881 @@
//! Memory directory (memdir) system.
//!
//! Provides persistent, file-based memory across sessions. Mirrors the
//! TypeScript modules under `src/memdir/`:
//! - `memoryScan.ts` → `scan_memory_dir`, `parse_frontmatter_quick`, `format_memory_manifest`
//! - `memoryAge.ts` → `memory_age_days`, `memory_freshness_text`, `memory_freshness_note`
//! - `memdir.ts` → `build_memory_prompt_content`, `load_memory_index`, `ensure_memory_dir_exists`
//! - `paths.ts` → `auto_memory_path`, `is_auto_memory_enabled`
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
// ---------------------------------------------------------------------------
// Memory type taxonomy
// ---------------------------------------------------------------------------
/// The four canonical memory types.
/// Matches the TypeScript `MemoryType` union in `memoryTypes.ts`.
///
/// Serialized as lowercase strings ("user", "feedback", …) per the
/// `rename_all` attribute, matching the `type:` frontmatter values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MemoryType {
    /// Information about the user's role, goals, and preferences.
    User,
    /// Guidance the user has given about how to approach work.
    Feedback,
    /// Information about ongoing work, goals, or incidents in the project.
    Project,
    /// Pointers to where information lives in external systems.
    Reference,
}
impl MemoryType {
/// Parse a raw frontmatter value into a `MemoryType`.
/// Returns `None` for missing or unrecognised values (legacy files degrade gracefully).
pub fn parse(raw: &str) -> Option<Self> {
match raw.trim() {
"user" => Some(Self::User),
"feedback" => Some(Self::Feedback),
"project" => Some(Self::Project),
"reference" => Some(Self::Reference),
_ => None,
}
}
/// Display as a lowercase string.
pub fn as_str(&self) -> &'static str {
match self {
Self::User => "user",
Self::Feedback => "feedback",
Self::Project => "project",
Self::Reference => "reference",
}
}
}
// ---------------------------------------------------------------------------
// Memory file metadata and content
// ---------------------------------------------------------------------------
/// Scanned metadata for a single memory file (without the full body).
/// Mirrors `MemoryHeader` in `memoryScan.ts`.
#[derive(Debug, Clone)]
pub struct MemoryFileMeta {
    /// Filename relative to the memory directory (e.g. `user_role.md`).
    pub filename: String,
    /// Absolute path to the file.
    pub path: PathBuf,
    /// `name:` frontmatter field.
    pub name: Option<String>,
    /// `description:` frontmatter field (used for relevance scoring).
    pub description: Option<String>,
    /// `type:` frontmatter field.
    pub memory_type: Option<MemoryType>,
    /// File modification time in seconds since UNIX epoch.
    pub modified_secs: u64,
}

/// A fully-loaded memory file including its body.
#[derive(Debug, Clone)]
pub struct MemoryFile {
    /// Header metadata (path, frontmatter fields, mtime).
    pub meta: MemoryFileMeta,
    /// Full text content of the memory file.
    pub content: String,
}
// ---------------------------------------------------------------------------
// Directory scanning
// ---------------------------------------------------------------------------
/// Maximum number of memory files kept after sorting.
/// Matches `MAX_MEMORY_FILES` in `memoryScan.ts`.
const MAX_MEMORY_FILES: usize = 200;
/// Number of lines scanned for frontmatter.
/// Matches `FRONTMATTER_MAX_LINES` in `memoryScan.ts`.
const FRONTMATTER_MAX_LINES: usize = 30;
/// Scan a memory directory, returning metadata for all `.md` files
/// (excluding `MEMORY.md`), sorted newest-first, capped at `MAX_MEMORY_FILES`.
///
/// This is a synchronous scan used during system-prompt assembly.
/// Mirrors `scanMemoryFiles` in `memoryScan.ts` (async version; this is the
/// sync equivalent used at prompt-build time).
pub fn scan_memory_dir(dir: &Path) -> Vec<MemoryFileMeta> {
let mut files: Vec<MemoryFileMeta> = Vec::new();
if !dir.exists() {
return files;
}
// Walk recursively using `walkdir`-style manual recursion to stay
// dependency-free (only std).
collect_md_files(dir, dir, &mut files);
// Sort newest-first.
files.sort_by(|a, b| b.modified_secs.cmp(&a.modified_secs));
files.truncate(MAX_MEMORY_FILES);
files
}
/// Recursively collect `.md` files (excluding `MEMORY.md`) from
/// `current_dir`, pushing one `MemoryFileMeta` per file onto `out`.
/// `base` is the memory-dir root used to compute relative filenames.
/// Unreadable directories/files degrade silently (best-effort scan).
fn collect_md_files(base: &Path, current_dir: &Path, out: &mut Vec<MemoryFileMeta>) {
    let entries = match std::fs::read_dir(current_dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            collect_md_files(base, &path, out);
            continue;
        }
        // Only markdown files participate in the memory manifest.
        if !path.extension().map(|e| e == "md").unwrap_or(false) {
            continue;
        }
        let file_name = path
            .file_name()
            .map(|n| n.to_string_lossy().into_owned())
            .unwrap_or_default();
        // MEMORY.md is the index file, not a memory entry.
        if file_name == "MEMORY.md" {
            continue;
        }
        // mtime in epoch seconds; unreadable metadata collapses to 0 so the
        // file still appears (just sorted oldest).
        let modified_secs = entry
            .metadata()
            .and_then(|m| m.modified())
            .map(|t| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs())
            .unwrap_or(0);
        let (name, description, memory_type) = match std::fs::read_to_string(&path) {
            Ok(content) => parse_frontmatter_quick(&content),
            Err(_) => (None, None, None),
        };
        // Relative path from the memory dir root.
        let relative = path
            .strip_prefix(base)
            .map(|p| p.to_string_lossy().into_owned())
            .unwrap_or_else(|_| file_name.clone());
        out.push(MemoryFileMeta {
            filename: relative,
            path,
            name,
            description,
            memory_type,
            modified_secs,
        });
    }
}
/// Parse YAML frontmatter from the first `FRONTMATTER_MAX_LINES` lines without
/// a full YAML parser. Returns `(name, description, memory_type)`.
///
/// Mirrors `parseFrontmatter` usage in `memoryScan.ts`.
pub fn parse_frontmatter_quick(
    content: &str,
) -> (Option<String>, Option<String>, Option<MemoryType>) {
    // Strip surrounding whitespace, then one layer of double quotes, then one
    // layer of single quotes (same unquoting order as before).
    fn clean(raw: &str) -> &str {
        raw.trim().trim_matches('"').trim_matches('\'')
    }
    let mut head = content.lines().take(FRONTMATTER_MAX_LINES);
    // Frontmatter must open with a `---` fence on the very first line.
    match head.next() {
        Some(first) if first.trim() == "---" => {}
        _ => return (None, None, None),
    }
    let mut name = None;
    let mut description = None;
    let mut memory_type = None;
    for line in head {
        if line.trim() == "---" {
            break;
        }
        if let Some(v) = line.strip_prefix("name:") {
            name = Some(clean(v).to_string());
        } else if let Some(v) = line.strip_prefix("description:") {
            description = Some(clean(v).to_string());
        } else if let Some(v) = line.strip_prefix("type:") {
            memory_type = MemoryType::parse(clean(v));
        }
    }
    (name, description, memory_type)
}
/// Format memory headers as a text manifest: one entry per file with
/// `[type] filename (iso-timestamp): description`.
///
/// Mirrors `formatMemoryManifest` in `memoryScan.ts`.
pub fn format_memory_manifest(memories: &[MemoryFileMeta]) -> String {
    let mut lines = Vec::with_capacity(memories.len());
    for m in memories {
        // Optional `[type] ` prefix.
        let tag = match &m.memory_type {
            Some(t) => format!("[{}] ", t.as_str()),
            None => String::new(),
        };
        let when = format_unix_secs_iso(m.modified_secs);
        // The `: description` suffix is only present when a description exists.
        let line = match &m.description {
            Some(desc) => format!("- {}{} ({}): {}", tag, m.filename, when, desc),
            None => format!("- {}{} ({})", tag, m.filename, when),
        };
        lines.push(line);
    }
    lines.join("\n")
}
/// Minimal ISO-8601 formatter for a Unix timestamp (no external deps).
fn format_unix_secs_iso(secs: u64) -> String {
    // Civil date via Julian Day Number arithmetic, keeping this module
    // std-only (chrono exists in the workspace but is deliberately avoided).
    // Day-level accuracy is sufficient for memory manifests.
    // JDN of 1970-01-01 is 2440588.
    let jdn = (secs / 86_400) as u32 + 2_440_588;
    let (year, month, day) = jdn_to_ymd(jdn);
    let rem = secs % 86_400;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        rem / 3_600,
        (rem / 60) % 60,
        rem % 60
    )
}
/// Convert a Julian Day Number to (year, month, day) in the Gregorian
/// calendar using pure integer arithmetic.
fn jdn_to_ymd(jdn: u32) -> (u32, u32, u32) {
    let shifted = jdn + 32044;
    let century = (4 * shifted + 3) / 146097;
    let day_of_century = shifted - (146097 * century) / 4;
    let year_of_century = (4 * day_of_century + 3) / 1461;
    let day_of_year = day_of_century - (1461 * year_of_century) / 4;
    // March-based month index: 0 = March … 11 = February.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = month_index + 3 - 12 * (month_index / 10);
    let year = 100 * century + year_of_century - 4800 + month_index / 10;
    (year, month, day)
}
// ---------------------------------------------------------------------------
// Memory age / freshness
// ---------------------------------------------------------------------------
/// Days elapsed since `modified_secs`. Floor-rounded; clamped to 0 for
/// future mtimes (clock skew).
///
/// Mirrors `memoryAgeDays` in `memoryAge.ts`.
pub fn memory_age_days(modified_secs: u64) -> u64 {
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    // saturating_sub keeps future timestamps from wrapping.
    now_secs.saturating_sub(modified_secs) / 86_400
}
/// Human-readable age string. Models are poor at date arithmetic — a raw
/// ISO timestamp does not trigger staleness reasoning the way "47 days ago" does.
///
/// Mirrors `memoryAge` in `memoryAge.ts`.
pub fn memory_age(modified_secs: u64) -> String {
    let days = memory_age_days(modified_secs);
    if days == 0 {
        "today".to_string()
    } else if days == 1 {
        "yesterday".to_string()
    } else {
        format!("{} days ago", days)
    }
}
/// Plain-text staleness caveat for memories > 1 day old.
/// Returns an empty string for fresh memories (today / yesterday).
///
/// Mirrors `memoryFreshnessText` in `memoryAge.ts`.
pub fn memory_freshness_text(modified_secs: u64) -> String {
    let d = memory_age_days(modified_secs);
    // Fresh memories (today / yesterday) need no caveat.
    if d <= 1 {
        return String::new();
    }
    // FIX: the previous literal rendered as "...not live state claims about
    // code behavior or file:line citations may be outdated." — a run-on with
    // no separator. A semicolon restores the intended two-clause sentence.
    format!(
        "This memory is {} days old. \
        Memories are point-in-time observations, not live state; \
        claims about code behavior or file:line citations may be outdated. \
        Verify against current code before asserting as fact.",
        d
    )
}
/// Per-memory staleness note wrapped in `<system-reminder>` tags.
/// Returns an empty string for memories ≤ 1 day old.
///
/// Mirrors `memoryFreshnessNote` in `memoryAge.ts`.
pub fn memory_freshness_note(modified_secs: u64) -> String {
    let text = memory_freshness_text(modified_secs);
    if text.is_empty() {
        String::new()
    } else {
        format!("<system-reminder>{}</system-reminder>\n", text)
    }
}
// ---------------------------------------------------------------------------
// Path resolution
// ---------------------------------------------------------------------------
/// Entrypoint filename within the memory directory.
pub const MEMORY_ENTRYPOINT: &str = "MEMORY.md";
/// Maximum number of lines loaded from `MEMORY.md`.
/// Enforced by `truncate_entrypoint_content`.
/// Matches `MAX_ENTRYPOINT_LINES` in `memdir.ts`.
pub const MAX_ENTRYPOINT_LINES: usize = 200;
/// Maximum bytes loaded from `MEMORY.md`.
/// Enforced by `truncate_entrypoint_content`.
/// Matches `MAX_ENTRYPOINT_BYTES` in `memdir.ts`.
pub const MAX_ENTRYPOINT_BYTES: usize = 25_000;
/// Compute the auto-memory directory path for a project root.
///
/// Resolution order (mirrors `getAutoMemPath` in `paths.ts`):
/// 1. `CLAUDE_COWORK_MEMORY_PATH_OVERRIDE` env var (full-path override).
/// 2. `<CLAUDE_CODE_REMOTE_MEMORY_DIR>/projects/<sanitized-root>/memory/`
///    when `CLAUDE_CODE_REMOTE_MEMORY_DIR` is set.
/// 3. `~/.claude/projects/<sanitized-root>/memory/` (default).
pub fn auto_memory_path(project_root: &Path) -> PathBuf {
    // 1. Cowork full-path override wins outright (non-empty values only).
    match std::env::var("CLAUDE_COWORK_MEMORY_PATH_OVERRIDE") {
        Ok(p) if !p.is_empty() => return PathBuf::from(p),
        _ => {}
    }
    // 2. Remote memory base when configured; otherwise `~/.claude`, falling
    //    back to the current directory when no home dir can be resolved.
    let base = match std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".claude"),
    };
    // 3. Project root sanitized into a single safe directory component.
    base.join("projects")
        .join(sanitize_path_component(&project_root.to_string_lossy()))
        .join("memory")
}
/// Sanitize an arbitrary string into a directory-name-safe component.
/// Matches `sanitizePath` used inside `getAutoMemPath` in `paths.ts`.
///
/// Every character outside `[A-Za-z0-9-_.]` becomes `_`. Leading and trailing
/// underscores are KEPT: `/home/user/project` must map to
/// `_home_user_project` (see `test_sanitize_path_component`).
pub fn sanitize_path_component(s: &str) -> String {
    // FIX: the previous version ran `.trim_matches('_')` on the result, which
    // stripped the leading underscore produced by an absolute path's first
    // `/` — "/home/user/project" became "home_user_project", failing the
    // existing unit test and letting distinct roots (e.g. "/a" and "a")
    // collide on the same directory name.
    s.chars()
        .map(|c| {
            if c.is_alphanumeric() || matches!(c, '-' | '_' | '.') {
                c
            } else {
                '_'
            }
        })
        .collect()
}
/// Whether the auto-memory system is enabled for this session.
///
/// Priority chain (mirrors `isAutoMemoryEnabled` in `paths.ts`):
/// 1. `CLAUDE_CODE_DISABLE_AUTO_MEMORY` — truthy → OFF, falsy (but defined) → ON.
/// 2. `CLAUDE_CODE_SIMPLE` (--bare) → OFF.
/// 3. Remote mode without `CLAUDE_CODE_REMOTE_MEMORY_DIR` → OFF.
/// 4. `settings_enabled` parameter (from settings.json `autoMemoryEnabled` field).
/// 5. Default: enabled.
pub fn is_auto_memory_enabled(settings_enabled: Option<bool>) -> bool {
    // 1. The explicit env kill-switch overrides everything below it:
    //    a defined-but-falsy value forces ON, any truthy value forces OFF.
    if let Ok(raw) = std::env::var("CLAUDE_CODE_DISABLE_AUTO_MEMORY") {
        return matches!(
            raw.to_lowercase().as_str(),
            "" | "0" | "false" | "no" | "off"
        );
    }
    // 2. --bare mode disables memory.
    if std::env::var("CLAUDE_CODE_SIMPLE").is_ok() {
        return false;
    }
    // 3. Remote sessions require an explicit remote memory dir.
    let is_remote = std::env::var("CLAUDE_CODE_REMOTE").is_ok();
    let has_remote_dir = std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR").is_ok();
    if is_remote && !has_remote_dir {
        return false;
    }
    // 4./5. settings.json value, defaulting to enabled.
    settings_enabled.unwrap_or(true)
}
// ---------------------------------------------------------------------------
// Index loading and truncation
// ---------------------------------------------------------------------------
/// Result of loading and (optionally) truncating the `MEMORY.md` entrypoint.
#[derive(Debug, Clone)]
pub struct EntrypointTruncation {
    /// Possibly-truncated text; a `> WARNING:` footer is appended whenever
    /// either cap fired.
    pub content: String,
    /// Line count of the trimmed ORIGINAL content (before truncation).
    pub line_count: usize,
    /// Byte count of the trimmed ORIGINAL content (before truncation).
    pub byte_count: usize,
    /// True when the `MAX_ENTRYPOINT_LINES` cap fired.
    pub was_line_truncated: bool,
    /// True when the `MAX_ENTRYPOINT_BYTES` cap fired.
    pub was_byte_truncated: bool,
}
/// Truncate `MEMORY.md` content to `MAX_ENTRYPOINT_LINES` lines and
/// `MAX_ENTRYPOINT_BYTES` bytes, appending a warning when either cap fires.
///
/// Mirrors `truncateEntrypointContent` in `memdir.ts`.
pub fn truncate_entrypoint_content(raw: &str) -> EntrypointTruncation {
    let trimmed = raw.trim();
    let content_lines: Vec<&str> = trimmed.lines().collect();
    let line_count = content_lines.len();
    let byte_count = trimmed.len();
    let was_line_truncated = line_count > MAX_ENTRYPOINT_LINES;
    let was_byte_truncated = byte_count > MAX_ENTRYPOINT_BYTES;
    // Fast path: both caps respected — return the trimmed content untouched.
    if !was_line_truncated && !was_byte_truncated {
        return EntrypointTruncation {
            content: trimmed.to_string(),
            line_count,
            byte_count,
            was_line_truncated: false,
            was_byte_truncated: false,
        };
    }
    // Apply the line cap first, then the byte cap on whatever remains.
    let mut truncated = if was_line_truncated {
        content_lines[..MAX_ENTRYPOINT_LINES].join("\n")
    } else {
        trimmed.to_string()
    };
    if truncated.len() > MAX_ENTRYPOINT_BYTES {
        // FIX: slicing at a fixed byte offset (`&truncated[..MAX_ENTRYPOINT_BYTES]`)
        // panics when that offset lands inside a multi-byte UTF-8 character.
        // Walk back to the nearest char boundary before slicing.
        let mut limit = MAX_ENTRYPOINT_BYTES;
        while !truncated.is_char_boundary(limit) {
            limit -= 1;
        }
        // Prefer cutting at the last complete line within the byte budget.
        let cut_at = truncated[..limit].rfind('\n').unwrap_or(limit);
        truncated.truncate(cut_at);
    }
    // Explain WHICH cap fired so the warning is actionable.
    let reason = match (was_line_truncated, was_byte_truncated) {
        (true, false) => format!("{} lines (limit: {})", line_count, MAX_ENTRYPOINT_LINES),
        (false, true) => format!(
            "{} bytes (limit: {}) — index entries are too long",
            byte_count, MAX_ENTRYPOINT_BYTES
        ),
        _ => format!(
            "{} lines and {} bytes",
            line_count, byte_count
        ),
    };
    truncated.push_str(&format!(
        "\n\n> WARNING: {} is {}. Only part of it was loaded. \
        Keep index entries to one line under ~200 chars; move detail into topic files.",
        MEMORY_ENTRYPOINT, reason
    ));
    EntrypointTruncation {
        content: truncated,
        line_count,
        byte_count,
        was_line_truncated,
        was_byte_truncated,
    }
}
/// Load and truncate the `MEMORY.md` index from `memory_dir`.
/// Returns `None` when the file does not exist or is empty.
///
/// Mirrors the entrypoint-reading path in `buildMemoryPrompt` / `loadMemoryPrompt`.
pub fn load_memory_index(memory_dir: &Path) -> Option<EntrypointTruncation> {
    let index_path = memory_dir.join(MEMORY_ENTRYPOINT);
    if !index_path.exists() {
        return None;
    }
    let raw = std::fs::read_to_string(&index_path).ok()?;
    // Whitespace-only files count as missing.
    match raw.trim() {
        "" => None,
        _ => Some(truncate_entrypoint_content(&raw)),
    }
}
// ---------------------------------------------------------------------------
// System-prompt memory content builder
// ---------------------------------------------------------------------------
/// Build the memory content string to inject into the system prompt's
/// `<memory>` block.
///
/// Always includes the `MEMORY.md` index when it exists.
/// Called during `build_system_prompt` → `SystemPromptOptions::memory_content`.
pub fn build_memory_prompt_content(memory_dir: &Path) -> String {
    // Currently the only section is the index; absent index → empty string.
    match load_memory_index(memory_dir) {
        Some(index) => format!("## Memory Index (MEMORY.md)\n{}", index.content),
        None => String::new(),
    }
}
/// Ensure the memory directory exists, creating it (and any parents) if needed.
/// Errors are silently swallowed (the Write tool will surface them if needed).
///
/// Mirrors `ensureMemoryDirExists` in `memdir.ts`.
pub fn ensure_memory_dir_exists(memory_dir: &Path) {
    match std::fs::create_dir_all(memory_dir) {
        Ok(()) => {}
        // Debug-level so --debug shows the cause without aborting startup.
        Err(e) => tracing::debug!(
            dir = %memory_dir.display(),
            error = %e,
            "ensureMemoryDirExists failed"
        ),
    }
}
// ---------------------------------------------------------------------------
// Simple relevance search (no LLM side-query)
// ---------------------------------------------------------------------------
/// Find and load the most relevant memory files for a query using a
/// lightweight TF-IDF-style keyword score.
///
/// The full Sonnet side-query (`findRelevantMemories` in TypeScript) lives
/// in `cc-query`; this function provides a cheaper fallback for contexts
/// where an API call is not available.
pub fn find_relevant_memories_simple(
    memory_dir: &Path,
    query: &str,
    max_files: usize,
) -> Vec<MemoryFile> {
    let needle = query.to_lowercase();
    let words: Vec<&str> = needle.split_whitespace().collect();
    if words.is_empty() {
        return Vec::new();
    }
    // Score each candidate: a hit in the frontmatter name weighs 2.0,
    // in the description 1.0, in the filename 0.5. Zero-score files drop out.
    let mut ranked: Vec<(f32, MemoryFileMeta)> = Vec::new();
    for meta in scan_memory_dir(memory_dir) {
        let name_lc = meta.name.as_deref().unwrap_or("").to_lowercase();
        let desc_lc = meta.description.as_deref().unwrap_or("").to_lowercase();
        let file_lc = meta.filename.to_lowercase();
        let mut score = 0.0_f32;
        for &w in &words {
            if name_lc.contains(w) {
                score += 2.0;
            }
            if desc_lc.contains(w) {
                score += 1.0;
            }
            if file_lc.contains(w) {
                score += 0.5;
            }
        }
        if score > 0.0 {
            ranked.push((score, meta));
        }
    }
    // Highest score first; ties keep scan order (stable sort).
    ranked.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    // Load the top-N bodies; files that fail to read are skipped.
    let mut loaded = Vec::new();
    for (_, meta) in ranked.into_iter().take(max_files) {
        if let Ok(content) = std::fs::read_to_string(&meta.path) {
            loaded.push(MemoryFile { meta, content });
        }
    }
    loaded
}
// ---------------------------------------------------------------------------
// Team memory helpers
// ---------------------------------------------------------------------------
/// Return the team-memory sub-directory path.
/// Mirrors `getTeamMemPath` in `teamMemPaths.ts`.
pub fn team_memory_path(auto_memory_dir: &Path) -> PathBuf {
    let mut path = auto_memory_dir.to_path_buf();
    path.push("team");
    path
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as IoWrite;
    // Helpers ----------------------------------------------------------------
    fn make_temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("tempdir")
    }
    fn write_file(dir: &Path, name: &str, content: &str) {
        let path = dir.join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
    }
    // ---- parse_frontmatter_quick -------------------------------------------
    #[test]
    fn test_parse_frontmatter_full() {
        let content = "---\nname: My Memory\ndescription: A test description\ntype: feedback\n---\n\nBody text.";
        let (name, desc, mt) = parse_frontmatter_quick(content);
        assert_eq!(name.as_deref(), Some("My Memory"));
        assert_eq!(desc.as_deref(), Some("A test description"));
        assert_eq!(mt, Some(MemoryType::Feedback));
    }
    #[test]
    fn test_parse_frontmatter_no_frontmatter() {
        let content = "Just plain text.";
        let (name, desc, mt) = parse_frontmatter_quick(content);
        assert!(name.is_none());
        assert!(desc.is_none());
        assert!(mt.is_none());
    }
    #[test]
    fn test_parse_frontmatter_quoted_values() {
        let content = "---\nname: \"Quoted Name\"\ndescription: 'Single quoted'\ntype: user\n---";
        let (name, desc, mt) = parse_frontmatter_quick(content);
        assert_eq!(name.as_deref(), Some("Quoted Name"));
        assert_eq!(desc.as_deref(), Some("Single quoted"));
        assert_eq!(mt, Some(MemoryType::User));
    }
    #[test]
    fn test_parse_frontmatter_unknown_type() {
        let content = "---\ntype: unknown_type\n---";
        let (_, _, mt) = parse_frontmatter_quick(content);
        assert!(mt.is_none());
    }
    // ---- memory_age_days ---------------------------------------------------
    #[test]
    fn test_memory_age_today() {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        assert_eq!(memory_age_days(now_secs), 0);
    }
    #[test]
    fn test_memory_age_one_day_ago() {
        let yesterday = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            .saturating_sub(86_400);
        assert_eq!(memory_age_days(yesterday), 1);
    }
    #[test]
    fn test_memory_age_future_clamps_to_zero() {
        let far_future = u64::MAX;
        assert_eq!(memory_age_days(far_future), 0);
    }
    // ---- memory_freshness_text ---------------------------------------------
    #[test]
    fn test_freshness_text_fresh() {
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
        assert!(memory_freshness_text(now).is_empty());
    }
    #[test]
    fn test_freshness_text_stale() {
        let old = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            .saturating_sub(10 * 86_400); // 10 days ago
        let text = memory_freshness_text(old);
        assert!(text.contains("10 days old"));
        assert!(text.contains("point-in-time"));
    }
    // ---- memory_freshness_note ---------------------------------------------
    #[test]
    fn test_freshness_note_fresh_is_empty() {
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
        assert!(memory_freshness_note(now).is_empty());
    }
    #[test]
    fn test_freshness_note_stale_has_tags() {
        let old = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            .saturating_sub(5 * 86_400);
        let note = memory_freshness_note(old);
        assert!(note.contains("<system-reminder>"));
        assert!(note.contains("</system-reminder>"));
    }
    // ---- truncate_entrypoint_content ---------------------------------------
    #[test]
    fn test_truncate_no_truncation_needed() {
        let content = "line1\nline2\nline3";
        let result = truncate_entrypoint_content(content);
        assert!(!result.was_line_truncated);
        assert!(!result.was_byte_truncated);
        assert_eq!(result.content, content);
    }
    #[test]
    fn test_truncate_line_limit() {
        let content = (0..=MAX_ENTRYPOINT_LINES)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let result = truncate_entrypoint_content(&content);
        assert!(result.was_line_truncated);
        assert!(result.content.contains("WARNING"));
    }
    // ---- sanitize_path_component -------------------------------------------
    #[test]
    fn test_sanitize_path_component() {
        assert_eq!(sanitize_path_component("/home/user/project"), "_home_user_project");
        assert_eq!(sanitize_path_component("normal-name_123"), "normal-name_123");
        assert_eq!(sanitize_path_component("C:\\Users\\foo"), "C__Users_foo");
    }
    // ---- load_memory_index -------------------------------------------------
    #[test]
    fn test_load_memory_index_nonexistent() {
        let dir = make_temp_dir();
        assert!(load_memory_index(dir.path()).is_none());
    }
    #[test]
    fn test_load_memory_index_empty() {
        let dir = make_temp_dir();
        write_file(dir.path(), "MEMORY.md", "   ");
        assert!(load_memory_index(dir.path()).is_none());
    }
    #[test]
    fn test_load_memory_index_with_content() {
        let dir = make_temp_dir();
        write_file(dir.path(), "MEMORY.md", "- [test.md](test.md) — something");
        let result = load_memory_index(dir.path()).unwrap();
        assert!(result.content.contains("test.md"));
    }
    // ---- scan_memory_dir ---------------------------------------------------
    #[test]
    fn test_scan_excludes_memory_md() {
        let dir = make_temp_dir();
        write_file(dir.path(), "MEMORY.md", "# index");
        write_file(dir.path(), "user_role.md", "---\nname: Role\n---");
        let metas = scan_memory_dir(dir.path());
        assert_eq!(metas.len(), 1);
        assert_eq!(metas[0].filename, "user_role.md");
    }
    #[test]
    fn test_scan_empty_dir() {
        let dir = make_temp_dir();
        assert!(scan_memory_dir(dir.path()).is_empty());
    }
    #[test]
    fn test_scan_nonexistent_dir() {
        let path = PathBuf::from("/tmp/nonexistent_memory_dir_cc_rust_test_xyz");
        assert!(scan_memory_dir(&path).is_empty());
    }
    // ---- format_memory_manifest --------------------------------------------
    #[test]
    fn test_format_memory_manifest_with_description() {
        let meta = MemoryFileMeta {
            filename: "user_role.md".to_string(),
            path: PathBuf::from("user_role.md"),
            name: Some("User Role".to_string()),
            description: Some("The user is a data scientist".to_string()),
            memory_type: Some(MemoryType::User),
            modified_secs: 0,
        };
        let manifest = format_memory_manifest(&[meta]);
        assert!(manifest.contains("[user]"));
        assert!(manifest.contains("user_role.md"));
        assert!(manifest.contains("data scientist"));
    }
    #[test]
    fn test_format_memory_manifest_no_description() {
        let meta = MemoryFileMeta {
            filename: "ref.md".to_string(),
            path: PathBuf::from("ref.md"),
            name: None,
            description: None,
            memory_type: None,
            modified_secs: 0,
        };
        let manifest = format_memory_manifest(&[meta]);
        assert!(manifest.contains("ref.md"));
        // FIX: the manifest entry without a description is still
        // "- ref.md (<iso-ts>)", so the old assertion
        // `!manifest.contains("ref.md (")` always failed. What we actually
        // want is the absence of the "): description" separator.
        assert!(!manifest.contains("): "));
    }
    // ---- MemoryType --------------------------------------------------------
    #[test]
    fn test_memory_type_roundtrip() {
        for (s, expected) in [
            ("user", MemoryType::User),
            ("feedback", MemoryType::Feedback),
            ("project", MemoryType::Project),
            ("reference", MemoryType::Reference),
        ] {
            let parsed = MemoryType::parse(s).unwrap();
            assert_eq!(parsed, expected);
            assert_eq!(parsed.as_str(), s);
        }
    }
    #[test]
    fn test_memory_type_unknown_returns_none() {
        assert!(MemoryType::parse("bogus").is_none());
    }
    // ---- is_auto_memory_enabled -------------------------------------------
    #[test]
    fn test_auto_memory_enabled_default() {
        // No env vars set for this test, settings None → should be enabled.
        // We can't guarantee the test environment is clean, so just check it
        // returns a bool without panicking.
        let _ = is_auto_memory_enabled(None);
    }
    #[test]
    fn test_auto_memory_disabled_by_setting() {
        // If settings explicitly disable it and no env override, returns false.
        // We only test the settings-path without touching process env.
        // Simulate: env vars not set, settings says false.
        // We can't unset env vars reliably in tests, so just ensure the
        // function handles Some(false) without panicking.
        // (The full env-var paths are integration-tested separately.)
        let _ = is_auto_memory_enabled(Some(false));
    }
}

View file

@ -0,0 +1,474 @@
//! Settings migration framework
//! Runs on startup to upgrade settings.json from older versions.
//!
//! Migrations are derived from the TypeScript originals:
//! - src/migrations/migrateFennecToOpus.ts
//! - src/migrations/migrateLegacyOpusToCurrent.ts
//! - src/migrations/migrateSonnet45ToSonnet46.ts
//! - src/migrations/migrateAutoUpdatesToSettings.ts
//! - (and several others without separate TS source files)
//!
//! Each migration is idempotent: it only touches fields it recognises and
//! only writes when it actually changes something.
use serde_json::Value;
/// A single migration function.
/// Returns `true` if the settings object was modified.
pub type MigrationFn = fn(&mut Value) -> bool;
/// All migrations in the order they must be applied.
///
/// Order matters: e.g. `migrate_fennec_to_opus` and
/// `migrate_legacy_opus_to_current` rewrite old model strings before the
/// later entries inspect the (possibly rewritten) values.
pub const MIGRATIONS: &[(&str, MigrationFn)] = &[
    ("migrate_fennec_to_opus", migrate_fennec_to_opus),
    ("migrate_legacy_opus_to_current", migrate_legacy_opus_to_current),
    ("migrate_opus_to_opus_1m", migrate_opus_to_opus_1m),
    ("migrate_sonnet_1m_to_sonnet_45", migrate_sonnet_1m_to_sonnet_45),
    ("migrate_sonnet_45_to_sonnet_46", migrate_sonnet_45_to_sonnet_46),
    (
        "migrate_bypass_permissions_to_settings",
        migrate_bypass_permissions_to_settings,
    ),
    (
        "migrate_repl_bridge_to_remote_control",
        migrate_repl_bridge_to_remote_control,
    ),
    ("migrate_enable_all_mcp_servers", migrate_enable_all_mcp_servers),
    ("migrate_auto_updates", migrate_auto_updates),
    ("reset_auto_mode_opt_in", reset_auto_mode_opt_in),
    ("reset_pro_to_opus_default", reset_pro_to_opus_default),
];
/// Apply every pending migration to a settings `Value` (must be a JSON object).
/// Returns `true` when at least one migration changed the settings.
pub fn run_migrations(settings: &mut Value) -> bool {
    // Run ALL migrations in table order, OR-ing their "changed" results.
    MIGRATIONS.iter().fold(false, |changed, (name, migration)| {
        if migration(settings) {
            tracing::info!("Applied settings migration: {}", name);
            true
        } else {
            changed
        }
    })
}
// ---------------------------------------------------------------------------
// Model-name migrations
// ---------------------------------------------------------------------------
/// Fennec was an internal alias; map to the current Opus line.
/// Source: migrateFennecToOpus.ts
fn migrate_fennec_to_opus(settings: &mut Value) -> bool {
    // fennec-latest[1m] → opus[1m], fennec-latest → opus
    // fennec-fast-latest / opus-4-5-fast → opus[1m] (fast-mode alias)
    let model = match settings.get("model").and_then(Value::as_str) {
        Some(m) => m.to_string(),
        None => return false,
    };
    // Longest prefix first so `fennec-latest[1m]` is not swallowed by the
    // plain `fennec-latest` rule.
    if model.starts_with("fennec-latest[1m]") {
        settings["model"] = Value::String("opus[1m]".to_string());
        true
    } else if model.starts_with("fennec-latest") {
        settings["model"] = Value::String("opus".to_string());
        true
    } else if model.starts_with("fennec-fast-latest") || model.starts_with("opus-4-5-fast") {
        // Fast-mode aliases map to the 1m model with fastMode switched on.
        settings["model"] = Value::String("opus[1m]".to_string());
        settings["fastMode"] = Value::Bool(true);
        true
    } else {
        false
    }
}
/// Migrate explicit Opus 4.0/4.1 strings to the `opus` alias.
/// Source: migrateLegacyOpusToCurrent.ts
fn migrate_legacy_opus_to_current(settings: &mut Value) -> bool {
    const LEGACY_OPUS: &[&str] = &[
        "claude-opus-4-20250514",
        "claude-opus-4-1-20250805",
        "claude-opus-4-0",
        "claude-opus-4-1",
    ];
    let is_legacy = settings
        .get("model")
        .and_then(Value::as_str)
        .map_or(false, |m| LEGACY_OPUS.contains(&m));
    if is_legacy {
        settings["model"] = Value::String("opus".to_string());
    }
    is_legacy
}
/// Rename the old explicit `claude-opus-4-0` model string (pre-alias era).
///
/// NOTE(review): despite the `_1m` in the name, the target is a dated
/// (non-1m) model ID. Also, `migrate_legacy_opus_to_current` runs earlier in
/// `MIGRATIONS` and already rewrites `claude-opus-4-0` in the `model` field —
/// this migration still matters for the `defaultModel` / `mainLoopModel`
/// fields, which only `rename_model` touches. Confirm against the TS original.
fn migrate_opus_to_opus_1m(settings: &mut Value) -> bool {
    rename_model(settings, "claude-opus-4-0", "claude-opus-4-5-20251001")
}
/// Migrate the old Sonnet 1m string to the Sonnet 4.5 release ID.
///
/// Applies to `model`, `defaultModel`, and `mainLoopModel` (via
/// `rename_model`). The resulting dated 4.5 ID is itself migrated onward by
/// `migrate_sonnet_45_to_sonnet_46`, which runs next in `MIGRATIONS`.
fn migrate_sonnet_1m_to_sonnet_45(settings: &mut Value) -> bool {
    rename_model(
        settings,
        "claude-sonnet-4-0-1m",
        "claude-sonnet-4-5-20251015",
    )
}
/// Migrate Sonnet 4.5 explicit IDs to `sonnet` (which resolves to 4.6).
/// Source: migrateSonnet45ToSonnet46.ts
fn migrate_sonnet_45_to_sonnet_46(settings: &mut Value) -> bool {
    const SONNET_45_IDS: &[&str] = &[
        "claude-sonnet-4-5-20250929",
        "claude-sonnet-4-5-20250929[1m]",
        "sonnet-4-5-20250929",
        "sonnet-4-5-20250929[1m]",
        // Also handle the model strings used in the older Rust migrations table:
        "claude-sonnet-4-5-20251015",
        "claude-sonnet-4-5",
    ];
    let model = match settings.get("model").and_then(Value::as_str) {
        Some(m) if SONNET_45_IDS.contains(&m) => m.to_string(),
        _ => return false,
    };
    // Preserve the 1m-context suffix when switching to the alias.
    let alias = if model.ends_with("[1m]") {
        "sonnet[1m]"
    } else {
        "sonnet"
    };
    settings["model"] = Value::String(alias.to_string());
    true
}
/// Rename `from` to `to` in the `model`, `defaultModel`, and `mainLoopModel`
/// fields. Returns `true` if any field was changed.
fn rename_model(settings: &mut Value, from: &str, to: &str) -> bool {
    let mut hit = false;
    for key in ["model", "defaultModel", "mainLoopModel"] {
        match settings.get_mut(key) {
            // Only exact string matches are rewritten.
            Some(slot) if slot.as_str() == Some(from) => {
                *slot = Value::String(to.to_string());
                hit = true;
            }
            _ => {}
        }
    }
    hit
}
// ---------------------------------------------------------------------------
// Config-structure migrations
// ---------------------------------------------------------------------------
/// Move `bypassPermissionsAccepted` boolean into `permissionMode`.
fn migrate_bypass_permissions_to_settings(settings: &mut Value) -> bool {
    // Remove the legacy key up front; absent key (or non-object) → no-op.
    let old = match settings
        .as_object_mut()
        .and_then(|obj| obj.remove("bypassPermissionsAccepted"))
    {
        Some(v) => v,
        None => return false,
    };
    // Only promote to "bypass" when it was accepted and no explicit
    // permissionMode has been set since.
    if old.as_bool().unwrap_or(false) && settings.get("permissionMode").is_none() {
        settings["permissionMode"] = Value::String("bypass".to_string());
    }
    true
}
/// Rename `replBridgeEnabled` → `remoteControlAtStartup`.
fn migrate_repl_bridge_to_remote_control(settings: &mut Value) -> bool {
    // Remove the legacy key up front; absent key (or non-object) → no-op.
    let old = match settings
        .as_object_mut()
        .and_then(|obj| obj.remove("replBridgeEnabled"))
    {
        Some(v) => v,
        None => return false,
    };
    // Don't clobber a value already stored under the new name.
    if settings.get("remoteControlAtStartup").is_none() {
        settings["remoteControlAtStartup"] = old;
    }
    true
}
/// Rename `enableAllProjectMcpServers` → `mcpAutoApprove`.
fn migrate_enable_all_mcp_servers(settings: &mut Value) -> bool {
    // Remove the legacy key up front; absent key (or non-object) → no-op.
    let old = match settings
        .as_object_mut()
        .and_then(|obj| obj.remove("enableAllProjectMcpServers"))
    {
        Some(v) => v,
        None => return false,
    };
    // Don't clobber a value already stored under the new name.
    if settings.get("mcpAutoApprove").is_none() {
        settings["mcpAutoApprove"] = old;
    }
    true
}
/// Migrate `autoUpdatesEnabled` → `autoUpdates`.
/// Source: migrateAutoUpdatesToSettings.ts
/// The TS version also writes an env-var to settings.json; here we keep the
/// simpler structural rename and leave env-var injection to the caller.
fn migrate_auto_updates(settings: &mut Value) -> bool {
    // Remove the legacy key up front; absent key (or non-object) → no-op.
    let old = match settings
        .as_object_mut()
        .and_then(|obj| obj.remove("autoUpdatesEnabled"))
    {
        Some(v) => v,
        None => return false,
    };
    // Don't clobber a value already stored under the new name.
    if settings.get("autoUpdates").is_none() {
        settings["autoUpdates"] = old;
    }
    true
}
/// Clear an old sentinel value for the auto-mode opt-in flag.
fn reset_auto_mode_opt_in(settings: &mut Value) -> bool {
    let is_sentinel =
        settings.get("autoModeOptIn").and_then(Value::as_str) == Some("default_offer_2024");
    if is_sentinel {
        // Null (not removal) so the field visibly exists but is unset.
        settings["autoModeOptIn"] = Value::Null;
    }
    is_sentinel
}
/// Reset users who were auto-defaulted to Opus back to Sonnet 4.6.
/// Only resets when `modelSetByUser` is not explicitly `true`.
fn reset_pro_to_opus_default(settings: &mut Value) -> bool {
    // Only the auto-default Opus ID is eligible for reset.
    if settings.get("model").and_then(Value::as_str) != Some("claude-opus-4-5-20251001") {
        return false;
    }
    // A model the user explicitly chose is never reset.
    if settings.get("modelSetByUser").and_then(Value::as_bool) == Some(true) {
        return false;
    }
    settings["model"] = Value::String("claude-sonnet-4-6".to_string());
    true
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
/// Unit tests for the settings-migration helpers defined above.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // Convenience: build a minimal settings document containing only a
    // `model` key.
    fn settings(model: &str) -> Value {
        json!({ "model": model })
    }
    // ---- rename_model -------------------------------------------------------
    #[test]
    fn rename_model_changes_matching_field() {
        let mut s = settings("old-model");
        assert!(rename_model(&mut s, "old-model", "new-model"));
        assert_eq!(s["model"].as_str(), Some("new-model"));
    }
    #[test]
    fn rename_model_no_change_when_different() {
        let mut s = settings("something-else");
        assert!(!rename_model(&mut s, "old-model", "new-model"));
        assert_eq!(s["model"].as_str(), Some("something-else"));
    }
    #[test]
    fn rename_model_covers_all_keys() {
        let mut s = json!({
            "model": "claude-foo",
            "defaultModel": "claude-foo",
            "mainLoopModel": "claude-foo",
        });
        assert!(rename_model(&mut s, "claude-foo", "claude-bar"));
        assert_eq!(s["model"].as_str(), Some("claude-bar"));
        assert_eq!(s["defaultModel"].as_str(), Some("claude-bar"));
        assert_eq!(s["mainLoopModel"].as_str(), Some("claude-bar"));
    }
    // ---- migrate_fennec_to_opus ---------------------------------------------
    #[test]
    fn fennec_latest_1m_maps_to_opus_1m() {
        let mut s = settings("fennec-latest[1m]");
        assert!(migrate_fennec_to_opus(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus[1m]"));
    }
    #[test]
    fn fennec_latest_maps_to_opus() {
        let mut s = settings("fennec-latest");
        assert!(migrate_fennec_to_opus(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus"));
    }
    #[test]
    fn fennec_fast_maps_to_opus_1m_with_fast_mode() {
        let mut s = settings("fennec-fast-latest");
        assert!(migrate_fennec_to_opus(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus[1m]"));
        assert_eq!(s["fastMode"].as_bool(), Some(true));
    }
    #[test]
    fn opus_4_5_fast_maps_to_opus_1m_with_fast_mode() {
        let mut s = settings("opus-4-5-fast");
        assert!(migrate_fennec_to_opus(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus[1m]"));
        assert_eq!(s["fastMode"].as_bool(), Some(true));
    }
    #[test]
    fn fennec_no_match_returns_false() {
        let mut s = settings("claude-sonnet-4-6");
        assert!(!migrate_fennec_to_opus(&mut s));
    }
    // ---- migrate_legacy_opus_to_current ------------------------------------
    #[test]
    fn legacy_opus_4_0_maps_to_opus() {
        let mut s = settings("claude-opus-4-0");
        assert!(migrate_legacy_opus_to_current(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus"));
    }
    #[test]
    fn legacy_opus_4_1_maps_to_opus() {
        let mut s = settings("claude-opus-4-1-20250805");
        assert!(migrate_legacy_opus_to_current(&mut s));
        assert_eq!(s["model"].as_str(), Some("opus"));
    }
    // ---- migrate_sonnet_45_to_sonnet_46 ------------------------------------
    #[test]
    fn sonnet_45_explicit_id_maps_to_sonnet() {
        let mut s = settings("claude-sonnet-4-5-20250929");
        assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
        assert_eq!(s["model"].as_str(), Some("sonnet"));
    }
    #[test]
    fn sonnet_45_1m_maps_to_sonnet_1m() {
        let mut s = settings("claude-sonnet-4-5-20250929[1m]");
        assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
        assert_eq!(s["model"].as_str(), Some("sonnet[1m]"));
    }
    #[test]
    fn sonnet_46_is_untouched() {
        let mut s = settings("claude-sonnet-4-6");
        assert!(!migrate_sonnet_45_to_sonnet_46(&mut s));
    }
    // ---- struct migrations -------------------------------------------------
    #[test]
    fn bypass_permissions_migrates_and_removes_old_key() {
        let mut s = json!({ "bypassPermissionsAccepted": true });
        assert!(migrate_bypass_permissions_to_settings(&mut s));
        assert!(s.get("bypassPermissionsAccepted").is_none());
        assert_eq!(s["permissionMode"].as_str(), Some("bypass"));
    }
    #[test]
    fn bypass_permissions_false_does_not_set_mode() {
        let mut s = json!({ "bypassPermissionsAccepted": false });
        assert!(migrate_bypass_permissions_to_settings(&mut s));
        assert!(s.get("permissionMode").is_none());
    }
    #[test]
    fn repl_bridge_renames_field() {
        let mut s = json!({ "replBridgeEnabled": true });
        assert!(migrate_repl_bridge_to_remote_control(&mut s));
        assert!(s.get("replBridgeEnabled").is_none());
        assert_eq!(s["remoteControlAtStartup"].as_bool(), Some(true));
    }
    #[test]
    fn enable_all_mcp_renames_field() {
        let mut s = json!({ "enableAllProjectMcpServers": true });
        assert!(migrate_enable_all_mcp_servers(&mut s));
        assert!(s.get("enableAllProjectMcpServers").is_none());
        assert_eq!(s["mcpAutoApprove"].as_bool(), Some(true));
    }
    #[test]
    fn auto_updates_renames_field() {
        let mut s = json!({ "autoUpdatesEnabled": false });
        assert!(migrate_auto_updates(&mut s));
        assert!(s.get("autoUpdatesEnabled").is_none());
        assert_eq!(s["autoUpdates"].as_bool(), Some(false));
    }
    #[test]
    fn reset_auto_mode_clears_sentinel() {
        let mut s = json!({ "autoModeOptIn": "default_offer_2024" });
        assert!(reset_auto_mode_opt_in(&mut s));
        assert!(s["autoModeOptIn"].is_null());
    }
    #[test]
    fn reset_auto_mode_leaves_other_values() {
        let mut s = json!({ "autoModeOptIn": "user_opted_in" });
        assert!(!reset_auto_mode_opt_in(&mut s));
        assert_eq!(s["autoModeOptIn"].as_str(), Some("user_opted_in"));
    }
    #[test]
    fn reset_pro_opus_default_resets_when_not_user_set() {
        let mut s = json!({ "model": "claude-opus-4-5-20251001" });
        assert!(reset_pro_to_opus_default(&mut s));
        assert_eq!(s["model"].as_str(), Some("claude-sonnet-4-6"));
    }
    #[test]
    fn reset_pro_opus_default_preserves_when_user_set() {
        let mut s = json!({ "model": "claude-opus-4-5-20251001", "modelSetByUser": true });
        assert!(!reset_pro_to_opus_default(&mut s));
        assert_eq!(s["model"].as_str(), Some("claude-opus-4-5-20251001"));
    }
    // ---- run_migrations integration ----------------------------------------
    #[test]
    fn run_migrations_applies_chain() {
        // A Sonnet 4.5 model should end up as "sonnet" after the full chain.
        let mut s = json!({ "model": "claude-sonnet-4-5-20250929" });
        let changed = run_migrations(&mut s);
        assert!(changed);
        assert_eq!(s["model"].as_str(), Some("sonnet"));
    }
    #[test]
    fn run_migrations_returns_false_when_nothing_changes() {
        let mut s = json!({ "model": "claude-sonnet-4-6", "someOtherKey": 42 });
        assert!(!run_migrations(&mut s));
    }
    #[test]
    fn run_migrations_handles_empty_object() {
        let mut s = json!({});
        // No model fields, no sentinel values → nothing to do.
        assert!(!run_migrations(&mut s));
    }
}

View file

@ -0,0 +1,364 @@
//! OAuth configuration for multiple environments.
//!
//! This module mirrors the TypeScript `src/constants/oauth.ts` and
//! `src/services/oauth/crypto.ts` constants. It is intentionally
//! *configuration-only* — no live network I/O except for the optional
//! `fetch_oauth_profile` helper at the bottom.
use serde::{Deserialize, Serialize};
// ---------------------------------------------------------------------------
// Scope constants (mirrors constants/oauth.ts)
// ---------------------------------------------------------------------------
/// The Claude.ai inference scope — required for Bearer-auth API calls.
pub const CLAUDE_AI_INFERENCE_SCOPE: &str = "user:inference";
/// The profile scope — required to read account / subscription data.
pub const CLAUDE_AI_PROFILE_SCOPE: &str = "user:profile";
/// Console scope — used when creating an API key via the Console flow.
pub const CONSOLE_SCOPE: &str = "org:create_api_key";
/// All Claude.ai OAuth scopes (mirrors `CLAUDE_AI_OAUTH_SCOPES`).
pub const CLAUDE_AI_OAUTH_SCOPES: &[&str] = &[
CLAUDE_AI_PROFILE_SCOPE,
CLAUDE_AI_INFERENCE_SCOPE,
"user:sessions:claude_code",
"user:mcp_servers",
"user:file_upload",
];
/// Console OAuth scopes (mirrors `CONSOLE_OAUTH_SCOPES`).
pub const CONSOLE_OAUTH_SCOPES: &[&str] = &[CONSOLE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
/// Union of all scopes used during login (mirrors `ALL_OAUTH_SCOPES`).
/// Requesting all at once lets a single login satisfy both Console and
/// claude.ai auth paths.
pub const ALL_OAUTH_SCOPES: &[&str] = &[
CONSOLE_SCOPE,
CLAUDE_AI_PROFILE_SCOPE,
CLAUDE_AI_INFERENCE_SCOPE,
"user:sessions:claude_code",
"user:mcp_servers",
"user:file_upload",
];
/// Minimum scopes required for basic operation.
pub const MINIMUM_SCOPES: &[&str] = &[CLAUDE_AI_INFERENCE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
// ---------------------------------------------------------------------------
// OAuthConfig struct
// ---------------------------------------------------------------------------
/// Full OAuth configuration for a deployment environment.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Base URL for Anthropic API calls.
    pub base_api_url: &'static str,
    /// Authorization endpoint for the Console (platform) login flow.
    pub console_authorize_url: &'static str,
    /// Authorization endpoint for the claude.ai login flow.
    pub claude_ai_authorize_url: &'static str,
    /// The raw claude.ai web origin (separate from the authorize URL which
    /// may bounce through claude.com for attribution).
    pub claude_ai_origin: &'static str,
    /// Token endpoint used for the OAuth code exchange.
    pub token_url: &'static str,
    /// Endpoint for creating an API key (Console flow).
    pub api_key_url: &'static str,
    /// Endpoint for the `claude_cli/roles` lookup.
    pub roles_url: &'static str,
    /// Browser landing page after a successful Console login.
    pub console_success_url: &'static str,
    /// Browser landing page after a successful claude.ai login.
    pub claudeai_success_url: &'static str,
    /// Redirect target for the manual (copy/paste the code) flow.
    pub manual_redirect_url: &'static str,
    /// Public OAuth client identifier (PKCE public client — not a secret).
    pub client_id: &'static str,
    /// Suffix for the stored credentials file name — distinguishes staging
    /// credentials from production (see `STAGING_OAUTH`).
    pub oauth_file_suffix: &'static str,
    /// Base URL of the MCP proxy service.
    pub mcp_proxy_url: &'static str,
    /// Path template on the proxy; `{server_id}` is presumably substituted
    /// by the caller — TODO confirm against the request-building code.
    pub mcp_proxy_path: &'static str,
}
// ---------------------------------------------------------------------------
// Production config (mirrors PROD_OAUTH_CONFIG in oauth.ts)
// ---------------------------------------------------------------------------
pub const PROD_OAUTH: OAuthConfig = OAuthConfig {
    base_api_url: "https://api.anthropic.com",
    console_authorize_url: "https://platform.claude.com/oauth/authorize",
    // Routes through claude.com/cai/* for attribution, 307s to claude.ai in
    // two hops — same behaviour as the TypeScript client.
    claude_ai_authorize_url: "https://claude.com/cai/oauth/authorize",
    claude_ai_origin: "https://claude.ai",
    token_url: "https://platform.claude.com/v1/oauth/token",
    api_key_url: "https://api.anthropic.com/api/oauth/claude_cli/create_api_key",
    roles_url: "https://api.anthropic.com/api/oauth/claude_cli/roles",
    console_success_url: "https://platform.claude.com/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
    claudeai_success_url: "https://platform.claude.com/oauth/code/success?app=claude-code",
    manual_redirect_url: "https://platform.claude.com/oauth/code/callback",
    // Public PKCE client identifier — safe to ship in the binary.
    client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e",
    // Production credentials use the unsuffixed file name.
    oauth_file_suffix: "",
    mcp_proxy_url: "https://mcp-proxy.anthropic.com",
    mcp_proxy_path: "/v1/mcp/{server_id}",
};
// ---------------------------------------------------------------------------
// Staging config (mirrors STAGING_OAUTH_CONFIG — ant builds only)
// ---------------------------------------------------------------------------
pub const STAGING_OAUTH: OAuthConfig = OAuthConfig {
    base_api_url: "https://api-staging.anthropic.com",
    console_authorize_url: "https://platform.staging.ant.dev/oauth/authorize",
    claude_ai_authorize_url: "https://claude-ai.staging.ant.dev/oauth/authorize",
    claude_ai_origin: "https://claude-ai.staging.ant.dev",
    token_url: "https://platform.staging.ant.dev/v1/oauth/token",
    api_key_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/create_api_key",
    roles_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/roles",
    console_success_url: "https://platform.staging.ant.dev/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
    claudeai_success_url: "https://platform.staging.ant.dev/oauth/code/success?app=claude-code",
    manual_redirect_url: "https://platform.staging.ant.dev/oauth/code/callback",
    client_id: "22422756-60c9-4084-8eb7-27705fd5cf9a",
    // Non-empty suffix keeps staging credentials in a separate file from
    // production ones.
    oauth_file_suffix: "-staging-oauth",
    mcp_proxy_url: "https://mcp-proxy-staging.anthropic.com",
    mcp_proxy_path: "/v1/mcp/{server_id}",
};
/// Client-ID Metadata Document URL for MCP OAuth (CIMD / SEP-991).
pub const MCP_CLIENT_METADATA_URL: &str =
    "https://claude.ai/oauth/claude-code-client-metadata";
// ---------------------------------------------------------------------------
// Config selection
// ---------------------------------------------------------------------------
/// Return the OAuth config appropriate for the current environment.
///
/// Selection mirrors `getOauthConfigType()` in `constants/oauth.ts`:
/// `USER_TYPE=ant` combined with a truthy `USE_STAGING_OAUTH`
/// (`1`/`true`/`yes`) selects staging; every other combination selects
/// production.
///
/// Note: the `local` variant from the TypeScript code is intentionally
/// omitted here — local dev servers are not needed in the Rust port yet.
pub fn get_oauth_config() -> &'static OAuthConfig {
    let is_ant_user = std::env::var("USER_TYPE").map_or(false, |t| t == "ant");
    let staging_requested = std::env::var("USE_STAGING_OAUTH")
        .map_or(false, |v| matches!(v.as_str(), "1" | "true" | "yes"));
    if is_ant_user && staging_requested {
        &STAGING_OAUTH
    } else {
        &PROD_OAUTH
    }
}
// ---------------------------------------------------------------------------
// PKCE helpers (mirrors src/services/oauth/crypto.ts)
// ---------------------------------------------------------------------------
/// PKCE code-challenge / code-verifier helpers.
pub mod pkce {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
    use sha2::{Digest, Sha256};
    /// Generate a cryptographically random code verifier: 43 Base64url
    /// characters (RFC 7636 §4.1 permits lengths of 43–128).
    ///
    /// Randomness comes from two UUID v4 values, which the `uuid` crate
    /// fills from the OS RNG — no separate `rand` dependency needed.
    /// NOTE(review): there is no fallback path here; if the OS RNG is
    /// unavailable, `Uuid::new_v4` will panic rather than degrade —
    /// confirm that is the desired behaviour.
    pub fn generate_code_verifier() -> String {
        // 32 random bytes → 43-char Base64url string (same as the TS impl).
        let bytes = random_bytes_32();
        URL_SAFE_NO_PAD.encode(bytes)
    }
    /// Compute `BASE64URL(SHA256(verifier))` — the S256 code challenge.
    pub fn code_challenge(verifier: &str) -> String {
        let hash = Sha256::digest(verifier.as_bytes());
        URL_SAFE_NO_PAD.encode(hash)
    }
    /// Generate a random state parameter (43 Base64url chars — the full
    /// no-pad encoding of 32 random bytes).
    pub fn generate_state() -> String {
        let bytes = random_bytes_32();
        let encoded = URL_SAFE_NO_PAD.encode(bytes);
        // 32 bytes encode to exactly 43 no-pad chars, so take(43) keeps the
        // whole string; it only truncates if the byte count ever grows.
        encoded.chars().take(43).collect()
    }
    // ------------------------------------------------------------------
    // Internal: produce 32 random bytes.
    // We derive them from a UUID v4 (which already pulls from the OS RNG
    // via the `uuid` crate) so we don't need to add a new `rand` dep.
    // NOTE(review): UUIDv4 fixes 6 version/variant bits per value, so the
    // 32 bytes carry ~244 bits of entropy rather than 256 — ample for
    // PKCE, but not uniformly random bytes.
    // ------------------------------------------------------------------
    fn random_bytes_32() -> [u8; 32] {
        // Two UUID v4 values give us 32 bytes of OS-backed randomness.
        let u1 = uuid::Uuid::new_v4();
        let u2 = uuid::Uuid::new_v4();
        let mut out = [0u8; 32];
        out[..16].copy_from_slice(u1.as_bytes());
        out[16..].copy_from_slice(u2.as_bytes());
        out
    }
}
// ---------------------------------------------------------------------------
// Token and profile types
// ---------------------------------------------------------------------------
/// Raw OAuth token response from the token endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// Bearer token used to authenticate subsequent API calls.
    pub access_token: String,
    /// Token type reported by the server (typically `"Bearer"` — confirm).
    pub token_type: String,
    /// Access-token lifetime in seconds, when the server reports one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Token used to obtain a fresh access token, when issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Space-separated scopes actually granted, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
/// Slim profile fetched after token exchange.
///
/// All fields are optional; a default (all-`None`) value is used when the
/// profile fetch fails (see `fetch_oauth_profile`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OAuthProfile {
    /// Account email address, when the server provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Human-readable account name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Opaque account identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub account_id: Option<String>,
    /// Subscription tier string (exact values server-defined).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subscription_tier: Option<String>,
}
/// Fetch the OAuth profile using an access token.
///
/// Returns a default (all-`None`) profile on any non-success response —
/// and on a JSON-decode failure — so callers can treat a profile fetch
/// failure as non-fatal. Transport errors still propagate as `Err`.
pub async fn fetch_oauth_profile(
    access_token: &str,
    api_base: &str,
) -> anyhow::Result<OAuthProfile> {
    let endpoint = format!("{}/api/auth/oauth/profile", api_base.trim_end_matches('/'));
    let response = reqwest::Client::new()
        .get(&endpoint)
        .bearer_auth(access_token)
        .timeout(std::time::Duration::from_secs(10))
        .send()
        .await?;
    if !response.status().is_success() {
        // Non-fatal: hand back an empty profile so the caller can continue.
        return Ok(OAuthProfile::default());
    }
    Ok(response.json().await.unwrap_or_default())
}
// ---------------------------------------------------------------------------
// Auth URL builder
// ---------------------------------------------------------------------------
/// Build the OAuth authorization URL (mirrors `buildAuthUrl` in client.ts).
///
/// Parameter order in the query string is fixed (code, client_id,
/// response_type, redirect_uri, scope, code_challenge,
/// code_challenge_method, state) to match the TypeScript client.
pub fn build_auth_url(
    code_challenge: &str,
    state: &str,
    port: u16,
    is_manual: bool,
    login_with_claude_ai: bool,
    inference_only: bool,
) -> String {
    let cfg = get_oauth_config();
    // claude.ai logins go through the attribution authorize URL; Console
    // logins hit the platform endpoint directly.
    let authorize_base = if login_with_claude_ai {
        cfg.claude_ai_authorize_url
    } else {
        cfg.console_authorize_url
    };
    // Manual flows paste the code back by hand; automatic flows use a
    // loopback listener on the given port.
    let redirect_uri = if is_manual {
        cfg.manual_redirect_url.to_string()
    } else {
        format!("http://localhost:{port}/callback")
    };
    let scope_str = if inference_only {
        CLAUDE_AI_INFERENCE_SCOPE.to_string()
    } else {
        ALL_OAUTH_SCOPES.join(" ")
    };
    let mut url = String::new();
    url.push_str(authorize_base);
    url.push_str("?code=true&client_id=");
    url.push_str(&urlencoding::encode(cfg.client_id));
    url.push_str("&response_type=code&redirect_uri=");
    url.push_str(&urlencoding::encode(&redirect_uri));
    url.push_str("&scope=");
    url.push_str(&urlencoding::encode(&scope_str));
    url.push_str("&code_challenge=");
    url.push_str(&urlencoding::encode(code_challenge));
    url.push_str("&code_challenge_method=S256&state=");
    url.push_str(&urlencoding::encode(state));
    url
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
/// Unit tests for OAuth configuration and PKCE helpers.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_prod_config_urls_are_https() {
        assert!(PROD_OAUTH.token_url.starts_with("https://"));
        assert!(PROD_OAUTH.api_key_url.starts_with("https://"));
        assert!(PROD_OAUTH.claude_ai_authorize_url.starts_with("https://"));
    }
    #[test]
    fn test_staging_config_urls_are_https() {
        assert!(STAGING_OAUTH.token_url.starts_with("https://"));
        assert!(STAGING_OAUTH.api_key_url.starts_with("https://"));
    }
    #[test]
    fn test_pkce_code_challenge_is_base64url() {
        let verifier = pkce::generate_code_verifier();
        assert!(!verifier.is_empty());
        // Base64url characters only (no +, /, =)
        assert!(!verifier.contains('+'));
        assert!(!verifier.contains('/'));
        assert!(!verifier.contains('='));
        let challenge = pkce::code_challenge(&verifier);
        assert!(!challenge.is_empty());
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
    }
    #[test]
    fn test_verifier_length_meets_rfc7636_minimum() {
        let verifier = pkce::generate_code_verifier();
        // RFC 7636 §4.1: code_verifier length ∈ [43, 128]
        assert!(
            verifier.len() >= 43,
            "verifier too short: {} chars",
            verifier.len()
        );
        assert!(verifier.len() <= 128, "verifier too long: {} chars", verifier.len());
    }
    #[test]
    fn test_all_oauth_scopes_contains_inference() {
        assert!(ALL_OAUTH_SCOPES.contains(&CLAUDE_AI_INFERENCE_SCOPE));
    }
    #[test]
    fn test_build_auth_url_contains_required_params() {
        let url = build_auth_url("challenge123", "state456", 8080, false, true, false);
        assert!(url.contains("challenge123"));
        assert!(url.contains("state456"));
        assert!(url.contains("S256"));
        assert!(url.contains("localhost"));
    }
}

View file

@ -0,0 +1,347 @@
//! Output style system — customises how Claude responds to the user.
//!
//! Styles are applied by injecting `OutputStyleDef::prompt` into the system
//! prompt. Built-in styles are defined in code; users can add their own by
//! placing `.md` or `.json` files in:
//! - Global: `~/.claude/output-styles/`
//! - Project: `.claude/output-styles/`
//!
//! Markdown style files have a simple structure:
//! Line 1: `# <Label>` (heading becomes the label)
//! Line 2: short description
//! Remainder: the prompt text injected into the system prompt
use serde::{Deserialize, Serialize};
use std::path::Path;
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// A single output style definition.
///
/// Equality compares all four fields (derived `PartialEq`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OutputStyleDef {
    /// Machine-readable identifier (e.g. `"concise"`).
    pub name: String,
    /// Human-readable label shown in picker UI (e.g. `"Concise"`).
    pub label: String,
    /// One-line description.
    pub description: String,
    /// Text injected into the system prompt when this style is active.
    /// Empty string for the default style (no extra injection).
    pub prompt: String,
}
impl OutputStyleDef {
    // ---- Built-in styles ---------------------------------------------------

    /// Internal helper: assemble a style from its four string parts.
    fn make(name: &str, label: &str, description: &str, prompt: &str) -> Self {
        Self {
            name: name.to_string(),
            label: label.to_string(),
            description: description.to_string(),
            prompt: prompt.to_string(),
        }
    }

    /// The default style: no extra system-prompt injection.
    pub fn builtin_default() -> Self {
        Self::make("default", "Default", "Standard Claude Code responses.", "")
    }

    /// Short, direct responses with minimal explanation.
    pub fn builtin_concise() -> Self {
        Self::make(
            "concise",
            "Concise",
            "Short, direct responses with minimal explanation.",
            "Be maximally concise. Skip preamble, summaries, and filler. \
             Lead with the answer.",
        )
    }

    /// Thorough, educational explanations.
    pub fn builtin_explanatory() -> Self {
        Self::make(
            "explanatory",
            "Explanatory",
            "Thorough explanations with reasoning and alternatives.",
            "When explaining code or concepts, be thorough and educational. \
             Include reasoning, alternatives considered, and potential pitfalls. \
             Err on the side of over-explaining.",
        )
    }

    /// Pedagogical mode for users who want to learn as the work happens.
    pub fn builtin_learning() -> Self {
        Self::make(
            "learning",
            "Learning",
            "Pedagogical mode — explains patterns and decisions.",
            "This user is learning. Explain concepts as you implement them. \
             Point out patterns, best practices, and why you made each decision. \
             Use analogies when helpful.",
        )
    }
}
// ---------------------------------------------------------------------------
// Built-ins
// ---------------------------------------------------------------------------
/// Return all built-in output styles in display order.
pub fn builtin_styles() -> Vec<OutputStyleDef> {
    // Display order is fixed: default first, then the specialised styles.
    let constructors: [fn() -> OutputStyleDef; 4] = [
        OutputStyleDef::builtin_default,
        OutputStyleDef::builtin_concise,
        OutputStyleDef::builtin_explanatory,
        OutputStyleDef::builtin_learning,
    ];
    constructors.iter().map(|ctor| ctor()).collect()
}
// ---------------------------------------------------------------------------
// Loading from disk
// ---------------------------------------------------------------------------
/// Load user-defined output styles from a directory.
///
/// Supported file formats:
/// - `.md` — Markdown: `# Label\ndescription\n\nprompt text…`
/// - `.json` — JSON: `{ "name": "…", "label": "…", "description": "…", "prompt": "…" }`
///
/// A missing or unreadable directory yields an empty list, and files that
/// cannot be parsed are silently skipped.
pub fn load_output_styles_dir(styles_dir: &Path) -> Vec<OutputStyleDef> {
    let entries = match std::fs::read_dir(styles_dir) {
        Ok(entries) => entries,
        // Covers both "directory does not exist" and read errors.
        Err(_) => return Vec::new(),
    };
    let mut styles: Vec<OutputStyleDef> = entries
        .flatten()
        .filter_map(|entry| {
            let path = entry.path();
            match path.extension().and_then(|e| e.to_str()) {
                Some("md") | Some("json") => load_style_file(&path),
                _ => None,
            }
        })
        .collect();
    // Sort alphabetically so the list is deterministic.
    styles.sort_by(|a, b| a.name.cmp(&b.name));
    styles
}
/// Parse a single style file (`.json` or Markdown); returns `None` when
/// the file is unreadable or the JSON does not deserialise.
fn load_style_file(path: &Path) -> Option<OutputStyleDef> {
    let raw = std::fs::read_to_string(path).ok()?;
    let stem = path.file_stem()?.to_string_lossy().into_owned();
    if path.extension().and_then(|e| e.to_str()) == Some("json") {
        // Deserialise directly; an empty `name` falls back to the file stem.
        let mut style: OutputStyleDef = serde_json::from_str(&raw).ok()?;
        if style.name.is_empty() {
            style.name = stem;
        }
        return Some(style);
    }
    // Markdown format:
    // Line 1: # Label (optional leading `#` and whitespace)
    // Line 2: description (short, plain text)
    // Lines 3+: prompt text (leading/trailing blank lines trimmed)
    let mut lines = raw.lines();
    let heading = lines
        .next()
        .map(|l| l.trim().trim_start_matches('#').trim().to_string())
        .unwrap_or_default();
    let description = lines.next().map_or(String::new(), |l| l.trim().to_string());
    let prompt = lines.collect::<Vec<_>>().join("\n").trim().to_string();
    // A missing heading falls back to the file stem.
    let label = if heading.is_empty() { stem.clone() } else { heading };
    Some(OutputStyleDef {
        name: stem,
        label,
        description,
        prompt,
    })
}
// ---------------------------------------------------------------------------
// Aggregated access
// ---------------------------------------------------------------------------
/// Return all styles available for `config_dir`:
/// built-ins first, then styles from `<config_dir>/output-styles/`.
///
/// `config_dir` is typically `~/.claude`.
pub fn all_styles(config_dir: &Path) -> Vec<OutputStyleDef> {
    let mut combined = builtin_styles();
    combined.extend(load_output_styles_dir(&config_dir.join("output-styles")));
    combined
}
/// Find a style by its `name` field; returns the first match, or `None`.
pub fn find_style<'a>(styles: &'a [OutputStyleDef], name: &str) -> Option<&'a OutputStyleDef> {
    for style in styles {
        if style.name == name {
            return Some(style);
        }
    }
    None
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
/// Unit tests for built-in styles and on-disk style loading.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as IoWrite;
    use tempfile::TempDir;
    // ---- builtin_styles ----------------------------------------------------
    #[test]
    fn builtin_styles_non_empty() {
        assert!(!builtin_styles().is_empty());
    }
    #[test]
    fn builtin_styles_have_unique_names() {
        let styles = builtin_styles();
        let mut seen = std::collections::HashSet::new();
        for s in &styles {
            assert!(seen.insert(&s.name), "duplicate style name: {}", s.name);
        }
    }
    #[test]
    fn builtin_default_has_empty_prompt() {
        let def = OutputStyleDef::builtin_default();
        assert!(def.prompt.is_empty());
    }
    #[test]
    fn builtin_non_default_have_prompts() {
        for s in builtin_styles() {
            if s.name != "default" {
                assert!(
                    !s.prompt.is_empty(),
                    "style '{}' should have a non-empty prompt",
                    s.name
                );
            }
        }
    }
    // ---- find_style --------------------------------------------------------
    #[test]
    fn find_style_by_name() {
        let styles = builtin_styles();
        let found = find_style(&styles, "concise");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "concise");
    }
    #[test]
    fn find_style_missing() {
        let styles = builtin_styles();
        assert!(find_style(&styles, "nonexistent-xyz").is_none());
    }
    // ---- load_output_styles_dir (markdown) ---------------------------------
    // Helper: create `name` inside the temp dir with the given content.
    fn write_file(dir: &TempDir, name: &str, content: &str) {
        let path = dir.path().join(name);
        let mut f = std::fs::File::create(path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
    }
    #[test]
    fn load_markdown_style() {
        let dir = TempDir::new().unwrap();
        write_file(
            &dir,
            "terse.md",
            "# Terse\nVery short answers.\n\nOne sentence per response.",
        );
        let styles = load_output_styles_dir(dir.path());
        assert_eq!(styles.len(), 1);
        let s = &styles[0];
        assert_eq!(s.name, "terse");
        assert_eq!(s.label, "Terse");
        assert_eq!(s.description, "Very short answers.");
        assert_eq!(s.prompt, "One sentence per response.");
    }
    #[test]
    fn load_json_style() {
        let dir = TempDir::new().unwrap();
        write_file(
            &dir,
            "formal.json",
            r#"{"name":"formal","label":"Formal","description":"Formal tone.","prompt":"Use formal language."}"#,
        );
        let styles = load_output_styles_dir(dir.path());
        assert_eq!(styles.len(), 1);
        let s = &styles[0];
        assert_eq!(s.name, "formal");
        assert_eq!(s.label, "Formal");
        assert_eq!(s.prompt, "Use formal language.");
    }
    #[test]
    fn load_skips_unknown_extensions() {
        let dir = TempDir::new().unwrap();
        write_file(&dir, "ignore.txt", "should be skipped");
        let styles = load_output_styles_dir(dir.path());
        assert!(styles.is_empty());
    }
    #[test]
    fn load_non_existent_dir_returns_empty() {
        use std::path::PathBuf;
        let styles = load_output_styles_dir(&PathBuf::from("/nonexistent/path/xyz"));
        assert!(styles.is_empty());
    }
    #[test]
    fn load_multiple_styles_sorted() {
        let dir = TempDir::new().unwrap();
        write_file(&dir, "zebra.md", "# Zebra\nZ style.\n\nZ prompt.");
        write_file(&dir, "apple.md", "# Apple\nA style.\n\nA prompt.");
        let styles = load_output_styles_dir(dir.path());
        assert_eq!(styles[0].name, "apple");
        assert_eq!(styles[1].name, "zebra");
    }
    // ---- all_styles --------------------------------------------------------
    #[test]
    fn all_styles_includes_builtins() {
        let dir = TempDir::new().unwrap();
        // no output-styles subdir → only built-ins
        let styles = all_styles(dir.path());
        assert!(styles.iter().any(|s| s.name == "default"));
        assert!(styles.iter().any(|s| s.name == "concise"));
    }
    #[test]
    fn all_styles_merges_user_styles() {
        let dir = TempDir::new().unwrap();
        let output_styles_dir = dir.path().join("output-styles");
        std::fs::create_dir_all(&output_styles_dir).unwrap();
        // Write a user style file.
        let mut f = std::fs::File::create(output_styles_dir.join("pirate.md")).unwrap();
        f.write_all(b"# Pirate\nSpeak like a pirate.\n\nArrr matey!").unwrap();
        let styles = all_styles(dir.path());
        assert!(styles.iter().any(|s| s.name == "pirate"));
        // Built-ins still present.
        assert!(styles.iter().any(|s| s.name == "default"));
    }
}

View file

@ -0,0 +1,526 @@
//! Modular system prompt assembly with caching support.
//!
//! Mirrors the TypeScript `systemPromptSections.ts` / `prompts.ts` architecture:
//! cacheable (static) sections are placed before `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`;
//! volatile, session-specific sections follow it.
use serde::{Deserialize, Serialize};
use std::sync::{Mutex, OnceLock};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Dynamic boundary marker
// ---------------------------------------------------------------------------
/// Marker that splits the cached vs dynamic parts of the system prompt.
/// Everything before this marker can be prompt-cached by the API.
/// Matches the TypeScript constant `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`.
/// NOTE(review): consumers presumably split the assembled prompt on this
/// exact string — keep the value in sync with the TypeScript side.
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
// ---------------------------------------------------------------------------
// Section cache (mirrors bootstrap/state.ts systemPromptSectionCache)
// ---------------------------------------------------------------------------
/// Process-wide, lazily-initialised cache of computed prompt sections,
/// keyed by section tag. `None` values record "section absent".
fn section_cache() -> &'static Mutex<HashMap<String, Option<String>>> {
    static CACHE: OnceLock<Mutex<HashMap<String, Option<String>>>> = OnceLock::new();
    CACHE.get_or_init(Default::default)
}
/// Clear all cached system prompt sections (called on /clear and /compact).
pub fn clear_system_prompt_sections() {
    // A poisoned lock means another thread panicked mid-update; skip the
    // clear rather than propagate the panic — the cache is best-effort.
    match section_cache().lock() {
        Ok(mut cache) => cache.clear(),
        Err(_) => {}
    }
}
/// A single named section of the system prompt.
#[derive(Debug, Clone)]
pub struct SystemPromptSection {
    /// Identifier used for cache lookups and invalidation.
    pub tag: &'static str,
    /// Computed content (None means the section is absent/disabled).
    pub content: Option<String>,
    /// If true the section is volatile and must not be prompt-cached.
    pub cache_break: bool,
}
impl SystemPromptSection {
    /// Build a memoizable (cacheable) section with fixed content.
    pub fn cached(tag: &'static str, content: impl Into<String>) -> Self {
        Self {
            tag,
            content: Some(content.into()),
            cache_break: false,
        }
    }
    /// Build a volatile section that is re-evaluated every turn.
    /// `None` content means the section is absent this turn.
    pub fn uncached(tag: &'static str, content: Option<impl Into<String>>) -> Self {
        Self {
            tag,
            content: content.map(Into::into),
            cache_break: true,
        }
    }
}
// ---------------------------------------------------------------------------
// Output style
// ---------------------------------------------------------------------------
/// Output styles that affect the system prompt.
/// Serialised as lowercase strings to match settings.json.
///
/// NOTE(review): the suffix text in `prompt_suffix` overlaps with the
/// built-in styles defined in the output-style module; confirm which copy
/// is canonical before editing either.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputStyle {
    /// Standard behaviour; injects nothing extra.
    #[default]
    Default,
    /// Thorough, educational explanations.
    Explanatory,
    /// Pedagogical mode — teaches as it works.
    Learning,
    /// Short, direct responses.
    Concise,
    /// Formal, professional tone.
    Formal,
    /// Casual, conversational tone.
    Casual,
}
impl OutputStyle {
    /// System-prompt text appended for this style; `None` for `Default`,
    /// which injects nothing.
    pub fn prompt_suffix(self) -> Option<&'static str> {
        let suffix = match self {
            OutputStyle::Default => return None,
            OutputStyle::Concise => {
                "Be maximally concise. Skip preamble, summaries, and filler. \
                 Lead with the answer. One sentence is better than three."
            }
            OutputStyle::Explanatory => {
                "When explaining code or concepts, be thorough and educational. \
                 Include reasoning, alternatives considered, and potential pitfalls. \
                 Err on the side of over-explaining."
            }
            OutputStyle::Learning => {
                "This user is learning. Explain concepts as you implement them. \
                 Point out patterns, best practices, and why you made each decision. \
                 Use analogies when helpful."
            }
            OutputStyle::Formal => {
                "Maintain a formal, professional tone. Use precise technical language."
            }
            OutputStyle::Casual => "Use a casual, conversational tone.",
        };
        Some(suffix)
    }
    /// Parse a style from its settings string (case-insensitive);
    /// unknown values fall back to `Default`.
    pub fn from_str(s: &str) -> Self {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "casual" => Self::Casual,
            "formal" => Self::Formal,
            "concise" => Self::Concise,
            "learning" => Self::Learning,
            "explanatory" => Self::Explanatory,
            _ => Self::Default,
        }
    }
}
// ---------------------------------------------------------------------------
// System prompt prefix variants
// ---------------------------------------------------------------------------
/// Which entrypoint context Claude Code is running in.
/// Determines the opening attribution line of the system prompt.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SystemPromptPrefix {
    /// Standard interactive CLI session.
    Cli,
    /// Running as a sub-agent spawned by the Claude Agent SDK.
    Sdk,
    /// The CLI preset running within the Agent SDK
    /// (non-interactive + append_system_prompt set).
    SdkPreset,
    /// Running on Vertex AI.
    Vertex,
    /// Running on AWS Bedrock.
    Bedrock,
    /// Remote / headless CCR session.
    Remote,
}
impl SystemPromptPrefix {
    /// Detect the variant from environment variables plus interactivity
    /// flags, mirroring the TypeScript `getCLISyspromptPrefix`.
    pub fn detect(is_non_interactive: bool, has_append_system_prompt: bool) -> Self {
        let env_present = |key: &str| std::env::var(key).is_ok();
        // Hosted platforms take priority over the interactivity flags.
        if env_present("ANTHROPIC_VERTEX_PROJECT_ID") || env_present("CLOUD_ML_PROJECT_ID") {
            Self::Vertex
        } else if env_present("AWS_BEDROCK_MODEL_ID") {
            Self::Bedrock
        } else if env_present("CLAUDE_CODE_REMOTE") {
            Self::Remote
        } else if is_non_interactive && has_append_system_prompt {
            // Non-interactive + append_system_prompt → the SDK "preset".
            Self::SdkPreset
        } else if is_non_interactive {
            Self::Sdk
        } else {
            Self::Cli
        }
    }
    /// The opening attribution string for this prefix variant.
    pub fn attribution_text(self) -> &'static str {
        match self {
            Self::Sdk => "You are a Claude agent, built on Anthropic's Claude Agent SDK.",
            Self::SdkPreset => {
                "You are Claude Code, Anthropic's official CLI for Claude, \
                 running within the Claude Agent SDK."
            }
            Self::Cli | Self::Vertex | Self::Bedrock | Self::Remote => {
                "You are Claude Code, Anthropic's official CLI for Claude."
            }
        }
    }
}
// ---------------------------------------------------------------------------
// Build options
// ---------------------------------------------------------------------------
/// All options controlling what goes into the assembled system prompt.
///
/// `Default` leaves every flag off and every optional section empty, so only
/// the static sections and the dynamic boundary are emitted.
#[derive(Debug, Clone, Default)]
pub struct SystemPromptOptions {
    /// Override auto-detected prefix.
    pub prefix: Option<SystemPromptPrefix>,
    /// Whether the session is non-interactive (SDK / pipe mode).
    pub is_non_interactive: bool,
    /// Whether --append-system-prompt is set (affects prefix detection).
    pub has_append_system_prompt: bool,
    /// Output style to inject.
    pub output_style: OutputStyle,
    /// Absolute path to the working directory (injected as dynamic section).
    pub working_directory: Option<String>,
    /// Pre-built memory content from memdir (injected as dynamic section).
    /// An empty string means no memory section is emitted.
    pub memory_content: String,
    /// Custom system prompt (--system-prompt flag or settings).
    pub custom_system_prompt: Option<String>,
    /// Additional text appended after everything else (--append-system-prompt).
    pub append_system_prompt: Option<String>,
    /// If true and `custom_system_prompt` is set, the entire default prompt is
    /// replaced — only the custom text + dynamic boundary are emitted.
    pub replace_system_prompt: bool,
    /// Inject the coordinator-mode section.
    pub coordinator_mode: bool,
}
// ---------------------------------------------------------------------------
// Main assembly function
// ---------------------------------------------------------------------------
/// Build the complete system prompt string.
///
/// The returned string embeds `SYSTEM_PROMPT_DYNAMIC_BOUNDARY` as an internal
/// marker. Callers (e.g. `buildSystemPromptBlocks` in cc-query) split on this
/// marker to determine which portions are eligible for Anthropic
/// prompt-caching.
pub fn build_system_prompt(opts: &SystemPromptOptions) -> String {
    // Replace mode short-circuits every default section; without a custom
    // prompt it falls through to normal assembly.
    if opts.replace_system_prompt {
        if let Some(custom) = &opts.custom_system_prompt {
            return format!("{}\n\n{}", custom, SYSTEM_PROMPT_DYNAMIC_BOUNDARY);
        }
    }

    // Explicit override wins; otherwise detect from env + session flags.
    let prefix = match opts.prefix {
        Some(p) => p,
        None => SystemPromptPrefix::detect(opts.is_non_interactive, opts.has_append_system_prompt),
    };

    let mut sections: Vec<String> = Vec::new();

    // ------------------------------------------------------------------ //
    // CACHEABLE sections (before the dynamic boundary)                   //
    // ------------------------------------------------------------------ //

    // Attribution header, then the fixed static instruction blocks in order.
    sections.push(prefix.attribution_text().to_string());
    for static_section in [
        CORE_CAPABILITIES,
        TOOL_USE_GUIDELINES,
        ACTIONS_SECTION,
        SAFETY_GUIDELINES,
        CYBER_RISK_INSTRUCTION, // owned by safeguards — content must not be edited
    ] {
        sections.push(static_section.to_string());
    }
    // Output style text is stable per style, so it stays on the cacheable side.
    if let Some(style_text) = opts.output_style.prompt_suffix() {
        sections.push(format!("\n## Output Style\n{}", style_text));
    }
    // Coordinator-mode section is constant, hence cacheable.
    if opts.coordinator_mode {
        sections.push(COORDINATOR_SYSTEM_PROMPT.to_string());
    }
    // Custom additions ride along at the end of the cacheable block.
    if let Some(custom) = &opts.custom_system_prompt {
        sections.push(format!(
            "\n<custom_instructions>\n{}\n</custom_instructions>",
            custom
        ));
    }

    // Marker separating cacheable from per-request content.
    sections.push(SYSTEM_PROMPT_DYNAMIC_BOUNDARY.to_string());

    // ------------------------------------------------------------------ //
    // DYNAMIC / UNCACHEABLE sections (after the boundary)                //
    // ------------------------------------------------------------------ //

    if let Some(cwd) = &opts.working_directory {
        sections.push(format!("\n<working_directory>{}</working_directory>", cwd));
    }
    if !opts.memory_content.is_empty() {
        sections.push(format!(
            "\n<memory>\n{}\n</memory>",
            opts.memory_content
        ));
    }
    if let Some(append) = &opts.append_system_prompt {
        sections.push(format!("\n{}", append));
    }

    sections.join("\n")
}
// ---------------------------------------------------------------------------
// Static system prompt sections
// ---------------------------------------------------------------------------

/// Tool overview + task-approach rules; first static section after the
/// attribution line.
const CORE_CAPABILITIES: &str = r#"
## Capabilities
You have access to powerful tools for software engineering tasks:
- **Read/Write files**: Read any file, write new files, edit existing files with precise diffs
- **Execute commands**: Run bash commands, PowerShell scripts, background processes
- **Search**: Glob patterns, regex grep, web search, file content search
- **Web**: Fetch URLs, search the internet
- **Agents**: Spawn parallel sub-agents for complex multi-step work
- **Memory**: Persistent notes across sessions via the memory system
- **MCP servers**: Connect to external tools and APIs via Model Context Protocol
- **Jupyter notebooks**: Read and edit notebook cells
## How to approach tasks
1. **Understand before acting**: Read relevant files before making changes
2. **Minimal changes**: Only modify what's needed. Don't refactor unrequested code.
3. **Verify**: Check your work with tests or by reading the result
4. **Communicate blockers**: If stuck, ask the user rather than guessing
"#;

/// Preferences for dedicated tools over raw shell equivalents.
const TOOL_USE_GUIDELINES: &str = r#"
## Tool use guidelines
- Use dedicated tools (Read, Edit, Glob, Grep) instead of bash equivalents
- For searches, prefer Grep over `grep`; prefer Glob over `find`
- Parallelize independent tool calls in a single response
- For file edits: always read the file first, then make targeted edits
- Bash commands timeout after 2 minutes; use background mode for long operations
"#;

/// Guidance on reversibility and blast radius of actions.
const ACTIONS_SECTION: &str = r#"
## Executing actions with care
Carefully consider the reversibility and blast radius of actions. For actions
that are hard to reverse, affect shared systems, or could be risky or
destructive, check with the user before proceeding. Authorization stands for
the scope specified, not beyond. Match the scope of your actions to what was
actually requested.
"#;

/// General safety rules around destructive operations and secrets.
const SAFETY_GUIDELINES: &str = r#"
## Safety guidelines
- Never delete files without explicit user confirmation
- Don't modify protected files (.gitconfig, .bashrc, .zshrc, .mcp.json, .claude.json)
- Be careful with destructive operations (rm -rf, DROP TABLE, etc.)
- Don't commit secrets, credentials, or API keys
- For ambiguous destructive actions, ask before proceeding
"#;

/// Security / cyber-risk instruction.
/// Owned by the safeguards team — do not modify the content here.
const CYBER_RISK_INSTRUCTION: &str = r#"
## Security
You are authorized to assist with security research, CTF challenges, penetration testing
with explicit authorization, defensive security, and educational security content. Do not
assist with creating malware, unauthorized access, denial-of-service attacks, or any
destructive security techniques without clear legitimate purpose.
"#;

/// Injected only when `SystemPromptOptions::coordinator_mode` is set.
const COORDINATOR_SYSTEM_PROMPT: &str = r#"
## Coordinator Mode
You are operating as an orchestrator. Spawn parallel worker agents using the Agent tool.
Each worker prompt must be fully self-contained. Synthesize findings before delegating
follow-up work. Use TaskCreate/TaskUpdate to track parallel work.
"#;
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// All-default options: every flag off, every optional section empty.
    fn default_opts() -> SystemPromptOptions {
        SystemPromptOptions::default()
    }

    #[test]
    fn test_default_prompt_contains_boundary() {
        let prompt = build_system_prompt(&default_opts());
        assert!(
            prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY),
            "System prompt must contain the dynamic boundary marker"
        );
    }

    #[test]
    fn test_default_prompt_contains_attribution() {
        let prompt = build_system_prompt(&default_opts());
        assert!(prompt.contains("Claude Code"), "Default prompt must contain attribution");
    }

    #[test]
    fn test_replace_system_prompt() {
        let opts = SystemPromptOptions {
            custom_system_prompt: Some("Custom only.".to_string()),
            replace_system_prompt: true,
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        assert!(prompt.starts_with("Custom only."));
        assert!(!prompt.contains("Capabilities"));
        assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
    }

    #[test]
    fn test_working_directory_in_dynamic_section() {
        let opts = SystemPromptOptions {
            working_directory: Some("/home/user/project".to_string()),
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        // Ordering check: the cwd section must land on the uncacheable side.
        let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
        let cwd_pos = prompt.find("/home/user/project").unwrap();
        assert!(
            cwd_pos > boundary_pos,
            "Working directory must appear after the dynamic boundary"
        );
    }

    #[test]
    fn test_memory_content_in_dynamic_section() {
        let opts = SystemPromptOptions {
            memory_content: "- [test.md](test.md) — a test memory".to_string(),
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
        let mem_pos = prompt.find("test.md").unwrap();
        assert!(
            mem_pos > boundary_pos,
            "Memory content must appear after the dynamic boundary"
        );
    }

    #[test]
    fn test_output_style_concise() {
        let opts = SystemPromptOptions {
            output_style: OutputStyle::Concise,
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        assert!(prompt.contains("maximally concise"));
    }

    #[test]
    fn test_output_style_default_has_no_suffix() {
        let opts = SystemPromptOptions {
            output_style: OutputStyle::Default,
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        // None of the style suffixes should appear
        assert!(!prompt.contains("maximally concise"));
        assert!(!prompt.contains("This user is learning"));
    }

    #[test]
    fn test_coordinator_mode_section() {
        let opts = SystemPromptOptions {
            coordinator_mode: true,
            ..Default::default()
        };
        let prompt = build_system_prompt(&opts);
        assert!(prompt.contains("Coordinator Mode"));
        assert!(prompt.contains("orchestrator"));
    }

    #[test]
    fn test_output_style_from_str() {
        assert_eq!(OutputStyle::from_str("concise"), OutputStyle::Concise);
        assert_eq!(OutputStyle::from_str("FORMAL"), OutputStyle::Formal);
        assert_eq!(OutputStyle::from_str("unknown"), OutputStyle::Default);
    }

    // NOTE: the two detect() tests below assume none of the platform env vars
    // (ANTHROPIC_VERTEX_PROJECT_ID, AWS_BEDROCK_MODEL_ID, ...) are set.
    #[test]
    fn test_sdk_prefix_non_interactive_no_append() {
        let prefix = SystemPromptPrefix::detect(true, false);
        assert_eq!(prefix, SystemPromptPrefix::Sdk);
        assert!(prefix.attribution_text().contains("Claude agent"));
    }

    #[test]
    fn test_sdk_preset_prefix_non_interactive_with_append() {
        let prefix = SystemPromptPrefix::detect(true, true);
        assert_eq!(prefix, SystemPromptPrefix::SdkPreset);
        assert!(prefix.attribution_text().contains("Claude Agent SDK"));
    }

    #[test]
    fn test_clear_section_cache() {
        // Populate cache then clear it — should not panic.
        {
            let mut cache = section_cache().lock().unwrap();
            cache.insert("test_section".to_string(), Some("content".to_string()));
        }
        clear_system_prompt_sections();
        let cache = section_cache().lock().unwrap();
        assert!(cache.is_empty());
    }
}

View file

@ -0,0 +1,682 @@
//! Team memory synchronization with claude.ai API.
//!
//! Implements delta push (only changed files) with ETag-based optimistic
//! concurrency and greedy bin-packing of changed entries into batches that
//! fit within the server's PUT body limit.
//!
//! Pull is server-wins: remote content overwrites local files unconditionally.
use std::collections::HashMap;
use std::path::PathBuf;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, Digest};
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
/// Maximum bytes per local file accepted for sync (250 KB).
/// Larger files are skipped by both `scan_local_files` (push) and `pull`.
const MAX_FILE_SIZE_BYTES: usize = 250 * 1024;
/// Maximum serialized bytes per PUT request body (200 KB).
/// `pack_batches` packs changed entries into batches under this limit.
const MAX_PUT_BODY_BYTES: usize = 200 * 1024;
// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------
/// Persisted per-repo sync state (stored alongside local team-memory files).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SyncState {
    /// ETag returned by the last successful GET or PUT.
    /// Sent back as `If-Match` on PUT for optimistic concurrency.
    pub last_known_etag: Option<String>,
    /// Per-key server-side checksums (`"sha256:<hex>"`).
    /// Used to diff local vs remote without re-uploading unchanged entries.
    pub server_checksums: HashMap<String, String>,
    /// Server-enforced max_entries from a prior 413 response.
    /// NOTE(review): never written in this module — presumably maintained by a
    /// caller; confirm before relying on it here.
    pub server_max_entries: Option<usize>,
}
/// A single team-memory entry (one markdown file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TeamMemoryEntry {
    /// Relative file path (forward-slash separated, e.g. `"MEMORY.md"`).
    pub key: String,
    /// UTF-8 file content (typically Markdown).
    pub content: String,
    /// `"sha256:<hex>"` of the content (see `content_checksum`).
    pub checksum: String,
}
/// Server response shape for GET `/api/claude_code/team_memory`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TeamMemoryData {
    /// All entries currently stored server-side for the repo.
    pub entries: Vec<TeamMemoryEntry>,
    /// ETag carried in the body. NOTE(review): `pull` reads the ETag from the
    /// response header; this field is unused within this module.
    pub etag: Option<String>,
}
// ---------------------------------------------------------------------------
// Checksum helper
// ---------------------------------------------------------------------------
/// Compute `"sha256:<lowercase hex>"` of a string.
pub fn content_checksum(content: &str) -> String {
    // One-shot digest: equivalent to new() + update() + finalize().
    let digest = Sha256::digest(content.as_bytes());
    format!("sha256:{}", hex::encode(digest))
}
// ---------------------------------------------------------------------------
// Path security validation
// ---------------------------------------------------------------------------
/// Reject paths that could escape the team-memory directory.
///
/// Checks performed (mirroring the TypeScript `securePath` validation):
/// - No null bytes
/// - No URL-encoded traversal sequences (`%2e`, `%2f`, case-insensitive)
/// - No backslashes
/// - Not an absolute path (Unix `/` or Windows `C:` style)
/// - No `..` components
pub fn validate_memory_path(path: &str) -> Result<()> {
    if path.contains('\0') {
        anyhow::bail!("Path contains null bytes: {:?}", path);
    }
    let lowered = path.to_ascii_lowercase();
    if lowered.contains("%2e") || lowered.contains("%2f") {
        anyhow::bail!("Path contains URL-encoded traversal sequences: {:?}", path);
    }
    if path.contains('\\') {
        anyhow::bail!("Path contains backslashes: {:?}", path);
    }
    if path.starts_with('/') {
        anyhow::bail!("Absolute Unix paths not allowed: {:?}", path);
    }
    // Windows-style absolute path: a drive letter followed by ':' (e.g. "C:").
    // Byte inspection is safe: an ASCII letter is always a one-byte char, so
    // checking bytes[0]/bytes[1] matches the char-based check exactly.
    let bytes = path.as_bytes();
    if bytes.len() >= 2 && bytes[0].is_ascii_alphabetic() && bytes[1] == b':' {
        anyhow::bail!("Absolute Windows paths not allowed: {:?}", path);
    }
    // Any literal ".." component would allow escaping the sync root.
    for component in path.split('/') {
        if component == ".." {
            anyhow::bail!("Path traversal not allowed: {:?}", path);
        }
    }
    Ok(())
}
// ---------------------------------------------------------------------------
// TeamMemorySync
// ---------------------------------------------------------------------------
/// Drives pull and push against the claude.ai team-memory API.
///
/// The persisted bookkeeping that makes delta push possible lives in
/// `SyncState`, which callers pass into `pull`/`push`.
pub struct TeamMemorySync {
    /// Base URL of the API, e.g. `"https://claude.ai"`.
    api_base: String,
    /// Repo identifier sent as a query parameter.
    repo: String,
    /// Bearer token for authentication.
    token: String,
    /// Local directory that mirrors the server's key namespace.
    team_dir: PathBuf,
}
impl TeamMemorySync {
    /// Construct a sync driver; no I/O happens until `pull`/`push` is called.
    pub fn new(api_base: String, repo: String, token: String, team_dir: PathBuf) -> Self {
        Self { api_base, repo, token, team_dir }
    }

    // -----------------------------------------------------------------------
    // Pull
    // -----------------------------------------------------------------------

    /// Pull all entries from the server. Server wins: overwrites local files.
    ///
    /// Updates `state.last_known_etag` and `state.server_checksums` on success.
    /// Returns `Ok(())` on HTTP 404 (no remote data yet).
    pub async fn pull(&self, state: &mut SyncState) -> Result<()> {
        // NOTE(review): a fresh Client per call forgoes connection reuse;
        // acceptable if pull runs once per sync, but worth confirming.
        let client = reqwest::Client::new();
        let url = format!(
            "{}/api/claude_code/team_memory?repo={}",
            self.api_base,
            urlencoding::encode(&self.repo),
        );
        let response = client
            .get(&url)
            .bearer_auth(&self.token)
            .send()
            .await
            .context("team memory pull: HTTP request failed")?;
        let http_status = response.status();
        if http_status.as_u16() == 404 {
            return Ok(()); // No remote data yet
        }
        if !http_status.is_success() {
            anyhow::bail!("team memory pull failed with status {}", http_status);
        }
        // Capture ETag before consuming the response body
        if let Some(etag) = response
            .headers()
            .get("etag")
            .and_then(|v| v.to_str().ok())
        {
            state.last_known_etag = Some(etag.to_string());
        }
        let data: TeamMemoryData = response
            .json()
            .await
            .context("team memory pull: failed to parse response JSON")?;
        state.server_checksums.clear();
        for entry in &data.entries {
            // Reject traversal-style keys before touching the filesystem.
            validate_memory_path(&entry.key)
                .with_context(|| format!("server returned unsafe path: {:?}", entry.key))?;
            // The checksum is recorded even when the write below is skipped
            // for size, so push won't try to re-upload the skipped entry.
            state
                .server_checksums
                .insert(entry.key.clone(), entry.checksum.clone());
            let local_path = self.team_dir.join(&entry.key);
            if let Some(parent) = local_path.parent() {
                tokio::fs::create_dir_all(parent)
                    .await
                    .with_context(|| format!("create_dir_all for {:?}", parent))?;
            }
            if entry.content.len() <= MAX_FILE_SIZE_BYTES {
                tokio::fs::write(&local_path, &entry.content)
                    .await
                    .with_context(|| format!("writing {:?}", local_path))?;
            }
            // Files exceeding MAX_FILE_SIZE_BYTES are silently skipped (same behaviour as push)
        }
        Ok(())
    }

    // -----------------------------------------------------------------------
    // Push
    // -----------------------------------------------------------------------

    /// Push local changes to the server using delta upload.
    ///
    /// Only entries whose local checksum differs from `state.server_checksums`
    /// are uploaded. Changed entries are packed into batches ≤ `MAX_PUT_BODY_BYTES`.
    ///
    /// NOTE(review): locally deleted files are never removed server-side —
    /// only additions and edits propagate. Confirm this is intended.
    pub async fn push(&self, state: &mut SyncState) -> Result<()> {
        let local_entries = self
            .scan_local_files()
            .await
            .context("team memory push: scanning local files")?;
        // Delta: entries where local hash ≠ last-known server hash
        let changed: Vec<TeamMemoryEntry> = local_entries
            .into_iter()
            .filter(|entry| {
                state
                    .server_checksums
                    .get(&entry.key)
                    .map(|s| s.as_str())
                    != Some(&entry.checksum)
            })
            .collect();
        if changed.is_empty() {
            return Ok(());
        }
        let batches = self.pack_batches(changed);
        for batch in batches {
            // If a later batch fails, earlier batches stay committed in
            // `state` — the next push only retries what is still different.
            self.upload_batch(batch, state)
                .await
                .context("team memory push: uploading batch")?;
        }
        Ok(())
    }

    // -----------------------------------------------------------------------
    // Internals
    // -----------------------------------------------------------------------

    /// Greedy bin-packing: pack entries into batches that each serialise to
    /// ≤ `MAX_PUT_BODY_BYTES`. Entries that individually exceed the limit go
    /// into singleton batches (server will reject them with 413, but that is
    /// the caller's problem).
    fn pack_batches(&self, entries: Vec<TeamMemoryEntry>) -> Vec<Vec<TeamMemoryEntry>> {
        let mut batches: Vec<Vec<TeamMemoryEntry>> = Vec::new();
        let mut current: Vec<TeamMemoryEntry> = Vec::new();
        let mut current_size: usize = 0;
        for entry in entries {
            // Rough size estimate: key + content + JSON envelope overhead
            let entry_size = entry.key.len() + entry.content.len() + 100;
            if entry_size > MAX_PUT_BODY_BYTES {
                // Oversized entry goes solo
                if !current.is_empty() {
                    batches.push(std::mem::take(&mut current));
                    current_size = 0;
                }
                batches.push(vec![entry]);
                continue;
            }
            if current_size + entry_size > MAX_PUT_BODY_BYTES && !current.is_empty() {
                batches.push(std::mem::take(&mut current));
                current_size = 0;
            }
            current_size += entry_size;
            current.push(entry);
        }
        if !current.is_empty() {
            batches.push(current);
        }
        batches
    }

    /// PUT one batch; on success, records the returned ETag and marks the
    /// batch's checksums as the server's current state.
    async fn upload_batch(
        &self,
        batch: Vec<TeamMemoryEntry>,
        state: &mut SyncState,
    ) -> Result<()> {
        let client = reqwest::Client::new();
        let url = format!(
            "{}/api/claude_code/team_memory?repo={}",
            self.api_base,
            urlencoding::encode(&self.repo),
        );
        let body = serde_json::json!({ "entries": batch });
        let mut req = client
            .put(&url)
            .bearer_auth(&self.token)
            .json(&body);
        // Optimistic concurrency: only overwrite the revision we last saw.
        if let Some(etag) = &state.last_known_etag {
            req = req.header("If-Match", etag);
        }
        let response = req
            .send()
            .await
            .context("team memory: PUT request failed")?;
        let status = response.status().as_u16();
        match status {
            200 | 201 | 204 => {
                if let Some(etag) = response
                    .headers()
                    .get("etag")
                    .and_then(|v| v.to_str().ok())
                {
                    state.last_known_etag = Some(etag.to_string());
                }
                // Update local checksum map to reflect uploaded state
                for entry in &batch {
                    state
                        .server_checksums
                        .insert(entry.key.clone(), entry.checksum.clone());
                }
                Ok(())
            }
            412 => anyhow::bail!("Conflict (412 Precondition Failed): ETag mismatch, retry needed"),
            413 => anyhow::bail!("Payload too large (413)"),
            401 | 403 => anyhow::bail!("Authentication error ({})", status),
            _ => anyhow::bail!("Upload failed with status {}", status),
        }
    }

    /// Recursively scan `team_dir` for `.md` files, returning entries sorted by key.
    async fn scan_local_files(&self) -> Result<Vec<TeamMemoryEntry>> {
        let mut entries = Vec::new();
        if !self.team_dir.exists() {
            return Ok(entries);
        }
        // Iterative DFS using an explicit stack to avoid deep recursion
        let mut stack = vec![self.team_dir.clone()];
        while let Some(dir) = stack.pop() {
            let mut read_dir = tokio::fs::read_dir(&dir)
                .await
                .with_context(|| format!("read_dir {:?}", dir))?;
            while let Some(entry) = read_dir.next_entry().await? {
                let path = entry.path();
                // NOTE(review): `path.is_dir()` is a blocking std call inside
                // an async fn; cheap for local dirs, but worth confirming.
                if path.is_dir() {
                    stack.push(path);
                } else if path.extension().map(|e| e == "md").unwrap_or(false) {
                    let content = tokio::fs::read_to_string(&path)
                        .await
                        .with_context(|| format!("reading {:?}", path))?;
                    if content.len() > MAX_FILE_SIZE_BYTES {
                        continue; // Skip files that are too large
                    }
                    // Keys are forward-slash separated regardless of host OS.
                    let key = path
                        .strip_prefix(&self.team_dir)
                        .unwrap()
                        .to_string_lossy()
                        .replace('\\', "/");
                    let checksum = content_checksum(&content);
                    entries.push(TeamMemoryEntry { key, content, checksum });
                }
            }
        }
        entries.sort_by(|a, b| a.key.cmp(&b.key));
        Ok(entries)
    }
}
// ---------------------------------------------------------------------------
// Secret scanner
// ---------------------------------------------------------------------------
/// A pattern matched during secret scanning.
///
/// Deliberately carries only the pattern label — never the matched text —
/// so findings are safe to log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecretMatch {
    /// Short label identifying the secret type, e.g. `"Anthropic API key"`.
    pub label: String,
}
/// Scan `content` for common high-confidence secret patterns.
///
/// Returns one [`SecretMatch`] per distinct pattern that fired. The actual
/// matched text is intentionally **not** returned to avoid logging credentials.
///
/// # Panics
/// Panics on first use if a built-in pattern fails to compile. The patterns
/// are compile-time constants, so a failure is a programming error that must
/// surface immediately — the previous `if let Ok(..)` silently disabled the
/// broken detector with no signal.
pub fn scan_for_secrets(content: &str) -> Vec<SecretMatch> {
    use std::sync::OnceLock;

    // Each tuple: (regex source, human-readable label)
    // Patterns ordered by likelihood of appearing in dev-team memory content.
    const PATTERNS: &[(&str, &str)] = &[
        // Cloud providers
        (r"(?:A3T[A-Z0-9]|AKIA|ASIA|ABIA|ACCA)[A-Z2-7]{16}", "AWS access key"),
        (r"AIza[\w-]{35}", "GCP API key"),
        // AI APIs
        (r"sk-ant-api03-[a-zA-Z0-9_\-]{93}AA", "Anthropic API key"),
        (r"sk-ant-admin01-[a-zA-Z0-9_\-]{93}AA", "Anthropic admin API key"),
        (r"sk-[a-zA-Z0-9]{20}T3BlbkFJ[a-zA-Z0-9]{20}", "OpenAI API key"),
        // Version control
        (r"ghp_[0-9a-zA-Z]{36}", "GitHub personal access token"),
        (r"github_pat_\w{82}", "GitHub fine-grained PAT"),
        (r"(?:ghu|ghs)_[0-9a-zA-Z]{36}", "GitHub app token"),
        (r"gho_[0-9a-zA-Z]{36}", "GitHub OAuth token"),
        (r"glpat-[\w-]{20}", "GitLab PAT"),
        // Communication
        (r"xoxb-[0-9]{10,13}-[0-9]{10,13}[a-zA-Z0-9-]*", "Slack bot token"),
        // Crypto / private keys
        (r"-----BEGIN[ A-Z0-9_-]{0,100}PRIVATE KEY", "Private key"),
        // Payments
        (r"(?:sk|rk)_(?:test|live|prod)_[a-zA-Z0-9]{10,99}", "Stripe secret key"),
        // NPM
        (r"npm_[a-zA-Z0-9]{36}", "NPM access token"),
    ];

    // Compile the pattern set once per process: scanning runs per file during
    // sync, and recompiling every regex on each call is pure waste.
    static COMPILED: OnceLock<Vec<(regex::Regex, &'static str)>> = OnceLock::new();
    let compiled = COMPILED.get_or_init(|| {
        PATTERNS
            .iter()
            .map(|&(pattern, label)| {
                let re = regex::Regex::new(pattern)
                    .unwrap_or_else(|e| panic!("invalid secret pattern {:?}: {}", pattern, e));
                (re, label)
            })
            .collect()
    });

    // Same output order as PATTERNS, one finding per pattern that matched.
    compiled
        .iter()
        .filter(|(re, _)| re.is_match(content))
        .map(|&(_, label)| SecretMatch { label: label.to_string() })
        .collect()
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // --- content_checksum ---

    #[test]
    fn test_checksum_format() {
        let cs = content_checksum("hello");
        assert!(cs.starts_with("sha256:"), "checksum should start with sha256:");
        assert_eq!(cs.len(), "sha256:".len() + 64, "sha256 hex is 64 chars");
    }

    #[test]
    fn test_checksum_deterministic() {
        assert_eq!(content_checksum("foo"), content_checksum("foo"));
    }

    #[test]
    fn test_checksum_distinct() {
        assert_ne!(content_checksum("foo"), content_checksum("bar"));
    }

    // --- validate_memory_path ---

    #[test]
    fn test_valid_paths_accepted() {
        let ok_paths = [
            "MEMORY.md",
            "sub/dir/file.md",
            "sub/dir/another-file.md",
            "a.md",
        ];
        for p in &ok_paths {
            assert!(validate_memory_path(p).is_ok(), "should accept: {}", p);
        }
    }

    #[test]
    fn test_null_byte_rejected() {
        assert!(validate_memory_path("foo\0bar").is_err());
    }

    #[test]
    fn test_url_encoded_dot_rejected() {
        assert!(validate_memory_path("%2e%2e/secret").is_err());
    }

    #[test]
    fn test_url_encoded_slash_rejected() {
        assert!(validate_memory_path("foo%2Fbar").is_err());
    }

    #[test]
    fn test_backslash_rejected() {
        assert!(validate_memory_path("foo\\bar").is_err());
    }

    #[test]
    fn test_absolute_unix_rejected() {
        assert!(validate_memory_path("/etc/passwd").is_err());
    }

    #[test]
    fn test_absolute_windows_rejected() {
        assert!(validate_memory_path("C:foo").is_err());
    }

    #[test]
    fn test_dotdot_rejected() {
        assert!(validate_memory_path("../secret").is_err());
        assert!(validate_memory_path("a/../../secret").is_err());
    }

    // --- pack_batches ---

    /// Helper: a sync whose network/config fields are never used by the pure
    /// `pack_batches` method under test.
    fn make_sync() -> TeamMemorySync {
        TeamMemorySync::new(
            "https://example.com".to_string(),
            "owner/repo".to_string(),
            "token123".to_string(),
            PathBuf::from("/tmp/team"),
        )
    }

    /// Helper: an entry with `size` bytes of filler content.
    fn entry(key: &str, size: usize) -> TeamMemoryEntry {
        let content = "x".repeat(size);
        let checksum = content_checksum(&content);
        TeamMemoryEntry { key: key.to_string(), content, checksum }
    }

    #[test]
    fn test_pack_batches_empty() {
        let sync = make_sync();
        let batches = sync.pack_batches(vec![]);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_pack_batches_single_entry() {
        let sync = make_sync();
        let batches = sync.pack_batches(vec![entry("a.md", 100)]);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_pack_batches_oversized_solo() {
        let sync = make_sync();
        // Entry > MAX_PUT_BODY_BYTES → goes solo
        let big = entry("big.md", MAX_PUT_BODY_BYTES + 1);
        let small = entry("small.md", 100);
        let batches = sync.pack_batches(vec![big, small]);
        // big is solo, small may be in a separate batch
        assert!(batches.len() >= 2);
        assert_eq!(batches[0].len(), 1, "oversized entry is solo");
    }

    #[test]
    fn test_pack_batches_groups_small_entries() {
        let sync = make_sync();
        // Many small entries that each fit in one batch
        let entries: Vec<_> = (0..5).map(|i| entry(&format!("{i}.md"), 1024)).collect();
        let batches = sync.pack_batches(entries);
        // All 5 should fit in one batch (5 * ~1124 bytes << 200KB)
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    // --- scan_for_secrets ---

    #[test]
    fn test_no_secrets_clean() {
        let findings = scan_for_secrets("# Team notes\n\nSome markdown content here.");
        assert!(findings.is_empty());
    }

    #[test]
    fn test_detects_github_pat() {
        let content = format!("token: ghp_{}", "A".repeat(36));
        let findings = scan_for_secrets(&content);
        assert!(
            findings.iter().any(|m| m.label.contains("GitHub")),
            "should detect GitHub PAT"
        );
    }

    #[test]
    fn test_detects_aws_key() {
        let content = "key=AKIAIOSFODNN7EXAMPLE";
        let findings = scan_for_secrets(content);
        assert!(
            findings.iter().any(|m| m.label.contains("AWS")),
            "should detect AWS key"
        );
    }

    #[test]
    fn test_detects_private_key() {
        let content = "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n";
        let findings = scan_for_secrets(content);
        assert!(
            findings.iter().any(|m| m.label.contains("Private key")),
            "should detect private key"
        );
    }

    // --- scan_local_files (integration-style) ---

    #[tokio::test]
    async fn test_scan_local_files_empty_dir() {
        let tmp = TempDir::new().unwrap();
        let sync = TeamMemorySync::new(
            "https://example.com".to_string(),
            "r".to_string(),
            "t".to_string(),
            tmp.path().to_path_buf(),
        );
        let entries = sync.scan_local_files().await.unwrap();
        assert!(entries.is_empty());
    }

    #[tokio::test]
    async fn test_scan_local_files_finds_md() {
        let tmp = TempDir::new().unwrap();
        tokio::fs::write(tmp.path().join("MEMORY.md"), "# Memory").await.unwrap();
        tokio::fs::write(tmp.path().join("ignore.txt"), "not md").await.unwrap();
        let sync = TeamMemorySync::new(
            "https://example.com".to_string(),
            "r".to_string(),
            "t".to_string(),
            tmp.path().to_path_buf(),
        );
        let entries = sync.scan_local_files().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "MEMORY.md");
    }

    #[tokio::test]
    async fn test_scan_local_files_sorted() {
        let tmp = TempDir::new().unwrap();
        tokio::fs::write(tmp.path().join("z.md"), "z").await.unwrap();
        tokio::fs::write(tmp.path().join("a.md"), "a").await.unwrap();
        tokio::fs::write(tmp.path().join("m.md"), "m").await.unwrap();
        let sync = TeamMemorySync::new(
            "https://example.com".to_string(),
            "r".to_string(),
            "t".to_string(),
            tmp.path().to_path_buf(),
        );
        let entries = sync.scan_local_files().await.unwrap();
        let keys: Vec<_> = entries.iter().map(|e| e.key.as_str()).collect();
        assert_eq!(keys, vec!["a.md", "m.md", "z.md"]);
    }

    #[tokio::test]
    async fn test_scan_local_files_checksums_match() {
        let tmp = TempDir::new().unwrap();
        let content = "# Hello world";
        tokio::fs::write(tmp.path().join("MEMORY.md"), content).await.unwrap();
        let sync = TeamMemorySync::new(
            "https://example.com".to_string(),
            "r".to_string(),
            "t".to_string(),
            tmp.path().to_path_buf(),
        );
        let entries = sync.scan_local_files().await.unwrap();
        assert_eq!(entries[0].checksum, content_checksum(content));
    }
}

View file

@ -0,0 +1,192 @@
//! Voice mode availability checks
use crate::oauth::OAuthTokens;
/// Result of a voice-mode availability check.
// Eq derived: every payload is Vec<String>, which is Eq, so equality is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VoiceAvailability {
    /// Voice mode can be started.
    Available,
    /// Not authenticated via first-party OAuth
    RequiresOAuth,
    /// OAuth token missing required scopes
    MissingScopes {
        required: Vec<String>,
        have: Vec<String>,
    },
    /// Feature disabled by kill-switch environment variable
    Disabled,
    /// Feature flag not enabled in this build
    NotEnabled,
}
/// Scopes an OAuth token must carry for voice mode to function.
const VOICE_REQUIRED_SCOPES: &[&str] = &["user:inference", "user:profile"];
/// Kill-switch environment variable: voice mode is disabled whenever this
/// variable is set, regardless of its value.
const KILL_SWITCH_ENV: &str = "CLAUDE_CODE_VOICE_DISABLED";
/// Check whether voice mode is available given the current OAuth tokens.
///
/// Pass `None` when the user is not authenticated via OAuth (API-key-only auth).
pub fn check_voice_availability(tokens: Option<&OAuthTokens>) -> VoiceAvailability {
    // The kill switch overrides everything, including missing auth.
    if std::env::var(KILL_SWITCH_ENV).is_ok() {
        return VoiceAvailability::Disabled;
    }

    // API-key-only auth is insufficient; first-party OAuth is mandatory.
    let Some(tokens) = tokens else {
        return VoiceAvailability::RequiresOAuth;
    };

    // Every required scope must appear in the token's granted scope list.
    let granted = &tokens.scopes;
    let has_all = VOICE_REQUIRED_SCOPES
        .iter()
        .all(|&required| granted.iter().any(|g| g == required));

    if has_all {
        VoiceAvailability::Available
    } else {
        VoiceAvailability::MissingScopes {
            required: VOICE_REQUIRED_SCOPES
                .iter()
                .map(|s| s.to_string())
                .collect(),
            have: granted.clone(),
        }
    }
}
impl VoiceAvailability {
    /// Returns `true` when voice mode can be started.
    pub fn is_available(&self) -> bool {
        *self == VoiceAvailability::Available
    }

    /// Returns a human-readable error message when voice is not available,
    /// or `None` when it is.
    pub fn error_message(&self) -> Option<String> {
        let message = match self {
            VoiceAvailability::Available => return None,
            VoiceAvailability::RequiresOAuth => {
                "Voice mode requires OAuth authentication. Run /login to authenticate."
                    .to_string()
            }
            VoiceAvailability::MissingScopes { required, have } => {
                // Render "none" rather than an empty list for scope-less tokens.
                let have_list = if have.is_empty() {
                    "none".to_string()
                } else {
                    have.join(", ")
                };
                format!(
                    "Voice mode requires scopes: {}. Your token has: {}",
                    required.join(", "),
                    have_list
                )
            }
            VoiceAvailability::Disabled => "Voice mode is currently disabled.".to_string(),
            VoiceAvailability::NotEnabled => {
                "Voice mode is not enabled in this build.".to_string()
            }
        };
        Some(message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Serializes every test that reads or writes the kill-switch env var.
    ///
    /// Rust runs tests in parallel by default, and `std::env::set_var` /
    /// `remove_var` mutate process-global state: without this lock the
    /// kill-switch tests could flip the variable while another test is
    /// inside `check_voice_availability` (which reads it), producing
    /// spurious `Disabled` results — or the two mutating tests could race
    /// each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Take the env lock, recovering from poisoning so one failed test
    /// does not cascade into unrelated lock-acquisition panics.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Build test tokens carrying exactly the given scopes.
    fn tokens_with_scopes(scopes: Vec<&str>) -> OAuthTokens {
        OAuthTokens {
            access_token: "test_token".to_string(),
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn test_no_tokens_requires_oauth() {
        let _guard = env_guard();
        let result = check_voice_availability(None);
        assert_eq!(result, VoiceAvailability::RequiresOAuth);
        assert!(!result.is_available());
        assert!(result.error_message().is_some());
    }

    #[test]
    fn test_available_with_all_scopes() {
        let _guard = env_guard();
        let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
        let result = check_voice_availability(Some(&tokens));
        assert_eq!(result, VoiceAvailability::Available);
        assert!(result.is_available());
        assert!(result.error_message().is_none());
    }

    #[test]
    fn test_missing_one_scope() {
        let _guard = env_guard();
        let tokens = tokens_with_scopes(vec!["user:inference"]);
        let result = check_voice_availability(Some(&tokens));
        assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
        assert!(!result.is_available());
        let msg = result.error_message().unwrap();
        assert!(msg.contains("user:profile"));
    }

    #[test]
    fn test_missing_all_scopes() {
        let _guard = env_guard();
        let tokens = tokens_with_scopes(vec!["org:create_api_key"]);
        let result = check_voice_availability(Some(&tokens));
        assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
        assert!(!result.is_available());
    }

    #[test]
    fn test_empty_scopes_missing() {
        let _guard = env_guard();
        let tokens = tokens_with_scopes(vec![]);
        let result = check_voice_availability(Some(&tokens));
        assert!(
            matches!(result, VoiceAvailability::MissingScopes { ref have, .. } if have.is_empty())
        );
        let msg = result.error_message().unwrap();
        assert!(msg.contains("none"));
    }

    #[test]
    fn test_kill_switch_disables_voice() {
        let _guard = env_guard();
        // Temporarily set the kill-switch env var; the lock above keeps
        // concurrently running tests from observing it.
        std::env::set_var(KILL_SWITCH_ENV, "1");
        let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
        let result = check_voice_availability(Some(&tokens));
        std::env::remove_var(KILL_SWITCH_ENV);
        assert_eq!(result, VoiceAvailability::Disabled);
        assert!(!result.is_available());
    }

    #[test]
    fn test_kill_switch_beats_no_auth() {
        let _guard = env_guard();
        std::env::set_var(KILL_SWITCH_ENV, "true");
        let result = check_voice_availability(None);
        std::env::remove_var(KILL_SWITCH_ENV);
        // Kill switch wins — returns Disabled, not RequiresOAuth
        assert_eq!(result, VoiceAvailability::Disabled);
    }

    #[test]
    fn test_not_enabled_error_message() {
        // Does not touch the env var or call check_voice_availability,
        // so no env lock is needed.
        let v = VoiceAvailability::NotEnabled;
        assert!(!v.is_available());
        assert!(v.error_message().unwrap().contains("not enabled"));
    }

    #[test]
    fn test_extra_scopes_still_available() {
        let _guard = env_guard();
        // Having more scopes than required is fine
        let tokens = tokens_with_scopes(vec![
            "user:inference",
            "user:profile",
            "org:create_api_key",
            "user:file_upload",
        ]);
        let result = check_voice_availability(Some(&tokens));
        assert_eq!(result, VoiceAvailability::Available);
    }
}