From 5ed6cdae02b27415be9553127a6fd9fff7a2f332 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Tue, 14 Oct 2025 20:44:19 -0700 Subject: [PATCH 01/23] initial static cost walk --- clarity/src/vm/costs/analysis.rs | 857 +++++++++++++++++++++++++++++++ clarity/src/vm/costs/mod.rs | 1 + 2 files changed, 858 insertions(+) create mode 100644 clarity/src/vm/costs/analysis.rs diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs new file mode 100644 index 0000000000..879f2ba44c --- /dev/null +++ b/clarity/src/vm/costs/analysis.rs @@ -0,0 +1,857 @@ +// Static cost analysis for Clarity expressions + +use std::collections::HashMap; + +use crate::vm::Value; +use clarity_types::representations::ContractName; +use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; + +use crate::vm::ast::parser::v2::parse; +use crate::vm::costs::cost_functions::{linear, CostValues}; +use crate::vm::costs::costs_3::Costs3; +use crate::vm::costs::ExecutionCost; +use crate::vm::errors::InterpreterResult; +use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolicExpressionType}; + +// TODO: +// contract-call? - get source from database +// type-checking +// lookups +// unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) + +const STRING_COST_BASE: u64 = 36; +const STRING_COST_MULTIPLIER: u64 = 3; + +/// Functions where string arguments have zero cost because the function +/// cost includes their processing +const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; + +#[derive(Debug, Clone)] +pub enum ExprNode { + If, + Match, + Unwrap, + Ok, + Err, + GT, + LT, + GE, + LE, + EQ, + Add, + Sub, + Mul, + Div, + Function(ClarityName), + AtomValue(Value), + Atom(ClarityName), + SugaredContractIdentifier(ContractName), + SugaredFieldIdentifier(ContractName, ClarityName), + FieldIdentifier(TraitIdentifier), + TraitReference(ClarityName), + // User function arguments + UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) +} + +#[derive(Debug, Clone)] +pub struct CostAnalysisNode { + pub expr: ExprNode, + pub cost: StaticCost, + pub children: Vec, + pub branching: bool, +} + +impl CostAnalysisNode { + pub fn new( + expr: ExprNode, + cost: StaticCost, + children: Vec, + branching: bool, + ) -> Self { + Self { + expr, + cost, + children, + branching, + } + } + + pub fn leaf(expr: ExprNode, cost: StaticCost) -> Self { + Self { + expr, + cost, + children: vec![], + branching: false, + } + } +} + +#[derive(Debug, Clone)] +pub struct StaticCost { + pub min: ExecutionCost, + pub max: ExecutionCost, +} + +impl StaticCost { + pub const ZERO: StaticCost = StaticCost { + min: ExecutionCost::ZERO, + max: ExecutionCost::ZERO, + }; +} + +#[derive(Debug, Clone)] +pub struct UserArgumentsContext { + /// Map from argument name to argument type + pub arguments: HashMap, +} + +impl UserArgumentsContext { + pub fn new() -> Self { + Self { + arguments: HashMap::new(), + } + } + + pub fn add_argument(&mut self, name: ClarityName, arg_type: ClarityName) { + self.arguments.insert(name, arg_type); + } + + pub fn is_user_argument(&self, name: &ClarityName) -> bool { + self.arguments.contains_key(name) + } + + pub fn get_argument_type(&self, name: &ClarityName) -> Option<&ClarityName> { + self.arguments.get(name) + } +} + +/// A type to track summed execution costs for different paths +/// This allows us to compute min and max costs across different execution paths +#[derive(Debug, Clone)] +pub struct SummingExecutionCost { + pub costs: Vec, +} + +impl SummingExecutionCost { + pub fn new() -> Self { + Self { costs: Vec::new() } + } + + pub fn from_single(cost: ExecutionCost) -> Self { + Self { costs: vec![cost] } + } + + pub fn add_cost(&mut self, cost: ExecutionCost) { + self.costs.push(cost); + } + + pub fn add_summing(&mut self, other: &SummingExecutionCost) { + self.costs.extend(other.costs.clone()); + } + + /// minimum cost across all paths + pub fn min(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.min(cost.runtime), + write_length: acc.write_length.min(cost.write_length), + write_count: acc.write_count.min(cost.write_count), + read_length: acc.read_length.min(cost.read_length), + read_count: acc.read_count.min(cost.read_count), + }) + } + } + + /// maximum cost across all paths + pub fn max(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.max(cost.runtime), + write_length: acc.write_length.max(cost.write_length), + write_count: acc.write_count.max(cost.write_count), + read_length: acc.read_length.max(cost.read_length), + read_count: acc.read_count.max(cost.read_count), + }) + } + } + + pub fn add_all(&self) -> ExecutionCost { + self.costs + .iter() + .fold(ExecutionCost::ZERO, |mut acc, cost| { + let _ = acc.add(cost); + acc + }) + } +} + +/// Parse Clarity source code and calculate its static execution cost +/// +/// theoretically you could inspect the tree at any node to get the spot cost +pub fn static_cost(source: &str) -> Result { + let pre_expressions = parse(source).map_err(|e| format!("Parse error: {:?}", e))?; + + if pre_expressions.is_empty() { + return Err("No expressions found".to_string()); + } + + // TODO what happens if multiple expressions are selected? + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_analysis_tree = build_cost_analysis_tree(pre_expr, &user_args)?; + + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); + Ok(summing_cost.into()) +} + +fn build_cost_analysis_tree( + expr: &PreSymbolicExpression, + user_args: &UserArgumentsContext, +) -> Result { + match &expr.pre_expr { + PreSymbolicExpressionType::List(list) => { + if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { + if function_name.as_str() == "define-public" + || function_name.as_str() == "define-private" + || function_name.as_str() == "define-read-only" + { + return build_function_definition_cost_analysis_tree(list, user_args); + } + } + build_listlike_cost_analysis_tree(list, "list", user_args) + } + PreSymbolicExpressionType::AtomValue(value) => { + let cost = calculate_value_cost(value)?; + Ok(CostAnalysisNode::leaf( + ExprNode::AtomValue(value.clone()), + cost, + )) + } + PreSymbolicExpressionType::Atom(name) => { + let expr_node = parse_atom_expression(name, user_args)?; + Ok(CostAnalysisNode::leaf(expr_node, StaticCost::ZERO)) + } + PreSymbolicExpressionType::Tuple(tuple) => { + build_listlike_cost_analysis_tree(tuple, "tuple", user_args) + } + PreSymbolicExpressionType::SugaredContractIdentifier(contract_name) => { + Ok(CostAnalysisNode::leaf( + ExprNode::SugaredContractIdentifier(contract_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )) + } + PreSymbolicExpressionType::SugaredFieldIdentifier(contract_name, field_name) => { + Ok(CostAnalysisNode::leaf( + ExprNode::SugaredFieldIdentifier(contract_name.clone(), field_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )) + } + PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(CostAnalysisNode::leaf( + ExprNode::FieldIdentifier(field_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )), + PreSymbolicExpressionType::TraitReference(trait_name) => Ok(CostAnalysisNode::leaf( + ExprNode::TraitReference(trait_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )), + // Comments and placeholders should be filtered out during traversal + PreSymbolicExpressionType::Comment(_comment) => { + Err("hit an irrelevant comment expr type".to_string()) + } + PreSymbolicExpressionType::Placeholder(_placeholder) => { + Err("hit an irrelevant placeholder expr type".to_string()) + } + } +} + +/// Parse an atom expression into an ExprNode +fn parse_atom_expression( + name: &ClarityName, + user_args: &UserArgumentsContext, +) -> Result { + // Check if this atom is a user-defined function argument + if user_args.is_user_argument(name) { + if let Some(arg_type) = user_args.get_argument_type(name) { + Ok(ExprNode::UserArgument(name.clone(), arg_type.clone())) + } else { + Ok(ExprNode::Atom(name.clone())) + } + } else { + Ok(ExprNode::Atom(name.clone())) + } +} + +/// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) +fn build_function_definition_cost_analysis_tree( + list: &[PreSymbolicExpression], + _user_args: &UserArgumentsContext, +) -> Result { + let define_type = list[0] + .match_atom() + .ok_or("Expected atom for define type")?; + let signature = list[1] + .match_list() + .ok_or("Expected list for function signature")?; + let body = &list[2]; + + let mut children = Vec::new(); + let mut function_user_args = UserArgumentsContext::new(); + + // Process function arguments: (a u64) + for arg_expr in signature.iter().skip(1) { + if let Some(arg_list) = arg_expr.match_list() { + if arg_list.len() == 2 { + let arg_name = arg_list[0] + .match_atom() + .ok_or("Expected atom for argument name")?; + + let arg_type = match &arg_list[1].pre_expr { + PreSymbolicExpressionType::Atom(type_name) => type_name.clone(), + PreSymbolicExpressionType::AtomValue(value) => { + ClarityName::from(value.to_string().as_str()) + } + _ => return Err("Argument type must be an atom or atom value".to_string()), + }; + + // Add to function's user arguments context + function_user_args.add_argument(arg_name.clone(), arg_type.clone()); + + // Create UserArgument node + children.push(CostAnalysisNode::leaf( + ExprNode::UserArgument(arg_name.clone(), arg_type), + StaticCost::ZERO, + )); + } + } + } + + // Process the function body with the function's user arguments context + let body_tree = build_cost_analysis_tree(body, &function_user_args)?; + children.push(body_tree); + + // Create the function definition node with zero cost (function definitions themselves don't have execution cost) + Ok(CostAnalysisNode::new( + ExprNode::Function(define_type.clone()), + StaticCost::ZERO, + children, + false, + )) +} + +/// Helper function to build expression trees for both lists and tuples +fn build_listlike_cost_analysis_tree( + items: &[PreSymbolicExpression], + container_type: &str, + user_args: &UserArgumentsContext, +) -> Result { + let function_name = match &items[0].pre_expr { + PreSymbolicExpressionType::Atom(name) => name, + _ => { + return Err(format!( + "First element of {} must be an atom (function name)", + container_type + )); + } + }; + + let args = &items[1..]; + let mut children = Vec::new(); + + // Build children for all arguments, skipping comments and placeholders + for arg in args { + match &arg.pre_expr { + PreSymbolicExpressionType::Comment(_) | PreSymbolicExpressionType::Placeholder(_) => { + // Skip comments and placeholders + continue; + } + _ => { + children.push(build_cost_analysis_tree(arg, user_args)?); + } + } + } + + let branching = is_branching_function(function_name); + let expr_node = map_function_to_expr_node(function_name.as_str()); + let cost = calculate_function_cost_from_name(function_name.as_str(), children.len() as u64)?; + + // Handle special cases for string arguments to functions that include their processing cost + if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { + for child in &mut children { + if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child.expr { + child.cost = StaticCost::ZERO; + } + } + } + + Ok(CostAnalysisNode::new(expr_node, cost, children, branching)) +} + +/// Maps function names to their corresponding ExprNode variants +fn map_function_to_expr_node(function_name: &str) -> ExprNode { + match function_name { + "if" => ExprNode::If, + "match" => ExprNode::Match, + "unwrap!" | "unwrap-err!" | "unwrap-panic" | "unwrap-err-panic" => ExprNode::Unwrap, + "ok" => ExprNode::Ok, + "err" => ExprNode::Err, + ">" => ExprNode::GT, + "<" => ExprNode::LT, + ">=" => ExprNode::GE, + "<=" => ExprNode::LE, + "=" | "is-eq" | "eq" => ExprNode::EQ, + "+" | "add" => ExprNode::Add, + "-" | "sub" => ExprNode::Sub, + "*" | "mul" => ExprNode::Mul, + "/" | "div" => ExprNode::Div, + _ => ExprNode::Function(ClarityName::from(function_name)), + } +} + +/// Determine if a function name represents a branching function +fn is_branching_function(function_name: &ClarityName) -> bool { + match function_name.as_str() { + "if" | "match" => true, + "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and + // unwrap-err traverse both branches regardless of result, so until this is + // fixed in clarity we'll set this to false + _ => false, + } +} + +/// Calculate the cost for a string based on its length +fn string_cost(length: usize) -> StaticCost { + let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); + let execution_cost = ExecutionCost::runtime(cost); + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + } +} + +/// Calculate cost for a value (used for literal values) +fn calculate_value_cost(value: &Value) -> Result { + match value { + Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { + Ok(string_cost(data.data.len())) + } + Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { + Ok(string_cost(data.data.len())) + } + _ => Ok(StaticCost::ZERO), + } +} + +fn calculate_function_cost_from_name( + function_name: &str, + arg_count: u64, +) -> Result { + let cost_function = match get_cost_function_for_name(function_name) { + Some(cost_fn) => cost_fn, + None => { + // TODO: zero cost for now + return Ok(StaticCost::ZERO); + } + }; + + let cost = get_costs(cost_function, arg_count)?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) +} + +/// Convert a function name string to its corresponding cost function +fn get_cost_function_for_name(name: &str) -> Option InterpreterResult> { + // Map function names to their cost functions using the existing enum structure + match name { + "+" | "add" => Some(Costs3::cost_add), + "-" | "sub" => Some(Costs3::cost_sub), + "*" | "mul" => Some(Costs3::cost_mul), + "/" | "div" => Some(Costs3::cost_div), + "mod" => Some(Costs3::cost_mod), + "pow" => Some(Costs3::cost_pow), + "sqrti" => Some(Costs3::cost_sqrti), + "log2" => Some(Costs3::cost_log2), + "to-int" | "to-uint" | "int-cast" => Some(Costs3::cost_int_cast), + "is-eq" | "=" | "eq" => Some(Costs3::cost_eq), + ">=" | "geq" => Some(Costs3::cost_geq), + "<=" | "leq" => Some(Costs3::cost_leq), + ">" | "ge" => Some(Costs3::cost_ge), + "<" | "le" => Some(Costs3::cost_le), + "xor" => Some(Costs3::cost_xor), + "not" => Some(Costs3::cost_not), + "and" => Some(Costs3::cost_and), + "or" => Some(Costs3::cost_or), + "concat" => Some(Costs3::cost_concat), + "len" => Some(Costs3::cost_len), + "as-max-len?" => Some(Costs3::cost_as_max_len), + "list" => Some(Costs3::cost_list_cons), + "element-at" | "element-at?" => Some(Costs3::cost_element_at), + "index-of" | "index-of?" => Some(Costs3::cost_index_of), + "fold" => Some(Costs3::cost_fold), + "map" => Some(Costs3::cost_map), + "filter" => Some(Costs3::cost_filter), + "append" => Some(Costs3::cost_append), + "tuple-get" => Some(Costs3::cost_tuple_get), + "tuple-merge" => Some(Costs3::cost_tuple_merge), + "tuple" => Some(Costs3::cost_tuple_cons), + "some" => Some(Costs3::cost_some_cons), + "ok" => Some(Costs3::cost_ok_cons), + "err" => Some(Costs3::cost_err_cons), + "default-to" => Some(Costs3::cost_default_to), + "unwrap!" => Some(Costs3::cost_unwrap_ret), + "unwrap-err!" => Some(Costs3::cost_unwrap_err_or_ret), + "is-ok" => Some(Costs3::cost_is_okay), + "is-none" => Some(Costs3::cost_is_none), + "is-err" => Some(Costs3::cost_is_err), + "is-some" => Some(Costs3::cost_is_some), + "unwrap-panic" => Some(Costs3::cost_unwrap), + "unwrap-err-panic" => Some(Costs3::cost_unwrap_err), + "try!" => Some(Costs3::cost_try_ret), + "if" => Some(Costs3::cost_if), + "match" => Some(Costs3::cost_match), + "begin" => Some(Costs3::cost_begin), + "let" => Some(Costs3::cost_let), + "asserts!" => Some(Costs3::cost_asserts), + "hash160" => Some(Costs3::cost_hash160), + "sha256" => Some(Costs3::cost_sha256), + "sha512" => Some(Costs3::cost_sha512), + "sha512/256" => Some(Costs3::cost_sha512t256), + "keccak256" => Some(Costs3::cost_keccak256), + "secp256k1-recover?" => Some(Costs3::cost_secp256k1recover), + "secp256k1-verify" => Some(Costs3::cost_secp256k1verify), + "print" => Some(Costs3::cost_print), + "contract-call?" => Some(Costs3::cost_contract_call), + "contract-of" => Some(Costs3::cost_contract_of), + "principal-of?" => Some(Costs3::cost_principal_of), + "at-block" => Some(Costs3::cost_at_block), + "load-contract" => Some(Costs3::cost_load_contract), + "create-map" => Some(Costs3::cost_create_map), + "create-var" => Some(Costs3::cost_create_var), + "create-non-fungible-token" => Some(Costs3::cost_create_nft), + "create-fungible-token" => Some(Costs3::cost_create_ft), + "map-get?" => Some(Costs3::cost_fetch_entry), + "map-set!" => Some(Costs3::cost_set_entry), + "var-get" => Some(Costs3::cost_fetch_var), + "var-set!" => Some(Costs3::cost_set_var), + "contract-storage" => Some(Costs3::cost_contract_storage), + "get-block-info?" => Some(Costs3::cost_block_info), + "get-burn-block-info?" => Some(Costs3::cost_burn_block_info), + "stx-get-balance" => Some(Costs3::cost_stx_balance), + "stx-transfer?" => Some(Costs3::cost_stx_transfer), + "stx-transfer-memo?" => Some(Costs3::cost_stx_transfer_memo), + "stx-account" => Some(Costs3::cost_stx_account), + "ft-mint?" => Some(Costs3::cost_ft_mint), + "ft-transfer?" => Some(Costs3::cost_ft_transfer), + "ft-get-balance" => Some(Costs3::cost_ft_balance), + "ft-get-supply" => Some(Costs3::cost_ft_get_supply), + "ft-burn?" => Some(Costs3::cost_ft_burn), + "nft-mint?" => Some(Costs3::cost_nft_mint), + "nft-transfer?" => Some(Costs3::cost_nft_transfer), + "nft-get-owner?" => Some(Costs3::cost_nft_owner), + "nft-burn?" => Some(Costs3::cost_nft_burn), + "buff-to-int-le?" => Some(Costs3::cost_buff_to_int_le), + "buff-to-uint-le?" => Some(Costs3::cost_buff_to_uint_le), + "buff-to-int-be?" => Some(Costs3::cost_buff_to_int_be), + "buff-to-uint-be?" => Some(Costs3::cost_buff_to_uint_be), + "to-consensus-buff?" => Some(Costs3::cost_to_consensus_buff), + "from-consensus-buff?" => Some(Costs3::cost_from_consensus_buff), + "is-standard?" => Some(Costs3::cost_is_standard), + "principal-destruct" => Some(Costs3::cost_principal_destruct), + "principal-construct?" => Some(Costs3::cost_principal_construct), + "as-contract" => Some(Costs3::cost_as_contract), + "string-to-int?" => Some(Costs3::cost_string_to_int), + "string-to-uint?" => Some(Costs3::cost_string_to_uint), + "int-to-ascii" => Some(Costs3::cost_int_to_ascii), + "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), + _ => None, // TODO + } +} + +/// Calculate total cost using SummingExecutionCost to handle branching properly +fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); + + for child in &node.children { + let child_summing = calculate_total_cost_with_summing(child); + summing_cost.add_summing(&child_summing); + } + + summing_cost +} + +fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::new(); + + if node.branching { + match &node.expr { + ExprNode::If | ExprNode::Match => { + // TODO match? + if node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&node.children[0]); + let condition_total = condition_cost.add_all(); + + // Add the root cost + condition cost to each branch + let mut root_and_condition = node.cost.min.clone(); + let _ = root_and_condition.add(&condition_total); + + for child_cost_node in node.children.iter().skip(1) { + let branch_cost = calculate_total_cost_with_summing(child_cost_node); + let branch_total = branch_cost.add_all(); + + let mut path_cost = root_and_condition.clone(); + let _ = path_cost.add(&branch_total); + + summing_cost.add_cost(path_cost); + } + } + } + _ => { + // For other branching functions, fall back to sequential processing + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + } + } else { + // For non-branching, add all costs sequentially + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + + summing_cost +} + +impl From for StaticCost { + fn from(summing: SummingExecutionCost) -> Self { + StaticCost { + min: summing.min(), + max: summing.max(), + } + } +} + +/// Helper: calculate min & max costs for a given cost function +/// This is likely tooo simplistic but for now it'll do +fn get_costs( + cost_fn: fn(u64) -> InterpreterResult, + arg_count: u64, +) -> Result { + let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(cost) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_constant() { + let source = "u2"; + let cost = static_cost(source).unwrap(); + assert_eq!(cost.min.runtime, 0); + assert_eq!(cost.max.runtime, 0); + } + + #[test] + fn test_simple_addition() { + let source = "(+ u1 u2)"; + let cost = static_cost(source).unwrap(); + + // Min: linear(2, 11, 125) = 11*2 + 125 = 147 + assert_eq!(cost.min.runtime, 147); + assert_eq!(cost.max.runtime, 147); + } + + #[test] + fn test_arithmetic() { + let source = "(- u4 (+ u1 u2))"; + let cost = static_cost(source).unwrap(); + assert_eq!(cost.min.runtime, 147 + 147); + assert_eq!(cost.max.runtime, 147 + 147); + } + + #[test] + fn test_nested_operations() { + let source = "(* (+ u1 u2) (- u3 u4))"; + let cost = static_cost(source).unwrap(); + // multiplication: 13*2 + 125 = 151 + assert_eq!(cost.min.runtime, 151 + 147 + 147); + assert_eq!(cost.max.runtime, 151 + 147 + 147); + } + + #[test] + fn test_string_concat_min_max() { + let source = "(concat \"hello\" \"world\")"; + let cost = static_cost(source).unwrap(); + + // For concat with 2 arguments: + // linear(2, 37, 220) = 37*2 + 220 = 294 + assert_eq!(cost.min.runtime, 294); + assert_eq!(cost.max.runtime, 294); + } + + #[test] + fn test_string_len_min_max() { + let source = "(len \"hello\")"; + let cost = static_cost(source).unwrap(); + + // cost: 429 (constant) - len doesn't depend on string size + assert_eq!(cost.min.runtime, 429); + assert_eq!(cost.max.runtime, 429); + } + + #[test] + fn test_branching() { + let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; + let cost = static_cost(source).unwrap(); + // min: 147 raw string + // max: 294 (concat) + + // ok = 199 + // if = 168 + // ge = (linear(n, 7, 128))) + let base_cost = 168 + ((2 * 7) + 128) + 199; + assert_eq!(cost.min.runtime, base_cost + 147); + assert_eq!(cost.max.runtime, base_cost + 294); + } + + // ---- ExprTreee building specific tests + #[test] + fn test_build_cost_analysis_tree_if_expression() { + let source = "(if (> 3 0) (ok true) (ok false))"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + + // Root should be an If node with branching=true + assert!(matches!(cost_tree.expr, ExprNode::If)); + assert!(cost_tree.branching); + assert_eq!(cost_tree.children.len(), 3); + + let gt_node = &cost_tree.children[0]; + assert!(matches!(gt_node.expr, ExprNode::GT)); + assert_eq!(gt_node.children.len(), 2); + + let left_val = >_node.children[0]; + let right_val = >_node.children[1]; + assert!(matches!(left_val.expr, ExprNode::AtomValue(_))); + assert!(matches!(right_val.expr, ExprNode::AtomValue(_))); + + let ok_true_node = &cost_tree.children[1]; + assert!(matches!(ok_true_node.expr, ExprNode::Ok)); + assert_eq!(ok_true_node.children.len(), 1); + + let ok_false_node = &cost_tree.children[2]; + assert!(matches!(ok_false_node.expr, ExprNode::Ok)); + assert_eq!(ok_false_node.children.len(), 1); + } + + #[test] + fn test_build_cost_analysis_tree_arithmetic() { + let source = "(+ (* 2 3) (- 5 1))"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + + assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(!cost_tree.branching); + assert_eq!(cost_tree.children.len(), 2); + + let mul_node = &cost_tree.children[0]; + assert!(matches!(mul_node.expr, ExprNode::Mul)); + assert_eq!(mul_node.children.len(), 2); + + let sub_node = &cost_tree.children[1]; + assert!(matches!(sub_node.expr, ExprNode::Sub)); + assert_eq!(sub_node.children.len(), 2); + } + + #[test] + fn test_build_cost_analysis_tree_with_comments() { + let source = "(+ 1 ;; this is a comment\n 2)"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + + assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(!cost_tree.branching); + assert_eq!(cost_tree.children.len(), 2); + + for child in &cost_tree.children { + assert!(matches!(child.expr, ExprNode::AtomValue(_))); + } + } + + #[test] + fn test_function_with_multiple_arguments() { + let src = r#"(define-public (add-two (x u64) (y u64)) (+ x y))"#; + let pre_expressions = parse(src).unwrap(); + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + + // Should have 3 children: UserArgument for (x u64), UserArgument for (y u64), and the body (+ x y) + assert_eq!(cost_tree.children.len(), 3); + + // First child should be UserArgument for (x u64) + let user_arg_x = &cost_tree.children[0]; + assert!(matches!(user_arg_x.expr, ExprNode::UserArgument(_, _))); + if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { + assert_eq!(arg_name.as_str(), "x"); + assert_eq!(arg_type.as_str(), "u64"); + } + + // Second child should be UserArgument for (y u64) + let user_arg_y = &cost_tree.children[1]; + assert!(matches!(user_arg_y.expr, ExprNode::UserArgument(_, _))); + if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { + assert_eq!(arg_name.as_str(), "y"); + assert_eq!(arg_type.as_str(), "u64"); + } + + // Third child should be the function body (+ x y) + let body = &cost_tree.children[2]; + assert!(matches!(body.expr, ExprNode::Add)); + assert_eq!(body.children.len(), 2); + + // Both arguments in the body should be UserArguments + let arg_x_ref = &body.children[0]; + let arg_y_ref = &body.children[1]; + assert!(matches!(arg_x_ref.expr, ExprNode::UserArgument(_, _))); + assert!(matches!(arg_y_ref.expr, ExprNode::UserArgument(_, _))); + + if let ExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { + assert_eq!(name.as_str(), "x"); + assert_eq!(arg_type.as_str(), "u64"); + } + if let ExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { + assert_eq!(name.as_str(), "y"); + assert_eq!(arg_type.as_str(), "u64"); + } + } +} diff --git a/clarity/src/vm/costs/mod.rs b/clarity/src/vm/costs/mod.rs index 4e006c890c..ea055ddea5 100644 --- a/clarity/src/vm/costs/mod.rs +++ b/clarity/src/vm/costs/mod.rs @@ -42,6 +42,7 @@ use crate::vm::types::{ FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, }; use crate::vm::{CallStack, ClarityName, Environment, LocalContext, SymbolicExpression, Value}; +pub mod analysis; pub mod constants; pub mod cost_functions; #[allow(unused_variables)] From 7b6ff35e3c8b43e0616b613035c417b3cad61471 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Tue, 14 Oct 2025 22:35:32 -0700 Subject: [PATCH 02/23] use SymbolicExpression and change to using NativeFunctions --- clarity/src/vm/costs/analysis.rs | 236 ++++++++++++++++--------------- 1 file changed, 124 insertions(+), 112 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 879f2ba44c..97120a76fa 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use crate::vm::Value; -use clarity_types::representations::ContractName; use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; use crate::vm::ast::parser::v2::parse; @@ -11,7 +10,9 @@ use crate::vm::costs::cost_functions::{linear, CostValues}; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; use crate::vm::errors::InterpreterResult; +use crate::vm::functions::NativeFunctions; use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolicExpressionType}; +use crate::vm::ClarityVersion; // TODO: // contract-call? - get source from database @@ -27,35 +28,23 @@ const STRING_COST_MULTIPLIER: u64 = 3; const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; #[derive(Debug, Clone)] -pub enum ExprNode { - If, - Match, - Unwrap, - Ok, - Err, - GT, - LT, - GE, - LE, - EQ, - Add, - Sub, - Mul, - Div, - Function(ClarityName), +pub enum CostExprNode { + // Native Clarity functions + NativeFunction(NativeFunctions), + // Non-native expressions AtomValue(Value), Atom(ClarityName), - SugaredContractIdentifier(ContractName), - SugaredFieldIdentifier(ContractName, ClarityName), FieldIdentifier(TraitIdentifier), TraitReference(ClarityName), // User function arguments UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) + // User-defined functions + UserFunction(ClarityName), } #[derive(Debug, Clone)] pub struct CostAnalysisNode { - pub expr: ExprNode, + pub expr: CostExprNode, pub cost: StaticCost, pub children: Vec, pub branching: bool, @@ -63,7 +52,7 @@ pub struct CostAnalysisNode { impl CostAnalysisNode { pub fn new( - expr: ExprNode, + expr: CostExprNode, cost: StaticCost, children: Vec, branching: bool, @@ -76,7 +65,7 @@ impl CostAnalysisNode { } } - pub fn leaf(expr: ExprNode, cost: StaticCost) -> Self { + pub fn leaf(expr: CostExprNode, cost: StaticCost) -> Self { Self { expr, cost, @@ -196,7 +185,7 @@ impl SummingExecutionCost { /// Parse Clarity source code and calculate its static execution cost /// /// theoretically you could inspect the tree at any node to get the spot cost -pub fn static_cost(source: &str) -> Result { +pub fn static_cost(source: &str, clarity_version: &ClarityVersion) -> Result { let pre_expressions = parse(source).map_err(|e| format!("Parse error: {:?}", e))?; if pre_expressions.is_empty() { @@ -206,7 +195,7 @@ pub fn static_cost(source: &str) -> Result { // TODO what happens if multiple expressions are selected? let pre_expr = &pre_expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_analysis_tree = build_cost_analysis_tree(pre_expr, &user_args)?; + let cost_analysis_tree = build_cost_analysis_tree(&pre_expr, &user_args, clarity_version)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) @@ -215,6 +204,7 @@ pub fn static_cost(source: &str) -> Result { fn build_cost_analysis_tree( expr: &PreSymbolicExpression, user_args: &UserArgumentsContext, + clarity_version: &ClarityVersion, ) -> Result { match &expr.pre_expr { PreSymbolicExpressionType::List(list) => { @@ -223,15 +213,19 @@ fn build_cost_analysis_tree( || function_name.as_str() == "define-private" || function_name.as_str() == "define-read-only" { - return build_function_definition_cost_analysis_tree(list, user_args); + return build_function_definition_cost_analysis_tree( + list, + user_args, + clarity_version, + ); } } - build_listlike_cost_analysis_tree(list, "list", user_args) + build_listlike_cost_analysis_tree(list, "list", user_args, clarity_version) } PreSymbolicExpressionType::AtomValue(value) => { let cost = calculate_value_cost(value)?; Ok(CostAnalysisNode::leaf( - ExprNode::AtomValue(value.clone()), + CostExprNode::AtomValue(value.clone()), cost, )) } @@ -240,30 +234,23 @@ fn build_cost_analysis_tree( Ok(CostAnalysisNode::leaf(expr_node, StaticCost::ZERO)) } PreSymbolicExpressionType::Tuple(tuple) => { - build_listlike_cost_analysis_tree(tuple, "tuple", user_args) + build_listlike_cost_analysis_tree(tuple, "tuple", user_args, clarity_version) } - PreSymbolicExpressionType::SugaredContractIdentifier(contract_name) => { + PreSymbolicExpressionType::SugaredContractIdentifier(_contract_name) => { Ok(CostAnalysisNode::leaf( - ExprNode::SugaredContractIdentifier(contract_name.clone()), - // TODO: Look up source - StaticCost::ZERO, - )) - } - PreSymbolicExpressionType::SugaredFieldIdentifier(contract_name, field_name) => { - Ok(CostAnalysisNode::leaf( - ExprNode::SugaredFieldIdentifier(contract_name.clone(), field_name.clone()), - // TODO: Look up source + CostExprNode::Atom(ClarityName::from("contract-identifier")), StaticCost::ZERO, )) } + PreSymbolicExpressionType::SugaredFieldIdentifier(_contract_name, field_name) => Ok( + CostAnalysisNode::leaf(CostExprNode::Atom(field_name.clone()), StaticCost::ZERO), + ), PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(CostAnalysisNode::leaf( - ExprNode::FieldIdentifier(field_name.clone()), - // TODO: Look up source + CostExprNode::FieldIdentifier(field_name.clone()), StaticCost::ZERO, )), PreSymbolicExpressionType::TraitReference(trait_name) => Ok(CostAnalysisNode::leaf( - ExprNode::TraitReference(trait_name.clone()), - // TODO: Look up source + CostExprNode::TraitReference(trait_name.clone()), StaticCost::ZERO, )), // Comments and placeholders should be filtered out during traversal @@ -280,16 +267,16 @@ fn build_cost_analysis_tree( fn parse_atom_expression( name: &ClarityName, user_args: &UserArgumentsContext, -) -> Result { +) -> Result { // Check if this atom is a user-defined function argument if user_args.is_user_argument(name) { if let Some(arg_type) = user_args.get_argument_type(name) { - Ok(ExprNode::UserArgument(name.clone(), arg_type.clone())) + Ok(CostExprNode::UserArgument(name.clone(), arg_type.clone())) } else { - Ok(ExprNode::Atom(name.clone())) + Ok(CostExprNode::Atom(name.clone())) } } else { - Ok(ExprNode::Atom(name.clone())) + Ok(CostExprNode::Atom(name.clone())) } } @@ -297,6 +284,7 @@ fn parse_atom_expression( fn build_function_definition_cost_analysis_tree( list: &[PreSymbolicExpression], _user_args: &UserArgumentsContext, + clarity_version: &ClarityVersion, ) -> Result { let define_type = list[0] .match_atom() @@ -330,7 +318,7 @@ fn build_function_definition_cost_analysis_tree( // Create UserArgument node children.push(CostAnalysisNode::leaf( - ExprNode::UserArgument(arg_name.clone(), arg_type), + CostExprNode::UserArgument(arg_name.clone(), arg_type), StaticCost::ZERO, )); } @@ -338,12 +326,12 @@ fn build_function_definition_cost_analysis_tree( } // Process the function body with the function's user arguments context - let body_tree = build_cost_analysis_tree(body, &function_user_args)?; + let body_tree = build_cost_analysis_tree(body, &function_user_args, clarity_version)?; children.push(body_tree); // Create the function definition node with zero cost (function definitions themselves don't have execution cost) Ok(CostAnalysisNode::new( - ExprNode::Function(define_type.clone()), + CostExprNode::UserFunction(define_type.clone()), StaticCost::ZERO, children, false, @@ -355,6 +343,7 @@ fn build_listlike_cost_analysis_tree( items: &[PreSymbolicExpression], container_type: &str, user_args: &UserArgumentsContext, + clarity_version: &ClarityVersion, ) -> Result { let function_name = match &items[0].pre_expr { PreSymbolicExpressionType::Atom(name) => name, @@ -377,19 +366,28 @@ fn build_listlike_cost_analysis_tree( continue; } _ => { - children.push(build_cost_analysis_tree(arg, user_args)?); + children.push(build_cost_analysis_tree(arg, user_args, clarity_version)?); } } } + // Try to lookup the function as a native function first + let expr_node = if let Some(native_function) = + NativeFunctions::lookup_by_name_at_version(function_name.as_str(), clarity_version) + { + CostExprNode::NativeFunction(native_function) + } else { + // If not a native function, treat as user-defined function + CostExprNode::UserFunction(function_name.clone()) + }; + let branching = is_branching_function(function_name); - let expr_node = map_function_to_expr_node(function_name.as_str()); let cost = calculate_function_cost_from_name(function_name.as_str(), children.len() as u64)?; // Handle special cases for string arguments to functions that include their processing cost if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { for child in &mut children { - if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child.expr { + if let CostExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child.expr { child.cost = StaticCost::ZERO; } } @@ -398,26 +396,8 @@ fn build_listlike_cost_analysis_tree( Ok(CostAnalysisNode::new(expr_node, cost, children, branching)) } -/// Maps function names to their corresponding ExprNode variants -fn map_function_to_expr_node(function_name: &str) -> ExprNode { - match function_name { - "if" => ExprNode::If, - "match" => ExprNode::Match, - "unwrap!" | "unwrap-err!" | "unwrap-panic" | "unwrap-err-panic" => ExprNode::Unwrap, - "ok" => ExprNode::Ok, - "err" => ExprNode::Err, - ">" => ExprNode::GT, - "<" => ExprNode::LT, - ">=" => ExprNode::GE, - "<=" => ExprNode::LE, - "=" | "is-eq" | "eq" => ExprNode::EQ, - "+" | "add" => ExprNode::Add, - "-" | "sub" => ExprNode::Sub, - "*" | "mul" => ExprNode::Mul, - "/" | "div" => ExprNode::Div, - _ => ExprNode::Function(ClarityName::from(function_name)), - } -} +/// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version +/// directly in build_listlike_cost_analysis_tree /// Determine if a function name represents a branching function fn is_branching_function(function_name: &ClarityName) -> bool { @@ -597,7 +577,8 @@ fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecut if node.branching { match &node.expr { - ExprNode::If | ExprNode::Match => { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => { // TODO match? if node.children.len() >= 2 { let condition_cost = calculate_total_cost_with_summing(&node.children[0]); @@ -668,16 +649,16 @@ mod tests { #[test] fn test_constant() { - let source = "u2"; - let cost = static_cost(source).unwrap(); + let source = "42"; + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); assert_eq!(cost.min.runtime, 0); assert_eq!(cost.max.runtime, 0); } #[test] fn test_simple_addition() { - let source = "(+ u1 u2)"; - let cost = static_cost(source).unwrap(); + let source = "(+ 1 2)"; + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); // Min: linear(2, 11, 125) = 11*2 + 125 = 147 assert_eq!(cost.min.runtime, 147); @@ -687,7 +668,7 @@ mod tests { #[test] fn test_arithmetic() { let source = "(- u4 (+ u1 u2))"; - let cost = static_cost(source).unwrap(); + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); assert_eq!(cost.min.runtime, 147 + 147); assert_eq!(cost.max.runtime, 147 + 147); } @@ -695,7 +676,7 @@ mod tests { #[test] fn test_nested_operations() { let source = "(* (+ u1 u2) (- u3 u4))"; - let cost = static_cost(source).unwrap(); + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); // multiplication: 13*2 + 125 = 151 assert_eq!(cost.min.runtime, 151 + 147 + 147); assert_eq!(cost.max.runtime, 151 + 147 + 147); @@ -703,8 +684,8 @@ mod tests { #[test] fn test_string_concat_min_max() { - let source = "(concat \"hello\" \"world\")"; - let cost = static_cost(source).unwrap(); + let source = r#"(concat "hello" "world")"#; + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); // For concat with 2 arguments: // linear(2, 37, 220) = 37*2 + 220 = 294 @@ -714,8 +695,8 @@ mod tests { #[test] fn test_string_len_min_max() { - let source = "(len \"hello\")"; - let cost = static_cost(source).unwrap(); + let source = r#"(len "hello")"#; + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); // cost: 429 (constant) - len doesn't depend on string size assert_eq!(cost.min.runtime, 429); @@ -725,7 +706,7 @@ mod tests { #[test] fn test_branching() { let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; - let cost = static_cost(source).unwrap(); + let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); // min: 147 raw string // max: 294 (concat) @@ -744,28 +725,41 @@ mod tests { let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + let cost_tree = + build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); // Root should be an If node with branching=true - assert!(matches!(cost_tree.expr, ExprNode::If)); + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::If) + )); assert!(cost_tree.branching); assert_eq!(cost_tree.children.len(), 3); let gt_node = &cost_tree.children[0]; - assert!(matches!(gt_node.expr, ExprNode::GT)); + assert!(matches!( + gt_node.expr, + CostExprNode::NativeFunction(NativeFunctions::CmpGreater) + )); assert_eq!(gt_node.children.len(), 2); let left_val = >_node.children[0]; let right_val = >_node.children[1]; - assert!(matches!(left_val.expr, ExprNode::AtomValue(_))); - assert!(matches!(right_val.expr, ExprNode::AtomValue(_))); + assert!(matches!(left_val.expr, CostExprNode::AtomValue(_))); + assert!(matches!(right_val.expr, CostExprNode::AtomValue(_))); let ok_true_node = &cost_tree.children[1]; - assert!(matches!(ok_true_node.expr, ExprNode::Ok)); + assert!(matches!( + ok_true_node.expr, + CostExprNode::NativeFunction(NativeFunctions::ConsOkay) + )); assert_eq!(ok_true_node.children.len(), 1); let ok_false_node = &cost_tree.children[2]; - assert!(matches!(ok_false_node.expr, ExprNode::Ok)); + assert!(matches!( + ok_false_node.expr, + CostExprNode::NativeFunction(NativeFunctions::ConsOkay) + )); assert_eq!(ok_false_node.children.len(), 1); } @@ -775,18 +769,28 @@ mod tests { let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + let cost_tree = + build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); - assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); assert!(!cost_tree.branching); assert_eq!(cost_tree.children.len(), 2); let mul_node = &cost_tree.children[0]; - assert!(matches!(mul_node.expr, ExprNode::Mul)); + assert!(matches!( + mul_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Multiply) + )); assert_eq!(mul_node.children.len(), 2); let sub_node = &cost_tree.children[1]; - assert!(matches!(sub_node.expr, ExprNode::Sub)); + assert!(matches!( + sub_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Subtract) + )); assert_eq!(sub_node.children.len(), 2); } @@ -796,14 +800,18 @@ mod tests { let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + let cost_tree = + build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); - assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); assert!(!cost_tree.branching); assert_eq!(cost_tree.children.len(), 2); for child in &cost_tree.children { - assert!(matches!(child.expr, ExprNode::AtomValue(_))); + assert!(matches!(child.expr, CostExprNode::AtomValue(_))); } } @@ -813,43 +821,47 @@ mod tests { let pre_expressions = parse(src).unwrap(); let pre_expr = &pre_expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + let cost_tree = + build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); // Should have 3 children: UserArgument for (x u64), UserArgument for (y u64), and the body (+ x y) assert_eq!(cost_tree.children.len(), 3); // First child should be UserArgument for (x u64) let user_arg_x = &cost_tree.children[0]; - assert!(matches!(user_arg_x.expr, ExprNode::UserArgument(_, _))); - if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { + assert!(matches!(user_arg_x.expr, CostExprNode::UserArgument(_, _))); + if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { assert_eq!(arg_name.as_str(), "x"); assert_eq!(arg_type.as_str(), "u64"); } // Second child should be UserArgument for (y u64) let user_arg_y = &cost_tree.children[1]; - assert!(matches!(user_arg_y.expr, ExprNode::UserArgument(_, _))); - if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { + assert!(matches!(user_arg_y.expr, CostExprNode::UserArgument(_, _))); + if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { assert_eq!(arg_name.as_str(), "y"); assert_eq!(arg_type.as_str(), "u64"); } // Third child should be the function body (+ x y) - let body = &cost_tree.children[2]; - assert!(matches!(body.expr, ExprNode::Add)); - assert_eq!(body.children.len(), 2); + let body_node = &cost_tree.children[2]; + assert!(matches!( + body_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); + assert_eq!(body_node.children.len(), 2); // Both arguments in the body should be UserArguments - let arg_x_ref = &body.children[0]; - let arg_y_ref = &body.children[1]; - assert!(matches!(arg_x_ref.expr, ExprNode::UserArgument(_, _))); - assert!(matches!(arg_y_ref.expr, ExprNode::UserArgument(_, _))); + let arg_x_ref = &body_node.children[0]; + let arg_y_ref = &body_node.children[1]; + assert!(matches!(arg_x_ref.expr, CostExprNode::UserArgument(_, _))); + assert!(matches!(arg_y_ref.expr, CostExprNode::UserArgument(_, _))); - if let ExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { + if let CostExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { assert_eq!(name.as_str(), "x"); assert_eq!(arg_type.as_str(), "u64"); } - if let ExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { + if let CostExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { assert_eq!(name.as_str(), "y"); assert_eq!(arg_type.as_str(), "u64"); } From 9d5bba62f5525a0be32a35eab31e97219f0a1dcb Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Wed, 15 Oct 2025 10:46:40 -0700 Subject: [PATCH 03/23] remove branching attribute --- clarity/src/vm/costs/analysis.rs | 36 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 97120a76fa..282ef65ee9 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -47,21 +47,14 @@ pub struct CostAnalysisNode { pub expr: CostExprNode, pub cost: StaticCost, pub children: Vec, - pub branching: bool, } impl CostAnalysisNode { - pub fn new( - expr: CostExprNode, - cost: StaticCost, - children: Vec, - branching: bool, - ) -> Self { + pub fn new(expr: CostExprNode, cost: StaticCost, children: Vec) -> Self { Self { expr, cost, children, - branching, } } @@ -70,7 +63,6 @@ impl CostAnalysisNode { expr, cost, children: vec![], - branching: false, } } } @@ -334,7 +326,6 @@ fn build_function_definition_cost_analysis_tree( CostExprNode::UserFunction(define_type.clone()), StaticCost::ZERO, children, - false, )) } @@ -381,7 +372,6 @@ fn build_listlike_cost_analysis_tree( CostExprNode::UserFunction(function_name.clone()) }; - let branching = is_branching_function(function_name); let cost = calculate_function_cost_from_name(function_name.as_str(), children.len() as u64)?; // Handle special cases for string arguments to functions that include their processing cost @@ -393,7 +383,7 @@ fn build_listlike_cost_analysis_tree( } } - Ok(CostAnalysisNode::new(expr_node, cost, children, branching)) + Ok(CostAnalysisNode::new(expr_node, cost, children)) } /// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version @@ -410,6 +400,17 @@ fn is_branching_function(function_name: &ClarityName) -> bool { } } +/// Helper function to determine if a node represents a branching operation +/// This is used in tests and cost calculation +fn is_node_branching(node: &CostAnalysisNode) -> bool { + match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => true, + CostExprNode::UserFunction(name) => is_branching_function(name), + _ => false, + } +} + /// Calculate the cost for a string based on its length fn string_cost(length: usize) -> StaticCost { let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); @@ -575,7 +576,10 @@ fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutio fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { let mut summing_cost = SummingExecutionCost::new(); - if node.branching { + // Check if this is a branching function by examining the node's expression + let is_branching = is_node_branching(node); + + if is_branching { match &node.expr { CostExprNode::NativeFunction(NativeFunctions::If) | CostExprNode::NativeFunction(NativeFunctions::Match) => { @@ -733,7 +737,7 @@ mod tests { cost_tree.expr, CostExprNode::NativeFunction(NativeFunctions::If) )); - assert!(cost_tree.branching); + assert!(is_node_branching(&cost_tree)); assert_eq!(cost_tree.children.len(), 3); let gt_node = &cost_tree.children[0]; @@ -776,7 +780,7 @@ mod tests { cost_tree.expr, CostExprNode::NativeFunction(NativeFunctions::Add) )); - assert!(!cost_tree.branching); + assert!(!is_node_branching(&cost_tree)); assert_eq!(cost_tree.children.len(), 2); let mul_node = &cost_tree.children[0]; @@ -807,7 +811,7 @@ mod tests { cost_tree.expr, CostExprNode::NativeFunction(NativeFunctions::Add) )); - assert!(!cost_tree.branching); + assert!(!is_node_branching(&cost_tree)); assert_eq!(cost_tree.children.len(), 2); for child in &cost_tree.children { From 54d0e97619e2e3555c7170ecc6d19e97078c8ee6 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Wed, 15 Oct 2025 11:07:13 -0700 Subject: [PATCH 04/23] actually parse into SE instead of PSE --- clarity/src/vm/costs/analysis.rs | 163 +++++++++++++++++-------------- 1 file changed, 88 insertions(+), 75 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 282ef65ee9..2cda4233ee 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -5,14 +5,16 @@ use std::collections::HashMap; use crate::vm::Value; use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; -use crate::vm::ast::parser::v2::parse; +use crate::vm::ast::build_ast; use crate::vm::costs::cost_functions::{linear, CostValues}; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; use crate::vm::errors::InterpreterResult; use crate::vm::functions::NativeFunctions; -use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolicExpressionType}; +use crate::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; +use crate::vm::types::QualifiedContractIdentifier; use crate::vm::ClarityVersion; +use stacks_common::types::StacksEpochId; // TODO: // contract-call? - get source from database @@ -178,28 +180,39 @@ impl SummingExecutionCost { /// /// theoretically you could inspect the tree at any node to get the spot cost pub fn static_cost(source: &str, clarity_version: &ClarityVersion) -> Result { - let pre_expressions = parse(source).map_err(|e| format!("Parse error: {:?}", e))?; - - if pre_expressions.is_empty() { + let contract_identifier = QualifiedContractIdentifier::transient(); + let mut cost_tracker = (); + let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version + + let ast = build_ast( + &contract_identifier, + source, + &mut cost_tracker, + *clarity_version, + epoch, + ) + .map_err(|e| format!("Parse error: {:?}", e))?; + + if ast.expressions.is_empty() { return Err("No expressions found".to_string()); } // TODO what happens if multiple expressions are selected? - let pre_expr = &pre_expressions[0]; + let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); - let cost_analysis_tree = build_cost_analysis_tree(&pre_expr, &user_args, clarity_version)?; + let cost_analysis_tree = build_cost_analysis_tree(expr, &user_args, clarity_version)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) } fn build_cost_analysis_tree( - expr: &PreSymbolicExpression, + expr: &SymbolicExpression, user_args: &UserArgumentsContext, clarity_version: &ClarityVersion, ) -> Result { - match &expr.pre_expr { - PreSymbolicExpressionType::List(list) => { + match &expr.expr { + SymbolicExpressionType::List(list) => { if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { if function_name.as_str() == "define-public" || function_name.as_str() == "define-private" @@ -214,44 +227,34 @@ fn build_cost_analysis_tree( } build_listlike_cost_analysis_tree(list, "list", user_args, clarity_version) } - PreSymbolicExpressionType::AtomValue(value) => { + SymbolicExpressionType::AtomValue(value) => { + let cost = calculate_value_cost(value)?; + Ok(CostAnalysisNode::leaf( + CostExprNode::AtomValue(value.clone()), + cost, + )) + } + SymbolicExpressionType::LiteralValue(value) => { let cost = calculate_value_cost(value)?; Ok(CostAnalysisNode::leaf( CostExprNode::AtomValue(value.clone()), cost, )) } - PreSymbolicExpressionType::Atom(name) => { + SymbolicExpressionType::Atom(name) => { let expr_node = parse_atom_expression(name, user_args)?; Ok(CostAnalysisNode::leaf(expr_node, StaticCost::ZERO)) } - PreSymbolicExpressionType::Tuple(tuple) => { - build_listlike_cost_analysis_tree(tuple, "tuple", user_args, clarity_version) - } - PreSymbolicExpressionType::SugaredContractIdentifier(_contract_name) => { + SymbolicExpressionType::Field(field_identifier) => Ok(CostAnalysisNode::leaf( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + )), + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => { Ok(CostAnalysisNode::leaf( - CostExprNode::Atom(ClarityName::from("contract-identifier")), + CostExprNode::TraitReference(trait_name.clone()), StaticCost::ZERO, )) } - PreSymbolicExpressionType::SugaredFieldIdentifier(_contract_name, field_name) => Ok( - CostAnalysisNode::leaf(CostExprNode::Atom(field_name.clone()), StaticCost::ZERO), - ), - PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(CostAnalysisNode::leaf( - CostExprNode::FieldIdentifier(field_name.clone()), - StaticCost::ZERO, - )), - PreSymbolicExpressionType::TraitReference(trait_name) => Ok(CostAnalysisNode::leaf( - CostExprNode::TraitReference(trait_name.clone()), - StaticCost::ZERO, - )), - // Comments and placeholders should be filtered out during traversal - PreSymbolicExpressionType::Comment(_comment) => { - Err("hit an irrelevant comment expr type".to_string()) - } - PreSymbolicExpressionType::Placeholder(_placeholder) => { - Err("hit an irrelevant placeholder expr type".to_string()) - } } } @@ -274,7 +277,7 @@ fn parse_atom_expression( /// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) fn build_function_definition_cost_analysis_tree( - list: &[PreSymbolicExpression], + list: &[SymbolicExpression], _user_args: &UserArgumentsContext, clarity_version: &ClarityVersion, ) -> Result { @@ -297,9 +300,12 @@ fn build_function_definition_cost_analysis_tree( .match_atom() .ok_or("Expected atom for argument name")?; - let arg_type = match &arg_list[1].pre_expr { - PreSymbolicExpressionType::Atom(type_name) => type_name.clone(), - PreSymbolicExpressionType::AtomValue(value) => { + let arg_type = match &arg_list[1].expr { + SymbolicExpressionType::Atom(type_name) => type_name.clone(), + SymbolicExpressionType::AtomValue(value) => { + ClarityName::from(value.to_string().as_str()) + } + SymbolicExpressionType::LiteralValue(value) => { ClarityName::from(value.to_string().as_str()) } _ => return Err("Argument type must be an atom or atom value".to_string()), @@ -331,13 +337,13 @@ fn build_function_definition_cost_analysis_tree( /// Helper function to build expression trees for both lists and tuples fn build_listlike_cost_analysis_tree( - items: &[PreSymbolicExpression], + items: &[SymbolicExpression], container_type: &str, user_args: &UserArgumentsContext, clarity_version: &ClarityVersion, ) -> Result { - let function_name = match &items[0].pre_expr { - PreSymbolicExpressionType::Atom(name) => name, + let function_name = match &items[0].expr { + SymbolicExpressionType::Atom(name) => name, _ => { return Err(format!( "First element of {} must be an atom (function name)", @@ -349,17 +355,9 @@ fn build_listlike_cost_analysis_tree( let args = &items[1..]; let mut children = Vec::new(); - // Build children for all arguments, skipping comments and placeholders + // Build children for all arguments for arg in args { - match &arg.pre_expr { - PreSymbolicExpressionType::Comment(_) | PreSymbolicExpressionType::Placeholder(_) => { - // Skip comments and placeholders - continue; - } - _ => { - children.push(build_cost_analysis_tree(arg, user_args, clarity_version)?); - } - } + children.push(build_cost_analysis_tree(arg, user_args, clarity_version)?); } // Try to lookup the function as a native function first @@ -649,11 +647,26 @@ fn get_costs( #[cfg(test)] mod tests { + use super::*; + fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST { + let contract_identifier = QualifiedContractIdentifier::transient(); + let mut cost_tracker = (); + let ast = build_ast( + &contract_identifier, + src, + &mut cost_tracker, + ClarityVersion::Clarity1, + StacksEpochId::latest(), + ) + .unwrap(); + ast + } + #[test] fn test_constant() { - let source = "42"; + let source = "9001"; let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); assert_eq!(cost.min.runtime, 0); assert_eq!(cost.max.runtime, 0); @@ -725,12 +738,12 @@ mod tests { // ---- ExprTreee building specific tests #[test] fn test_build_cost_analysis_tree_if_expression() { - let source = "(if (> 3 0) (ok true) (ok false))"; - let pre_expressions = parse(source).unwrap(); - let pre_expr = &pre_expressions[0]; + let src = "(if (> 3 0) (ok true) (ok false))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_tree = - build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); // Root should be an If node with branching=true assert!(matches!( @@ -769,12 +782,12 @@ mod tests { #[test] fn test_build_cost_analysis_tree_arithmetic() { - let source = "(+ (* 2 3) (- 5 1))"; - let pre_expressions = parse(source).unwrap(); - let pre_expr = &pre_expressions[0]; + let src = "(+ (* 2 3) (- 5 1))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_tree = - build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); assert!(matches!( cost_tree.expr, @@ -800,12 +813,12 @@ mod tests { #[test] fn test_build_cost_analysis_tree_with_comments() { - let source = "(+ 1 ;; this is a comment\n 2)"; - let pre_expressions = parse(source).unwrap(); - let pre_expr = &pre_expressions[0]; + let src = ";; This is a comment\n(+ 5 ;; another comment\n7)"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_tree = - build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); assert!(matches!( cost_tree.expr, @@ -821,22 +834,22 @@ mod tests { #[test] fn test_function_with_multiple_arguments() { - let src = r#"(define-public (add-two (x u64) (y u64)) (+ x y))"#; - let pre_expressions = parse(src).unwrap(); - let pre_expr = &pre_expressions[0]; + let src = r#"(define-public (add-two (x uint) (y uint)) (+ x y))"#; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_tree = - build_cost_analysis_tree(pre_expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); - // Should have 3 children: UserArgument for (x u64), UserArgument for (y u64), and the body (+ x y) + // Should have 3 children: UserArgument for (x uint), UserArgument for (y uint), and the body (+ x y) assert_eq!(cost_tree.children.len(), 3); - // First child should be UserArgument for (x u64) + // First child should be UserArgument for (x uint) let user_arg_x = &cost_tree.children[0]; assert!(matches!(user_arg_x.expr, CostExprNode::UserArgument(_, _))); if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { assert_eq!(arg_name.as_str(), "x"); - assert_eq!(arg_type.as_str(), "u64"); + assert_eq!(arg_type.as_str(), "uint"); } // Second child should be UserArgument for (y u64) @@ -844,7 +857,7 @@ mod tests { assert!(matches!(user_arg_y.expr, CostExprNode::UserArgument(_, _))); if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { assert_eq!(arg_name.as_str(), "y"); - assert_eq!(arg_type.as_str(), "u64"); + assert_eq!(arg_type.as_str(), "uint"); } // Third child should be the function body (+ x y) @@ -863,11 +876,11 @@ mod tests { if let CostExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { assert_eq!(name.as_str(), "x"); - assert_eq!(arg_type.as_str(), "u64"); + assert_eq!(arg_type.as_str(), "uint"); } if let CostExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { assert_eq!(name.as_str(), "y"); - assert_eq!(arg_type.as_str(), "u64"); + assert_eq!(arg_type.as_str(), "uint"); } } } From 06cdea6c2b1fd628b39f3ce66f54b4e37a4f0f71 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Mon, 20 Oct 2025 14:04:03 -0700 Subject: [PATCH 05/23] use NativeFunctions and attempt cost lookup --- clarity/src/vm/costs/analysis.rs | 416 +++++++++++++++++++------------ clarity/src/vm/tests/analysis.rs | 178 +++++++++++++ clarity/src/vm/tests/mod.rs | 2 + 3 files changed, 434 insertions(+), 162 deletions(-) create mode 100644 clarity/src/vm/tests/analysis.rs diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 2cda4233ee..d0729c2d9a 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; -use crate::vm::Value; use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; +use stacks_common::types::StacksEpochId; use crate::vm::ast::build_ast; use crate::vm::costs::cost_functions::{linear, CostValues}; @@ -13,8 +13,7 @@ use crate::vm::errors::InterpreterResult; use crate::vm::functions::NativeFunctions; use crate::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; use crate::vm::types::QualifiedContractIdentifier; -use crate::vm::ClarityVersion; -use stacks_common::types::StacksEpochId; +use crate::vm::{ClarityVersion, Value}; // TODO: // contract-call? - get source from database @@ -176,14 +175,13 @@ impl SummingExecutionCost { } } -/// Parse Clarity source code and calculate its static execution cost -/// -/// theoretically you could inspect the tree at any node to get the spot cost -pub fn static_cost(source: &str, clarity_version: &ClarityVersion) -> Result { +fn make_ast( + source: &str, + epoch: StacksEpochId, + clarity_version: &ClarityVersion, +) -> Result { let contract_identifier = QualifiedContractIdentifier::transient(); let mut cost_tracker = (); - let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version - let ast = build_ast( &contract_identifier, source, @@ -192,23 +190,59 @@ pub fn static_cost(source: &str, clarity_version: &ClarityVersion) -> Result, + clarity_version: &ClarityVersion, +) -> Result { + let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version + let ast = make_ast(source, epoch, clarity_version)?; + let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); - let cost_analysis_tree = build_cost_analysis_tree(expr, &user_args, clarity_version)?; + let expr = &exprs[0]; + let cost_analysis_tree = + build_cost_analysis_tree(&expr, &user_args, cost_map, clarity_version)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) } +/// Parse Clarity source code and calculate its static execution cost for the specified function +pub fn static_cost( + source: &str, + clarity_version: &ClarityVersion, +) -> Result, String> { + let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version + let ast = make_ast(source, epoch, clarity_version)?; + + if ast.expressions.is_empty() { + return Err("No expressions found".to_string()); + } + let exprs = &ast.expressions; + let user_args = UserArgumentsContext::new(); + let mut costs = HashMap::new(); + for expr in exprs { + let cost_analysis_tree = + build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?; + + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); + costs.insert( + expr.match_atom() + .map(|name| name.to_string()) + .unwrap_or_default(), + summing_cost.into(), + ); + } + Ok(costs) +} fn build_cost_analysis_tree( expr: &SymbolicExpression, user_args: &UserArgumentsContext, + cost_map: &HashMap, clarity_version: &ClarityVersion, ) -> Result { match &expr.expr { @@ -221,11 +255,12 @@ fn build_cost_analysis_tree( return build_function_definition_cost_analysis_tree( list, user_args, + cost_map, clarity_version, ); } } - build_listlike_cost_analysis_tree(list, "list", user_args, clarity_version) + build_listlike_cost_analysis_tree(list, user_args, cost_map, clarity_version) } SymbolicExpressionType::AtomValue(value) => { let cost = calculate_value_cost(value)?; @@ -279,6 +314,7 @@ fn parse_atom_expression( fn build_function_definition_cost_analysis_tree( list: &[SymbolicExpression], _user_args: &UserArgumentsContext, + cost_map: &HashMap, clarity_version: &ClarityVersion, ) -> Result { let define_type = list[0] @@ -324,7 +360,7 @@ fn build_function_definition_cost_analysis_tree( } // Process the function body with the function's user arguments context - let body_tree = build_cost_analysis_tree(body, &function_user_args, clarity_version)?; + let body_tree = build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version)?; children.push(body_tree); // Create the function definition node with zero cost (function definitions themselves don't have execution cost) @@ -335,43 +371,51 @@ fn build_function_definition_cost_analysis_tree( )) } +fn get_function_name(expr: &SymbolicExpression) -> Result { + match &expr.expr { + SymbolicExpressionType::Atom(name) => Ok(name.clone()), + _ => Err("First element must be an atom (function name)".to_string()), + } +} + /// Helper function to build expression trees for both lists and tuples fn build_listlike_cost_analysis_tree( - items: &[SymbolicExpression], - container_type: &str, + exprs: &[SymbolicExpression], user_args: &UserArgumentsContext, + cost_map: &HashMap, clarity_version: &ClarityVersion, ) -> Result { - let function_name = match &items[0].expr { - SymbolicExpressionType::Atom(name) => name, - _ => { - return Err(format!( - "First element of {} must be an atom (function name)", - container_type - )); - } - }; - - let args = &items[1..]; let mut children = Vec::new(); - // Build children for all arguments - for arg in args { - children.push(build_cost_analysis_tree(arg, user_args, clarity_version)?); + // Build children for all exprs + for expr in exprs[1..].iter() { + children.push(build_cost_analysis_tree( + expr, + user_args, + cost_map, + clarity_version, + )?); } + let function_name = get_function_name(&exprs[0])?; // Try to lookup the function as a native function first - let expr_node = if let Some(native_function) = + let (expr_node, cost) = if let Some(native_function) = NativeFunctions::lookup_by_name_at_version(function_name.as_str(), clarity_version) { - CostExprNode::NativeFunction(native_function) + CostExprNode::NativeFunction(native_function); + let cost = calculate_function_cost_from_native_function( + native_function, + children.len() as u64, + clarity_version, + )?; + (CostExprNode::NativeFunction(native_function), cost) } else { - // If not a native function, treat as user-defined function - CostExprNode::UserFunction(function_name.clone()) + // If not a native function, treat as user-defined function and look it up + let expr_node = CostExprNode::UserFunction(function_name.clone()); + let cost = calculate_function_cost(function_name.to_string(), cost_map)?; + (expr_node, cost) }; - let cost = calculate_function_cost_from_name(function_name.as_str(), children.len() as u64)?; - // Handle special cases for string arguments to functions that include their processing cost if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { for child in &mut children { @@ -384,6 +428,15 @@ fn build_listlike_cost_analysis_tree( Ok(CostAnalysisNode::new(expr_node, cost, children)) } +// this is a bit tricky, we need to ensure the previously defined function is +// within the cost_map already or we need to find it and compute the cost first +fn calculate_function_cost( + function_name: String, + cost_map: &HashMap, +) -> Result { + let cost = cost_map.get(&function_name).unwrap_or(&StaticCost::ZERO); + Ok(cost.clone()) +} /// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version /// directly in build_listlike_cost_analysis_tree @@ -432,11 +485,12 @@ fn calculate_value_cost(value: &Value) -> Result { } } -fn calculate_function_cost_from_name( - function_name: &str, +fn calculate_function_cost_from_native_function( + native_function: NativeFunctions, arg_count: u64, + clarity_version: &ClarityVersion, ) -> Result { - let cost_function = match get_cost_function_for_name(function_name) { + let cost_function = match get_cost_function_for_native(native_function, clarity_version) { Some(cost_fn) => cost_fn, None => { // TODO: zero cost for now @@ -451,111 +505,132 @@ fn calculate_function_cost_from_name( }) } -/// Convert a function name string to its corresponding cost function -fn get_cost_function_for_name(name: &str) -> Option InterpreterResult> { - // Map function names to their cost functions using the existing enum structure - match name { - "+" | "add" => Some(Costs3::cost_add), - "-" | "sub" => Some(Costs3::cost_sub), - "*" | "mul" => Some(Costs3::cost_mul), - "/" | "div" => Some(Costs3::cost_div), - "mod" => Some(Costs3::cost_mod), - "pow" => Some(Costs3::cost_pow), - "sqrti" => Some(Costs3::cost_sqrti), - "log2" => Some(Costs3::cost_log2), - "to-int" | "to-uint" | "int-cast" => Some(Costs3::cost_int_cast), - "is-eq" | "=" | "eq" => Some(Costs3::cost_eq), - ">=" | "geq" => Some(Costs3::cost_geq), - "<=" | "leq" => Some(Costs3::cost_leq), - ">" | "ge" => Some(Costs3::cost_ge), - "<" | "le" => Some(Costs3::cost_le), - "xor" => Some(Costs3::cost_xor), - "not" => Some(Costs3::cost_not), - "and" => Some(Costs3::cost_and), - "or" => Some(Costs3::cost_or), - "concat" => Some(Costs3::cost_concat), - "len" => Some(Costs3::cost_len), - "as-max-len?" => Some(Costs3::cost_as_max_len), - "list" => Some(Costs3::cost_list_cons), - "element-at" | "element-at?" => Some(Costs3::cost_element_at), - "index-of" | "index-of?" => Some(Costs3::cost_index_of), - "fold" => Some(Costs3::cost_fold), - "map" => Some(Costs3::cost_map), - "filter" => Some(Costs3::cost_filter), - "append" => Some(Costs3::cost_append), - "tuple-get" => Some(Costs3::cost_tuple_get), - "tuple-merge" => Some(Costs3::cost_tuple_merge), - "tuple" => Some(Costs3::cost_tuple_cons), - "some" => Some(Costs3::cost_some_cons), - "ok" => Some(Costs3::cost_ok_cons), - "err" => Some(Costs3::cost_err_cons), - "default-to" => Some(Costs3::cost_default_to), - "unwrap!" => Some(Costs3::cost_unwrap_ret), - "unwrap-err!" => Some(Costs3::cost_unwrap_err_or_ret), - "is-ok" => Some(Costs3::cost_is_okay), - "is-none" => Some(Costs3::cost_is_none), - "is-err" => Some(Costs3::cost_is_err), - "is-some" => Some(Costs3::cost_is_some), - "unwrap-panic" => Some(Costs3::cost_unwrap), - "unwrap-err-panic" => Some(Costs3::cost_unwrap_err), - "try!" => Some(Costs3::cost_try_ret), - "if" => Some(Costs3::cost_if), - "match" => Some(Costs3::cost_match), - "begin" => Some(Costs3::cost_begin), - "let" => Some(Costs3::cost_let), - "asserts!" => Some(Costs3::cost_asserts), - "hash160" => Some(Costs3::cost_hash160), - "sha256" => Some(Costs3::cost_sha256), - "sha512" => Some(Costs3::cost_sha512), - "sha512/256" => Some(Costs3::cost_sha512t256), - "keccak256" => Some(Costs3::cost_keccak256), - "secp256k1-recover?" => Some(Costs3::cost_secp256k1recover), - "secp256k1-verify" => Some(Costs3::cost_secp256k1verify), - "print" => Some(Costs3::cost_print), - "contract-call?" => Some(Costs3::cost_contract_call), - "contract-of" => Some(Costs3::cost_contract_of), - "principal-of?" => Some(Costs3::cost_principal_of), - "at-block" => Some(Costs3::cost_at_block), - "load-contract" => Some(Costs3::cost_load_contract), - "create-map" => Some(Costs3::cost_create_map), - "create-var" => Some(Costs3::cost_create_var), - "create-non-fungible-token" => Some(Costs3::cost_create_nft), - "create-fungible-token" => Some(Costs3::cost_create_ft), - "map-get?" => Some(Costs3::cost_fetch_entry), - "map-set!" => Some(Costs3::cost_set_entry), - "var-get" => Some(Costs3::cost_fetch_var), - "var-set!" => Some(Costs3::cost_set_var), - "contract-storage" => Some(Costs3::cost_contract_storage), - "get-block-info?" => Some(Costs3::cost_block_info), - "get-burn-block-info?" => Some(Costs3::cost_burn_block_info), - "stx-get-balance" => Some(Costs3::cost_stx_balance), - "stx-transfer?" => Some(Costs3::cost_stx_transfer), - "stx-transfer-memo?" => Some(Costs3::cost_stx_transfer_memo), - "stx-account" => Some(Costs3::cost_stx_account), - "ft-mint?" => Some(Costs3::cost_ft_mint), - "ft-transfer?" => Some(Costs3::cost_ft_transfer), - "ft-get-balance" => Some(Costs3::cost_ft_balance), - "ft-get-supply" => Some(Costs3::cost_ft_get_supply), - "ft-burn?" => Some(Costs3::cost_ft_burn), - "nft-mint?" => Some(Costs3::cost_nft_mint), - "nft-transfer?" => Some(Costs3::cost_nft_transfer), - "nft-get-owner?" => Some(Costs3::cost_nft_owner), - "nft-burn?" => Some(Costs3::cost_nft_burn), - "buff-to-int-le?" => Some(Costs3::cost_buff_to_int_le), - "buff-to-uint-le?" => Some(Costs3::cost_buff_to_uint_le), - "buff-to-int-be?" => Some(Costs3::cost_buff_to_int_be), - "buff-to-uint-be?" => Some(Costs3::cost_buff_to_uint_be), - "to-consensus-buff?" => Some(Costs3::cost_to_consensus_buff), - "from-consensus-buff?" => Some(Costs3::cost_from_consensus_buff), - "is-standard?" => Some(Costs3::cost_is_standard), - "principal-destruct" => Some(Costs3::cost_principal_destruct), - "principal-construct?" => Some(Costs3::cost_principal_construct), - "as-contract" => Some(Costs3::cost_as_contract), - "string-to-int?" => Some(Costs3::cost_string_to_int), - "string-to-uint?" => Some(Costs3::cost_string_to_uint), - "int-to-ascii" => Some(Costs3::cost_int_to_ascii), - "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), - _ => None, // TODO +/// Convert a NativeFunctions enum variant to its corresponding cost function +/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in +fn get_cost_function_for_native( + function: NativeFunctions, + _clarity_version: &ClarityVersion, +) -> Option InterpreterResult> { + use crate::vm::functions::NativeFunctions::*; + + // Map NativeFunctions enum variants to their cost functions + match function { + Add => Some(Costs3::cost_add), + Subtract => Some(Costs3::cost_sub), + Multiply => Some(Costs3::cost_mul), + Divide => Some(Costs3::cost_div), + Modulo => Some(Costs3::cost_mod), + Power => Some(Costs3::cost_pow), + Sqrti => Some(Costs3::cost_sqrti), + Log2 => Some(Costs3::cost_log2), + ToInt | ToUInt => Some(Costs3::cost_int_cast), + Equals => Some(Costs3::cost_eq), + CmpGeq => Some(Costs3::cost_geq), + CmpLeq => Some(Costs3::cost_leq), + CmpGreater => Some(Costs3::cost_ge), + CmpLess => Some(Costs3::cost_le), + BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor), + Not | BitwiseNot => Some(Costs3::cost_not), + And | BitwiseAnd => Some(Costs3::cost_and), + Or | BitwiseOr => Some(Costs3::cost_or), + Concat => Some(Costs3::cost_concat), + Len => Some(Costs3::cost_len), + AsMaxLen => Some(Costs3::cost_as_max_len), + ListCons => Some(Costs3::cost_list_cons), + ElementAt | ElementAtAlias => Some(Costs3::cost_element_at), + IndexOf | IndexOfAlias => Some(Costs3::cost_index_of), + Fold => Some(Costs3::cost_fold), + Map => Some(Costs3::cost_map), + Filter => Some(Costs3::cost_filter), + Append => Some(Costs3::cost_append), + TupleGet => Some(Costs3::cost_tuple_get), + TupleMerge => Some(Costs3::cost_tuple_merge), + TupleCons => Some(Costs3::cost_tuple_cons), + ConsSome => Some(Costs3::cost_some_cons), + ConsOkay => Some(Costs3::cost_ok_cons), + ConsError => Some(Costs3::cost_err_cons), + DefaultTo => Some(Costs3::cost_default_to), + UnwrapRet => Some(Costs3::cost_unwrap_ret), + UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret), + IsOkay => Some(Costs3::cost_is_okay), + IsNone => Some(Costs3::cost_is_none), + IsErr => Some(Costs3::cost_is_err), + IsSome => Some(Costs3::cost_is_some), + Unwrap => Some(Costs3::cost_unwrap), + UnwrapErr => Some(Costs3::cost_unwrap_err), + TryRet => Some(Costs3::cost_try_ret), + If => Some(Costs3::cost_if), + Match => Some(Costs3::cost_match), + Begin => Some(Costs3::cost_begin), + Let => Some(Costs3::cost_let), + Asserts => Some(Costs3::cost_asserts), + Hash160 => Some(Costs3::cost_hash160), + Sha256 => Some(Costs3::cost_sha256), + Sha512 => Some(Costs3::cost_sha512), + Sha512Trunc256 => Some(Costs3::cost_sha512t256), + Keccak256 => Some(Costs3::cost_keccak256), + Secp256k1Recover => Some(Costs3::cost_secp256k1recover), + Secp256k1Verify => Some(Costs3::cost_secp256k1verify), + Print => Some(Costs3::cost_print), + ContractCall => Some(Costs3::cost_contract_call), + ContractOf => Some(Costs3::cost_contract_of), + PrincipalOf => Some(Costs3::cost_principal_of), + AtBlock => Some(Costs3::cost_at_block), + CreateMap => Some(Costs3::cost_create_map), + CreateVar => Some(Costs3::cost_create_var), + CreateNonFungibleToken => Some(Costs3::cost_create_nft), + CreateFungibleToken => Some(Costs3::cost_create_ft), + FetchEntry => Some(Costs3::cost_fetch_entry), + SetEntry => Some(Costs3::cost_set_entry), + FetchVar => Some(Costs3::cost_fetch_var), + SetVar => Some(Costs3::cost_set_var), + ContractStorage => Some(Costs3::cost_contract_storage), + GetBlockInfo => Some(Costs3::cost_block_info), + GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), + GetStxBalance => Some(Costs3::cost_stx_balance), + StxTransfer => Some(Costs3::cost_stx_transfer), + StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), + StxGetAccount => Some(Costs3::cost_stx_account), + MintToken => Some(Costs3::cost_ft_mint), + TransferToken => Some(Costs3::cost_ft_transfer), + GetTokenBalance => Some(Costs3::cost_ft_balance), + GetTokenSupply => Some(Costs3::cost_ft_get_supply), + BurnToken => Some(Costs3::cost_ft_burn), + MintAsset => Some(Costs3::cost_nft_mint), + TransferAsset => Some(Costs3::cost_nft_transfer), + GetAssetOwner => Some(Costs3::cost_nft_owner), + BurnAsset => Some(Costs3::cost_nft_burn), + BuffToIntLe => Some(Costs3::cost_buff_to_int_le), + BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le), + BuffToIntBe => Some(Costs3::cost_buff_to_int_be), + BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be), + ToConsensusBuff => Some(Costs3::cost_to_consensus_buff), + FromConsensusBuff => Some(Costs3::cost_from_consensus_buff), + IsStandard => Some(Costs3::cost_is_standard), + PrincipalDestruct => Some(Costs3::cost_principal_destruct), + PrincipalConstruct => Some(Costs3::cost_principal_construct), + AsContract | AsContractSafe => Some(Costs3::cost_as_contract), + StringToInt => Some(Costs3::cost_string_to_int), + StringToUInt => Some(Costs3::cost_string_to_uint), + IntToAscii => Some(Costs3::cost_int_to_ascii), + IntToUtf8 => Some(Costs3::cost_int_to_utf8), + BitwiseLShift => Some(Costs3::cost_bitwise_left_shift), + BitwiseRShift => Some(Costs3::cost_bitwise_right_shift), + Slice => Some(Costs3::cost_slice), + ReplaceAt => Some(Costs3::cost_replace_at), + GetStacksBlockInfo => Some(Costs3::cost_block_info), + GetTenureInfo => Some(Costs3::cost_burn_block_info), // XXX ??? + ContractHash => Some(Costs3::cost_contract_hash), + ToAscii => Some(Costs3::cost_to_ascii), + RestrictAssets => None, // TODO: add cost function + AllowanceWithStx => None, // TODO: add cost function + AllowanceWithFt => None, // TODO: add cost function + AllowanceWithNft => None, // TODO: add cost function + AllowanceWithStacking => None, // TODO: add cost function + AllowanceAll => None, // TODO: add cost function + InsertEntry => None, // TODO: add cost function + DeleteEntry => None, // TODO: add cost function + StxBurn => None, // TODO: add cost function } } @@ -650,6 +725,14 @@ mod tests { use super::*; + fn static_cost_native_test( + source: &str, + clarity_version: &ClarityVersion, + ) -> Result { + let cost_map = HashMap::new(); + static_cost_native(source, &cost_map, clarity_version) + } + fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST { let contract_identifier = QualifiedContractIdentifier::transient(); let mut cost_tracker = (); @@ -657,7 +740,7 @@ mod tests { &contract_identifier, src, &mut cost_tracker, - ClarityVersion::Clarity1, + ClarityVersion::Clarity3, StacksEpochId::latest(), ) .unwrap(); @@ -667,7 +750,7 @@ mod tests { #[test] fn test_constant() { let source = "9001"; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); assert_eq!(cost.min.runtime, 0); assert_eq!(cost.max.runtime, 0); } @@ -675,7 +758,7 @@ mod tests { #[test] fn test_simple_addition() { let source = "(+ 1 2)"; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); // Min: linear(2, 11, 125) = 11*2 + 125 = 147 assert_eq!(cost.min.runtime, 147); @@ -685,7 +768,7 @@ mod tests { #[test] fn test_arithmetic() { let source = "(- u4 (+ u1 u2))"; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); assert_eq!(cost.min.runtime, 147 + 147); assert_eq!(cost.max.runtime, 147 + 147); } @@ -693,7 +776,7 @@ mod tests { #[test] fn test_nested_operations() { let source = "(* (+ u1 u2) (- u3 u4))"; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); // multiplication: 13*2 + 125 = 151 assert_eq!(cost.min.runtime, 151 + 147 + 147); assert_eq!(cost.max.runtime, 151 + 147 + 147); @@ -702,7 +785,7 @@ mod tests { #[test] fn test_string_concat_min_max() { let source = r#"(concat "hello" "world")"#; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); // For concat with 2 arguments: // linear(2, 37, 220) = 37*2 + 220 = 294 @@ -713,7 +796,7 @@ mod tests { #[test] fn test_string_len_min_max() { let source = r#"(len "hello")"#; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); // cost: 429 (constant) - len doesn't depend on string size assert_eq!(cost.min.runtime, 429); @@ -723,7 +806,7 @@ mod tests { #[test] fn test_branching() { let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; - let cost = static_cost(source, &ClarityVersion::Clarity1).unwrap(); + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); // min: 147 raw string // max: 294 (concat) @@ -742,10 +825,12 @@ mod tests { let ast = build_test_ast(src); let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests let cost_tree = - build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) + .unwrap(); - // Root should be an If node with branching=true + // Root should be an If node assert!(matches!( cost_tree.expr, CostExprNode::NativeFunction(NativeFunctions::If) @@ -760,12 +845,13 @@ mod tests { )); assert_eq!(gt_node.children.len(), 2); + // The comparison node has 3 children: the function name, left operand, right operand let left_val = >_node.children[0]; let right_val = >_node.children[1]; assert!(matches!(left_val.expr, CostExprNode::AtomValue(_))); assert!(matches!(right_val.expr, CostExprNode::AtomValue(_))); - let ok_true_node = &cost_tree.children[1]; + let ok_true_node = &cost_tree.children[2]; assert!(matches!( ok_true_node.expr, CostExprNode::NativeFunction(NativeFunctions::ConsOkay) @@ -786,8 +872,10 @@ mod tests { let ast = build_test_ast(src); let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests let cost_tree = - build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) + .unwrap(); assert!(matches!( cost_tree.expr, @@ -817,8 +905,10 @@ mod tests { let ast = build_test_ast(src); let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests let cost_tree = - build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) + .unwrap(); assert!(matches!( cost_tree.expr, @@ -838,8 +928,10 @@ mod tests { let ast = build_test_ast(src); let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests let cost_tree = - build_cost_analysis_tree(expr, &user_args, &ClarityVersion::Clarity1).unwrap(); + build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) + .unwrap(); // Should have 3 children: UserArgument for (x uint), UserArgument for (y uint), and the body (+ x y) assert_eq!(cost_tree.children.len(), 3); diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs new file mode 100644 index 0000000000..aa212b77fa --- /dev/null +++ b/clarity/src/vm/tests/analysis.rs @@ -0,0 +1,178 @@ +// Copyright (C) 2013-2020 Blockstack PBC, a public benefit corporation +// Copyright (C) 2020 Stacks Open Internet Foundation +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +use rstest::rstest; +use stacks_common::types::StacksEpochId; + +use crate::vm::contexts::OwnedEnvironment; +use crate::vm::costs::analysis::static_cost; +use crate::vm::costs::ExecutionCost; +use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator}; +use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; +use crate::vm::ClarityVersion; + +const SIMPLE_TRAIT_SRC: &str = r#"(define-trait mytrait ( + (somefunc (uint uint) (response uint uint)) +)) +"#; + +#[rstest] +#[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] +fn test_simple_trait_implementation_costs( + #[case] version: ClarityVersion, + #[case] epoch: StacksEpochId, + mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, +) { + // Simple trait implementation - very brief function that basically does nothing + let simple_impl = r#"(impl-trait .mytrait.mytrait) + (define-public (somefunc (a uint) (b uint)) + (ok (+ a b)) + )"#; + + // Set up environment with cost tracking - use regular environment but try to get actual costs + let mut owned_env = tl_env_factory.get_env(epoch); + + // Get static cost analysis + let static_cost = static_cost(simple_impl, &version).unwrap(); + // Deploy and execute the contract to get dynamic costs + let contract_id = QualifiedContractIdentifier::local("simple-impl").unwrap(); + owned_env + .initialize_versioned_contract(contract_id.clone(), version, simple_impl, None) + .unwrap(); + + let dynamic_cost = execute_contract_function_and_get_cost( + &mut owned_env, + &contract_id, + "somefunc", + &[4, 5], + version, + ); + println!("dynamic_cost: {:?}", dynamic_cost); + println!("static_cost: {:?}", static_cost); + + let key = static_cost.keys().nth(1).unwrap(); + let cost = static_cost.get(key).unwrap(); + assert!(dynamic_cost.runtime >= cost.min.runtime); + assert!(dynamic_cost.runtime <= cost.max.runtime); +} + +/// Helper function to execute a contract function and return the execution cost +fn execute_contract_function_and_get_cost( + env: &mut OwnedEnvironment, + contract_id: &QualifiedContractIdentifier, + function_name: &str, + args: &[u64], + version: ClarityVersion, +) -> ExecutionCost { + // Start with a fresh cost tracker + let initial_cost = env.get_cost_total(); + + // Create a dummy sender + let sender = PrincipalData::parse_qualified_contract_principal( + "ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender", + ) + .unwrap(); + + // Build function call string + let arg_str = args + .iter() + .map(|a| format!("u{}", a)) + .collect::>() + .join(" "); + let function_call = format!("({} {})", function_name, arg_str); + + // Parse the function call into a symbolic expression + let ast = crate::vm::ast::parse( + &QualifiedContractIdentifier::transient(), + &function_call, + version, + StacksEpochId::Epoch21, + ) + .expect("Failed to parse function call"); + + if !ast.is_empty() { + let _result = env.execute_transaction( + sender, + None, + contract_id.clone(), + &function_call, + &ast[0..1], + ); + } + + // Get the cost after execution + let final_cost = env.get_cost_total(); + + // Return the difference + ExecutionCost { + write_length: final_cost.write_length - initial_cost.write_length, + write_count: final_cost.write_count - initial_cost.write_count, + read_length: final_cost.read_length - initial_cost.read_length, + read_count: final_cost.read_count - initial_cost.read_count, + runtime: final_cost.runtime - initial_cost.runtime, + } +} + +#[rstest] +#[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] +fn test_complex_trait_implementation_costs( + #[case] version: ClarityVersion, + #[case] epoch: StacksEpochId, + mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, +) { + // Complex trait implementation with expensive operations but no external calls + let complex_impl = r#"(define-public (somefunc (a uint) (b uint)) + (begin + ;; do something expensive + ;; emit events + (print a) + (print b) + (print "doing complex calculation") + (let ((result (* a b))) + (print result) + (ok (+ result (/ (+ a b) u2))) + ) + ) +)"#; + + let mut owned_env = tl_env_factory.get_env(epoch); + + let static_cost_result = static_cost(complex_impl, &version); + match static_cost_result { + Ok(static_cost) => { + let contract_id = QualifiedContractIdentifier::local("complex-impl").unwrap(); + owned_env + .initialize_versioned_contract(contract_id.clone(), version, complex_impl, None) + .unwrap(); + + let dynamic_cost = execute_contract_function_and_get_cost( + &mut owned_env, + &contract_id, + "somefunc", + &[7, 8], + version, + ); + + let key = static_cost.keys().nth(1).unwrap(); + let cost = static_cost.get(key).unwrap(); + assert!(dynamic_cost.runtime >= cost.min.runtime); + assert!(dynamic_cost.runtime <= cost.max.runtime); + } + Err(e) => { + println!("Static cost analysis failed: {}", e); + } + } +} diff --git a/clarity/src/vm/tests/mod.rs b/clarity/src/vm/tests/mod.rs index 3d4408abec..8a30cc13de 100644 --- a/clarity/src/vm/tests/mod.rs +++ b/clarity/src/vm/tests/mod.rs @@ -24,6 +24,8 @@ use crate::vm::contexts::OwnedEnvironment; pub use crate::vm::database::BurnStateDB; use crate::vm::database::MemoryBackingStore; +#[cfg(test)] +mod analysis; mod assets; mod contracts; #[cfg(test)] From 64dbb3e8708b1f05433488839cfd353484e6ccb7 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Mon, 20 Oct 2025 14:21:38 -0700 Subject: [PATCH 06/23] fix NativeFunction matching --- clarity/src/vm/costs/analysis.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index d0729c2d9a..78ae2c16de 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -576,15 +576,15 @@ fn get_cost_function_for_native( ContractOf => Some(Costs3::cost_contract_of), PrincipalOf => Some(Costs3::cost_principal_of), AtBlock => Some(Costs3::cost_at_block), - CreateMap => Some(Costs3::cost_create_map), - CreateVar => Some(Costs3::cost_create_var), - CreateNonFungibleToken => Some(Costs3::cost_create_nft), - CreateFungibleToken => Some(Costs3::cost_create_ft), + // CreateMap => Some(Costs3::cost_create_map), + // CreateVar => Some(Costs3::cost_create_var), + // CreateNonFungibleToken => Some(Costs3::cost_create_nft), + // CreateFungibleToken => Some(Costs3::cost_create_ft), FetchEntry => Some(Costs3::cost_fetch_entry), SetEntry => Some(Costs3::cost_set_entry), FetchVar => Some(Costs3::cost_fetch_var), SetVar => Some(Costs3::cost_set_var), - ContractStorage => Some(Costs3::cost_contract_storage), + // ContractStorage => Some(Costs3::cost_contract_storage), GetBlockInfo => Some(Costs3::cost_block_info), GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), GetStxBalance => Some(Costs3::cost_stx_balance), @@ -619,18 +619,18 @@ fn get_cost_function_for_native( Slice => Some(Costs3::cost_slice), ReplaceAt => Some(Costs3::cost_replace_at), GetStacksBlockInfo => Some(Costs3::cost_block_info), - GetTenureInfo => Some(Costs3::cost_burn_block_info), // XXX ??? + GetTenureInfo => Some(Costs3::cost_block_info), ContractHash => Some(Costs3::cost_contract_hash), ToAscii => Some(Costs3::cost_to_ascii), + InsertEntry => Some(Costs3::cost_set_entry), + DeleteEntry => Some(Costs3::cost_set_entry), + StxBurn => Some(Costs3::cost_stx_transfer), RestrictAssets => None, // TODO: add cost function AllowanceWithStx => None, // TODO: add cost function AllowanceWithFt => None, // TODO: add cost function AllowanceWithNft => None, // TODO: add cost function AllowanceWithStacking => None, // TODO: add cost function AllowanceAll => None, // TODO: add cost function - InsertEntry => None, // TODO: add cost function - DeleteEntry => None, // TODO: add cost function - StxBurn => None, // TODO: add cost function } } From 61abc75c6f227c9ab66355a657fef727b65b084e Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 21 Oct 2025 10:27:44 -0700 Subject: [PATCH 07/23] properly build function definition top-level map --- clarity/src/vm/costs/analysis.rs | 113 ++++++++++++++++++++----------- clarity/src/vm/tests/analysis.rs | 42 +++++++++++- 2 files changed, 112 insertions(+), 43 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 78ae2c16de..cb8a60cf3d 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -20,6 +20,8 @@ use crate::vm::{ClarityVersion, Value}; // type-checking // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) +// possibly use ContractContext? Enviornment? we need to use this somehow to +// provide full view of a contract, rather than passing in source const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; @@ -28,6 +30,15 @@ const STRING_COST_MULTIPLIER: u64 = 3; /// cost includes their processing const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; +/// Function definition keywords in Clarity +const FUNCTION_DEFINITION_KEYWORDS: &[&str] = + &["define-public", "define-private", "define-read-only"]; + +/// Check if a function name is a function definition keyword +fn is_function_definition(function_name: &str) -> bool { + FUNCTION_DEFINITION_KEYWORDS.contains(&function_name) +} + #[derive(Debug, Clone)] pub enum CostExprNode { // Native Clarity functions @@ -204,7 +215,7 @@ fn static_cost_native( let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); let expr = &exprs[0]; - let cost_analysis_tree = + let (_, cost_analysis_tree) = build_cost_analysis_tree(&expr, &user_args, cost_map, clarity_version)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); @@ -225,7 +236,7 @@ pub fn static_cost( let user_args = UserArgumentsContext::new(); let mut costs = HashMap::new(); for expr in exprs { - let cost_analysis_tree = + let (_, cost_analysis_tree) = build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); @@ -239,57 +250,71 @@ pub fn static_cost( Ok(costs) } -fn build_cost_analysis_tree( +pub fn build_cost_analysis_tree( expr: &SymbolicExpression, user_args: &UserArgumentsContext, cost_map: &HashMap, clarity_version: &ClarityVersion, -) -> Result { +) -> Result<(Option, CostAnalysisNode), String> { match &expr.expr { SymbolicExpressionType::List(list) => { if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { - if function_name.as_str() == "define-public" - || function_name.as_str() == "define-private" - || function_name.as_str() == "define-read-only" - { - return build_function_definition_cost_analysis_tree( + if is_function_definition(function_name.as_str()) { + let (returned_function_name, cost_analysis_tree) = + build_function_definition_cost_analysis_tree( + list, + user_args, + cost_map, + clarity_version, + )?; + Ok((Some(returned_function_name), cost_analysis_tree)) + } else { + let cost_analysis_tree = build_listlike_cost_analysis_tree( list, user_args, cost_map, clarity_version, - ); + )?; + Ok((None, cost_analysis_tree)) } + } else { + let cost_analysis_tree = + build_listlike_cost_analysis_tree(list, user_args, cost_map, clarity_version)?; + Ok((None, cost_analysis_tree)) } - build_listlike_cost_analysis_tree(list, user_args, cost_map, clarity_version) } SymbolicExpressionType::AtomValue(value) => { let cost = calculate_value_cost(value)?; - Ok(CostAnalysisNode::leaf( - CostExprNode::AtomValue(value.clone()), - cost, + Ok(( + None, + CostAnalysisNode::leaf(CostExprNode::AtomValue(value.clone()), cost), )) } SymbolicExpressionType::LiteralValue(value) => { let cost = calculate_value_cost(value)?; - Ok(CostAnalysisNode::leaf( - CostExprNode::AtomValue(value.clone()), - cost, + Ok(( + None, + CostAnalysisNode::leaf(CostExprNode::AtomValue(value.clone()), cost), )) } SymbolicExpressionType::Atom(name) => { let expr_node = parse_atom_expression(name, user_args)?; - Ok(CostAnalysisNode::leaf(expr_node, StaticCost::ZERO)) + Ok((None, CostAnalysisNode::leaf(expr_node, StaticCost::ZERO))) } - SymbolicExpressionType::Field(field_identifier) => Ok(CostAnalysisNode::leaf( - CostExprNode::FieldIdentifier(field_identifier.clone()), - StaticCost::ZERO, + SymbolicExpressionType::Field(field_identifier) => Ok(( + None, + CostAnalysisNode::leaf( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + ), )), - SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => { - Ok(CostAnalysisNode::leaf( + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => Ok(( + None, + CostAnalysisNode::leaf( CostExprNode::TraitReference(trait_name.clone()), StaticCost::ZERO, - )) - } + ), + )), } } @@ -316,13 +341,14 @@ fn build_function_definition_cost_analysis_tree( _user_args: &UserArgumentsContext, cost_map: &HashMap, clarity_version: &ClarityVersion, -) -> Result { +) -> Result<(String, CostAnalysisNode), String> { let define_type = list[0] .match_atom() .ok_or("Expected atom for define type")?; let signature = list[1] .match_list() .ok_or("Expected list for function signature")?; + println!("signature: {:?}", signature); let body = &list[2]; let mut children = Vec::new(); @@ -360,14 +386,23 @@ fn build_function_definition_cost_analysis_tree( } // Process the function body with the function's user arguments context - let body_tree = build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version)?; + let (_, body_tree) = + build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version)?; children.push(body_tree); + // Get the function name from the signature + let function_name = signature[0] + .match_atom() + .ok_or("Expected atom for function name")?; + // Create the function definition node with zero cost (function definitions themselves don't have execution cost) - Ok(CostAnalysisNode::new( - CostExprNode::UserFunction(define_type.clone()), - StaticCost::ZERO, - children, + Ok(( + function_name.clone().to_string(), + CostAnalysisNode::new( + CostExprNode::UserFunction(define_type.clone()), + StaticCost::ZERO, + children, + ), )) } @@ -389,12 +424,8 @@ fn build_listlike_cost_analysis_tree( // Build children for all exprs for expr in exprs[1..].iter() { - children.push(build_cost_analysis_tree( - expr, - user_args, - cost_map, - clarity_version, - )?); + let (_, child_tree) = build_cost_analysis_tree(expr, user_args, cost_map, clarity_version)?; + children.push(child_tree); } let function_name = get_function_name(&exprs[0])?; @@ -826,7 +857,7 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let cost_tree = + let (_, cost_tree) = build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) .unwrap(); @@ -873,7 +904,7 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let cost_tree = + let (_, cost_tree) = build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) .unwrap(); @@ -906,7 +937,7 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let cost_tree = + let (_, cost_tree) = build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) .unwrap(); @@ -929,7 +960,7 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let cost_tree = + let (_, cost_tree) = build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) .unwrap(); diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index aa212b77fa..974db38247 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -14,15 +14,17 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +use std::collections::HashMap; + use rstest::rstest; use stacks_common::types::StacksEpochId; use crate::vm::contexts::OwnedEnvironment; -use crate::vm::costs::analysis::static_cost; +use crate::vm::costs::analysis::{build_cost_analysis_tree, static_cost, UserArgumentsContext}; use crate::vm::costs::ExecutionCost; use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator}; use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; -use crate::vm::ClarityVersion; +use crate::vm::{ast, ClarityVersion}; const SIMPLE_TRAIT_SRC: &str = r#"(define-trait mytrait ( (somefunc (uint uint) (response uint uint)) @@ -176,3 +178,39 @@ fn test_complex_trait_implementation_costs( } } } + +#[test] +fn test_build_cost_analysis_tree_function_definition() { + let source = r#"(define-public (somefunc (a uint)) + (ok (+ a 1)) +)"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let ast = ast::parse( + &contract_id, + source, + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .expect("Failed to parse source code"); + + let expr = &ast[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); + + let clarity_version = ClarityVersion::Clarity3; + let result = build_cost_analysis_tree(expr, &user_args, &cost_map, &clarity_version); + + match result { + Ok((function_name, node)) => { + assert_eq!(function_name, Some("somefunc".to_string())); + assert!(matches!( + node.expr, + crate::vm::costs::analysis::CostExprNode::UserFunction(_) + )); + } + Err(e) => { + panic!("Expected Ok result, got error: {}", e); + } + } +} From 84d1789b054070fe0224f7aaaa36e9652db59380 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Wed, 22 Oct 2025 16:46:56 -0700 Subject: [PATCH 08/23] 2-pass for dependent function costs --- clarity/src/vm/costs/analysis.rs | 93 +++++++++++++++++++++++--------- clarity/src/vm/tests/analysis.rs | 28 ++++++++-- 2 files changed, 94 insertions(+), 27 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index cb8a60cf3d..1e3616ad31 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -207,7 +207,7 @@ fn make_ast( /// somewhat of a passthrough since we don't have to build the whole context we can jsut return the cost of the single expression fn static_cost_native( source: &str, - cost_map: &HashMap, + cost_map: &HashMap>, clarity_version: &ClarityVersion, ) -> Result { let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version @@ -234,26 +234,53 @@ pub fn static_cost( } let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); - let mut costs = HashMap::new(); + let mut costs: HashMap> = HashMap::new(); + + // First pass registers all function definitions + for expr in exprs { + if let Some(function_name) = extract_function_name(expr) { + costs.insert(function_name, None); + } + } + + // Second pass computes costs for expr in exprs { - let (_, cost_analysis_tree) = - build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?; - - let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); - costs.insert( - expr.match_atom() - .map(|name| name.to_string()) - .unwrap_or_default(), - summing_cost.into(), - ); - } - Ok(costs) + if let Some(function_name) = extract_function_name(expr) { + let (_, cost_analysis_tree) = + build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?; + + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); + costs.insert(function_name, Some(summing_cost.into())); + } + } + + Ok(costs + .into_iter() + .filter_map(|(name, cost)| cost.map(|c| (name, c))) + .collect()) +} + +/// Extract function name from a symbolic expression +fn extract_function_name(expr: &SymbolicExpression) -> Option { + if let Some(list) = expr.match_list() { + if let Some(first_atom) = list.first().and_then(|first| first.match_atom()) { + if is_function_definition(first_atom.as_str()) { + if let Some(signature) = list.get(1).and_then(|sig| sig.match_list()) { + return signature + .first() + .and_then(|name| name.match_atom()) + .map(|name| name.to_string()); + } + } + } + } + None } pub fn build_cost_analysis_tree( expr: &SymbolicExpression, user_args: &UserArgumentsContext, - cost_map: &HashMap, + cost_map: &HashMap>, clarity_version: &ClarityVersion, ) -> Result<(Option, CostAnalysisNode), String> { match &expr.expr { @@ -339,7 +366,7 @@ fn parse_atom_expression( fn build_function_definition_cost_analysis_tree( list: &[SymbolicExpression], _user_args: &UserArgumentsContext, - cost_map: &HashMap, + cost_map: &HashMap>, clarity_version: &ClarityVersion, ) -> Result<(String, CostAnalysisNode), String> { let define_type = list[0] @@ -417,7 +444,7 @@ fn get_function_name(expr: &SymbolicExpression) -> Result { fn build_listlike_cost_analysis_tree( exprs: &[SymbolicExpression], user_args: &UserArgumentsContext, - cost_map: &HashMap, + cost_map: &HashMap>, clarity_version: &ClarityVersion, ) -> Result { let mut children = Vec::new(); @@ -442,8 +469,9 @@ fn build_listlike_cost_analysis_tree( (CostExprNode::NativeFunction(native_function), cost) } else { // If not a native function, treat as user-defined function and look it up + println!("in user-defined function"); let expr_node = CostExprNode::UserFunction(function_name.clone()); - let cost = calculate_function_cost(function_name.to_string(), cost_map)?; + let cost = calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?; (expr_node, cost) }; @@ -459,14 +487,31 @@ fn build_listlike_cost_analysis_tree( Ok(CostAnalysisNode::new(expr_node, cost, children)) } -// this is a bit tricky, we need to ensure the previously defined function is -// within the cost_map already or we need to find it and compute the cost first +// Calculate function cost with lazy evaluation support fn calculate_function_cost( function_name: String, - cost_map: &HashMap, + cost_map: &HashMap>, + _clarity_version: &ClarityVersion, ) -> Result { - let cost = cost_map.get(&function_name).unwrap_or(&StaticCost::ZERO); - Ok(cost.clone()) + match cost_map.get(&function_name) { + Some(Some(cost)) => { + // Cost already computed + Ok(cost.clone()) + } + Some(None) => { + // Function exists but cost not yet computed - this indicates a circular dependency + // For now, return zero cost to avoid infinite recursion + println!( + "Circular dependency detected for function: {}", + function_name + ); + Ok(StaticCost::ZERO) + } + None => { + // Function not found + Ok(StaticCost::ZERO) + } + } } /// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version /// directly in build_listlike_cost_analysis_tree @@ -760,7 +805,7 @@ mod tests { source: &str, clarity_version: &ClarityVersion, ) -> Result { - let cost_map = HashMap::new(); + let cost_map: HashMap> = HashMap::new(); static_cost_native(source, &cost_map, clarity_version) } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index 974db38247..3d8d8e42b5 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -181,18 +181,18 @@ fn test_complex_trait_implementation_costs( #[test] fn test_build_cost_analysis_tree_function_definition() { - let source = r#"(define-public (somefunc (a uint)) + let src = r#"(define-public (somefunc (a uint)) (ok (+ a 1)) )"#; let contract_id = QualifiedContractIdentifier::transient(); let ast = ast::parse( &contract_id, - source, + src, ClarityVersion::Clarity3, StacksEpochId::Epoch32, ) - .expect("Failed to parse source code"); + .expect("Failed to parse"); let expr = &ast[0]; let user_args = UserArgumentsContext::new(); @@ -214,3 +214,25 @@ fn test_build_cost_analysis_tree_function_definition() { } } } + +#[test] +fn test_dependent_function_calls() { + let src = r#"(define-public (add-one (a uint)) + (begin + (print "somefunc") + (somefunc a) + ) +) +(define-private (somefunc (a uint)) + (ok (+ a 1)) +)"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let function_map = static_cost(src, &ClarityVersion::Clarity3).unwrap(); + + let add_one_cost = function_map.get("add-one").unwrap(); + let somefunc_cost = function_map.get("somefunc").unwrap(); + + assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime); + assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime); +} From 53bbf6cea80e112dff00e524822615c39febee1b Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 28 Oct 2025 11:30:59 -0700 Subject: [PATCH 09/23] use Environment for static_cost --- clarity/src/vm/costs/analysis.rs | 98 ++++++++++++++++-- clarity/src/vm/tests/analysis.rs | 168 ++++++++++++++++--------------- 2 files changed, 176 insertions(+), 90 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 1e3616ad31..7059b28af7 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -6,6 +6,7 @@ use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; use stacks_common::types::StacksEpochId; use crate::vm::ast::build_ast; +use crate::vm::contexts::Environment; use crate::vm::costs::cost_functions::{linear, CostValues}; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; @@ -221,18 +222,17 @@ fn static_cost_native( let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) } -/// Parse Clarity source code and calculate its static execution cost for the specified function -pub fn static_cost( - source: &str, + +pub fn static_cost_from_ast( + contract_ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, ) -> Result, String> { - let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version - let ast = make_ast(source, epoch, clarity_version)?; + let exprs = &contract_ast.expressions; - if ast.expressions.is_empty() { - return Err("No expressions found".to_string()); + if exprs.is_empty() { + return Err("No expressions found in contract AST".to_string()); } - let exprs = &ast.expressions; + let user_args = UserArgumentsContext::new(); let mut costs: HashMap> = HashMap::new(); @@ -260,6 +260,40 @@ pub fn static_cost( .collect()) } +/// Calculate static execution cost for functions using Environment context +/// This replaces the old source-string based approach with Environment integration +pub fn static_cost( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result, String> { + // Get the contract source from the environment's database + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| "Contract source not found in database".to_string())?; + + // Get the contract's clarity version from the environment + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; + + let clarity_version = contract.contract_context.get_clarity_version(); + + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; + + static_cost_from_ast(&ast, clarity_version) +} + +// pub fn static_cost_tree( +// source: &str, +// clarity_version: &ClarityVersion, +// ) -> Result, String> { +// } + /// Extract function name from a symbolic expression fn extract_function_name(expr: &SymbolicExpression) -> Option { if let Some(list) = expr.match_list() { @@ -469,7 +503,6 @@ fn build_listlike_cost_analysis_tree( (CostExprNode::NativeFunction(native_function), cost) } else { // If not a native function, treat as user-defined function and look it up - println!("in user-defined function"); let expr_node = CostExprNode::UserFunction(function_name.clone()); let cost = calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?; (expr_node, cost) @@ -499,6 +532,7 @@ fn calculate_function_cost( Ok(cost.clone()) } Some(None) => { + // Should be impossible but alas.. // Function exists but cost not yet computed - this indicates a circular dependency // For now, return zero cost to avoid infinite recursion println!( @@ -809,6 +843,15 @@ mod tests { static_cost_native(source, &cost_map, clarity_version) } + fn static_cost_test( + source: &str, + clarity_version: &ClarityVersion, + ) -> Result, String> { + let epoch = StacksEpochId::latest(); + let ast = make_ast(source, epoch, clarity_version)?; + static_cost_from_ast(&ast, clarity_version) + } + fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST { let contract_identifier = QualifiedContractIdentifier::transient(); let mut cost_tracker = (); @@ -1051,4 +1094,41 @@ mod tests { assert_eq!(arg_type.as_str(), "uint"); } } + + #[test] + fn test_static_cost_simple_addition() { + let source = "(define-public (add (a uint) (b uint)) (+ a b))"; + let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); + + // Should have one function + assert_eq!(ast_cost.len(), 1); + assert!(ast_cost.contains_key("add")); + + // Check that the cost is reasonable (non-zero for addition) + let add_cost = ast_cost.get("add").unwrap(); + assert!(add_cost.min.runtime > 0); + assert!(add_cost.max.runtime > 0); + } + + #[test] + fn test_static_cost_multiple_functions() { + let source = r#" + (define-public (func1 (x uint)) (+ x 1)) + (define-private (func2 (y uint)) (* y 2)) + "#; + let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); + + // Should have 2 functions + assert_eq!(ast_cost.len(), 2); + + // Check that both functions are present + assert!(ast_cost.contains_key("func1")); + assert!(ast_cost.contains_key("func2")); + + // Check that costs are reasonable + let func1_cost = ast_cost.get("func1").unwrap(); + let func2_cost = ast_cost.get("func2").unwrap(); + assert!(func1_cost.min.runtime > 0); + assert!(func2_cost.min.runtime > 0); + } } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index 3d8d8e42b5..2e815470f0 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -1,26 +1,12 @@ -// Copyright (C) 2013-2020 Blockstack PBC, a public benefit corporation -// Copyright (C) 2020 Stacks Open Internet Foundation -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . - use std::collections::HashMap; use rstest::rstest; use stacks_common::types::StacksEpochId; use crate::vm::contexts::OwnedEnvironment; -use crate::vm::costs::analysis::{build_cost_analysis_tree, static_cost, UserArgumentsContext}; +use crate::vm::costs::analysis::{ + build_cost_analysis_tree, static_cost_from_ast, UserArgumentsContext, +}; use crate::vm::costs::ExecutionCost; use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator}; use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; @@ -38,17 +24,23 @@ fn test_simple_trait_implementation_costs( #[case] epoch: StacksEpochId, mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, ) { - // Simple trait implementation - very brief function that basically does nothing let simple_impl = r#"(impl-trait .mytrait.mytrait) (define-public (somefunc (a uint) (b uint)) (ok (+ a b)) )"#; - // Set up environment with cost tracking - use regular environment but try to get actual costs let mut owned_env = tl_env_factory.get_env(epoch); - // Get static cost analysis - let static_cost = static_cost(simple_impl, &version).unwrap(); + let epoch = StacksEpochId::Epoch21; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + simple_impl, + &mut (), + version, + epoch, + ) + .unwrap(); + let static_cost = static_cost_from_ast(&ast, &version).unwrap(); // Deploy and execute the contract to get dynamic costs let contract_id = QualifiedContractIdentifier::local("simple-impl").unwrap(); owned_env @@ -71,63 +63,6 @@ fn test_simple_trait_implementation_costs( assert!(dynamic_cost.runtime <= cost.max.runtime); } -/// Helper function to execute a contract function and return the execution cost -fn execute_contract_function_and_get_cost( - env: &mut OwnedEnvironment, - contract_id: &QualifiedContractIdentifier, - function_name: &str, - args: &[u64], - version: ClarityVersion, -) -> ExecutionCost { - // Start with a fresh cost tracker - let initial_cost = env.get_cost_total(); - - // Create a dummy sender - let sender = PrincipalData::parse_qualified_contract_principal( - "ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender", - ) - .unwrap(); - - // Build function call string - let arg_str = args - .iter() - .map(|a| format!("u{}", a)) - .collect::>() - .join(" "); - let function_call = format!("({} {})", function_name, arg_str); - - // Parse the function call into a symbolic expression - let ast = crate::vm::ast::parse( - &QualifiedContractIdentifier::transient(), - &function_call, - version, - StacksEpochId::Epoch21, - ) - .expect("Failed to parse function call"); - - if !ast.is_empty() { - let _result = env.execute_transaction( - sender, - None, - contract_id.clone(), - &function_call, - &ast[0..1], - ); - } - - // Get the cost after execution - let final_cost = env.get_cost_total(); - - // Return the difference - ExecutionCost { - write_length: final_cost.write_length - initial_cost.write_length, - write_count: final_cost.write_count - initial_cost.write_count, - read_length: final_cost.read_length - initial_cost.read_length, - read_count: final_cost.read_count - initial_cost.read_count, - runtime: final_cost.runtime - initial_cost.runtime, - } -} - #[rstest] #[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] fn test_complex_trait_implementation_costs( @@ -135,7 +70,6 @@ fn test_complex_trait_implementation_costs( #[case] epoch: StacksEpochId, mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, ) { - // Complex trait implementation with expensive operations but no external calls let complex_impl = r#"(define-public (somefunc (a uint) (b uint)) (begin ;; do something expensive @@ -152,7 +86,16 @@ fn test_complex_trait_implementation_costs( let mut owned_env = tl_env_factory.get_env(epoch); - let static_cost_result = static_cost(complex_impl, &version); + let epoch = StacksEpochId::Epoch21; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + complex_impl, + &mut (), + version, + epoch, + ) + .unwrap(); + let static_cost_result = static_cost_from_ast(&ast, &version); match static_cost_result { Ok(static_cost) => { let contract_id = QualifiedContractIdentifier::local("complex-impl").unwrap(); @@ -228,7 +171,16 @@ fn test_dependent_function_calls() { )"#; let contract_id = QualifiedContractIdentifier::transient(); - let function_map = static_cost(src, &ClarityVersion::Clarity3).unwrap(); + let epoch = StacksEpochId::Epoch32; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + src, + &mut (), + ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); let add_one_cost = function_map.get("add-one").unwrap(); let somefunc_cost = function_map.get("somefunc").unwrap(); @@ -236,3 +188,57 @@ fn test_dependent_function_calls() { assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime); assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime); } + +/// Helper function to execute a contract function and return the execution cost +fn execute_contract_function_and_get_cost( + env: &mut OwnedEnvironment, + contract_id: &QualifiedContractIdentifier, + function_name: &str, + args: &[u64], + version: ClarityVersion, +) -> ExecutionCost { + // Start with a fresh cost tracker + let initial_cost = env.get_cost_total(); + + // Create a dummy sender + let sender = PrincipalData::parse_qualified_contract_principal( + "ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender", + ) + .unwrap(); + + let arg_str = args + .iter() + .map(|a| format!("u{}", a)) + .collect::>() + .join(" "); + let function_call = format!("({} {})", function_name, arg_str); + + let ast = crate::vm::ast::parse( + &QualifiedContractIdentifier::transient(), + &function_call, + version, + StacksEpochId::Epoch21, + ) + .expect("Failed to parse function call"); + + if !ast.is_empty() { + let _result = env.execute_transaction( + sender, + None, + contract_id.clone(), + &function_call, + &ast[0..1], + ); + } + + // Get the cost after execution + let final_cost = env.get_cost_total(); + + ExecutionCost { + write_length: final_cost.write_length - initial_cost.write_length, + write_count: final_cost.write_count - initial_cost.write_count, + read_length: final_cost.read_length - initial_cost.read_length, + read_count: final_cost.read_count - initial_cost.read_count, + runtime: final_cost.runtime - initial_cost.runtime, + } +} From 825d4ace36558826f18074cc487a824e8e1fcf68 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 4 Nov 2025 11:01:11 -0800 Subject: [PATCH 10/23] return the cost analysis node root in static_cost_tree --- clarity/src/vm/costs/analysis.rs | 88 +++++++++++++++++++------------- clarity/src/vm/tests/analysis.rs | 1 + 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 7059b28af7..86ac445a1b 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -21,8 +21,6 @@ use crate::vm::{ClarityVersion, Value}; // type-checking // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) -// possibly use ContractContext? Enviornment? we need to use this somehow to -// provide full view of a contract, rather than passing in source const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; @@ -227,33 +225,37 @@ pub fn static_cost_from_ast( contract_ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, ) -> Result, String> { - let exprs = &contract_ast.expressions; + let cost_trees = static_cost_tree_from_ast(contract_ast, clarity_version)?; - if exprs.is_empty() { - return Err("No expressions found in contract AST".to_string()); - } + Ok(cost_trees + .into_iter() + .map(|(name, cost_analysis_node)| { + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node); + (name, summing_cost.into()) + }) + .collect()) +} +fn static_cost_tree_from_ast( + ast: &crate::vm::ast::ContractAST, + clarity_version: &ClarityVersion, +) -> Result, String> { + let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); - let mut costs: HashMap> = HashMap::new(); - - // First pass registers all function definitions + let costs_map: HashMap> = HashMap::new(); + let mut costs: HashMap> = HashMap::new(); for expr in exprs { if let Some(function_name) = extract_function_name(expr) { costs.insert(function_name, None); } } - - // Second pass computes costs for expr in exprs { if let Some(function_name) = extract_function_name(expr) { let (_, cost_analysis_tree) = - build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?; - - let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); - costs.insert(function_name, Some(summing_cost.into())); + build_cost_analysis_tree(expr, &user_args, &costs_map, clarity_version)?; + costs.insert(function_name, Some(cost_analysis_tree)); } } - Ok(costs .into_iter() .filter_map(|(name, cost)| cost.map(|c| (name, c))) @@ -261,19 +263,20 @@ pub fn static_cost_from_ast( } /// Calculate static execution cost for functions using Environment context -/// This replaces the old source-string based approach with Environment integration +/// returns the top level cost for specific functions +/// function_name -> cost pub fn static_cost( env: &mut Environment, contract_identifier: &QualifiedContractIdentifier, ) -> Result, String> { - // Get the contract source from the environment's database + // Get contract source from the environment's database let contract_source = env .global_context .database .get_contract_src(contract_identifier) .ok_or_else(|| "Contract source not found in database".to_string())?; - // Get the contract's clarity version from the environment + // Get clarity version from the environment let contract = env .global_context .database @@ -288,11 +291,32 @@ pub fn static_cost( static_cost_from_ast(&ast, clarity_version) } -// pub fn static_cost_tree( -// source: &str, -// clarity_version: &ClarityVersion, -// ) -> Result, String> { -// } +/// same idea as `static_cost` but returns the root of the cost analysis tree for each function +pub fn static_cost_tree( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result, String> { + // Get contract source from the environment's database + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| "Contract source not found in database".to_string())?; + + // Get clarity version from the environment + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; + + let clarity_version = contract.contract_context.get_clarity_version(); + + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; + + static_cost_tree_from_ast(&ast, clarity_version) +} /// Extract function name from a symbolic expression fn extract_function_name(expr: &SymbolicExpression) -> Option { @@ -686,15 +710,13 @@ fn get_cost_function_for_native( ContractOf => Some(Costs3::cost_contract_of), PrincipalOf => Some(Costs3::cost_principal_of), AtBlock => Some(Costs3::cost_at_block), - // CreateMap => Some(Costs3::cost_create_map), - // CreateVar => Some(Costs3::cost_create_var), - // CreateNonFungibleToken => Some(Costs3::cost_create_nft), - // CreateFungibleToken => Some(Costs3::cost_create_ft), + // => Some(Costs3::cost_create_map), + // => Some(Costs3::cost_create_var), + // ContractStorage => Some(Costs3::cost_contract_storage), FetchEntry => Some(Costs3::cost_fetch_entry), SetEntry => Some(Costs3::cost_set_entry), FetchVar => Some(Costs3::cost_fetch_var), SetVar => Some(Costs3::cost_set_var), - // ContractStorage => Some(Costs3::cost_contract_storage), GetBlockInfo => Some(Costs3::cost_block_info), GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), GetStxBalance => Some(Costs3::cost_stx_balance), @@ -702,11 +724,11 @@ fn get_cost_function_for_native( StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), StxGetAccount => Some(Costs3::cost_stx_account), MintToken => Some(Costs3::cost_ft_mint), + MintAsset => Some(Costs3::cost_nft_mint), TransferToken => Some(Costs3::cost_ft_transfer), GetTokenBalance => Some(Costs3::cost_ft_balance), GetTokenSupply => Some(Costs3::cost_ft_get_supply), BurnToken => Some(Costs3::cost_ft_burn), - MintAsset => Some(Costs3::cost_nft_mint), TransferAsset => Some(Costs3::cost_nft_transfer), GetAssetOwner => Some(Costs3::cost_nft_owner), BurnAsset => Some(Costs3::cost_nft_burn), @@ -1052,7 +1074,6 @@ mod tests { build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) .unwrap(); - // Should have 3 children: UserArgument for (x uint), UserArgument for (y uint), and the body (+ x y) assert_eq!(cost_tree.children.len(), 3); // First child should be UserArgument for (x uint) @@ -1100,11 +1121,9 @@ mod tests { let source = "(define-public (add (a uint) (b uint)) (+ a b))"; let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); - // Should have one function assert_eq!(ast_cost.len(), 1); assert!(ast_cost.contains_key("add")); - // Check that the cost is reasonable (non-zero for addition) let add_cost = ast_cost.get("add").unwrap(); assert!(add_cost.min.runtime > 0); assert!(add_cost.max.runtime > 0); @@ -1118,14 +1137,11 @@ mod tests { "#; let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); - // Should have 2 functions assert_eq!(ast_cost.len(), 2); - // Check that both functions are present assert!(ast_cost.contains_key("func1")); assert!(ast_cost.contains_key("func2")); - // Check that costs are reasonable let func1_cost = ast_cost.get("func1").unwrap(); let func2_cost = ast_cost.get("func2").unwrap(); assert!(func1_cost.min.runtime > 0); diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index 2e815470f0..9169eb6238 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -1,3 +1,4 @@ +// TODO: This needs work to get the dynamic vs static testing working use std::collections::HashMap; use rstest::rstest; From 71e49be80e4b488230a77f750e2afaaffb9b61b1 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:50:49 -0800 Subject: [PATCH 11/23] attempt at trait counting walk --- clarity/src/vm/costs/analysis.rs | 213 +++++++++++++++++++++++++++---- clarity/src/vm/tests/analysis.rs | 56 +++++++- 2 files changed, 238 insertions(+), 31 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 86ac445a1b..81497d4d9f 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -1,6 +1,6 @@ -// Static cost analysis for Clarity expressions +// Static cost analysis for Clarity contracts -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; use stacks_common::types::StacksEpochId; @@ -48,7 +48,7 @@ pub enum CostExprNode { FieldIdentifier(TraitIdentifier), TraitReference(ClarityName), // User function arguments - UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) + UserArgument(ClarityName, SymbolicExpressionType), // (argument_name, argument_type) // User-defined functions UserFunction(ClarityName), } @@ -94,7 +94,7 @@ impl StaticCost { #[derive(Debug, Clone)] pub struct UserArgumentsContext { /// Map from argument name to argument type - pub arguments: HashMap, + pub arguments: HashMap, } impl UserArgumentsContext { @@ -104,7 +104,7 @@ impl UserArgumentsContext { } } - pub fn add_argument(&mut self, name: ClarityName, arg_type: ClarityName) { + pub fn add_argument(&mut self, name: ClarityName, arg_type: SymbolicExpressionType) { self.arguments.insert(name, arg_type); } @@ -112,7 +112,7 @@ impl UserArgumentsContext { self.arguments.contains_key(name) } - pub fn get_argument_type(&self, name: &ClarityName) -> Option<&ClarityName> { + pub fn get_argument_type(&self, name: &ClarityName) -> Option<&SymbolicExpressionType> { self.arguments.get(name) } } @@ -203,7 +203,8 @@ fn make_ast( Ok(ast) } -/// somewhat of a passthrough since we don't have to build the whole context we can jsut return the cost of the single expression +/// somewhat of a passthrough since we don't have to build the whole context we +/// can jsut return the cost of the single expression fn static_cost_native( source: &str, cost_map: &HashMap>, @@ -221,18 +222,159 @@ fn static_cost_native( Ok(summing_cost.into()) } +type MinMaxTraitCount = (u64, u64); +type TraitCount = HashMap; + +// // "" -> "trait-name" +// // ClarityName can't contain +// fn strip_trait_surrounding_brackets(name: &ClarityName) -> ClarityName { +// let stripped = name +// .as_str() +// .strip_prefix("<") +// .and_then(|name| name.strip_suffix(">")); +// if let Some(name) = stripped { +// ClarityName::from(name) +// } else { +// name.clone() +// } +// } +fn get_trait_count(costs: &HashMap) -> Option { + let mut trait_counts = HashMap::new(); + let mut trait_names = HashMap::new(); + // walk tree + for (name, cost_analysis_node) in costs.iter() { + get_trait_count_from_node( + cost_analysis_node, + &mut trait_counts, + &mut trait_names, + name.clone(), + 1, + ); + // trait_counts.extend(counts); + } + Some(trait_counts) +} +fn get_trait_count_from_node( + cost_analysis_node: &CostAnalysisNode, + mut trait_counts: &mut TraitCount, + mut trait_names: &mut HashMap, + containing_fn_name: String, + multiplier: u64, +) -> TraitCount { + match &cost_analysis_node.expr { + CostExprNode::UserArgument(arg_name, arg_type) => match arg_type { + SymbolicExpressionType::TraitReference(name, _) => { + trait_names.insert(arg_name.clone(), name.clone().to_string()); + trait_counts.entry(name.to_string()).or_insert((0, 0)); + } + _ => {} + }, + CostExprNode::NativeFunction(native_function) => { + println!("native function: {:?}", native_function); + match native_function { + // if map, filter, or fold, we need to check if traits are called + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + println!("map: {:?}", cost_analysis_node.children); + let list_to_traverse = cost_analysis_node.children[1].clone(); + let multiplier = match list_to_traverse.expr { + CostExprNode::UserArgument(_, arg_type) => match arg_type { + SymbolicExpressionType::List(list) => { + if list[0].match_atom().unwrap().as_str() == "list" { + match list[1].clone().expr { + SymbolicExpressionType::LiteralValue(value) => { + match value { + Value::Int(value) => value as u64, + _ => 1, + } + } + _ => 1, + } + } else { + 1 + } + } + _ => 1, + }, + _ => 1, + }; + println!("multiplier: {:?}", multiplier); + cost_analysis_node.children.iter().for_each(|child| { + get_trait_count_from_node( + child, + &mut trait_counts, + &mut trait_names, + containing_fn_name.clone(), + multiplier, + ); + }); + } + _ => {} + } + } + CostExprNode::AtomValue(atom_value) => { + println!("atom value: {:?}", atom_value); + // do nothing + } + CostExprNode::Atom(atom) => { + println!("atom: {:?}", atom); + if trait_names.get(atom).is_some() { + trait_counts + .entry(containing_fn_name.clone()) + .and_modify(|(min, max)| { + *min += 1; + *max += multiplier; + }) + .or_insert((1, multiplier)); + } + // do nothing + } + CostExprNode::FieldIdentifier(field_identifier) => { + println!("field identifier: {:?}", field_identifier); + // do nothing + } + CostExprNode::TraitReference(trait_name) => { + println!("trait_name: {:?}", trait_name); + trait_counts + .entry(trait_name.to_string()) + .and_modify(|(min, max)| { + *min += 1; + *max += multiplier; + }) + .or_insert((1, multiplier)); + } + CostExprNode::UserFunction(user_function) => { + println!("user function: {:?}", user_function); + cost_analysis_node.children.iter().for_each(|child| { + get_trait_count_from_node( + child, + &mut trait_counts, + &mut trait_names, + containing_fn_name.clone(), + multiplier, + ); + }); + } + } + trait_counts.clone() +} + pub fn static_cost_from_ast( contract_ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, -) -> Result, String> { +) -> Result)>, String> { let cost_trees = static_cost_tree_from_ast(contract_ast, clarity_version)?; - Ok(cost_trees + let trait_count = get_trait_count(&cost_trees); + let costs: HashMap = cost_trees .into_iter() .map(|(name, cost_analysis_node)| { let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node); (name, summing_cost.into()) }) + .collect(); + Ok(costs + .into_iter() + .map(|(name, cost)| (name, (cost, trait_count.clone()))) .collect()) } @@ -288,7 +430,11 @@ pub fn static_cost( let epoch = env.global_context.epoch_id; let ast = make_ast(&contract_source, epoch, clarity_version)?; - static_cost_from_ast(&ast, clarity_version) + let costs = static_cost_from_ast(&ast, clarity_version)?; + Ok(costs + .into_iter() + .map(|(name, (cost, _trait_count))| (name, cost)) + .collect()) } /// same idea as `static_cost` but returns the root of the cost analysis tree for each function @@ -447,23 +593,31 @@ fn build_function_definition_cost_analysis_tree( .match_atom() .ok_or("Expected atom for argument name")?; - let arg_type = match &arg_list[1].expr { - SymbolicExpressionType::Atom(type_name) => type_name.clone(), - SymbolicExpressionType::AtomValue(value) => { - ClarityName::from(value.to_string().as_str()) - } - SymbolicExpressionType::LiteralValue(value) => { - ClarityName::from(value.to_string().as_str()) - } - _ => return Err("Argument type must be an atom or atom value".to_string()), - }; + let arg_type = arg_list[1].clone(); + // let arg_type = match &arg_list[1].expr { + // SymbolicExpressionType::Atom(type_name) => type_name.clone(), + // SymbolicExpressionType::AtomValue(value) => { + // ClarityName::from(value.to_string().as_str()) + // } + // SymbolicExpressionType::LiteralValue(value) => { + // ClarityName::from(value.to_string().as_str()) + // } + // SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => { + // trait_name.clone() + // } + // SymbolicExpressionType::List(_) => ClarityName::from("list"), + // _ => { + // println!("arg: {:?}", arg_list[1].expr); + // return Err("Argument type must be an atom or atom value".to_string()); + // } + // }; // Add to function's user arguments context - function_user_args.add_argument(arg_name.clone(), arg_type.clone()); + function_user_args.add_argument(arg_name.clone(), arg_type.clone().expr); // Create UserArgument node children.push(CostAnalysisNode::leaf( - CostExprNode::UserArgument(arg_name.clone(), arg_type), + CostExprNode::UserArgument(arg_name.clone(), arg_type.clone().expr), StaticCost::ZERO, )); } @@ -757,6 +911,7 @@ fn get_cost_function_for_native( InsertEntry => Some(Costs3::cost_set_entry), DeleteEntry => Some(Costs3::cost_set_entry), StxBurn => Some(Costs3::cost_stx_transfer), + Secp256r1Verify => Some(Costs3::cost_secp256r1verify), RestrictAssets => None, // TODO: add cost function AllowanceWithStx => None, // TODO: add cost function AllowanceWithFt => None, // TODO: add cost function @@ -871,7 +1026,11 @@ mod tests { ) -> Result, String> { let epoch = StacksEpochId::latest(); let ast = make_ast(source, epoch, clarity_version)?; - static_cost_from_ast(&ast, clarity_version) + let costs = static_cost_from_ast(&ast, clarity_version)?; + Ok(costs + .into_iter() + .map(|(name, (cost, _trait_count))| (name, cost)) + .collect()) } fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST { @@ -1081,7 +1240,7 @@ mod tests { assert!(matches!(user_arg_x.expr, CostExprNode::UserArgument(_, _))); if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { assert_eq!(arg_name.as_str(), "x"); - assert_eq!(arg_type.as_str(), "uint"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); } // Second child should be UserArgument for (y u64) @@ -1089,7 +1248,7 @@ mod tests { assert!(matches!(user_arg_y.expr, CostExprNode::UserArgument(_, _))); if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { assert_eq!(arg_name.as_str(), "y"); - assert_eq!(arg_type.as_str(), "uint"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); } // Third child should be the function body (+ x y) @@ -1108,11 +1267,11 @@ mod tests { if let CostExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { assert_eq!(name.as_str(), "x"); - assert_eq!(arg_type.as_str(), "uint"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); } if let CostExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { assert_eq!(name.as_str(), "y"); - assert_eq!(arg_type.as_str(), "uint"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); } } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index 9169eb6238..bf3312e703 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -59,11 +59,55 @@ fn test_simple_trait_implementation_costs( println!("static_cost: {:?}", static_cost); let key = static_cost.keys().nth(1).unwrap(); - let cost = static_cost.get(key).unwrap(); + let (cost, _trait_count) = static_cost.get(key).unwrap(); assert!(dynamic_cost.runtime >= cost.min.runtime); assert!(dynamic_cost.runtime <= cost.max.runtime); } +#[rstest] +fn test_trait_counting() { + // map, fold, filter over traits counting + let src = r#"(define-trait trait-name ( + (send (uint principal) (response uint uint)) +)) +(define-public (something (trait ) (addresses (list 10 principal))) + (map (send u500 trait) addresses) +) +(define-private (send (amount uint) (trait ) (addr principal)) (trait true)) +"#; + let contract_id = QualifiedContractIdentifier::local("trait-counting").unwrap(); + let ast = crate::vm::ast::build_ast( + &contract_id, + src, + &mut (), + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .unwrap(); + let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3) + .unwrap() + .clone(); + // trait count for 'something' function should be minimum 1 maximum 10 + println!("static_cost: {:?}", static_cost); + //trait count for send should be 1 + println!("trait_count: {:?}", static_cost.get("something").unwrap()); + println!("trait_count: {:?}", static_cost.get("send").unwrap()); + assert_eq!( + static_cost + .get("send") + .unwrap() + .1 + .clone() + .unwrap() + .get("trait-name") + .unwrap() + .0, + 1 + ); + // assert_eq!(trait_count.get("trait-name").unwrap().0, 1); + // assert_eq!(trait_count.get("trait-name").unwrap().1, 1); +} + #[rstest] #[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] fn test_complex_trait_implementation_costs( @@ -113,7 +157,9 @@ fn test_complex_trait_implementation_costs( ); let key = static_cost.keys().nth(1).unwrap(); - let cost = static_cost.get(key).unwrap(); + let (cost, _trait_count) = static_cost.get(key).unwrap(); + println!("dynamic_cost: {:?}", dynamic_cost); + println!("cost: {:?}", cost); assert!(dynamic_cost.runtime >= cost.min.runtime); assert!(dynamic_cost.runtime <= cost.max.runtime); } @@ -183,9 +229,11 @@ fn test_dependent_function_calls() { .unwrap(); let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); - let add_one_cost = function_map.get("add-one").unwrap(); - let somefunc_cost = function_map.get("somefunc").unwrap(); + let (add_one_cost, _) = function_map.get("add-one").unwrap(); + let (somefunc_cost, _) = function_map.get("somefunc").unwrap(); + println!("add_one_cost: {:?}", add_one_cost); + println!("add_one_cost: {:?}", somefunc_cost); assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime); assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime); } From c0650f6ca27b6f91a1fd1c497c5ed7dd5c134d72 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:03:15 -0800 Subject: [PATCH 12/23] visitor pattern cleanup --- clarity/src/vm/costs/analysis.rs | 584 +++++++++++++++++++++++-------- clarity/src/vm/tests/analysis.rs | 127 ++++--- 2 files changed, 525 insertions(+), 186 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 81497d4d9f..530b959b6d 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -1,6 +1,6 @@ // Static cost analysis for Clarity contracts -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; use stacks_common::types::StacksEpochId; @@ -29,11 +29,9 @@ const STRING_COST_MULTIPLIER: u64 = 3; /// cost includes their processing const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; -/// Function definition keywords in Clarity const FUNCTION_DEFINITION_KEYWORDS: &[&str] = &["define-public", "define-private", "define-read-only"]; -/// Check if a function name is a function definition keyword fn is_function_definition(function_name: &str) -> bool { FUNCTION_DEFINITION_KEYWORDS.contains(&function_name) } @@ -225,137 +223,428 @@ fn static_cost_native( type MinMaxTraitCount = (u64, u64); type TraitCount = HashMap; -// // "" -> "trait-name" -// // ClarityName can't contain -// fn strip_trait_surrounding_brackets(name: &ClarityName) -> ClarityName { -// let stripped = name -// .as_str() -// .strip_prefix("<") -// .and_then(|name| name.strip_suffix(">")); -// if let Some(name) = stripped { -// ClarityName::from(name) -// } else { -// name.clone() -// } -// } -fn get_trait_count(costs: &HashMap) -> Option { - let mut trait_counts = HashMap::new(); - let mut trait_names = HashMap::new(); - // walk tree - for (name, cost_analysis_node) in costs.iter() { - get_trait_count_from_node( - cost_analysis_node, - &mut trait_counts, - &mut trait_names, - name.clone(), - 1, - ); - // trait_counts.extend(counts); +/// Context passed to visitors during trait count analysis +struct TraitCountContext { + containing_fn_name: String, + multiplier: u64, +} + +impl TraitCountContext { + fn new(containing_fn_name: String, multiplier: u64) -> Self { + Self { + containing_fn_name, + multiplier, + } + } + + fn with_multiplier(&self, multiplier: u64) -> Self { + Self { + containing_fn_name: self.containing_fn_name.clone(), + multiplier, + } + } + + fn with_fn_name(&self, fn_name: String) -> Self { + Self { + containing_fn_name: fn_name, + multiplier: self.multiplier, + } } - Some(trait_counts) } -fn get_trait_count_from_node( - cost_analysis_node: &CostAnalysisNode, - mut trait_counts: &mut TraitCount, - mut trait_names: &mut HashMap, - containing_fn_name: String, + +/// Extract the list size multiplier from a list expression (for map/filter/fold operations) +/// Expects a list in the form `(list )` where size is an integer literal +fn extract_list_multiplier(list: &[SymbolicExpression]) -> u64 { + if list.is_empty() { + return 1; + } + + let is_list_atom = list[0] + .match_atom() + .map(|a| a.as_str() == "list") + .unwrap_or(false); + if !is_list_atom || list.len() < 2 { + return 1; + } + + match &list[1].expr { + SymbolicExpressionType::LiteralValue(Value::Int(value)) => *value as u64, + _ => 1, + } +} + +/// Increment trait count for a function +fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: u64) { + trait_counts + .entry(fn_name.to_string()) + .and_modify(|(min, max)| { + *min += 1; + *max += multiplier; + }) + .or_insert((1, multiplier)); +} + +/// Propagate trait count from one function to another with a multiplier +fn propagate_trait_count( + trait_counts: &mut TraitCount, + from_fn: &str, + to_fn: &str, multiplier: u64, -) -> TraitCount { - match &cost_analysis_node.expr { - CostExprNode::UserArgument(arg_name, arg_type) => match arg_type { - SymbolicExpressionType::TraitReference(name, _) => { - trait_names.insert(arg_name.clone(), name.clone().to_string()); - trait_counts.entry(name.to_string()).or_insert((0, 0)); +) { + if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() { + trait_counts + .entry(to_fn.to_string()) + .and_modify(|(min, max)| { + *min += called_trait_count.0; + *max += called_trait_count.1 * multiplier; + }) + .or_insert((called_trait_count.0, called_trait_count.1 * multiplier)); + } +} + +/// Visitor trait for traversing cost analysis nodes and collecting/propagating trait counts +trait TraitCountVisitor { + fn visit_user_argument( + &mut self, + node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + context: &TraitCountContext, + ); + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ); + fn visit_atom_value(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_atom( + &mut self, + node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ); + fn visit_field_identifier(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_trait_reference( + &mut self, + node: &CostAnalysisNode, + trait_name: &ClarityName, + context: &TraitCountContext, + ); + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ); + + fn visit(&mut self, node: &CostAnalysisNode, context: &TraitCountContext) { + match &node.expr { + CostExprNode::UserArgument(arg_name, arg_type) => { + self.visit_user_argument(node, arg_name, arg_type, context); } - _ => {} - }, - CostExprNode::NativeFunction(native_function) => { - println!("native function: {:?}", native_function); - match native_function { - // if map, filter, or fold, we need to check if traits are called - NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { - println!("map: {:?}", cost_analysis_node.children); - let list_to_traverse = cost_analysis_node.children[1].clone(); - let multiplier = match list_to_traverse.expr { - CostExprNode::UserArgument(_, arg_type) => match arg_type { - SymbolicExpressionType::List(list) => { - if list[0].match_atom().unwrap().as_str() == "list" { - match list[1].clone().expr { - SymbolicExpressionType::LiteralValue(value) => { - match value { - Value::Int(value) => value as u64, - _ => 1, - } - } - _ => 1, - } - } else { - 1 - } - } - _ => 1, - }, - _ => 1, - }; - println!("multiplier: {:?}", multiplier); - cost_analysis_node.children.iter().for_each(|child| { - get_trait_count_from_node( - child, - &mut trait_counts, - &mut trait_names, - containing_fn_name.clone(), - multiplier, - ); - }); - } - _ => {} + CostExprNode::NativeFunction(native_function) => { + self.visit_native_function(node, native_function, context); + } + CostExprNode::AtomValue(_atom_value) => { + self.visit_atom_value(node, context); + } + CostExprNode::Atom(atom) => { + self.visit_atom(node, atom, context); + } + CostExprNode::FieldIdentifier(_field_identifier) => { + self.visit_field_identifier(node, context); + } + CostExprNode::TraitReference(trait_name) => { + self.visit_trait_reference(node, trait_name, context); + } + CostExprNode::UserFunction(user_function) => { + self.visit_user_function(node, user_function, context); } } - CostExprNode::AtomValue(atom_value) => { - println!("atom value: {:?}", atom_value); - // do nothing + } +} + +struct TraitCountCollector { + trait_counts: TraitCount, + trait_names: HashMap, +} + +impl TraitCountCollector { + fn new() -> Self { + Self { + trait_counts: HashMap::new(), + trait_names: HashMap::new(), + } + } +} + +impl TraitCountVisitor for TraitCountCollector { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + if let SymbolicExpressionType::TraitReference(name, _) = arg_type { + self.trait_names + .insert(arg_name.clone(), name.clone().to_string()); } - CostExprNode::Atom(atom) => { - println!("atom: {:?}", atom); - if trait_names.get(atom).is_some() { - trait_counts - .entry(containing_fn_name.clone()) - .and_modify(|(min, max)| { - *min += 1; - *max += multiplier; - }) - .or_insert((1, multiplier)); + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + 1 + }; + let new_context = context.with_multiplier(multiplier); + for child in &node.children { + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } } - // do nothing } - CostExprNode::FieldIdentifier(field_identifier) => { - println!("field identifier: {:?}", field_identifier); - // do nothing + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for atom values + } + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ) { + if self.trait_names.contains_key(atom) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); } - CostExprNode::TraitReference(trait_name) => { - println!("trait_name: {:?}", trait_name); - trait_counts - .entry(trait_name.to_string()) - .and_modify(|(min, max)| { - *min += 1; - *max += multiplier; - }) - .or_insert((1, multiplier)); + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for field identifiers + } + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + context: &TraitCountContext, + ) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + // Check if this is a trait call (the function name is a trait argument) + if self.trait_names.contains_key(user_function) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); } - CostExprNode::UserFunction(user_function) => { - println!("user function: {:?}", user_function); - cost_analysis_node.children.iter().for_each(|child| { - get_trait_count_from_node( - child, - &mut trait_counts, - &mut trait_names, - containing_fn_name.clone(), - multiplier, - ); - }); + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); } } - trait_counts.clone() +} + +/// Second pass visitor: propagates trait counts through function calls +struct TraitCountPropagator<'a> { + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, +} + +impl<'a> TraitCountPropagator<'a> { + fn new( + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, + ) -> Self { + Self { + trait_counts, + trait_names, + } + } +} + +impl<'a> TraitCountVisitor for TraitCountPropagator<'a> { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + _arg_name: &ClarityName, + _arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + // No propagation needed for arguments + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + 1 + }; + + // Process the function being called in map/filter/fold + let mut skip_first_child = false; + if let Some(function_node) = node.children.get(0) { + if let CostExprNode::UserFunction(function_name) = &function_node.expr { + if !self.trait_names.contains_key(function_name) { + // This is a regular function call, not a trait call + propagate_trait_count( + self.trait_counts, + &function_name.to_string(), + &context.containing_fn_name, + multiplier, + ); + skip_first_child = true; + } + } + } + + // Continue traversing children, but skip the function node if we already propagated it + for (idx, child) in node.children.iter().enumerate() { + if idx == 0 && skip_first_child { + continue; + } + let new_context = context.with_multiplier(multiplier); + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } + } + } + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + _atom: &ClarityName, + _context: &TraitCountContext, + ) { + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + _context: &TraitCountContext, + ) { + // No propagation needed for trait references (already counted in first pass) + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + if !is_function_definition(user_function.as_str()) + && !self.trait_names.contains_key(user_function) + { + // This is a regular function call, not a trait call or function definition + propagate_trait_count( + self.trait_counts, + &user_function.to_string(), + &context.containing_fn_name, + context.multiplier, + ); + } + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); + } + } +} + +pub(crate) fn get_trait_count(costs: &HashMap) -> Option { + // First pass: collect trait counts and trait names + let mut collector = TraitCountCollector::new(); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), 1); + collector.visit(cost_analysis_node, &context); + } + + // Second pass: propagate trait counts through function calls + // If function A calls function B and uses a map, filter, or fold with + // traits, the maximum will reflect that in A's trait call counts + let mut propagator = + TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), 1); + propagator.visit(cost_analysis_node, &context); + } + + Some(collector.trait_counts) } pub fn static_cost_from_ast( @@ -378,7 +667,7 @@ pub fn static_cost_from_ast( .collect()) } -fn static_cost_tree_from_ast( +pub(crate) fn static_cost_tree_from_ast( ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, ) -> Result, String> { @@ -404,21 +693,19 @@ fn static_cost_tree_from_ast( .collect()) } -/// Calculate static execution cost for functions using Environment context +/// STatic execution cost for functions within Environment /// returns the top level cost for specific functions -/// function_name -> cost +/// {function_name: cost} pub fn static_cost( env: &mut Environment, contract_identifier: &QualifiedContractIdentifier, ) -> Result, String> { - // Get contract source from the environment's database let contract_source = env .global_context .database .get_contract_src(contract_identifier) .ok_or_else(|| "Contract source not found in database".to_string())?; - // Get clarity version from the environment let contract = env .global_context .database @@ -438,18 +725,17 @@ pub fn static_cost( } /// same idea as `static_cost` but returns the root of the cost analysis tree for each function +/// Useful if you need to analyze specific nodes in the cost tree pub fn static_cost_tree( env: &mut Environment, contract_identifier: &QualifiedContractIdentifier, ) -> Result, String> { - // Get contract source from the environment's database let contract_source = env .global_context .database .get_contract_src(contract_identifier) .ok_or_else(|| "Contract source not found in database".to_string())?; - // Get clarity version from the environment let contract = env .global_context .database @@ -466,19 +752,16 @@ pub fn static_cost_tree( /// Extract function name from a symbolic expression fn extract_function_name(expr: &SymbolicExpression) -> Option { - if let Some(list) = expr.match_list() { - if let Some(first_atom) = list.first().and_then(|first| first.match_atom()) { - if is_function_definition(first_atom.as_str()) { - if let Some(signature) = list.get(1).and_then(|sig| sig.match_list()) { - return signature - .first() - .and_then(|name| name.match_atom()) - .map(|name| name.to_string()); - } - } - } - } - None + expr.match_list().and_then(|list| { + list.first() + .and_then(|first| first.match_atom()) + .filter(|atom| is_function_definition(atom.as_str())) + .and_then(|_| list.get(1)) + .and_then(|sig| sig.match_list()) + .and_then(|signature| signature.first()) + .and_then(|name| name.match_atom()) + .map(|name| name.to_string()) + }) } pub fn build_cost_analysis_tree( @@ -1306,4 +1589,23 @@ mod tests { assert!(func1_cost.min.runtime > 0); assert!(func2_cost.min.runtime > 0); } + + #[test] + fn test_extract_function_name_define_public() { + let src = "(define-public (my-func (x uint)) (ok x))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let result = extract_function_name(expr); + assert_eq!(result, Some("my-func".to_string())); + } + + #[test] + fn test_extract_function_name_function_call_not_definition() { + // function call (not a definition) should return None + let src = "(my-func arg1 arg2)"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let result = extract_function_name(expr); + assert_eq!(result, None); + } } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index bf3312e703..a3ac8ea016 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -6,7 +6,8 @@ use stacks_common::types::StacksEpochId; use crate::vm::contexts::OwnedEnvironment; use crate::vm::costs::analysis::{ - build_cost_analysis_tree, static_cost_from_ast, UserArgumentsContext, + build_cost_analysis_tree, get_trait_count, static_cost_from_ast, static_cost_tree_from_ast, + UserArgumentsContext, }; use crate::vm::costs::ExecutionCost; use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator}; @@ -64,50 +65,6 @@ fn test_simple_trait_implementation_costs( assert!(dynamic_cost.runtime <= cost.max.runtime); } -#[rstest] -fn test_trait_counting() { - // map, fold, filter over traits counting - let src = r#"(define-trait trait-name ( - (send (uint principal) (response uint uint)) -)) -(define-public (something (trait ) (addresses (list 10 principal))) - (map (send u500 trait) addresses) -) -(define-private (send (amount uint) (trait ) (addr principal)) (trait true)) -"#; - let contract_id = QualifiedContractIdentifier::local("trait-counting").unwrap(); - let ast = crate::vm::ast::build_ast( - &contract_id, - src, - &mut (), - ClarityVersion::Clarity3, - StacksEpochId::Epoch32, - ) - .unwrap(); - let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3) - .unwrap() - .clone(); - // trait count for 'something' function should be minimum 1 maximum 10 - println!("static_cost: {:?}", static_cost); - //trait count for send should be 1 - println!("trait_count: {:?}", static_cost.get("something").unwrap()); - println!("trait_count: {:?}", static_cost.get("send").unwrap()); - assert_eq!( - static_cost - .get("send") - .unwrap() - .1 - .clone() - .unwrap() - .get("trait-name") - .unwrap() - .0, - 1 - ); - // assert_eq!(trait_count.get("trait-name").unwrap().0, 1); - // assert_eq!(trait_count.get("trait-name").unwrap().1, 1); -} - #[rstest] #[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] fn test_complex_trait_implementation_costs( @@ -238,6 +195,86 @@ fn test_dependent_function_calls() { assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime); } +#[test] +fn test_get_trait_count_direct() { + let src = r#"(define-trait trait-name ( + (send (uint principal) (response uint uint)) +)) +(define-public (something (trait ) (addresses (list 10 principal))) + (map (send u500 trait) addresses) +) +(define-private (send (trait ) (addr principal)) (trait addr)) +"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let ast = crate::vm::ast::build_ast( + &contract_id, + src, + &mut (), + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .unwrap(); + + // Build the cost analysis tree + let costs = static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); + + // Call get_trait_count directly + let trait_count = get_trait_count(&costs); + + // Expected result: {something: (1,10), send: (1,1)} + let expected = { + let mut map = HashMap::new(); + map.insert("something".to_string(), (1, 10)); + map.insert("send".to_string(), (1, 1)); + Some(map) + }; + + assert_eq!(trait_count, expected); +} + +#[rstest] +fn test_trait_counting() { + // map, fold, filter over traits counting + let src = r#"(define-trait trait-name ( + (send (uint principal) (response uint uint)) +)) +(define-public (something (trait ) (addresses (list 10 principal))) + (map (send u500 trait) addresses) +) +(define-private (send (trait ) (addr principal)) (trait addr)) +"#; + let contract_id = QualifiedContractIdentifier::local("trait-counting").unwrap(); + let ast = crate::vm::ast::build_ast( + &contract_id, + src, + &mut (), + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .unwrap(); + let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3) + .unwrap() + .clone(); + // trait count for 'something' function should be minimum 1 maximum 10 + println!("static_cost: {:?}", static_cost); + //trait count for send should be 1 + println!("trait_count: {:?}", static_cost.get("something").unwrap()); + println!("trait_count: {:?}", static_cost.get("send").unwrap()); + // Trait counts are now keyed by function name, not trait name + // Check that "send" function has trait count of (1, 1) + let send_trait_count_map = static_cost.get("send").unwrap().1.clone().unwrap(); + let send_trait_count = send_trait_count_map.get("send").unwrap(); + assert_eq!(send_trait_count.0, 1); + assert_eq!(send_trait_count.1, 1); + + // Check that "something" function has trait count of (1, 10) + let something_trait_count_map = static_cost.get("something").unwrap().1.clone().unwrap(); + let something_trait_count = something_trait_count_map.get("something").unwrap(); + assert_eq!(something_trait_count.0, 1); + assert_eq!(something_trait_count.1, 10); +} + /// Helper function to execute a contract function and return the execution cost fn execute_contract_function_and_get_cost( env: &mut OwnedEnvironment, From 4ada29566f084f1b03f2c2870f2458e9e0e2d3c6 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Fri, 21 Nov 2025 10:04:20 -0800 Subject: [PATCH 13/23] trait counting on map/filter/fold minimum to zero --- clarity/src/vm/costs/analysis.rs | 43 +++++++++++++++++--------------- clarity/src/vm/tests/analysis.rs | 18 ++----------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 530b959b6d..a6b32bc872 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -226,18 +226,18 @@ type TraitCount = HashMap; /// Context passed to visitors during trait count analysis struct TraitCountContext { containing_fn_name: String, - multiplier: u64, + multiplier: (u64, u64), } impl TraitCountContext { - fn new(containing_fn_name: String, multiplier: u64) -> Self { + fn new(containing_fn_name: String, multiplier: (u64, u64)) -> Self { Self { containing_fn_name, multiplier, } } - fn with_multiplier(&self, multiplier: u64) -> Self { + fn with_multiplier(&self, multiplier: (u64, u64)) -> Self { Self { containing_fn_name: self.containing_fn_name.clone(), multiplier, @@ -254,9 +254,9 @@ impl TraitCountContext { /// Extract the list size multiplier from a list expression (for map/filter/fold operations) /// Expects a list in the form `(list )` where size is an integer literal -fn extract_list_multiplier(list: &[SymbolicExpression]) -> u64 { +fn extract_list_multiplier(list: &[SymbolicExpression]) -> (u64, u64) { if list.is_empty() { - return 1; + return (1, 1); } let is_list_atom = list[0] @@ -264,24 +264,24 @@ fn extract_list_multiplier(list: &[SymbolicExpression]) -> u64 { .map(|a| a.as_str() == "list") .unwrap_or(false); if !is_list_atom || list.len() < 2 { - return 1; + return (1, 1); } match &list[1].expr { - SymbolicExpressionType::LiteralValue(Value::Int(value)) => *value as u64, - _ => 1, + SymbolicExpressionType::LiteralValue(Value::Int(value)) => (0, *value as u64), + _ => (1, 1), } } /// Increment trait count for a function -fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: u64) { +fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: (u64, u64)) { trait_counts .entry(fn_name.to_string()) .and_modify(|(min, max)| { - *min += 1; - *max += multiplier; + *min += multiplier.0; + *max += multiplier.1; }) - .or_insert((1, multiplier)); + .or_insert(multiplier); } /// Propagate trait count from one function to another with a multiplier @@ -289,16 +289,19 @@ fn propagate_trait_count( trait_counts: &mut TraitCount, from_fn: &str, to_fn: &str, - multiplier: u64, + multiplier: (u64, u64), ) { if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() { trait_counts .entry(to_fn.to_string()) .and_modify(|(min, max)| { - *min += called_trait_count.0; - *max += called_trait_count.1 * multiplier; + *min += called_trait_count.0 * multiplier.0; + *max += called_trait_count.1 * multiplier.1; }) - .or_insert((called_trait_count.0, called_trait_count.1 * multiplier)); + .or_insert(( + called_trait_count.0 * multiplier.0, + called_trait_count.1 * multiplier.1, + )); } } @@ -409,7 +412,7 @@ impl TraitCountVisitor for TraitCountCollector { { extract_list_multiplier(list) } else { - 1 + (1, 1) }; let new_context = context.with_multiplier(multiplier); for child in &node.children { @@ -535,7 +538,7 @@ impl<'a> TraitCountVisitor for TraitCountPropagator<'a> { { extract_list_multiplier(list) } else { - 1 + (1, 1) }; // Process the function being called in map/filter/fold @@ -630,7 +633,7 @@ pub(crate) fn get_trait_count(costs: &HashMap) -> Opti // First pass: collect trait counts and trait names let mut collector = TraitCountCollector::new(); for (name, cost_analysis_node) in costs.iter() { - let context = TraitCountContext::new(name.clone(), 1); + let context = TraitCountContext::new(name.clone(), (1, 1)); collector.visit(cost_analysis_node, &context); } @@ -640,7 +643,7 @@ pub(crate) fn get_trait_count(costs: &HashMap) -> Opti let mut propagator = TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names); for (name, cost_analysis_node) in costs.iter() { - let context = TraitCountContext::new(name.clone(), 1); + let context = TraitCountContext::new(name.clone(), (1, 1)); propagator.visit(cost_analysis_node, &context); } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index a3ac8ea016..eb71e00a57 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -216,16 +216,13 @@ fn test_get_trait_count_direct() { ) .unwrap(); - // Build the cost analysis tree let costs = static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); - // Call get_trait_count directly let trait_count = get_trait_count(&costs); - // Expected result: {something: (1,10), send: (1,1)} let expected = { let mut map = HashMap::new(); - map.insert("something".to_string(), (1, 10)); + map.insert("something".to_string(), (0, 10)); map.insert("send".to_string(), (1, 1)); Some(map) }; @@ -256,22 +253,14 @@ fn test_trait_counting() { let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3) .unwrap() .clone(); - // trait count for 'something' function should be minimum 1 maximum 10 - println!("static_cost: {:?}", static_cost); - //trait count for send should be 1 - println!("trait_count: {:?}", static_cost.get("something").unwrap()); - println!("trait_count: {:?}", static_cost.get("send").unwrap()); - // Trait counts are now keyed by function name, not trait name - // Check that "send" function has trait count of (1, 1) let send_trait_count_map = static_cost.get("send").unwrap().1.clone().unwrap(); let send_trait_count = send_trait_count_map.get("send").unwrap(); assert_eq!(send_trait_count.0, 1); assert_eq!(send_trait_count.1, 1); - // Check that "something" function has trait count of (1, 10) let something_trait_count_map = static_cost.get("something").unwrap().1.clone().unwrap(); let something_trait_count = something_trait_count_map.get("something").unwrap(); - assert_eq!(something_trait_count.0, 1); + assert_eq!(something_trait_count.0, 0); assert_eq!(something_trait_count.1, 10); } @@ -283,10 +272,8 @@ fn execute_contract_function_and_get_cost( args: &[u64], version: ClarityVersion, ) -> ExecutionCost { - // Start with a fresh cost tracker let initial_cost = env.get_cost_total(); - // Create a dummy sender let sender = PrincipalData::parse_qualified_contract_principal( "ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender", ) @@ -317,7 +304,6 @@ fn execute_contract_function_and_get_cost( ); } - // Get the cost after execution let final_cost = env.get_cost_total(); ExecutionCost { From 50b007b8996972289dc85d4f3873bcacaab6d5fe Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Fri, 21 Nov 2025 10:10:56 -0800 Subject: [PATCH 14/23] replace deprecated InterpreterResult usages --- clarity/src/vm/costs/analysis.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index a6b32bc872..123d904f86 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -10,7 +10,7 @@ use crate::vm::contexts::Environment; use crate::vm::costs::cost_functions::{linear, CostValues}; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; -use crate::vm::errors::InterpreterResult; +use crate::vm::errors::VmExecutionError; use crate::vm::functions::NativeFunctions; use crate::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; use crate::vm::types::QualifiedContractIdentifier; @@ -1084,7 +1084,7 @@ fn calculate_function_cost_from_native_function( fn get_cost_function_for_native( function: NativeFunctions, _clarity_version: &ClarityVersion, -) -> Option InterpreterResult> { +) -> Option Result> { use crate::vm::functions::NativeFunctions::*; // Map NativeFunctions enum variants to their cost functions @@ -1286,7 +1286,7 @@ impl From for StaticCost { /// Helper: calculate min & max costs for a given cost function /// This is likely tooo simplistic but for now it'll do fn get_costs( - cost_fn: fn(u64) -> InterpreterResult, + cost_fn: fn(u64) -> Result, arg_count: u64, ) -> Result { let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; From ea0c605b3e61a7c65f6ef2c9aa0065cecc3844d4 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Sat, 22 Nov 2025 13:33:42 -0800 Subject: [PATCH 15/23] test larger contract analysis with pox-4 --- clarity/src/vm/costs/analysis.rs | 93 ++++++++++++++++++++++++-------- clarity/src/vm/tests/analysis.rs | 53 ++++++++++++++++++ 2 files changed, 125 insertions(+), 21 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 123d904f86..4fedd97b9f 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -953,30 +953,81 @@ fn build_listlike_cost_analysis_tree( children.push(child_tree); } - let function_name = get_function_name(&exprs[0])?; - // Try to lookup the function as a native function first - let (expr_node, cost) = if let Some(native_function) = - NativeFunctions::lookup_by_name_at_version(function_name.as_str(), clarity_version) - { - CostExprNode::NativeFunction(native_function); - let cost = calculate_function_cost_from_native_function( - native_function, - children.len() as u64, - clarity_version, - )?; - (CostExprNode::NativeFunction(native_function), cost) - } else { - // If not a native function, treat as user-defined function and look it up - let expr_node = CostExprNode::UserFunction(function_name.clone()); - let cost = calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?; - (expr_node, cost) + // Try to get function name from first element + let (expr_node, cost, function_name_opt) = match get_function_name(&exprs[0]) { + Ok(function_name) => { + // Try to lookup the function as a native function first + if let Some(native_function) = + NativeFunctions::lookup_by_name_at_version(function_name.as_str(), clarity_version) + { + let cost = calculate_function_cost_from_native_function( + native_function, + children.len() as u64, + clarity_version, + )?; + ( + CostExprNode::NativeFunction(native_function), + cost, + Some(function_name), + ) + } else { + // If not a native function, treat as user-defined function and look it up + let expr_node = CostExprNode::UserFunction(function_name.clone()); + let cost = + calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?; + (expr_node, cost, Some(function_name)) + } + } + Err(_) => { + // First element is not an atom - it might be a List that needs to be recursively analyzed + match &exprs[0].expr { + SymbolicExpressionType::List(_) => { + // Recursively analyze the nested list structure + let (_, nested_tree) = + build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version)?; + // Add the nested tree as a child (its cost will be included when summing children) + children.insert(0, nested_tree); + // The root cost is zero - the actual cost comes from the nested expression + let expr_node = CostExprNode::Atom(ClarityName::from("nested-expression")); + (expr_node, StaticCost::ZERO, None) + } + SymbolicExpressionType::Atom(name) => { + // It's an atom but not a function name - treat as atom with zero cost + (CostExprNode::Atom(name.clone()), StaticCost::ZERO, None) + } + SymbolicExpressionType::AtomValue(value) => { + // It's an atom value - calculate its cost + let cost = calculate_value_cost(value)?; + (CostExprNode::AtomValue(value.clone()), cost, None) + } + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => ( + CostExprNode::TraitReference(trait_name.clone()), + StaticCost::ZERO, + None, + ), + SymbolicExpressionType::Field(field_identifier) => ( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + None, + ), + SymbolicExpressionType::LiteralValue(value) => { + let cost = calculate_value_cost(value)?; + // TODO not sure if LiteralValue is needed in the CostExprNode types + (CostExprNode::AtomValue(value.clone()), cost, None) + } + } + } }; // Handle special cases for string arguments to functions that include their processing cost - if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { - for child in &mut children { - if let CostExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child.expr { - child.cost = StaticCost::ZERO; + if let Some(function_name) = &function_name_opt { + if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { + for child in &mut children { + if let CostExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = + &child.expr + { + child.cost = StaticCost::ZERO; + } } } } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index eb71e00a57..a788c92fdb 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -1,5 +1,6 @@ // TODO: This needs work to get the dynamic vs static testing working use std::collections::HashMap; +use std::path::Path; use rstest::rstest; use stacks_common::types::StacksEpochId; @@ -314,3 +315,55 @@ fn execute_contract_function_and_get_cost( runtime: final_cost.runtime - initial_cost.runtime, } } + +#[test] +fn test_pox_4_costs() { + let workspace_root = Path::new(env!("CARGO_MANIFEST_DIR")).parent().unwrap(); + let pox_4_path = workspace_root + .join("contrib") + .join("boot-contracts-unit-tests") + .join("boot_contracts") + .join("pox-4.clar"); + let contract_source = std::fs::read_to_string(&pox_4_path) + .unwrap_or_else(|e| panic!("Failed to read pox-4.clar file at {:?}: {}", pox_4_path, e)); + + let contract_id = QualifiedContractIdentifier::transient(); + let epoch = StacksEpochId::Epoch32; + let clarity_version = ClarityVersion::Clarity3; + + let ast = crate::vm::ast::build_ast( + &contract_id, + &contract_source, + &mut (), + clarity_version, + epoch, + ) + .expect("Failed to build AST from pox-4.clar"); + + let cost_map = static_cost_from_ast(&ast, &clarity_version) + .expect("Failed to perform static cost analysis on pox-4.clar"); + + // Check some functions in the cost map + let key_functions = vec![ + "stack-stx", + "delegate-stx", + "get-stacker-info", + "current-pox-reward-cycle", + "stack-aggregation-commit", + "stack-increase", + "stack-extend", + ]; + + for function_name in key_functions { + assert!( + cost_map.contains_key(function_name), + "Expected function '{}' to be present in cost map", + function_name + ); + + let (_cost, _trait_count) = cost_map.get(function_name).expect(&format!( + "Failed to get cost for function '{}'", + function_name + )); + } +} From 2df940b33a89da19e235f82aaefef50200aba49e Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Mon, 24 Nov 2025 12:40:44 -0800 Subject: [PATCH 16/23] clean up modules and simplify listlike builder --- clarity/src/vm/ast/mod.rs | 1 + clarity/src/vm/ast/static_cost/mod.rs | 329 ++++++ .../src/vm/ast/static_cost/trait_counter.rs | 417 ++++++++ clarity/src/vm/costs/analysis.rs | 980 ++---------------- clarity/src/vm/tests/analysis.rs | 115 +- 5 files changed, 863 insertions(+), 979 deletions(-) create mode 100644 clarity/src/vm/ast/static_cost/mod.rs create mode 100644 clarity/src/vm/ast/static_cost/trait_counter.rs diff --git a/clarity/src/vm/ast/mod.rs b/clarity/src/vm/ast/mod.rs index b09de172e6..c778354b36 100644 --- a/clarity/src/vm/ast/mod.rs +++ b/clarity/src/vm/ast/mod.rs @@ -17,6 +17,7 @@ pub mod definition_sorter; pub mod expression_identifier; pub mod parser; +pub mod static_cost; pub mod traits_resolver; pub mod errors; diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs new file mode 100644 index 0000000000..bcf20c1110 --- /dev/null +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -0,0 +1,329 @@ +mod trait_counter; +use std::collections::HashMap; + +use clarity_types::types::{CharType, SequenceData}; +pub use trait_counter::{ + TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, +}; + +// Import types from analysis.rs +use crate::vm::costs::analysis::{ + CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, +}; +use crate::vm::costs::cost_functions::{linear, CostValues}; +use crate::vm::costs::costs_3::Costs3; +use crate::vm::costs::ExecutionCost; +use crate::vm::errors::VmExecutionError; +use crate::vm::functions::NativeFunctions; +use crate::vm::representations::ClarityName; +use crate::vm::{ClarityVersion, Value}; + +const STRING_COST_BASE: u64 = 36; +const STRING_COST_MULTIPLIER: u64 = 3; + +/// Convert a NativeFunctions enum variant to its corresponding cost function +/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in +pub(crate) fn get_cost_function_for_native( + function: NativeFunctions, + _clarity_version: &ClarityVersion, +) -> Option Result> { + use crate::vm::functions::NativeFunctions::*; + + // Map NativeFunctions enum variants to their cost functions + match function { + Add => Some(Costs3::cost_add), + Subtract => Some(Costs3::cost_sub), + Multiply => Some(Costs3::cost_mul), + Divide => Some(Costs3::cost_div), + Modulo => Some(Costs3::cost_mod), + Power => Some(Costs3::cost_pow), + Sqrti => Some(Costs3::cost_sqrti), + Log2 => Some(Costs3::cost_log2), + ToInt | ToUInt => Some(Costs3::cost_int_cast), + Equals => Some(Costs3::cost_eq), + CmpGeq => Some(Costs3::cost_geq), + CmpLeq => Some(Costs3::cost_leq), + CmpGreater => Some(Costs3::cost_ge), + CmpLess => Some(Costs3::cost_le), + BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor), + Not | BitwiseNot => Some(Costs3::cost_not), + And | BitwiseAnd => Some(Costs3::cost_and), + Or | BitwiseOr => Some(Costs3::cost_or), + Concat => Some(Costs3::cost_concat), + Len => Some(Costs3::cost_len), + AsMaxLen => Some(Costs3::cost_as_max_len), + ListCons => Some(Costs3::cost_list_cons), + ElementAt | ElementAtAlias => Some(Costs3::cost_element_at), + IndexOf | IndexOfAlias => Some(Costs3::cost_index_of), + Fold => Some(Costs3::cost_fold), + Map => Some(Costs3::cost_map), + Filter => Some(Costs3::cost_filter), + Append => Some(Costs3::cost_append), + TupleGet => Some(Costs3::cost_tuple_get), + TupleMerge => Some(Costs3::cost_tuple_merge), + TupleCons => Some(Costs3::cost_tuple_cons), + ConsSome => Some(Costs3::cost_some_cons), + ConsOkay => Some(Costs3::cost_ok_cons), + ConsError => Some(Costs3::cost_err_cons), + DefaultTo => Some(Costs3::cost_default_to), + UnwrapRet => Some(Costs3::cost_unwrap_ret), + UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret), + IsOkay => Some(Costs3::cost_is_okay), + IsNone => Some(Costs3::cost_is_none), + IsErr => Some(Costs3::cost_is_err), + IsSome => Some(Costs3::cost_is_some), + Unwrap => Some(Costs3::cost_unwrap), + UnwrapErr => Some(Costs3::cost_unwrap_err), + TryRet => Some(Costs3::cost_try_ret), + If => Some(Costs3::cost_if), + Match => Some(Costs3::cost_match), + Begin => Some(Costs3::cost_begin), + Let => Some(Costs3::cost_let), + Asserts => Some(Costs3::cost_asserts), + Hash160 => Some(Costs3::cost_hash160), + Sha256 => Some(Costs3::cost_sha256), + Sha512 => Some(Costs3::cost_sha512), + Sha512Trunc256 => Some(Costs3::cost_sha512t256), + Keccak256 => Some(Costs3::cost_keccak256), + Secp256k1Recover => Some(Costs3::cost_secp256k1recover), + Secp256k1Verify => Some(Costs3::cost_secp256k1verify), + Print => Some(Costs3::cost_print), + ContractCall => Some(Costs3::cost_contract_call), + ContractOf => Some(Costs3::cost_contract_of), + PrincipalOf => Some(Costs3::cost_principal_of), + AtBlock => Some(Costs3::cost_at_block), + // => Some(Costs3::cost_create_map), + // => Some(Costs3::cost_create_var), + // ContractStorage => Some(Costs3::cost_contract_storage), + FetchEntry => Some(Costs3::cost_fetch_entry), + SetEntry => Some(Costs3::cost_set_entry), + FetchVar => Some(Costs3::cost_fetch_var), + SetVar => Some(Costs3::cost_set_var), + GetBlockInfo => Some(Costs3::cost_block_info), + GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), + GetStxBalance => Some(Costs3::cost_stx_balance), + StxTransfer => Some(Costs3::cost_stx_transfer), + StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), + StxGetAccount => Some(Costs3::cost_stx_account), + MintToken => Some(Costs3::cost_ft_mint), + MintAsset => Some(Costs3::cost_nft_mint), + TransferToken => Some(Costs3::cost_ft_transfer), + GetTokenBalance => Some(Costs3::cost_ft_balance), + GetTokenSupply => Some(Costs3::cost_ft_get_supply), + BurnToken => Some(Costs3::cost_ft_burn), + TransferAsset => Some(Costs3::cost_nft_transfer), + GetAssetOwner => Some(Costs3::cost_nft_owner), + BurnAsset => Some(Costs3::cost_nft_burn), + BuffToIntLe => Some(Costs3::cost_buff_to_int_le), + BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le), + BuffToIntBe => Some(Costs3::cost_buff_to_int_be), + BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be), + ToConsensusBuff => Some(Costs3::cost_to_consensus_buff), + FromConsensusBuff => Some(Costs3::cost_from_consensus_buff), + IsStandard => Some(Costs3::cost_is_standard), + PrincipalDestruct => Some(Costs3::cost_principal_destruct), + PrincipalConstruct => Some(Costs3::cost_principal_construct), + AsContract | AsContractSafe => Some(Costs3::cost_as_contract), + StringToInt => Some(Costs3::cost_string_to_int), + StringToUInt => Some(Costs3::cost_string_to_uint), + IntToAscii => Some(Costs3::cost_int_to_ascii), + IntToUtf8 => Some(Costs3::cost_int_to_utf8), + BitwiseLShift => Some(Costs3::cost_bitwise_left_shift), + BitwiseRShift => Some(Costs3::cost_bitwise_right_shift), + Slice => Some(Costs3::cost_slice), + ReplaceAt => Some(Costs3::cost_replace_at), + GetStacksBlockInfo => Some(Costs3::cost_block_info), + GetTenureInfo => Some(Costs3::cost_block_info), + ContractHash => Some(Costs3::cost_contract_hash), + ToAscii => Some(Costs3::cost_to_ascii), + InsertEntry => Some(Costs3::cost_set_entry), + DeleteEntry => Some(Costs3::cost_set_entry), + StxBurn => Some(Costs3::cost_stx_transfer), + Secp256r1Verify => Some(Costs3::cost_secp256r1verify), + RestrictAssets => None, // TODO: add cost function + AllowanceWithStx => None, // TODO: add cost function + AllowanceWithFt => None, // TODO: add cost function + AllowanceWithNft => None, // TODO: add cost function + AllowanceWithStacking => None, // TODO: add cost function + AllowanceAll => None, // TODO: add cost function + } +} + +// Calculate function cost with lazy evaluation support +pub(crate) fn calculate_function_cost( + function_name: String, + cost_map: &HashMap>, + _clarity_version: &ClarityVersion, +) -> Result { + match cost_map.get(&function_name) { + Some(Some(cost)) => { + // Cost already computed + Ok(cost.clone()) + } + Some(None) => { + // Should be impossible but alas.. + // Function exists but cost not yet computed - this indicates a circular dependency + // For now, return zero cost to avoid infinite recursion + println!( + "Circular dependency detected for function: {}", + function_name + ); + Ok(StaticCost::ZERO) + } + None => { + // Function not found + Ok(StaticCost::ZERO) + } + } +} + +/// Determine if a function name represents a branching function +pub(crate) fn is_branching_function(function_name: &ClarityName) -> bool { + match function_name.as_str() { + "if" | "match" => true, + "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and + // unwrap-err traverse both branches regardless of result, so until this is + // fixed in clarity we'll set this to false + _ => false, + } +} + +/// Helper function to determine if a node represents a branching operation +/// This is used in tests and cost calculation +pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool { + match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => true, + CostExprNode::UserFunction(name) => is_branching_function(name), + _ => false, + } +} + +/// Calculate the cost for a string based on its length +fn string_cost(length: usize) -> StaticCost { + let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); + let execution_cost = ExecutionCost::runtime(cost); + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + } +} + +/// Calculate cost for a value (used for literal values) +pub(crate) fn calculate_value_cost(value: &Value) -> Result { + match value { + Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { + Ok(string_cost(data.data.len())) + } + Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { + Ok(string_cost(data.data.len())) + } + _ => Ok(StaticCost::ZERO), + } +} + +pub(crate) fn calculate_function_cost_from_native_function( + native_function: NativeFunctions, + arg_count: u64, + clarity_version: &ClarityVersion, +) -> Result { + let cost_function = match get_cost_function_for_native(native_function, clarity_version) { + Some(cost_fn) => cost_fn, + None => { + // TODO: zero cost for now + return Ok(StaticCost::ZERO); + } + }; + + let cost = get_costs(cost_function, arg_count)?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) +} + +/// Calculate total cost using SummingExecutionCost to handle branching properly +pub(crate) fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); + + for child in &node.children { + let child_summing = calculate_total_cost_with_summing(child); + summing_cost.add_summing(&child_summing); + } + + summing_cost +} + +pub(crate) fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::new(); + + // Check if this is a branching function by examining the node's expression + let is_branching = is_node_branching(node); + + if is_branching { + match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => { + // TODO match? + if node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&node.children[0]); + let condition_total = condition_cost.add_all(); + + // Add the root cost + condition cost to each branch + let mut root_and_condition = node.cost.min.clone(); + let _ = root_and_condition.add(&condition_total); + + for child_cost_node in node.children.iter().skip(1) { + let branch_cost = calculate_total_cost_with_summing(child_cost_node); + let branch_total = branch_cost.add_all(); + + let mut path_cost = root_and_condition.clone(); + let _ = path_cost.add(&branch_total); + + summing_cost.add_cost(path_cost); + } + } + } + _ => { + // For other branching functions, fall back to sequential processing + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + } + } else { + // For non-branching, add all costs sequentially + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + + summing_cost +} + +impl From for StaticCost { + fn from(summing: SummingExecutionCost) -> Self { + StaticCost { + min: summing.min(), + max: summing.max(), + } + } +} + +/// Helper: calculate min & max costs for a given cost function +/// This is likely tooo simplistic but for now it'll do +fn get_costs( + cost_fn: fn(u64) -> Result, + arg_count: u64, +) -> Result { + let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(cost) +} diff --git a/clarity/src/vm/ast/static_cost/trait_counter.rs b/clarity/src/vm/ast/static_cost/trait_counter.rs new file mode 100644 index 0000000000..10a67afd8f --- /dev/null +++ b/clarity/src/vm/ast/static_cost/trait_counter.rs @@ -0,0 +1,417 @@ +use std::collections::HashMap; + +use clarity_types::representations::ClarityName; +use clarity_types::types::Value; + +use crate::vm::ast::static_cost::{CostAnalysisNode, CostExprNode}; +use crate::vm::costs::analysis::is_function_definition; +use crate::vm::functions::NativeFunctions; +use crate::vm::representations::{SymbolicExpression, SymbolicExpressionType}; +type MinMaxTraitCount = (u64, u64); +pub type TraitCount = HashMap; + +/// Context passed to visitors during trait count analysis +pub struct TraitCountContext { + containing_fn_name: String, + multiplier: (u64, u64), +} + +impl TraitCountContext { + pub fn new(containing_fn_name: String, multiplier: (u64, u64)) -> Self { + Self { + containing_fn_name, + multiplier, + } + } + + fn with_multiplier(&self, multiplier: (u64, u64)) -> Self { + Self { + containing_fn_name: self.containing_fn_name.clone(), + multiplier, + } + } + + fn with_fn_name(&self, fn_name: String) -> Self { + Self { + containing_fn_name: fn_name, + multiplier: self.multiplier, + } + } +} + +/// Extract the list size multiplier from a list expression (for map/filter/fold operations) +/// Expects a list in the form `(list )` where size is an integer literal +fn extract_list_multiplier(list: &[SymbolicExpression]) -> (u64, u64) { + if list.is_empty() { + return (1, 1); + } + + let is_list_atom = list[0] + .match_atom() + .map(|a| a.as_str() == "list") + .unwrap_or(false); + if !is_list_atom || list.len() < 2 { + return (1, 1); + } + + match &list[1].expr { + SymbolicExpressionType::LiteralValue(Value::Int(value)) => (0, *value as u64), + _ => (1, 1), + } +} + +/// Increment trait count for a function +fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: (u64, u64)) { + trait_counts + .entry(fn_name.to_string()) + .and_modify(|(min, max)| { + *min += multiplier.0; + *max += multiplier.1; + }) + .or_insert(multiplier); +} + +/// Propagate trait count from one function to another with a multiplier +fn propagate_trait_count( + trait_counts: &mut TraitCount, + from_fn: &str, + to_fn: &str, + multiplier: (u64, u64), +) { + if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() { + trait_counts + .entry(to_fn.to_string()) + .and_modify(|(min, max)| { + *min += called_trait_count.0 * multiplier.0; + *max += called_trait_count.1 * multiplier.1; + }) + .or_insert(( + called_trait_count.0 * multiplier.0, + called_trait_count.1 * multiplier.1, + )); + } +} + +/// Visitor trait for traversing cost analysis nodes and collecting/propagating trait counts +pub trait TraitCountVisitor { + fn visit_user_argument( + &mut self, + node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + context: &TraitCountContext, + ); + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ); + fn visit_atom_value(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_atom( + &mut self, + node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ); + fn visit_field_identifier(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_trait_reference( + &mut self, + node: &CostAnalysisNode, + trait_name: &ClarityName, + context: &TraitCountContext, + ); + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ); + + fn visit(&mut self, node: &CostAnalysisNode, context: &TraitCountContext) { + match &node.expr { + CostExprNode::UserArgument(arg_name, arg_type) => { + self.visit_user_argument(node, arg_name, arg_type, context); + } + CostExprNode::NativeFunction(native_function) => { + self.visit_native_function(node, native_function, context); + } + CostExprNode::AtomValue(_atom_value) => { + self.visit_atom_value(node, context); + } + CostExprNode::Atom(atom) => { + self.visit_atom(node, atom, context); + } + CostExprNode::FieldIdentifier(_field_identifier) => { + self.visit_field_identifier(node, context); + } + CostExprNode::TraitReference(trait_name) => { + self.visit_trait_reference(node, trait_name, context); + } + CostExprNode::UserFunction(user_function) => { + self.visit_user_function(node, user_function, context); + } + } + } +} + +pub struct TraitCountCollector { + pub trait_counts: TraitCount, + pub trait_names: HashMap, +} + +impl TraitCountCollector { + pub fn new() -> Self { + Self { + trait_counts: HashMap::new(), + trait_names: HashMap::new(), + } + } +} + +impl TraitCountVisitor for TraitCountCollector { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + if let SymbolicExpressionType::TraitReference(name, _) = arg_type { + self.trait_names + .insert(arg_name.clone(), name.clone().to_string()); + } + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + (1, 1) + }; + let new_context = context.with_multiplier(multiplier); + for child in &node.children { + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } + } + } + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for atom values + } + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ) { + if self.trait_names.contains_key(atom) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for field identifiers + } + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + context: &TraitCountContext, + ) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + // Check if this is a trait call (the function name is a trait argument) + if self.trait_names.contains_key(user_function) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); + } + } +} + +/// Second pass visitor: propagates trait counts through function calls +pub struct TraitCountPropagator<'a> { + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, +} + +impl<'a> TraitCountPropagator<'a> { + pub fn new( + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, + ) -> Self { + Self { + trait_counts, + trait_names, + } + } +} + +impl<'a> TraitCountVisitor for TraitCountPropagator<'a> { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + _arg_name: &ClarityName, + _arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + // No propagation needed for arguments + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + (1, 1) + }; + + // Process the function being called in map/filter/fold + let mut skip_first_child = false; + if let Some(function_node) = node.children.get(0) { + if let CostExprNode::UserFunction(function_name) = &function_node.expr { + if !self.trait_names.contains_key(function_name) { + // This is a regular function call, not a trait call + propagate_trait_count( + self.trait_counts, + &function_name.to_string(), + &context.containing_fn_name, + multiplier, + ); + skip_first_child = true; + } + } + } + + // Continue traversing children, but skip the function node if we already propagated it + for (idx, child) in node.children.iter().enumerate() { + if idx == 0 && skip_first_child { + continue; + } + let new_context = context.with_multiplier(multiplier); + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } + } + } + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + _atom: &ClarityName, + _context: &TraitCountContext, + ) { + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + _context: &TraitCountContext, + ) { + // No propagation needed for trait references (already counted in first pass) + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + if !is_function_definition(user_function.as_str()) + && !self.trait_names.contains_key(user_function) + { + // This is a regular function call, not a trait call or function definition + propagate_trait_count( + self.trait_counts, + &user_function.to_string(), + &context.containing_fn_name, + context.multiplier, + ); + } + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); + } + } +} diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 4fedd97b9f..cca8aa5414 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -2,25 +2,29 @@ use std::collections::HashMap; -use clarity_types::types::{CharType, SequenceData, TraitIdentifier}; +use clarity_types::types::TraitIdentifier; use stacks_common::types::StacksEpochId; use crate::vm::ast::build_ast; +#[cfg(test)] +use crate::vm::ast::static_cost::is_node_branching; +use crate::vm::ast::static_cost::{ + calculate_function_cost, calculate_function_cost_from_native_function, + calculate_total_cost_with_branching, calculate_value_cost, TraitCount, TraitCountCollector, + TraitCountContext, TraitCountPropagator, TraitCountVisitor, +}; use crate::vm::contexts::Environment; -use crate::vm::costs::cost_functions::{linear, CostValues}; -use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; -use crate::vm::errors::VmExecutionError; use crate::vm::functions::NativeFunctions; use crate::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; use crate::vm::types::QualifiedContractIdentifier; use crate::vm::{ClarityVersion, Value}; - // TODO: // contract-call? - get source from database // type-checking // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) +// split up trait counting and expr node tree impl into separate module? const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; @@ -32,7 +36,7 @@ const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; const FUNCTION_DEFINITION_KEYWORDS: &[&str] = &["define-public", "define-private", "define-read-only"]; -fn is_function_definition(function_name: &str) -> bool { +pub(crate) fn is_function_definition(function_name: &str) -> bool { FUNCTION_DEFINITION_KEYWORDS.contains(&function_name) } @@ -220,434 +224,71 @@ fn static_cost_native( Ok(summing_cost.into()) } -type MinMaxTraitCount = (u64, u64); -type TraitCount = HashMap; - -/// Context passed to visitors during trait count analysis -struct TraitCountContext { - containing_fn_name: String, - multiplier: (u64, u64), -} - -impl TraitCountContext { - fn new(containing_fn_name: String, multiplier: (u64, u64)) -> Self { - Self { - containing_fn_name, - multiplier, - } - } - - fn with_multiplier(&self, multiplier: (u64, u64)) -> Self { - Self { - containing_fn_name: self.containing_fn_name.clone(), - multiplier, - } - } - - fn with_fn_name(&self, fn_name: String) -> Self { - Self { - containing_fn_name: fn_name, - multiplier: self.multiplier, - } - } -} - -/// Extract the list size multiplier from a list expression (for map/filter/fold operations) -/// Expects a list in the form `(list )` where size is an integer literal -fn extract_list_multiplier(list: &[SymbolicExpression]) -> (u64, u64) { - if list.is_empty() { - return (1, 1); - } - - let is_list_atom = list[0] - .match_atom() - .map(|a| a.as_str() == "list") - .unwrap_or(false); - if !is_list_atom || list.len() < 2 { - return (1, 1); - } - - match &list[1].expr { - SymbolicExpressionType::LiteralValue(Value::Int(value)) => (0, *value as u64), - _ => (1, 1), - } -} - -/// Increment trait count for a function -fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: (u64, u64)) { - trait_counts - .entry(fn_name.to_string()) - .and_modify(|(min, max)| { - *min += multiplier.0; - *max += multiplier.1; - }) - .or_insert(multiplier); -} - -/// Propagate trait count from one function to another with a multiplier -fn propagate_trait_count( - trait_counts: &mut TraitCount, - from_fn: &str, - to_fn: &str, - multiplier: (u64, u64), -) { - if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() { - trait_counts - .entry(to_fn.to_string()) - .and_modify(|(min, max)| { - *min += called_trait_count.0 * multiplier.0; - *max += called_trait_count.1 * multiplier.1; - }) - .or_insert(( - called_trait_count.0 * multiplier.0, - called_trait_count.1 * multiplier.1, - )); - } -} - -/// Visitor trait for traversing cost analysis nodes and collecting/propagating trait counts -trait TraitCountVisitor { - fn visit_user_argument( - &mut self, - node: &CostAnalysisNode, - arg_name: &ClarityName, - arg_type: &SymbolicExpressionType, - context: &TraitCountContext, - ); - fn visit_native_function( - &mut self, - node: &CostAnalysisNode, - native_function: &NativeFunctions, - context: &TraitCountContext, - ); - fn visit_atom_value(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); - fn visit_atom( - &mut self, - node: &CostAnalysisNode, - atom: &ClarityName, - context: &TraitCountContext, - ); - fn visit_field_identifier(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); - fn visit_trait_reference( - &mut self, - node: &CostAnalysisNode, - trait_name: &ClarityName, - context: &TraitCountContext, - ); - fn visit_user_function( - &mut self, - node: &CostAnalysisNode, - user_function: &ClarityName, - context: &TraitCountContext, - ); - - fn visit(&mut self, node: &CostAnalysisNode, context: &TraitCountContext) { - match &node.expr { - CostExprNode::UserArgument(arg_name, arg_type) => { - self.visit_user_argument(node, arg_name, arg_type, context); - } - CostExprNode::NativeFunction(native_function) => { - self.visit_native_function(node, native_function, context); - } - CostExprNode::AtomValue(_atom_value) => { - self.visit_atom_value(node, context); - } - CostExprNode::Atom(atom) => { - self.visit_atom(node, atom, context); - } - CostExprNode::FieldIdentifier(_field_identifier) => { - self.visit_field_identifier(node, context); - } - CostExprNode::TraitReference(trait_name) => { - self.visit_trait_reference(node, trait_name, context); - } - CostExprNode::UserFunction(user_function) => { - self.visit_user_function(node, user_function, context); - } - } - } -} - -struct TraitCountCollector { - trait_counts: TraitCount, - trait_names: HashMap, -} - -impl TraitCountCollector { - fn new() -> Self { - Self { - trait_counts: HashMap::new(), - trait_names: HashMap::new(), - } - } -} - -impl TraitCountVisitor for TraitCountCollector { - fn visit_user_argument( - &mut self, - _node: &CostAnalysisNode, - arg_name: &ClarityName, - arg_type: &SymbolicExpressionType, - _context: &TraitCountContext, - ) { - if let SymbolicExpressionType::TraitReference(name, _) = arg_type { - self.trait_names - .insert(arg_name.clone(), name.clone().to_string()); - } - } - - fn visit_native_function( - &mut self, - node: &CostAnalysisNode, - native_function: &NativeFunctions, - context: &TraitCountContext, - ) { - match native_function { - NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { - if node.children.len() > 1 { - let list_node = &node.children[1]; - let multiplier = - if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = - &list_node.expr - { - extract_list_multiplier(list) - } else { - (1, 1) - }; - let new_context = context.with_multiplier(multiplier); - for child in &node.children { - self.visit(child, &new_context); - } - } - } - _ => { - for child in &node.children { - self.visit(child, context); - } - } - } - } - - fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { - // No action needed for atom values - } - - fn visit_atom( - &mut self, - _node: &CostAnalysisNode, - atom: &ClarityName, - context: &TraitCountContext, - ) { - if self.trait_names.contains_key(atom) { - increment_trait_count( - &mut self.trait_counts, - &context.containing_fn_name, - context.multiplier, - ); - } - } - - fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { - // No action needed for field identifiers - } - - fn visit_trait_reference( - &mut self, - _node: &CostAnalysisNode, - _trait_name: &ClarityName, - context: &TraitCountContext, - ) { - increment_trait_count( - &mut self.trait_counts, - &context.containing_fn_name, - context.multiplier, - ); - } - - fn visit_user_function( - &mut self, - node: &CostAnalysisNode, - user_function: &ClarityName, - context: &TraitCountContext, - ) { - // Check if this is a trait call (the function name is a trait argument) - if self.trait_names.contains_key(user_function) { - increment_trait_count( - &mut self.trait_counts, - &context.containing_fn_name, - context.multiplier, - ); - } +/// STatic execution cost for functions within Environment +/// returns the top level cost for specific functions +/// {function_name: cost} +pub fn static_cost( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result, String> { + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| { + format!( + "Contract source ({:?}) not found in database", + contract_identifier.to_string(), + ) + })?; - // Determine the containing function name for children - let fn_name = if is_function_definition(user_function.as_str()) { - context.containing_fn_name.clone() - } else { - user_function.to_string() - }; - let child_context = context.with_fn_name(fn_name); + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; - for child in &node.children { - self.visit(child, &child_context); - } - } -} + let clarity_version = contract.contract_context.get_clarity_version(); -/// Second pass visitor: propagates trait counts through function calls -struct TraitCountPropagator<'a> { - trait_counts: &'a mut TraitCount, - trait_names: &'a HashMap, -} + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; -impl<'a> TraitCountPropagator<'a> { - fn new( - trait_counts: &'a mut TraitCount, - trait_names: &'a HashMap, - ) -> Self { - Self { - trait_counts, - trait_names, - } - } + let costs = static_cost_from_ast(&ast, clarity_version)?; + Ok(costs + .into_iter() + .map(|(name, (cost, _trait_count))| (name, cost)) + .collect()) } -impl<'a> TraitCountVisitor for TraitCountPropagator<'a> { - fn visit_user_argument( - &mut self, - _node: &CostAnalysisNode, - _arg_name: &ClarityName, - _arg_type: &SymbolicExpressionType, - _context: &TraitCountContext, - ) { - // No propagation needed for arguments - } - - fn visit_native_function( - &mut self, - node: &CostAnalysisNode, - native_function: &NativeFunctions, - context: &TraitCountContext, - ) { - match native_function { - NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { - if node.children.len() > 1 { - let list_node = &node.children[1]; - let multiplier = - if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = - &list_node.expr - { - extract_list_multiplier(list) - } else { - (1, 1) - }; - - // Process the function being called in map/filter/fold - let mut skip_first_child = false; - if let Some(function_node) = node.children.get(0) { - if let CostExprNode::UserFunction(function_name) = &function_node.expr { - if !self.trait_names.contains_key(function_name) { - // This is a regular function call, not a trait call - propagate_trait_count( - self.trait_counts, - &function_name.to_string(), - &context.containing_fn_name, - multiplier, - ); - skip_first_child = true; - } - } - } - - // Continue traversing children, but skip the function node if we already propagated it - for (idx, child) in node.children.iter().enumerate() { - if idx == 0 && skip_first_child { - continue; - } - let new_context = context.with_multiplier(multiplier); - self.visit(child, &new_context); - } - } - } - _ => { - for child in &node.children { - self.visit(child, context); - } - } - } - } - - fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} - - fn visit_atom( - &mut self, - _node: &CostAnalysisNode, - _atom: &ClarityName, - _context: &TraitCountContext, - ) { - } - - fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} - - fn visit_trait_reference( - &mut self, - _node: &CostAnalysisNode, - _trait_name: &ClarityName, - _context: &TraitCountContext, - ) { - // No propagation needed for trait references (already counted in first pass) - } - - fn visit_user_function( - &mut self, - node: &CostAnalysisNode, - user_function: &ClarityName, - context: &TraitCountContext, - ) { - if !is_function_definition(user_function.as_str()) - && !self.trait_names.contains_key(user_function) - { - // This is a regular function call, not a trait call or function definition - propagate_trait_count( - self.trait_counts, - &user_function.to_string(), - &context.containing_fn_name, - context.multiplier, - ); - } - - // Determine the containing function name for children - let fn_name = if is_function_definition(user_function.as_str()) { - context.containing_fn_name.clone() - } else { - user_function.to_string() - }; - let child_context = context.with_fn_name(fn_name); +/// same idea as `static_cost` but returns the root of the cost analysis tree for each function +/// Useful if you need to analyze specific nodes in the cost tree +pub fn static_cost_tree( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result, String> { + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| { + format!( + "Contract source ({:?}) not found in database", + contract_identifier.to_string(), + ) + })?; - for child in &node.children { - self.visit(child, &child_context); - } - } -} + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; -pub(crate) fn get_trait_count(costs: &HashMap) -> Option { - // First pass: collect trait counts and trait names - let mut collector = TraitCountCollector::new(); - for (name, cost_analysis_node) in costs.iter() { - let context = TraitCountContext::new(name.clone(), (1, 1)); - collector.visit(cost_analysis_node, &context); - } + let clarity_version = contract.contract_context.get_clarity_version(); - // Second pass: propagate trait counts through function calls - // If function A calls function B and uses a map, filter, or fold with - // traits, the maximum will reflect that in A's trait call counts - let mut propagator = - TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names); - for (name, cost_analysis_node) in costs.iter() { - let context = TraitCountContext::new(name.clone(), (1, 1)); - propagator.visit(cost_analysis_node, &context); - } + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; - Some(collector.trait_counts) + static_cost_tree_from_ast(&ast, clarity_version) } pub fn static_cost_from_ast( @@ -696,63 +337,6 @@ pub(crate) fn static_cost_tree_from_ast( .collect()) } -/// STatic execution cost for functions within Environment -/// returns the top level cost for specific functions -/// {function_name: cost} -pub fn static_cost( - env: &mut Environment, - contract_identifier: &QualifiedContractIdentifier, -) -> Result, String> { - let contract_source = env - .global_context - .database - .get_contract_src(contract_identifier) - .ok_or_else(|| "Contract source not found in database".to_string())?; - - let contract = env - .global_context - .database - .get_contract(contract_identifier) - .map_err(|e| format!("Failed to get contract: {:?}", e))?; - - let clarity_version = contract.contract_context.get_clarity_version(); - - let epoch = env.global_context.epoch_id; - let ast = make_ast(&contract_source, epoch, clarity_version)?; - - let costs = static_cost_from_ast(&ast, clarity_version)?; - Ok(costs - .into_iter() - .map(|(name, (cost, _trait_count))| (name, cost)) - .collect()) -} - -/// same idea as `static_cost` but returns the root of the cost analysis tree for each function -/// Useful if you need to analyze specific nodes in the cost tree -pub fn static_cost_tree( - env: &mut Environment, - contract_identifier: &QualifiedContractIdentifier, -) -> Result, String> { - let contract_source = env - .global_context - .database - .get_contract_src(contract_identifier) - .ok_or_else(|| "Contract source not found in database".to_string())?; - - let contract = env - .global_context - .database - .get_contract(contract_identifier) - .map_err(|e| format!("Failed to get contract: {:?}", e))?; - - let clarity_version = contract.contract_context.get_clarity_version(); - - let epoch = env.global_context.epoch_id; - let ast = make_ast(&contract_source, epoch, clarity_version)?; - - static_cost_tree_from_ast(&ast, clarity_version) -} - /// Extract function name from a symbolic expression fn extract_function_name(expr: &SymbolicExpression) -> Option { expr.match_list().and_then(|list| { @@ -880,23 +464,6 @@ fn build_function_definition_cost_analysis_tree( .ok_or("Expected atom for argument name")?; let arg_type = arg_list[1].clone(); - // let arg_type = match &arg_list[1].expr { - // SymbolicExpressionType::Atom(type_name) => type_name.clone(), - // SymbolicExpressionType::AtomValue(value) => { - // ClarityName::from(value.to_string().as_str()) - // } - // SymbolicExpressionType::LiteralValue(value) => { - // ClarityName::from(value.to_string().as_str()) - // } - // SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => { - // trait_name.clone() - // } - // SymbolicExpressionType::List(_) => ClarityName::from("list"), - // _ => { - // println!("arg: {:?}", arg_list[1].expr); - // return Err("Argument type must be an atom or atom value".to_string()); - // } - // }; // Add to function's user arguments context function_user_args.add_argument(arg_name.clone(), arg_type.clone().expr); @@ -953,395 +520,78 @@ fn build_listlike_cost_analysis_tree( children.push(child_tree); } - // Try to get function name from first element - let (expr_node, cost, function_name_opt) = match get_function_name(&exprs[0]) { - Ok(function_name) => { + let (expr_node, cost) = match &exprs[0].expr { + SymbolicExpressionType::List(_) => { + // Recursively analyze the nested list structure + let (_, nested_tree) = + build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version)?; + // Add the nested tree as a child (its cost will be included when summing children) + children.insert(0, nested_tree); + // The root cost is zero - the actual cost comes from the nested expression + let expr_node = CostExprNode::Atom(ClarityName::from("nested-expression")); + (expr_node, StaticCost::ZERO) + } + SymbolicExpressionType::Atom(name) => { + // Try to get function name from first element // Try to lookup the function as a native function first if let Some(native_function) = - NativeFunctions::lookup_by_name_at_version(function_name.as_str(), clarity_version) + NativeFunctions::lookup_by_name_at_version(name.as_str(), clarity_version) { let cost = calculate_function_cost_from_native_function( native_function, children.len() as u64, clarity_version, )?; - ( - CostExprNode::NativeFunction(native_function), - cost, - Some(function_name), - ) + (CostExprNode::NativeFunction(native_function), cost) } else { // If not a native function, treat as user-defined function and look it up - let expr_node = CostExprNode::UserFunction(function_name.clone()); - let cost = - calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?; - (expr_node, cost, Some(function_name)) + let expr_node = CostExprNode::UserFunction(name.clone()); + let cost = calculate_function_cost(name.to_string(), cost_map, clarity_version)?; + (expr_node, cost) } } - Err(_) => { - // First element is not an atom - it might be a List that needs to be recursively analyzed - match &exprs[0].expr { - SymbolicExpressionType::List(_) => { - // Recursively analyze the nested list structure - let (_, nested_tree) = - build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version)?; - // Add the nested tree as a child (its cost will be included when summing children) - children.insert(0, nested_tree); - // The root cost is zero - the actual cost comes from the nested expression - let expr_node = CostExprNode::Atom(ClarityName::from("nested-expression")); - (expr_node, StaticCost::ZERO, None) - } - SymbolicExpressionType::Atom(name) => { - // It's an atom but not a function name - treat as atom with zero cost - (CostExprNode::Atom(name.clone()), StaticCost::ZERO, None) - } - SymbolicExpressionType::AtomValue(value) => { - // It's an atom value - calculate its cost - let cost = calculate_value_cost(value)?; - (CostExprNode::AtomValue(value.clone()), cost, None) - } - SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => ( - CostExprNode::TraitReference(trait_name.clone()), - StaticCost::ZERO, - None, - ), - SymbolicExpressionType::Field(field_identifier) => ( - CostExprNode::FieldIdentifier(field_identifier.clone()), - StaticCost::ZERO, - None, - ), - SymbolicExpressionType::LiteralValue(value) => { - let cost = calculate_value_cost(value)?; - // TODO not sure if LiteralValue is needed in the CostExprNode types - (CostExprNode::AtomValue(value.clone()), cost, None) - } - } - } - }; - - // Handle special cases for string arguments to functions that include their processing cost - if let Some(function_name) = &function_name_opt { - if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { - for child in &mut children { - if let CostExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = - &child.expr - { - child.cost = StaticCost::ZERO; - } - } - } - } - - Ok(CostAnalysisNode::new(expr_node, cost, children)) -} - -// Calculate function cost with lazy evaluation support -fn calculate_function_cost( - function_name: String, - cost_map: &HashMap>, - _clarity_version: &ClarityVersion, -) -> Result { - match cost_map.get(&function_name) { - Some(Some(cost)) => { - // Cost already computed - Ok(cost.clone()) - } - Some(None) => { - // Should be impossible but alas.. - // Function exists but cost not yet computed - this indicates a circular dependency - // For now, return zero cost to avoid infinite recursion - println!( - "Circular dependency detected for function: {}", - function_name - ); - Ok(StaticCost::ZERO) - } - None => { - // Function not found - Ok(StaticCost::ZERO) - } - } -} -/// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version -/// directly in build_listlike_cost_analysis_tree - -/// Determine if a function name represents a branching function -fn is_branching_function(function_name: &ClarityName) -> bool { - match function_name.as_str() { - "if" | "match" => true, - "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and - // unwrap-err traverse both branches regardless of result, so until this is - // fixed in clarity we'll set this to false - _ => false, - } -} - -/// Helper function to determine if a node represents a branching operation -/// This is used in tests and cost calculation -fn is_node_branching(node: &CostAnalysisNode) -> bool { - match &node.expr { - CostExprNode::NativeFunction(NativeFunctions::If) - | CostExprNode::NativeFunction(NativeFunctions::Match) => true, - CostExprNode::UserFunction(name) => is_branching_function(name), - _ => false, - } -} - -/// Calculate the cost for a string based on its length -fn string_cost(length: usize) -> StaticCost { - let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); - let execution_cost = ExecutionCost::runtime(cost); - StaticCost { - min: execution_cost.clone(), - max: execution_cost, - } -} - -/// Calculate cost for a value (used for literal values) -fn calculate_value_cost(value: &Value) -> Result { - match value { - Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { - Ok(string_cost(data.data.len())) - } - Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { - Ok(string_cost(data.data.len())) + SymbolicExpressionType::AtomValue(value) => { + // It's an atom value - calculate its cost + let cost = calculate_value_cost(value)?; + (CostExprNode::AtomValue(value.clone()), cost) } - _ => Ok(StaticCost::ZERO), - } -} - -fn calculate_function_cost_from_native_function( - native_function: NativeFunctions, - arg_count: u64, - clarity_version: &ClarityVersion, -) -> Result { - let cost_function = match get_cost_function_for_native(native_function, clarity_version) { - Some(cost_fn) => cost_fn, - None => { - // TODO: zero cost for now - return Ok(StaticCost::ZERO); + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => ( + CostExprNode::TraitReference(trait_name.clone()), + StaticCost::ZERO, + ), + SymbolicExpressionType::Field(field_identifier) => ( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + ), + SymbolicExpressionType::LiteralValue(value) => { + let cost = calculate_value_cost(value)?; + // TODO not sure if LiteralValue is needed in the CostExprNode types + (CostExprNode::AtomValue(value.clone()), cost) } }; - let cost = get_costs(cost_function, arg_count)?; - Ok(StaticCost { - min: cost.clone(), - max: cost, - }) -} - -/// Convert a NativeFunctions enum variant to its corresponding cost function -/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in -fn get_cost_function_for_native( - function: NativeFunctions, - _clarity_version: &ClarityVersion, -) -> Option Result> { - use crate::vm::functions::NativeFunctions::*; - - // Map NativeFunctions enum variants to their cost functions - match function { - Add => Some(Costs3::cost_add), - Subtract => Some(Costs3::cost_sub), - Multiply => Some(Costs3::cost_mul), - Divide => Some(Costs3::cost_div), - Modulo => Some(Costs3::cost_mod), - Power => Some(Costs3::cost_pow), - Sqrti => Some(Costs3::cost_sqrti), - Log2 => Some(Costs3::cost_log2), - ToInt | ToUInt => Some(Costs3::cost_int_cast), - Equals => Some(Costs3::cost_eq), - CmpGeq => Some(Costs3::cost_geq), - CmpLeq => Some(Costs3::cost_leq), - CmpGreater => Some(Costs3::cost_ge), - CmpLess => Some(Costs3::cost_le), - BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor), - Not | BitwiseNot => Some(Costs3::cost_not), - And | BitwiseAnd => Some(Costs3::cost_and), - Or | BitwiseOr => Some(Costs3::cost_or), - Concat => Some(Costs3::cost_concat), - Len => Some(Costs3::cost_len), - AsMaxLen => Some(Costs3::cost_as_max_len), - ListCons => Some(Costs3::cost_list_cons), - ElementAt | ElementAtAlias => Some(Costs3::cost_element_at), - IndexOf | IndexOfAlias => Some(Costs3::cost_index_of), - Fold => Some(Costs3::cost_fold), - Map => Some(Costs3::cost_map), - Filter => Some(Costs3::cost_filter), - Append => Some(Costs3::cost_append), - TupleGet => Some(Costs3::cost_tuple_get), - TupleMerge => Some(Costs3::cost_tuple_merge), - TupleCons => Some(Costs3::cost_tuple_cons), - ConsSome => Some(Costs3::cost_some_cons), - ConsOkay => Some(Costs3::cost_ok_cons), - ConsError => Some(Costs3::cost_err_cons), - DefaultTo => Some(Costs3::cost_default_to), - UnwrapRet => Some(Costs3::cost_unwrap_ret), - UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret), - IsOkay => Some(Costs3::cost_is_okay), - IsNone => Some(Costs3::cost_is_none), - IsErr => Some(Costs3::cost_is_err), - IsSome => Some(Costs3::cost_is_some), - Unwrap => Some(Costs3::cost_unwrap), - UnwrapErr => Some(Costs3::cost_unwrap_err), - TryRet => Some(Costs3::cost_try_ret), - If => Some(Costs3::cost_if), - Match => Some(Costs3::cost_match), - Begin => Some(Costs3::cost_begin), - Let => Some(Costs3::cost_let), - Asserts => Some(Costs3::cost_asserts), - Hash160 => Some(Costs3::cost_hash160), - Sha256 => Some(Costs3::cost_sha256), - Sha512 => Some(Costs3::cost_sha512), - Sha512Trunc256 => Some(Costs3::cost_sha512t256), - Keccak256 => Some(Costs3::cost_keccak256), - Secp256k1Recover => Some(Costs3::cost_secp256k1recover), - Secp256k1Verify => Some(Costs3::cost_secp256k1verify), - Print => Some(Costs3::cost_print), - ContractCall => Some(Costs3::cost_contract_call), - ContractOf => Some(Costs3::cost_contract_of), - PrincipalOf => Some(Costs3::cost_principal_of), - AtBlock => Some(Costs3::cost_at_block), - // => Some(Costs3::cost_create_map), - // => Some(Costs3::cost_create_var), - // ContractStorage => Some(Costs3::cost_contract_storage), - FetchEntry => Some(Costs3::cost_fetch_entry), - SetEntry => Some(Costs3::cost_set_entry), - FetchVar => Some(Costs3::cost_fetch_var), - SetVar => Some(Costs3::cost_set_var), - GetBlockInfo => Some(Costs3::cost_block_info), - GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), - GetStxBalance => Some(Costs3::cost_stx_balance), - StxTransfer => Some(Costs3::cost_stx_transfer), - StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), - StxGetAccount => Some(Costs3::cost_stx_account), - MintToken => Some(Costs3::cost_ft_mint), - MintAsset => Some(Costs3::cost_nft_mint), - TransferToken => Some(Costs3::cost_ft_transfer), - GetTokenBalance => Some(Costs3::cost_ft_balance), - GetTokenSupply => Some(Costs3::cost_ft_get_supply), - BurnToken => Some(Costs3::cost_ft_burn), - TransferAsset => Some(Costs3::cost_nft_transfer), - GetAssetOwner => Some(Costs3::cost_nft_owner), - BurnAsset => Some(Costs3::cost_nft_burn), - BuffToIntLe => Some(Costs3::cost_buff_to_int_le), - BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le), - BuffToIntBe => Some(Costs3::cost_buff_to_int_be), - BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be), - ToConsensusBuff => Some(Costs3::cost_to_consensus_buff), - FromConsensusBuff => Some(Costs3::cost_from_consensus_buff), - IsStandard => Some(Costs3::cost_is_standard), - PrincipalDestruct => Some(Costs3::cost_principal_destruct), - PrincipalConstruct => Some(Costs3::cost_principal_construct), - AsContract | AsContractSafe => Some(Costs3::cost_as_contract), - StringToInt => Some(Costs3::cost_string_to_int), - StringToUInt => Some(Costs3::cost_string_to_uint), - IntToAscii => Some(Costs3::cost_int_to_ascii), - IntToUtf8 => Some(Costs3::cost_int_to_utf8), - BitwiseLShift => Some(Costs3::cost_bitwise_left_shift), - BitwiseRShift => Some(Costs3::cost_bitwise_right_shift), - Slice => Some(Costs3::cost_slice), - ReplaceAt => Some(Costs3::cost_replace_at), - GetStacksBlockInfo => Some(Costs3::cost_block_info), - GetTenureInfo => Some(Costs3::cost_block_info), - ContractHash => Some(Costs3::cost_contract_hash), - ToAscii => Some(Costs3::cost_to_ascii), - InsertEntry => Some(Costs3::cost_set_entry), - DeleteEntry => Some(Costs3::cost_set_entry), - StxBurn => Some(Costs3::cost_stx_transfer), - Secp256r1Verify => Some(Costs3::cost_secp256r1verify), - RestrictAssets => None, // TODO: add cost function - AllowanceWithStx => None, // TODO: add cost function - AllowanceWithFt => None, // TODO: add cost function - AllowanceWithNft => None, // TODO: add cost function - AllowanceWithStacking => None, // TODO: add cost function - AllowanceAll => None, // TODO: add cost function - } -} - -/// Calculate total cost using SummingExecutionCost to handle branching properly -fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { - let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); - - for child in &node.children { - let child_summing = calculate_total_cost_with_summing(child); - summing_cost.add_summing(&child_summing); - } - - summing_cost + Ok(CostAnalysisNode::new(expr_node, cost, children)) } -fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { - let mut summing_cost = SummingExecutionCost::new(); - - // Check if this is a branching function by examining the node's expression - let is_branching = is_node_branching(node); - - if is_branching { - match &node.expr { - CostExprNode::NativeFunction(NativeFunctions::If) - | CostExprNode::NativeFunction(NativeFunctions::Match) => { - // TODO match? - if node.children.len() >= 2 { - let condition_cost = calculate_total_cost_with_summing(&node.children[0]); - let condition_total = condition_cost.add_all(); - - // Add the root cost + condition cost to each branch - let mut root_and_condition = node.cost.min.clone(); - let _ = root_and_condition.add(&condition_total); - - for child_cost_node in node.children.iter().skip(1) { - let branch_cost = calculate_total_cost_with_summing(child_cost_node); - let branch_total = branch_cost.add_all(); - - let mut path_cost = root_and_condition.clone(); - let _ = path_cost.add(&branch_total); - - summing_cost.add_cost(path_cost); - } - } - } - _ => { - // For other branching functions, fall back to sequential processing - let mut total_cost = node.cost.min.clone(); - for child_cost_node in &node.children { - let child_summing = calculate_total_cost_with_summing(child_cost_node); - let combined_cost = child_summing.add_all(); - let _ = total_cost.add(&combined_cost); - } - summing_cost.add_cost(total_cost); - } - } - } else { - // For non-branching, add all costs sequentially - let mut total_cost = node.cost.min.clone(); - for child_cost_node in &node.children { - let child_summing = calculate_total_cost_with_summing(child_cost_node); - let combined_cost = child_summing.add_all(); - let _ = total_cost.add(&combined_cost); - } - summing_cost.add_cost(total_cost); +pub(crate) fn get_trait_count(costs: &HashMap) -> Option { + // First pass: collect trait counts and trait names + let mut collector = TraitCountCollector::new(); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), (1, 1)); + collector.visit(cost_analysis_node, &context); } - summing_cost -} - -impl From for StaticCost { - fn from(summing: SummingExecutionCost) -> Self { - StaticCost { - min: summing.min(), - max: summing.max(), - } + // Second pass: propagate trait counts through function calls + // If function A calls function B and uses a map, filter, or fold with + // traits, the maximum will reflect that in A's trait call counts + let mut propagator = + TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), (1, 1)); + propagator.visit(cost_analysis_node, &context); } -} -/// Helper: calculate min & max costs for a given cost function -/// This is likely tooo simplistic but for now it'll do -fn get_costs( - cost_fn: fn(u64) -> Result, - arg_count: u64, -) -> Result { - let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; - Ok(cost) + Some(collector.trait_counts) } #[cfg(test)] diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index a788c92fdb..a1f36a7083 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -11,122 +11,9 @@ use crate::vm::costs::analysis::{ UserArgumentsContext, }; use crate::vm::costs::ExecutionCost; -use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator}; use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; use crate::vm::{ast, ClarityVersion}; -const SIMPLE_TRAIT_SRC: &str = r#"(define-trait mytrait ( - (somefunc (uint uint) (response uint uint)) -)) -"#; - -#[rstest] -#[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] -fn test_simple_trait_implementation_costs( - #[case] version: ClarityVersion, - #[case] epoch: StacksEpochId, - mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, -) { - let simple_impl = r#"(impl-trait .mytrait.mytrait) - (define-public (somefunc (a uint) (b uint)) - (ok (+ a b)) - )"#; - - let mut owned_env = tl_env_factory.get_env(epoch); - - let epoch = StacksEpochId::Epoch21; - let ast = crate::vm::ast::build_ast( - &QualifiedContractIdentifier::transient(), - simple_impl, - &mut (), - version, - epoch, - ) - .unwrap(); - let static_cost = static_cost_from_ast(&ast, &version).unwrap(); - // Deploy and execute the contract to get dynamic costs - let contract_id = QualifiedContractIdentifier::local("simple-impl").unwrap(); - owned_env - .initialize_versioned_contract(contract_id.clone(), version, simple_impl, None) - .unwrap(); - - let dynamic_cost = execute_contract_function_and_get_cost( - &mut owned_env, - &contract_id, - "somefunc", - &[4, 5], - version, - ); - println!("dynamic_cost: {:?}", dynamic_cost); - println!("static_cost: {:?}", static_cost); - - let key = static_cost.keys().nth(1).unwrap(); - let (cost, _trait_count) = static_cost.get(key).unwrap(); - assert!(dynamic_cost.runtime >= cost.min.runtime); - assert!(dynamic_cost.runtime <= cost.max.runtime); -} - -#[rstest] -#[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)] -fn test_complex_trait_implementation_costs( - #[case] version: ClarityVersion, - #[case] epoch: StacksEpochId, - mut tl_env_factory: TopLevelMemoryEnvironmentGenerator, -) { - let complex_impl = r#"(define-public (somefunc (a uint) (b uint)) - (begin - ;; do something expensive - ;; emit events - (print a) - (print b) - (print "doing complex calculation") - (let ((result (* a b))) - (print result) - (ok (+ result (/ (+ a b) u2))) - ) - ) -)"#; - - let mut owned_env = tl_env_factory.get_env(epoch); - - let epoch = StacksEpochId::Epoch21; - let ast = crate::vm::ast::build_ast( - &QualifiedContractIdentifier::transient(), - complex_impl, - &mut (), - version, - epoch, - ) - .unwrap(); - let static_cost_result = static_cost_from_ast(&ast, &version); - match static_cost_result { - Ok(static_cost) => { - let contract_id = QualifiedContractIdentifier::local("complex-impl").unwrap(); - owned_env - .initialize_versioned_contract(contract_id.clone(), version, complex_impl, None) - .unwrap(); - - let dynamic_cost = execute_contract_function_and_get_cost( - &mut owned_env, - &contract_id, - "somefunc", - &[7, 8], - version, - ); - - let key = static_cost.keys().nth(1).unwrap(); - let (cost, _trait_count) = static_cost.get(key).unwrap(); - println!("dynamic_cost: {:?}", dynamic_cost); - println!("cost: {:?}", cost); - assert!(dynamic_cost.runtime >= cost.min.runtime); - assert!(dynamic_cost.runtime <= cost.max.runtime); - } - Err(e) => { - println!("Static cost analysis failed: {}", e); - } - } -} - #[test] fn test_build_cost_analysis_tree_function_definition() { let src = r#"(define-public (somefunc (a uint)) @@ -341,7 +228,7 @@ fn test_pox_4_costs() { .expect("Failed to build AST from pox-4.clar"); let cost_map = static_cost_from_ast(&ast, &clarity_version) - .expect("Failed to perform static cost analysis on pox-4.clar"); + .expect("Failed to get static cost analysis for pox-4.clar"); // Check some functions in the cost map let key_functions = vec![ From 8e499d6b04c279422d951f0c5e20fd9deb8d6721 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:25:01 -0800 Subject: [PATCH 17/23] move static_cost_native to test only --- clarity/src/vm/costs/analysis.rs | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index cca8aa5414..0c594f4d57 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -205,25 +205,6 @@ fn make_ast( Ok(ast) } -/// somewhat of a passthrough since we don't have to build the whole context we -/// can jsut return the cost of the single expression -fn static_cost_native( - source: &str, - cost_map: &HashMap>, - clarity_version: &ClarityVersion, -) -> Result { - let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version - let ast = make_ast(source, epoch, clarity_version)?; - let exprs = &ast.expressions; - let user_args = UserArgumentsContext::new(); - let expr = &exprs[0]; - let (_, cost_analysis_tree) = - build_cost_analysis_tree(&expr, &user_args, cost_map, clarity_version)?; - - let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); - Ok(summing_cost.into()) -} - /// STatic execution cost for functions within Environment /// returns the top level cost for specific functions /// {function_name: cost} @@ -604,7 +585,17 @@ mod tests { clarity_version: &ClarityVersion, ) -> Result { let cost_map: HashMap> = HashMap::new(); - static_cost_native(source, &cost_map, clarity_version) + + let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version + let ast = make_ast(source, epoch, clarity_version)?; + let exprs = &ast.expressions; + let user_args = UserArgumentsContext::new(); + let expr = &exprs[0]; + let (_, cost_analysis_tree) = + build_cost_analysis_tree(&expr, &user_args, &cost_map, clarity_version)?; + + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); + Ok(summing_cost.into()) } fn static_cost_test( From 4dffab25f6994d656490dabcb71cecf60c1a5453 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:35:42 -0800 Subject: [PATCH 18/23] make static_cost functions return trait counts also --- clarity/src/vm/costs/analysis.rs | 44 ++++++++++++++++++++++---------- clarity/src/vm/tests/analysis.rs | 9 ++++--- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 0c594f4d57..7c0045b002 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -207,11 +207,11 @@ fn make_ast( /// STatic execution cost for functions within Environment /// returns the top level cost for specific functions -/// {function_name: cost} +/// {some-function-name: (CostAnalysisNode, Some({some-function-name: (1,1)}))} pub fn static_cost( env: &mut Environment, contract_identifier: &QualifiedContractIdentifier, -) -> Result, String> { +) -> Result)>, String> { let contract_source = env .global_context .database @@ -234,11 +234,7 @@ pub fn static_cost( let epoch = env.global_context.epoch_id; let ast = make_ast(&contract_source, epoch, clarity_version)?; - let costs = static_cost_from_ast(&ast, clarity_version)?; - Ok(costs - .into_iter() - .map(|(name, (cost, _trait_count))| (name, cost)) - .collect()) + static_cost_tree_from_ast(&ast, clarity_version) } /// same idea as `static_cost` but returns the root of the cost analysis tree for each function @@ -246,7 +242,7 @@ pub fn static_cost( pub fn static_cost_tree( env: &mut Environment, contract_identifier: &QualifiedContractIdentifier, -) -> Result, String> { +) -> Result)>, String> { let contract_source = env .global_context .database @@ -276,16 +272,23 @@ pub fn static_cost_from_ast( contract_ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, ) -> Result)>, String> { - let cost_trees = static_cost_tree_from_ast(contract_ast, clarity_version)?; + let cost_trees_with_traits = static_cost_tree_from_ast(contract_ast, clarity_version)?; - let trait_count = get_trait_count(&cost_trees); - let costs: HashMap = cost_trees + // Extract trait_count from the first entry (all entries have the same trait_count) + let trait_count = cost_trees_with_traits + .values() + .next() + .and_then(|(_, trait_count)| trait_count.clone()); + + // Convert CostAnalysisNode to StaticCost + let costs: HashMap = cost_trees_with_traits .into_iter() - .map(|(name, cost_analysis_node)| { + .map(|(name, (cost_analysis_node, _))| { let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node); (name, summing_cost.into()) }) .collect(); + Ok(costs .into_iter() .map(|(name, cost)| (name, (cost, trait_count.clone()))) @@ -295,16 +298,18 @@ pub fn static_cost_from_ast( pub(crate) fn static_cost_tree_from_ast( ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, -) -> Result, String> { +) -> Result)>, String> { let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); let costs_map: HashMap> = HashMap::new(); let mut costs: HashMap> = HashMap::new(); + // first pass extracts the function names for expr in exprs { if let Some(function_name) = extract_function_name(expr) { costs.insert(function_name, None); } } + // second pass computes the cost for expr in exprs { if let Some(function_name) = extract_function_name(expr) { let (_, cost_analysis_tree) = @@ -312,9 +317,20 @@ pub(crate) fn static_cost_tree_from_ast( costs.insert(function_name, Some(cost_analysis_tree)); } } - Ok(costs + + // Build the final map with cost analysis nodes + let cost_trees: HashMap = costs .into_iter() .filter_map(|(name, cost)| cost.map(|c| (name, c))) + .collect(); + + // Compute trait_count while creating the root CostAnalysisNode + let trait_count = get_trait_count(&cost_trees); + + // Return each node with its trait_count + Ok(cost_trees + .into_iter() + .map(|(name, node)| (name, (node, trait_count.clone()))) .collect()) } diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index a1f36a7083..d4750210f2 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -7,8 +7,7 @@ use stacks_common::types::StacksEpochId; use crate::vm::contexts::OwnedEnvironment; use crate::vm::costs::analysis::{ - build_cost_analysis_tree, get_trait_count, static_cost_from_ast, static_cost_tree_from_ast, - UserArgumentsContext, + build_cost_analysis_tree, static_cost_from_ast, static_cost_tree_from_ast, UserArgumentsContext, }; use crate::vm::costs::ExecutionCost; use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; @@ -106,7 +105,11 @@ fn test_get_trait_count_direct() { let costs = static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); - let trait_count = get_trait_count(&costs); + // Extract trait_count from the result (all entries have the same trait_count) + let trait_count = costs + .values() + .next() + .and_then(|(_, trait_count)| trait_count.clone()); let expected = { let mut map = HashMap::new(); From 9e194d6f838f11e57986a0fa45bd4315b5b8f7ab Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:11:22 -0800 Subject: [PATCH 19/23] delegate to correct cost modules --- clarity/src/vm/ast/static_cost/mod.rs | 260 +++++++++++++------------- 1 file changed, 135 insertions(+), 125 deletions(-) diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs index bcf20c1110..3f1fe9569f 100644 --- a/clarity/src/vm/ast/static_cost/mod.rs +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -11,7 +11,10 @@ use crate::vm::costs::analysis::{ CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, }; use crate::vm::costs::cost_functions::{linear, CostValues}; +use crate::vm::costs::costs_1::Costs1; +use crate::vm::costs::costs_2::Costs2; use crate::vm::costs::costs_3::Costs3; +use crate::vm::costs::costs_4::Costs4; use crate::vm::costs::ExecutionCost; use crate::vm::errors::VmExecutionError; use crate::vm::functions::NativeFunctions; @@ -21,126 +24,138 @@ use crate::vm::{ClarityVersion, Value}; const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; -/// Convert a NativeFunctions enum variant to its corresponding cost function -/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in +/// Macro to dispatch to the correct Costs module based on clarity version +/// returns a function pointer to the appropriate cost function +macro_rules! dispatch_cost { + ($version:expr, $cost_fn:ident) => { + match $version { + ClarityVersion::Clarity1 => Costs1::$cost_fn, + ClarityVersion::Clarity2 => Costs2::$cost_fn, + ClarityVersion::Clarity3 => Costs3::$cost_fn, + ClarityVersion::Clarity4 => Costs4::$cost_fn, + } + }; +} + +/// NativeFunctions -> cost via appropriate cost fn pub(crate) fn get_cost_function_for_native( function: NativeFunctions, - _clarity_version: &ClarityVersion, + clarity_version: &ClarityVersion, ) -> Option Result> { use crate::vm::functions::NativeFunctions::*; - // Map NativeFunctions enum variants to their cost functions + // Map NativeFunctions enum to cost functions match function { - Add => Some(Costs3::cost_add), - Subtract => Some(Costs3::cost_sub), - Multiply => Some(Costs3::cost_mul), - Divide => Some(Costs3::cost_div), - Modulo => Some(Costs3::cost_mod), - Power => Some(Costs3::cost_pow), - Sqrti => Some(Costs3::cost_sqrti), - Log2 => Some(Costs3::cost_log2), - ToInt | ToUInt => Some(Costs3::cost_int_cast), - Equals => Some(Costs3::cost_eq), - CmpGeq => Some(Costs3::cost_geq), - CmpLeq => Some(Costs3::cost_leq), - CmpGreater => Some(Costs3::cost_ge), - CmpLess => Some(Costs3::cost_le), - BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor), - Not | BitwiseNot => Some(Costs3::cost_not), - And | BitwiseAnd => Some(Costs3::cost_and), - Or | BitwiseOr => Some(Costs3::cost_or), - Concat => Some(Costs3::cost_concat), - Len => Some(Costs3::cost_len), - AsMaxLen => Some(Costs3::cost_as_max_len), - ListCons => Some(Costs3::cost_list_cons), - ElementAt | ElementAtAlias => Some(Costs3::cost_element_at), - IndexOf | IndexOfAlias => Some(Costs3::cost_index_of), - Fold => Some(Costs3::cost_fold), - Map => Some(Costs3::cost_map), - Filter => Some(Costs3::cost_filter), - Append => Some(Costs3::cost_append), - TupleGet => Some(Costs3::cost_tuple_get), - TupleMerge => Some(Costs3::cost_tuple_merge), - TupleCons => Some(Costs3::cost_tuple_cons), - ConsSome => Some(Costs3::cost_some_cons), - ConsOkay => Some(Costs3::cost_ok_cons), - ConsError => Some(Costs3::cost_err_cons), - DefaultTo => Some(Costs3::cost_default_to), - UnwrapRet => Some(Costs3::cost_unwrap_ret), - UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret), - IsOkay => Some(Costs3::cost_is_okay), - IsNone => Some(Costs3::cost_is_none), - IsErr => Some(Costs3::cost_is_err), - IsSome => Some(Costs3::cost_is_some), - Unwrap => Some(Costs3::cost_unwrap), - UnwrapErr => Some(Costs3::cost_unwrap_err), - TryRet => Some(Costs3::cost_try_ret), - If => Some(Costs3::cost_if), - Match => Some(Costs3::cost_match), - Begin => Some(Costs3::cost_begin), - Let => Some(Costs3::cost_let), - Asserts => Some(Costs3::cost_asserts), - Hash160 => Some(Costs3::cost_hash160), - Sha256 => Some(Costs3::cost_sha256), - Sha512 => Some(Costs3::cost_sha512), - Sha512Trunc256 => Some(Costs3::cost_sha512t256), - Keccak256 => Some(Costs3::cost_keccak256), - Secp256k1Recover => Some(Costs3::cost_secp256k1recover), - Secp256k1Verify => Some(Costs3::cost_secp256k1verify), - Print => Some(Costs3::cost_print), - ContractCall => Some(Costs3::cost_contract_call), - ContractOf => Some(Costs3::cost_contract_of), - PrincipalOf => Some(Costs3::cost_principal_of), - AtBlock => Some(Costs3::cost_at_block), - // => Some(Costs3::cost_create_map), - // => Some(Costs3::cost_create_var), - // ContractStorage => Some(Costs3::cost_contract_storage), - FetchEntry => Some(Costs3::cost_fetch_entry), - SetEntry => Some(Costs3::cost_set_entry), - FetchVar => Some(Costs3::cost_fetch_var), - SetVar => Some(Costs3::cost_set_var), - GetBlockInfo => Some(Costs3::cost_block_info), - GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), - GetStxBalance => Some(Costs3::cost_stx_balance), - StxTransfer => Some(Costs3::cost_stx_transfer), - StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), - StxGetAccount => Some(Costs3::cost_stx_account), - MintToken => Some(Costs3::cost_ft_mint), - MintAsset => Some(Costs3::cost_nft_mint), - TransferToken => Some(Costs3::cost_ft_transfer), - GetTokenBalance => Some(Costs3::cost_ft_balance), - GetTokenSupply => Some(Costs3::cost_ft_get_supply), - BurnToken => Some(Costs3::cost_ft_burn), - TransferAsset => Some(Costs3::cost_nft_transfer), - GetAssetOwner => Some(Costs3::cost_nft_owner), - BurnAsset => Some(Costs3::cost_nft_burn), - BuffToIntLe => Some(Costs3::cost_buff_to_int_le), - BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le), - BuffToIntBe => Some(Costs3::cost_buff_to_int_be), - BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be), - ToConsensusBuff => Some(Costs3::cost_to_consensus_buff), - FromConsensusBuff => Some(Costs3::cost_from_consensus_buff), - IsStandard => Some(Costs3::cost_is_standard), - PrincipalDestruct => Some(Costs3::cost_principal_destruct), - PrincipalConstruct => Some(Costs3::cost_principal_construct), - AsContract | AsContractSafe => Some(Costs3::cost_as_contract), - StringToInt => Some(Costs3::cost_string_to_int), - StringToUInt => Some(Costs3::cost_string_to_uint), - IntToAscii => Some(Costs3::cost_int_to_ascii), - IntToUtf8 => Some(Costs3::cost_int_to_utf8), - BitwiseLShift => Some(Costs3::cost_bitwise_left_shift), - BitwiseRShift => Some(Costs3::cost_bitwise_right_shift), - Slice => Some(Costs3::cost_slice), - ReplaceAt => Some(Costs3::cost_replace_at), - GetStacksBlockInfo => Some(Costs3::cost_block_info), - GetTenureInfo => Some(Costs3::cost_block_info), - ContractHash => Some(Costs3::cost_contract_hash), - ToAscii => Some(Costs3::cost_to_ascii), - InsertEntry => Some(Costs3::cost_set_entry), - DeleteEntry => Some(Costs3::cost_set_entry), - StxBurn => Some(Costs3::cost_stx_transfer), - Secp256r1Verify => Some(Costs3::cost_secp256r1verify), - RestrictAssets => None, // TODO: add cost function + Add => Some(dispatch_cost!(clarity_version, cost_add)), + Subtract => Some(dispatch_cost!(clarity_version, cost_sub)), + Multiply => Some(dispatch_cost!(clarity_version, cost_mul)), + Divide => Some(dispatch_cost!(clarity_version, cost_div)), + Modulo => Some(dispatch_cost!(clarity_version, cost_mod)), + Power => Some(dispatch_cost!(clarity_version, cost_pow)), + Sqrti => Some(dispatch_cost!(clarity_version, cost_sqrti)), + Log2 => Some(dispatch_cost!(clarity_version, cost_log2)), + ToInt | ToUInt => Some(dispatch_cost!(clarity_version, cost_int_cast)), + Equals => Some(dispatch_cost!(clarity_version, cost_eq)), + CmpGeq => Some(dispatch_cost!(clarity_version, cost_geq)), + CmpLeq => Some(dispatch_cost!(clarity_version, cost_leq)), + CmpGreater => Some(dispatch_cost!(clarity_version, cost_ge)), + CmpLess => Some(dispatch_cost!(clarity_version, cost_le)), + BitwiseXor | BitwiseXor2 => Some(dispatch_cost!(clarity_version, cost_xor)), + Not | BitwiseNot => Some(dispatch_cost!(clarity_version, cost_not)), + And | BitwiseAnd => Some(dispatch_cost!(clarity_version, cost_and)), + Or | BitwiseOr => Some(dispatch_cost!(clarity_version, cost_or)), + Concat => Some(dispatch_cost!(clarity_version, cost_concat)), + Len => Some(dispatch_cost!(clarity_version, cost_len)), + AsMaxLen => Some(dispatch_cost!(clarity_version, cost_as_max_len)), + ListCons => Some(dispatch_cost!(clarity_version, cost_list_cons)), + ElementAt | ElementAtAlias => Some(dispatch_cost!(clarity_version, cost_element_at)), + IndexOf | IndexOfAlias => Some(dispatch_cost!(clarity_version, cost_index_of)), + Fold => Some(dispatch_cost!(clarity_version, cost_fold)), + Map => Some(dispatch_cost!(clarity_version, cost_map)), + Filter => Some(dispatch_cost!(clarity_version, cost_filter)), + Append => Some(dispatch_cost!(clarity_version, cost_append)), + TupleGet => Some(dispatch_cost!(clarity_version, cost_tuple_get)), + TupleMerge => Some(dispatch_cost!(clarity_version, cost_tuple_merge)), + TupleCons => Some(dispatch_cost!(clarity_version, cost_tuple_cons)), + ConsSome => Some(dispatch_cost!(clarity_version, cost_some_cons)), + ConsOkay => Some(dispatch_cost!(clarity_version, cost_ok_cons)), + ConsError => Some(dispatch_cost!(clarity_version, cost_err_cons)), + DefaultTo => Some(dispatch_cost!(clarity_version, cost_default_to)), + UnwrapRet => Some(dispatch_cost!(clarity_version, cost_unwrap_ret)), + UnwrapErrRet => Some(dispatch_cost!(clarity_version, cost_unwrap_err_or_ret)), + IsOkay => Some(dispatch_cost!(clarity_version, cost_is_okay)), + IsNone => Some(dispatch_cost!(clarity_version, cost_is_none)), + IsErr => Some(dispatch_cost!(clarity_version, cost_is_err)), + IsSome => Some(dispatch_cost!(clarity_version, cost_is_some)), + Unwrap => Some(dispatch_cost!(clarity_version, cost_unwrap)), + UnwrapErr => Some(dispatch_cost!(clarity_version, cost_unwrap_err)), + TryRet => Some(dispatch_cost!(clarity_version, cost_try_ret)), + If => Some(dispatch_cost!(clarity_version, cost_if)), + Match => Some(dispatch_cost!(clarity_version, cost_match)), + Begin => Some(dispatch_cost!(clarity_version, cost_begin)), + Let => Some(dispatch_cost!(clarity_version, cost_let)), + Asserts => Some(dispatch_cost!(clarity_version, cost_asserts)), + Hash160 => Some(dispatch_cost!(clarity_version, cost_hash160)), + Sha256 => Some(dispatch_cost!(clarity_version, cost_sha256)), + Sha512 => Some(dispatch_cost!(clarity_version, cost_sha512)), + Sha512Trunc256 => Some(dispatch_cost!(clarity_version, cost_sha512t256)), + Keccak256 => Some(dispatch_cost!(clarity_version, cost_keccak256)), + Secp256k1Recover => Some(dispatch_cost!(clarity_version, cost_secp256k1recover)), + Secp256k1Verify => Some(dispatch_cost!(clarity_version, cost_secp256k1verify)), + Print => Some(dispatch_cost!(clarity_version, cost_print)), + ContractCall => Some(dispatch_cost!(clarity_version, cost_contract_call)), + ContractOf => Some(dispatch_cost!(clarity_version, cost_contract_of)), + PrincipalOf => Some(dispatch_cost!(clarity_version, cost_principal_of)), + AtBlock => Some(dispatch_cost!(clarity_version, cost_at_block)), + // => Some(dispatch_cost!(clarity_version, cost_create_map)), + // => Some(dispatch_cost!(clarity_version, cost_create_var)), + // ContractStorage => Some(dispatch_cost!(clarity_version, cost_contract_storage)), + FetchEntry => Some(dispatch_cost!(clarity_version, cost_fetch_entry)), + SetEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), + FetchVar => Some(dispatch_cost!(clarity_version, cost_fetch_var)), + SetVar => Some(dispatch_cost!(clarity_version, cost_set_var)), + GetBlockInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), + GetBurnBlockInfo => Some(dispatch_cost!(clarity_version, cost_burn_block_info)), + GetStxBalance => Some(dispatch_cost!(clarity_version, cost_stx_balance)), + StxTransfer => Some(dispatch_cost!(clarity_version, cost_stx_transfer)), + StxTransferMemo => Some(dispatch_cost!(clarity_version, cost_stx_transfer_memo)), + StxGetAccount => Some(dispatch_cost!(clarity_version, cost_stx_account)), + MintToken => Some(dispatch_cost!(clarity_version, cost_ft_mint)), + MintAsset => Some(dispatch_cost!(clarity_version, cost_nft_mint)), + TransferToken => Some(dispatch_cost!(clarity_version, cost_ft_transfer)), + GetTokenBalance => Some(dispatch_cost!(clarity_version, cost_ft_balance)), + GetTokenSupply => Some(dispatch_cost!(clarity_version, cost_ft_get_supply)), + BurnToken => Some(dispatch_cost!(clarity_version, cost_ft_burn)), + TransferAsset => Some(dispatch_cost!(clarity_version, cost_nft_transfer)), + GetAssetOwner => Some(dispatch_cost!(clarity_version, cost_nft_owner)), + BurnAsset => Some(dispatch_cost!(clarity_version, cost_nft_burn)), + BuffToIntLe => Some(dispatch_cost!(clarity_version, cost_buff_to_int_le)), + BuffToUIntLe => Some(dispatch_cost!(clarity_version, cost_buff_to_uint_le)), + BuffToIntBe => Some(dispatch_cost!(clarity_version, cost_buff_to_int_be)), + BuffToUIntBe => Some(dispatch_cost!(clarity_version, cost_buff_to_uint_be)), + ToConsensusBuff => Some(dispatch_cost!(clarity_version, cost_to_consensus_buff)), + FromConsensusBuff => Some(dispatch_cost!(clarity_version, cost_from_consensus_buff)), + IsStandard => Some(dispatch_cost!(clarity_version, cost_is_standard)), + PrincipalDestruct => Some(dispatch_cost!(clarity_version, cost_principal_destruct)), + PrincipalConstruct => Some(dispatch_cost!(clarity_version, cost_principal_construct)), + AsContract | AsContractSafe => Some(dispatch_cost!(clarity_version, cost_as_contract)), + StringToInt => Some(dispatch_cost!(clarity_version, cost_string_to_int)), + StringToUInt => Some(dispatch_cost!(clarity_version, cost_string_to_uint)), + IntToAscii => Some(dispatch_cost!(clarity_version, cost_int_to_ascii)), + IntToUtf8 => Some(dispatch_cost!(clarity_version, cost_int_to_utf8)), + BitwiseLShift => Some(dispatch_cost!(clarity_version, cost_bitwise_left_shift)), + BitwiseRShift => Some(dispatch_cost!(clarity_version, cost_bitwise_right_shift)), + Slice => Some(dispatch_cost!(clarity_version, cost_slice)), + ReplaceAt => Some(dispatch_cost!(clarity_version, cost_replace_at)), + GetStacksBlockInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), + GetTenureInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), + ContractHash => Some(dispatch_cost!(clarity_version, cost_contract_hash)), + ToAscii => Some(dispatch_cost!(clarity_version, cost_to_ascii)), + InsertEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), + DeleteEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), + StxBurn => Some(dispatch_cost!(clarity_version, cost_stx_transfer)), + Secp256r1Verify => Some(dispatch_cost!(clarity_version, cost_secp256r1verify)), + RestrictAssets => Some(dispatch_cost!(clarity_version, cost_restrict_assets)), AllowanceWithStx => None, // TODO: add cost function AllowanceWithFt => None, // TODO: add cost function AllowanceWithNft => None, // TODO: add cost function @@ -149,7 +164,6 @@ pub(crate) fn get_cost_function_for_native( } } -// Calculate function cost with lazy evaluation support pub(crate) fn calculate_function_cost( function_name: String, cost_map: &HashMap>, @@ -161,8 +175,8 @@ pub(crate) fn calculate_function_cost( Ok(cost.clone()) } Some(None) => { - // Should be impossible but alas.. - // Function exists but cost not yet computed - this indicates a circular dependency + // Should be impossible.. + // Function exists but cost not yet computed, circular dependency? // For now, return zero cost to avoid infinite recursion println!( "Circular dependency detected for function: {}", @@ -188,8 +202,6 @@ pub(crate) fn is_branching_function(function_name: &ClarityName) -> bool { } } -/// Helper function to determine if a node represents a branching operation -/// This is used in tests and cost calculation pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool { match &node.expr { CostExprNode::NativeFunction(NativeFunctions::If) @@ -199,7 +211,7 @@ pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool { } } -/// Calculate the cost for a string based on its length +/// string cost based on length fn string_cost(length: usize) -> StaticCost { let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); let execution_cost = ExecutionCost::runtime(cost); @@ -209,7 +221,7 @@ fn string_cost(length: usize) -> StaticCost { } } -/// Calculate cost for a value (used for literal values) +/// Strings are the only Value's with costs associated pub(crate) fn calculate_value_cost(value: &Value) -> Result { match value { Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { @@ -230,7 +242,6 @@ pub(crate) fn calculate_function_cost_from_native_function( let cost_function = match get_cost_function_for_native(native_function, clarity_version) { Some(cost_fn) => cost_fn, None => { - // TODO: zero cost for now return Ok(StaticCost::ZERO); } }; @@ -242,7 +253,7 @@ pub(crate) fn calculate_function_cost_from_native_function( }) } -/// Calculate total cost using SummingExecutionCost to handle branching properly +/// total cost handling branching pub(crate) fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); @@ -318,8 +329,7 @@ impl From for StaticCost { } } -/// Helper: calculate min & max costs for a given cost function -/// This is likely tooo simplistic but for now it'll do +/// get min & max costs for a given cost function fn get_costs( cost_fn: fn(u64) -> Result, arg_count: u64, From d2cc5dec057896f3aafbefc68555c1716f6b97a1 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:41:47 -0800 Subject: [PATCH 20/23] gate static costs behind developer-mode --- clarity/src/vm/ast/mod.rs | 1 + clarity/src/vm/ast/static_cost/mod.rs | 1 - clarity/src/vm/ast/static_cost/trait_counter.rs | 1 + clarity/src/vm/costs/mod.rs | 1 + clarity/src/vm/tests/mod.rs | 2 +- 5 files changed, 4 insertions(+), 2 deletions(-) diff --git a/clarity/src/vm/ast/mod.rs b/clarity/src/vm/ast/mod.rs index c778354b36..371adedbcc 100644 --- a/clarity/src/vm/ast/mod.rs +++ b/clarity/src/vm/ast/mod.rs @@ -17,6 +17,7 @@ pub mod definition_sorter; pub mod expression_identifier; pub mod parser; +#[cfg(feature = "developer-mode")] pub mod static_cost; pub mod traits_resolver; diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs index 3f1fe9569f..8280c3d168 100644 --- a/clarity/src/vm/ast/static_cost/mod.rs +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -6,7 +6,6 @@ pub use trait_counter::{ TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, }; -// Import types from analysis.rs use crate::vm::costs::analysis::{ CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, }; diff --git a/clarity/src/vm/ast/static_cost/trait_counter.rs b/clarity/src/vm/ast/static_cost/trait_counter.rs index 10a67afd8f..7cd615fd76 100644 --- a/clarity/src/vm/ast/static_cost/trait_counter.rs +++ b/clarity/src/vm/ast/static_cost/trait_counter.rs @@ -7,6 +7,7 @@ use crate::vm::ast::static_cost::{CostAnalysisNode, CostExprNode}; use crate::vm::costs::analysis::is_function_definition; use crate::vm::functions::NativeFunctions; use crate::vm::representations::{SymbolicExpression, SymbolicExpressionType}; + type MinMaxTraitCount = (u64, u64); pub type TraitCount = HashMap; diff --git a/clarity/src/vm/costs/mod.rs b/clarity/src/vm/costs/mod.rs index ea055ddea5..dadf218462 100644 --- a/clarity/src/vm/costs/mod.rs +++ b/clarity/src/vm/costs/mod.rs @@ -42,6 +42,7 @@ use crate::vm::types::{ FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, }; use crate::vm::{CallStack, ClarityName, Environment, LocalContext, SymbolicExpression, Value}; +#[cfg(feature = "developer-mode")] pub mod analysis; pub mod constants; pub mod cost_functions; diff --git a/clarity/src/vm/tests/mod.rs b/clarity/src/vm/tests/mod.rs index 8a30cc13de..bd2353273e 100644 --- a/clarity/src/vm/tests/mod.rs +++ b/clarity/src/vm/tests/mod.rs @@ -24,7 +24,7 @@ use crate::vm::contexts::OwnedEnvironment; pub use crate::vm::database::BurnStateDB; use crate::vm::database::MemoryBackingStore; -#[cfg(test)] +#[cfg(all(test, feature = "developer-mode"))] mod analysis; mod assets; mod contracts; From 2c3dca4705b4ef865025e3519d0283b288412236 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:58:45 -0800 Subject: [PATCH 21/23] use lookup_reserved_functions --- clarity/src/vm/ast/static_cost/mod.rs | 169 +++----------------------- clarity/src/vm/costs/analysis.rs | 25 ++-- 2 files changed, 29 insertions(+), 165 deletions(-) diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs index 8280c3d168..d5470dbba8 100644 --- a/clarity/src/vm/ast/static_cost/mod.rs +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -6,163 +6,24 @@ pub use trait_counter::{ TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, }; +use crate::vm::callables::CallableType; use crate::vm::costs::analysis::{ CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, }; -use crate::vm::costs::cost_functions::{linear, CostValues}; +use crate::vm::costs::cost_functions::linear; use crate::vm::costs::costs_1::Costs1; use crate::vm::costs::costs_2::Costs2; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::costs_4::Costs4; use crate::vm::costs::ExecutionCost; use crate::vm::errors::VmExecutionError; -use crate::vm::functions::NativeFunctions; +use crate::vm::functions::{lookup_reserved_functions, NativeFunctions}; use crate::vm::representations::ClarityName; use crate::vm::{ClarityVersion, Value}; const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; -/// Macro to dispatch to the correct Costs module based on clarity version -/// returns a function pointer to the appropriate cost function -macro_rules! dispatch_cost { - ($version:expr, $cost_fn:ident) => { - match $version { - ClarityVersion::Clarity1 => Costs1::$cost_fn, - ClarityVersion::Clarity2 => Costs2::$cost_fn, - ClarityVersion::Clarity3 => Costs3::$cost_fn, - ClarityVersion::Clarity4 => Costs4::$cost_fn, - } - }; -} - -/// NativeFunctions -> cost via appropriate cost fn -pub(crate) fn get_cost_function_for_native( - function: NativeFunctions, - clarity_version: &ClarityVersion, -) -> Option Result> { - use crate::vm::functions::NativeFunctions::*; - - // Map NativeFunctions enum to cost functions - match function { - Add => Some(dispatch_cost!(clarity_version, cost_add)), - Subtract => Some(dispatch_cost!(clarity_version, cost_sub)), - Multiply => Some(dispatch_cost!(clarity_version, cost_mul)), - Divide => Some(dispatch_cost!(clarity_version, cost_div)), - Modulo => Some(dispatch_cost!(clarity_version, cost_mod)), - Power => Some(dispatch_cost!(clarity_version, cost_pow)), - Sqrti => Some(dispatch_cost!(clarity_version, cost_sqrti)), - Log2 => Some(dispatch_cost!(clarity_version, cost_log2)), - ToInt | ToUInt => Some(dispatch_cost!(clarity_version, cost_int_cast)), - Equals => Some(dispatch_cost!(clarity_version, cost_eq)), - CmpGeq => Some(dispatch_cost!(clarity_version, cost_geq)), - CmpLeq => Some(dispatch_cost!(clarity_version, cost_leq)), - CmpGreater => Some(dispatch_cost!(clarity_version, cost_ge)), - CmpLess => Some(dispatch_cost!(clarity_version, cost_le)), - BitwiseXor | BitwiseXor2 => Some(dispatch_cost!(clarity_version, cost_xor)), - Not | BitwiseNot => Some(dispatch_cost!(clarity_version, cost_not)), - And | BitwiseAnd => Some(dispatch_cost!(clarity_version, cost_and)), - Or | BitwiseOr => Some(dispatch_cost!(clarity_version, cost_or)), - Concat => Some(dispatch_cost!(clarity_version, cost_concat)), - Len => Some(dispatch_cost!(clarity_version, cost_len)), - AsMaxLen => Some(dispatch_cost!(clarity_version, cost_as_max_len)), - ListCons => Some(dispatch_cost!(clarity_version, cost_list_cons)), - ElementAt | ElementAtAlias => Some(dispatch_cost!(clarity_version, cost_element_at)), - IndexOf | IndexOfAlias => Some(dispatch_cost!(clarity_version, cost_index_of)), - Fold => Some(dispatch_cost!(clarity_version, cost_fold)), - Map => Some(dispatch_cost!(clarity_version, cost_map)), - Filter => Some(dispatch_cost!(clarity_version, cost_filter)), - Append => Some(dispatch_cost!(clarity_version, cost_append)), - TupleGet => Some(dispatch_cost!(clarity_version, cost_tuple_get)), - TupleMerge => Some(dispatch_cost!(clarity_version, cost_tuple_merge)), - TupleCons => Some(dispatch_cost!(clarity_version, cost_tuple_cons)), - ConsSome => Some(dispatch_cost!(clarity_version, cost_some_cons)), - ConsOkay => Some(dispatch_cost!(clarity_version, cost_ok_cons)), - ConsError => Some(dispatch_cost!(clarity_version, cost_err_cons)), - DefaultTo => Some(dispatch_cost!(clarity_version, cost_default_to)), - UnwrapRet => Some(dispatch_cost!(clarity_version, cost_unwrap_ret)), - UnwrapErrRet => Some(dispatch_cost!(clarity_version, cost_unwrap_err_or_ret)), - IsOkay => Some(dispatch_cost!(clarity_version, cost_is_okay)), - IsNone => Some(dispatch_cost!(clarity_version, cost_is_none)), - IsErr => Some(dispatch_cost!(clarity_version, cost_is_err)), - IsSome => Some(dispatch_cost!(clarity_version, cost_is_some)), - Unwrap => Some(dispatch_cost!(clarity_version, cost_unwrap)), - UnwrapErr => Some(dispatch_cost!(clarity_version, cost_unwrap_err)), - TryRet => Some(dispatch_cost!(clarity_version, cost_try_ret)), - If => Some(dispatch_cost!(clarity_version, cost_if)), - Match => Some(dispatch_cost!(clarity_version, cost_match)), - Begin => Some(dispatch_cost!(clarity_version, cost_begin)), - Let => Some(dispatch_cost!(clarity_version, cost_let)), - Asserts => Some(dispatch_cost!(clarity_version, cost_asserts)), - Hash160 => Some(dispatch_cost!(clarity_version, cost_hash160)), - Sha256 => Some(dispatch_cost!(clarity_version, cost_sha256)), - Sha512 => Some(dispatch_cost!(clarity_version, cost_sha512)), - Sha512Trunc256 => Some(dispatch_cost!(clarity_version, cost_sha512t256)), - Keccak256 => Some(dispatch_cost!(clarity_version, cost_keccak256)), - Secp256k1Recover => Some(dispatch_cost!(clarity_version, cost_secp256k1recover)), - Secp256k1Verify => Some(dispatch_cost!(clarity_version, cost_secp256k1verify)), - Print => Some(dispatch_cost!(clarity_version, cost_print)), - ContractCall => Some(dispatch_cost!(clarity_version, cost_contract_call)), - ContractOf => Some(dispatch_cost!(clarity_version, cost_contract_of)), - PrincipalOf => Some(dispatch_cost!(clarity_version, cost_principal_of)), - AtBlock => Some(dispatch_cost!(clarity_version, cost_at_block)), - // => Some(dispatch_cost!(clarity_version, cost_create_map)), - // => Some(dispatch_cost!(clarity_version, cost_create_var)), - // ContractStorage => Some(dispatch_cost!(clarity_version, cost_contract_storage)), - FetchEntry => Some(dispatch_cost!(clarity_version, cost_fetch_entry)), - SetEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), - FetchVar => Some(dispatch_cost!(clarity_version, cost_fetch_var)), - SetVar => Some(dispatch_cost!(clarity_version, cost_set_var)), - GetBlockInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), - GetBurnBlockInfo => Some(dispatch_cost!(clarity_version, cost_burn_block_info)), - GetStxBalance => Some(dispatch_cost!(clarity_version, cost_stx_balance)), - StxTransfer => Some(dispatch_cost!(clarity_version, cost_stx_transfer)), - StxTransferMemo => Some(dispatch_cost!(clarity_version, cost_stx_transfer_memo)), - StxGetAccount => Some(dispatch_cost!(clarity_version, cost_stx_account)), - MintToken => Some(dispatch_cost!(clarity_version, cost_ft_mint)), - MintAsset => Some(dispatch_cost!(clarity_version, cost_nft_mint)), - TransferToken => Some(dispatch_cost!(clarity_version, cost_ft_transfer)), - GetTokenBalance => Some(dispatch_cost!(clarity_version, cost_ft_balance)), - GetTokenSupply => Some(dispatch_cost!(clarity_version, cost_ft_get_supply)), - BurnToken => Some(dispatch_cost!(clarity_version, cost_ft_burn)), - TransferAsset => Some(dispatch_cost!(clarity_version, cost_nft_transfer)), - GetAssetOwner => Some(dispatch_cost!(clarity_version, cost_nft_owner)), - BurnAsset => Some(dispatch_cost!(clarity_version, cost_nft_burn)), - BuffToIntLe => Some(dispatch_cost!(clarity_version, cost_buff_to_int_le)), - BuffToUIntLe => Some(dispatch_cost!(clarity_version, cost_buff_to_uint_le)), - BuffToIntBe => Some(dispatch_cost!(clarity_version, cost_buff_to_int_be)), - BuffToUIntBe => Some(dispatch_cost!(clarity_version, cost_buff_to_uint_be)), - ToConsensusBuff => Some(dispatch_cost!(clarity_version, cost_to_consensus_buff)), - FromConsensusBuff => Some(dispatch_cost!(clarity_version, cost_from_consensus_buff)), - IsStandard => Some(dispatch_cost!(clarity_version, cost_is_standard)), - PrincipalDestruct => Some(dispatch_cost!(clarity_version, cost_principal_destruct)), - PrincipalConstruct => Some(dispatch_cost!(clarity_version, cost_principal_construct)), - AsContract | AsContractSafe => Some(dispatch_cost!(clarity_version, cost_as_contract)), - StringToInt => Some(dispatch_cost!(clarity_version, cost_string_to_int)), - StringToUInt => Some(dispatch_cost!(clarity_version, cost_string_to_uint)), - IntToAscii => Some(dispatch_cost!(clarity_version, cost_int_to_ascii)), - IntToUtf8 => Some(dispatch_cost!(clarity_version, cost_int_to_utf8)), - BitwiseLShift => Some(dispatch_cost!(clarity_version, cost_bitwise_left_shift)), - BitwiseRShift => Some(dispatch_cost!(clarity_version, cost_bitwise_right_shift)), - Slice => Some(dispatch_cost!(clarity_version, cost_slice)), - ReplaceAt => Some(dispatch_cost!(clarity_version, cost_replace_at)), - GetStacksBlockInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), - GetTenureInfo => Some(dispatch_cost!(clarity_version, cost_block_info)), - ContractHash => Some(dispatch_cost!(clarity_version, cost_contract_hash)), - ToAscii => Some(dispatch_cost!(clarity_version, cost_to_ascii)), - InsertEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), - DeleteEntry => Some(dispatch_cost!(clarity_version, cost_set_entry)), - StxBurn => Some(dispatch_cost!(clarity_version, cost_stx_transfer)), - Secp256r1Verify => Some(dispatch_cost!(clarity_version, cost_secp256r1verify)), - RestrictAssets => Some(dispatch_cost!(clarity_version, cost_restrict_assets)), - AllowanceWithStx => None, // TODO: add cost function - AllowanceWithFt => None, // TODO: add cost function - AllowanceWithNft => None, // TODO: add cost function - AllowanceWithStacking => None, // TODO: add cost function - AllowanceAll => None, // TODO: add cost function - } -} - pub(crate) fn calculate_function_cost( function_name: String, cost_map: &HashMap>, @@ -238,14 +99,24 @@ pub(crate) fn calculate_function_cost_from_native_function( arg_count: u64, clarity_version: &ClarityVersion, ) -> Result { - let cost_function = match get_cost_function_for_native(native_function, clarity_version) { - Some(cost_fn) => cost_fn, - None => { - return Ok(StaticCost::ZERO); - } - }; + let cost_function = + match lookup_reserved_functions(native_function.to_string().as_str(), clarity_version) { + Some(CallableType::NativeFunction(_, _, cost_fn)) => cost_fn, + Some(CallableType::NativeFunction205(_, _, cost_fn, _)) => cost_fn, + Some(CallableType::SpecialFunction(_, _)) => return Ok(StaticCost::ZERO), + Some(CallableType::UserFunction(_)) => return Ok(StaticCost::ZERO), // TODO ? + None => { + return Ok(StaticCost::ZERO); + } + }; - let cost = get_costs(cost_function, arg_count)?; + let cost = match clarity_version { + ClarityVersion::Clarity1 => cost_function.eval::(arg_count), + ClarityVersion::Clarity2 => cost_function.eval::(arg_count), + ClarityVersion::Clarity3 => cost_function.eval::(arg_count), + ClarityVersion::Clarity4 => cost_function.eval::(arg_count), + } + .map_err(|e| format!("Cost calculation error: {:?}", e))?; Ok(StaticCost { min: cost.clone(), max: cost, diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 7c0045b002..70815ad5e9 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -681,10 +681,8 @@ mod tests { let source = r#"(concat "hello" "world")"#; let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); - // For concat with 2 arguments: - // linear(2, 37, 220) = 37*2 + 220 = 294 - assert_eq!(cost.min.runtime, 294); - assert_eq!(cost.max.runtime, 294); + assert_eq!(cost.min.runtime, 366); + assert_eq!(cost.max.runtime, 366); } #[test] @@ -692,24 +690,19 @@ mod tests { let source = r#"(len "hello")"#; let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); - // cost: 429 (constant) - len doesn't depend on string size - assert_eq!(cost.min.runtime, 429); - assert_eq!(cost.max.runtime, 429); + assert_eq!(cost.min.runtime, 612); + assert_eq!(cost.max.runtime, 612); } #[test] fn test_branching() { let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); - // min: 147 raw string - // max: 294 (concat) - - // ok = 199 - // if = 168 - // ge = (linear(n, 7, 128))) - let base_cost = 168 + ((2 * 7) + 128) + 199; - assert_eq!(cost.min.runtime, base_cost + 147); - assert_eq!(cost.max.runtime, base_cost + 294); + // min: raw string + // max: concat + + assert_eq!(cost.min.runtime, 346); + assert_eq!(cost.max.runtime, 565); } // ---- ExprTreee building specific tests From df8e4adab9f45a9cdaf44c9068c17d7a1b825a09 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Wed, 26 Nov 2025 09:19:27 -0800 Subject: [PATCH 22/23] pass epoch along for cost eval --- clarity/src/vm/ast/static_cost/mod.rs | 28 ++++++--- clarity/src/vm/costs/analysis.rs | 85 +++++++++++++++++++-------- clarity/src/vm/tests/analysis.rs | 24 ++++---- 3 files changed, 93 insertions(+), 44 deletions(-) diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs index d5470dbba8..47d30752c0 100644 --- a/clarity/src/vm/ast/static_cost/mod.rs +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -2,6 +2,7 @@ mod trait_counter; use std::collections::HashMap; use clarity_types::types::{CharType, SequenceData}; +use stacks_common::types::StacksEpochId; pub use trait_counter::{ TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, }; @@ -97,10 +98,12 @@ pub(crate) fn calculate_value_cost(value: &Value) -> Result pub(crate) fn calculate_function_cost_from_native_function( native_function: NativeFunctions, arg_count: u64, - clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result { + // Derive clarity_version from epoch for lookup_reserved_functions + let clarity_version = ClarityVersion::default_for_epoch(epoch); let cost_function = - match lookup_reserved_functions(native_function.to_string().as_str(), clarity_version) { + match lookup_reserved_functions(native_function.to_string().as_str(), &clarity_version) { Some(CallableType::NativeFunction(_, _, cost_fn)) => cost_fn, Some(CallableType::NativeFunction205(_, _, cost_fn, _)) => cost_fn, Some(CallableType::SpecialFunction(_, _)) => return Ok(StaticCost::ZERO), @@ -110,11 +113,22 @@ pub(crate) fn calculate_function_cost_from_native_function( } }; - let cost = match clarity_version { - ClarityVersion::Clarity1 => cost_function.eval::(arg_count), - ClarityVersion::Clarity2 => cost_function.eval::(arg_count), - ClarityVersion::Clarity3 => cost_function.eval::(arg_count), - ClarityVersion::Clarity4 => cost_function.eval::(arg_count), + let cost = match epoch { + StacksEpochId::Epoch20 => cost_function.eval::(arg_count), + StacksEpochId::Epoch2_05 => cost_function.eval::(arg_count), + StacksEpochId::Epoch21 + | StacksEpochId::Epoch22 + | StacksEpochId::Epoch23 + | StacksEpochId::Epoch24 + | StacksEpochId::Epoch25 + | StacksEpochId::Epoch30 + | StacksEpochId::Epoch31 + | StacksEpochId::Epoch32 => cost_function.eval::(arg_count), + StacksEpochId::Epoch33 => cost_function.eval::(arg_count), + StacksEpochId::Epoch10 => { + // fallback to costs 1 since epoch 1 doesn't have direct cost mapping + cost_function.eval::(arg_count) + } } .map_err(|e| format!("Cost calculation error: {:?}", e))?; Ok(StaticCost { diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 70815ad5e9..57fb36110f 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -234,7 +234,7 @@ pub fn static_cost( let epoch = env.global_context.epoch_id; let ast = make_ast(&contract_source, epoch, clarity_version)?; - static_cost_tree_from_ast(&ast, clarity_version) + static_cost_tree_from_ast(&ast, clarity_version, epoch) } /// same idea as `static_cost` but returns the root of the cost analysis tree for each function @@ -265,14 +265,15 @@ pub fn static_cost_tree( let epoch = env.global_context.epoch_id; let ast = make_ast(&contract_source, epoch, clarity_version)?; - static_cost_tree_from_ast(&ast, clarity_version) + static_cost_tree_from_ast(&ast, clarity_version, epoch) } pub fn static_cost_from_ast( contract_ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result)>, String> { - let cost_trees_with_traits = static_cost_tree_from_ast(contract_ast, clarity_version)?; + let cost_trees_with_traits = static_cost_tree_from_ast(contract_ast, clarity_version, epoch)?; // Extract trait_count from the first entry (all entries have the same trait_count) let trait_count = cost_trees_with_traits @@ -298,6 +299,7 @@ pub fn static_cost_from_ast( pub(crate) fn static_cost_tree_from_ast( ast: &crate::vm::ast::ContractAST, clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result)>, String> { let exprs = &ast.expressions; let user_args = UserArgumentsContext::new(); @@ -313,7 +315,7 @@ pub(crate) fn static_cost_tree_from_ast( for expr in exprs { if let Some(function_name) = extract_function_name(expr) { let (_, cost_analysis_tree) = - build_cost_analysis_tree(expr, &user_args, &costs_map, clarity_version)?; + build_cost_analysis_tree(expr, &user_args, &costs_map, clarity_version, epoch)?; costs.insert(function_name, Some(cost_analysis_tree)); } } @@ -353,6 +355,7 @@ pub fn build_cost_analysis_tree( user_args: &UserArgumentsContext, cost_map: &HashMap>, clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result<(Option, CostAnalysisNode), String> { match &expr.expr { SymbolicExpressionType::List(list) => { @@ -364,6 +367,7 @@ pub fn build_cost_analysis_tree( user_args, cost_map, clarity_version, + epoch, )?; Ok((Some(returned_function_name), cost_analysis_tree)) } else { @@ -372,12 +376,18 @@ pub fn build_cost_analysis_tree( user_args, cost_map, clarity_version, + epoch, )?; Ok((None, cost_analysis_tree)) } } else { - let cost_analysis_tree = - build_listlike_cost_analysis_tree(list, user_args, cost_map, clarity_version)?; + let cost_analysis_tree = build_listlike_cost_analysis_tree( + list, + user_args, + cost_map, + clarity_version, + epoch, + )?; Ok((None, cost_analysis_tree)) } } @@ -439,6 +449,7 @@ fn build_function_definition_cost_analysis_tree( _user_args: &UserArgumentsContext, cost_map: &HashMap>, clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result<(String, CostAnalysisNode), String> { let define_type = list[0] .match_atom() @@ -476,7 +487,7 @@ fn build_function_definition_cost_analysis_tree( // Process the function body with the function's user arguments context let (_, body_tree) = - build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version)?; + build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version, epoch)?; children.push(body_tree); // Get the function name from the signature @@ -508,12 +519,14 @@ fn build_listlike_cost_analysis_tree( user_args: &UserArgumentsContext, cost_map: &HashMap>, clarity_version: &ClarityVersion, + epoch: StacksEpochId, ) -> Result { let mut children = Vec::new(); // Build children for all exprs for expr in exprs[1..].iter() { - let (_, child_tree) = build_cost_analysis_tree(expr, user_args, cost_map, clarity_version)?; + let (_, child_tree) = + build_cost_analysis_tree(expr, user_args, cost_map, clarity_version, epoch)?; children.push(child_tree); } @@ -521,7 +534,7 @@ fn build_listlike_cost_analysis_tree( SymbolicExpressionType::List(_) => { // Recursively analyze the nested list structure let (_, nested_tree) = - build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version)?; + build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version, epoch)?; // Add the nested tree as a child (its cost will be included when summing children) children.insert(0, nested_tree); // The root cost is zero - the actual cost comes from the nested expression @@ -537,7 +550,7 @@ fn build_listlike_cost_analysis_tree( let cost = calculate_function_cost_from_native_function( native_function, children.len() as u64, - clarity_version, + epoch, )?; (CostExprNode::NativeFunction(native_function), cost) } else { @@ -608,7 +621,7 @@ mod tests { let user_args = UserArgumentsContext::new(); let expr = &exprs[0]; let (_, cost_analysis_tree) = - build_cost_analysis_tree(&expr, &user_args, &cost_map, clarity_version)?; + build_cost_analysis_tree(&expr, &user_args, &cost_map, clarity_version, epoch)?; let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) @@ -620,7 +633,7 @@ mod tests { ) -> Result, String> { let epoch = StacksEpochId::latest(); let ast = make_ast(source, epoch, clarity_version)?; - let costs = static_cost_from_ast(&ast, clarity_version)?; + let costs = static_cost_from_ast(&ast, clarity_version, epoch)?; Ok(costs .into_iter() .map(|(name, (cost, _trait_count))| (name, cost)) @@ -713,9 +726,15 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let (_, cost_tree) = - build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) - .unwrap(); + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); // Root should be an If node assert!(matches!( @@ -760,9 +779,15 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let (_, cost_tree) = - build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) - .unwrap(); + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); assert!(matches!( cost_tree.expr, @@ -793,9 +818,15 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let (_, cost_tree) = - build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) - .unwrap(); + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); assert!(matches!( cost_tree.expr, @@ -816,9 +847,15 @@ mod tests { let expr = &ast.expressions[0]; let user_args = UserArgumentsContext::new(); let cost_map = HashMap::new(); // Empty cost map for tests - let (_, cost_tree) = - build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3) - .unwrap(); + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); assert_eq!(cost_tree.children.len(), 3); diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index d4750210f2..8b89dd7e60 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -33,7 +33,8 @@ fn test_build_cost_analysis_tree_function_definition() { let cost_map = HashMap::new(); let clarity_version = ClarityVersion::Clarity3; - let result = build_cost_analysis_tree(expr, &user_args, &cost_map, &clarity_version); + let epoch = StacksEpochId::Epoch32; + let result = build_cost_analysis_tree(expr, &user_args, &cost_map, &clarity_version, epoch); match result { Ok((function_name, node)) => { @@ -71,7 +72,7 @@ fn test_dependent_function_calls() { epoch, ) .unwrap(); - let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); + let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch).unwrap(); let (add_one_cost, _) = function_map.get("add-one").unwrap(); let (somefunc_cost, _) = function_map.get("somefunc").unwrap(); @@ -103,7 +104,8 @@ fn test_get_trait_count_direct() { ) .unwrap(); - let costs = static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3).unwrap(); + let costs = + static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3, StacksEpochId::Epoch32).unwrap(); // Extract trait_count from the result (all entries have the same trait_count) let trait_count = costs @@ -133,15 +135,11 @@ fn test_trait_counting() { (define-private (send (trait ) (addr principal)) (trait addr)) "#; let contract_id = QualifiedContractIdentifier::local("trait-counting").unwrap(); - let ast = crate::vm::ast::build_ast( - &contract_id, - src, - &mut (), - ClarityVersion::Clarity3, - StacksEpochId::Epoch32, - ) - .unwrap(); - let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3) + let epoch = StacksEpochId::Epoch32; + let ast = + crate::vm::ast::build_ast(&contract_id, src, &mut (), ClarityVersion::Clarity3, epoch) + .unwrap(); + let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch) .unwrap() .clone(); let send_trait_count_map = static_cost.get("send").unwrap().1.clone().unwrap(); @@ -230,7 +228,7 @@ fn test_pox_4_costs() { ) .expect("Failed to build AST from pox-4.clar"); - let cost_map = static_cost_from_ast(&ast, &clarity_version) + let cost_map = static_cost_from_ast(&ast, &clarity_version, epoch) .expect("Failed to get static cost analysis for pox-4.clar"); // Check some functions in the cost map From 80c781581fccda9893d61694bae724aecf287e88 Mon Sep 17 00:00:00 2001 From: "brady.ouren" <233826532+brady-stacks@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:41:37 -0800 Subject: [PATCH 23/23] stub in special function calculation --- clarity/src/vm/ast/static_cost/mod.rs | 65 +++++++++++------------ clarity/src/vm/costs/analysis.rs | 20 ++++--- clarity/src/vm/costs/cost_functions.rs | 30 +++++++++++ clarity/src/vm/functions/mod.rs | 1 + clarity/src/vm/functions/special_costs.rs | 25 +++++++++ clarity/src/vm/tests/analysis.rs | 22 ++++++++ 6 files changed, 120 insertions(+), 43 deletions(-) create mode 100644 clarity/src/vm/functions/special_costs.rs diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs index 47d30752c0..25449360d6 100644 --- a/clarity/src/vm/ast/static_cost/mod.rs +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -1,6 +1,7 @@ mod trait_counter; use std::collections::HashMap; +use clarity_types::representations::SymbolicExpression; use clarity_types::types::{CharType, SequenceData}; use stacks_common::types::StacksEpochId; pub use trait_counter::{ @@ -12,15 +13,12 @@ use crate::vm::costs::analysis::{ CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, }; use crate::vm::costs::cost_functions::linear; -use crate::vm::costs::costs_1::Costs1; -use crate::vm::costs::costs_2::Costs2; -use crate::vm::costs::costs_3::Costs3; -use crate::vm::costs::costs_4::Costs4; use crate::vm::costs::ExecutionCost; use crate::vm::errors::VmExecutionError; use crate::vm::functions::{lookup_reserved_functions, NativeFunctions}; use crate::vm::representations::ClarityName; use crate::vm::{ClarityVersion, Value}; +use crate::vm::functions::special_costs; const STRING_COST_BASE: u64 = 36; const STRING_COST_MULTIPLIER: u64 = 3; @@ -98,43 +96,40 @@ pub(crate) fn calculate_value_cost(value: &Value) -> Result pub(crate) fn calculate_function_cost_from_native_function( native_function: NativeFunctions, arg_count: u64, + args: &[SymbolicExpression], epoch: StacksEpochId, ) -> Result { // Derive clarity_version from epoch for lookup_reserved_functions let clarity_version = ClarityVersion::default_for_epoch(epoch); - let cost_function = - match lookup_reserved_functions(native_function.to_string().as_str(), &clarity_version) { - Some(CallableType::NativeFunction(_, _, cost_fn)) => cost_fn, - Some(CallableType::NativeFunction205(_, _, cost_fn, _)) => cost_fn, - Some(CallableType::SpecialFunction(_, _)) => return Ok(StaticCost::ZERO), - Some(CallableType::UserFunction(_)) => return Ok(StaticCost::ZERO), // TODO ? - None => { - return Ok(StaticCost::ZERO); - } - }; - - let cost = match epoch { - StacksEpochId::Epoch20 => cost_function.eval::(arg_count), - StacksEpochId::Epoch2_05 => cost_function.eval::(arg_count), - StacksEpochId::Epoch21 - | StacksEpochId::Epoch22 - | StacksEpochId::Epoch23 - | StacksEpochId::Epoch24 - | StacksEpochId::Epoch25 - | StacksEpochId::Epoch30 - | StacksEpochId::Epoch31 - | StacksEpochId::Epoch32 => cost_function.eval::(arg_count), - StacksEpochId::Epoch33 => cost_function.eval::(arg_count), - StacksEpochId::Epoch10 => { - // fallback to costs 1 since epoch 1 doesn't have direct cost mapping - cost_function.eval::(arg_count) + match lookup_reserved_functions(native_function.to_string().as_str(), &clarity_version) { + Some(CallableType::NativeFunction(_, _, cost_fn)) => { + let cost = cost_fn + .eval_for_epoch(arg_count, epoch) + .map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) + } + Some(CallableType::NativeFunction205(_, _, cost_fn, _)) => { + let cost = cost_fn + .eval_for_epoch(arg_count, epoch) + .map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) + } + Some(CallableType::SpecialFunction(_, _)) => { + let cost = special_costs::get_cost_for_special_function(native_function, args, epoch); + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) } + Some(CallableType::UserFunction(_)) => Ok(StaticCost::ZERO), // TODO ? + None => Ok(StaticCost::ZERO), } - .map_err(|e| format!("Cost calculation error: {:?}", e))?; - Ok(StaticCost { - min: cost.clone(), - max: cost, - }) } /// total cost handling branching diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 57fb36110f..06b04c12d1 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -6,8 +6,7 @@ use clarity_types::types::TraitIdentifier; use stacks_common::types::StacksEpochId; use crate::vm::ast::build_ast; -#[cfg(test)] -use crate::vm::ast::static_cost::is_node_branching; +// #[cfg(feature = "developer-mode")] use crate::vm::ast::static_cost::{ calculate_function_cost, calculate_function_cost_from_native_function, calculate_total_cost_with_branching, calculate_value_cost, TraitCount, TraitCountCollector, @@ -543,15 +542,19 @@ fn build_listlike_cost_analysis_tree( } SymbolicExpressionType::Atom(name) => { // Try to get function name from first element - // Try to lookup the function as a native function first + // lookup the function as a native function first + // special functions + // - let, etc use bindings lengths not argument lengths if let Some(native_function) = NativeFunctions::lookup_by_name_at_version(name.as_str(), clarity_version) { - let cost = calculate_function_cost_from_native_function( - native_function, - children.len() as u64, - epoch, - )?; + let cost = calculate_function_cost_from_native_function( + native_function, + children.len() as u64, + &exprs[1..], + epoch, + )?; + (CostExprNode::NativeFunction(native_function), cost) } else { // If not a native function, treat as user-defined function and look it up @@ -608,6 +611,7 @@ pub(crate) fn get_trait_count(costs: &HashMap) -> Opti mod tests { use super::*; + use crate::vm::ast::static_cost::is_node_branching; fn static_cost_native_test( source: &str, diff --git a/clarity/src/vm/costs/cost_functions.rs b/clarity/src/vm/costs/cost_functions.rs index 9621d9cc8b..c6137ac598 100644 --- a/clarity/src/vm/costs/cost_functions.rs +++ b/clarity/src/vm/costs/cost_functions.rs @@ -14,6 +14,11 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . use super::ExecutionCost; +use super::costs_1::Costs1; +use super::costs_2::Costs2; +use super::costs_3::Costs3; +use super::costs_4::Costs4; +use stacks_common::types::StacksEpochId; use crate::vm::errors::{RuntimeError, VmExecutionError}; define_named_enum!(ClarityCostFunction { @@ -342,6 +347,31 @@ pub trait CostValues { } impl ClarityCostFunction { + /// shortcut to eval() + pub fn eval_for_epoch( + &self, + n: u64, + epoch: StacksEpochId, + ) -> Result { + match epoch { + StacksEpochId::Epoch20 => self.eval::(n), + StacksEpochId::Epoch2_05 => self.eval::(n), + StacksEpochId::Epoch21 + | StacksEpochId::Epoch22 + | StacksEpochId::Epoch23 + | StacksEpochId::Epoch24 + | StacksEpochId::Epoch25 + | StacksEpochId::Epoch30 + | StacksEpochId::Epoch31 + | StacksEpochId::Epoch32 => self.eval::(n), + StacksEpochId::Epoch33 => self.eval::(n), + StacksEpochId::Epoch10 => { + // fallback to costs 1 since epoch 1 doesn't have direct cost mapping + self.eval::(n) + } + } + } + pub fn eval(&self, n: u64) -> Result { match self { ClarityCostFunction::AnalysisTypeAnnotate => C::cost_analysis_type_annotate(n), diff --git a/clarity/src/vm/functions/mod.rs b/clarity/src/vm/functions/mod.rs index bcffec2b2c..5e7681dda8 100644 --- a/clarity/src/vm/functions/mod.rs +++ b/clarity/src/vm/functions/mod.rs @@ -79,6 +79,7 @@ mod options; mod post_conditions; pub mod principals; mod sequences; +pub mod special_costs; pub mod tuples; define_versioned_named_enum_with_max!(NativeFunctions(ClarityVersion) { diff --git a/clarity/src/vm/functions/special_costs.rs b/clarity/src/vm/functions/special_costs.rs new file mode 100644 index 0000000000..4b6f95c26a --- /dev/null +++ b/clarity/src/vm/functions/special_costs.rs @@ -0,0 +1,25 @@ +use clarity_types::execution_cost::ExecutionCost; +use clarity_types::representations::SymbolicExpression; +use stacks_common::types::StacksEpochId; +use crate::vm::{costs::cost_functions::ClarityCostFunction, functions::NativeFunctions}; + +pub fn get_cost_for_special_function(native_function: NativeFunctions, args: &[SymbolicExpression], epoch: StacksEpochId) -> ExecutionCost { + match native_function { + NativeFunctions::Let => cost_binding_list_len(args, epoch), + NativeFunctions::If => cost_binding_list_len(args, epoch), + NativeFunctions::TupleCons => cost_binding_list_len(args, epoch), + _ => ExecutionCost::ZERO, + } +} + +pub fn cost_binding_list_len(args: &[SymbolicExpression], epoch: StacksEpochId) -> ExecutionCost { + let binding_len = args.get(1).and_then(|e| e.match_list()).map(|binding_list| binding_list.len() as u64).unwrap_or(0); + ClarityCostFunction::Let.eval_for_epoch(binding_len, epoch).unwrap_or_else(|_| { + ExecutionCost::ZERO + }) +} + +pub fn noop(_args: &[SymbolicExpression], _epoch: StacksEpochId) -> ExecutionCost { + ExecutionCost::ZERO +} + diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs index 8b89dd7e60..e2d927f65c 100644 --- a/clarity/src/vm/tests/analysis.rs +++ b/clarity/src/vm/tests/analysis.rs @@ -50,6 +50,28 @@ fn test_build_cost_analysis_tree_function_definition() { } } +#[test] +fn test_let_cost() { + let src = "(let ((a 1) (b 2)) (+ a b))"; + let src2 = "(let ((a 1) (b 2) (c 3)) (+ a b))"; // should compute for 3 bindings not 2 + + let contract_id = QualifiedContractIdentifier::transient(); + let epoch = StacksEpochId::Epoch32; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + src, + &mut (), + ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch).unwrap(); + let (let_cost, _) = function_map.get("let").unwrap(); + let (let2_cost, _) = function_map.get("let2").unwrap(); + assert_ne!(let2_cost.min.runtime, let_cost.min.runtime); +} + + #[test] fn test_dependent_function_calls() { let src = r#"(define-public (add-one (a uint))