|
| 1 | +mod trait_counter; |
| 2 | +use std::collections::HashMap; |
| 3 | + |
| 4 | +use clarity_types::types::{CharType, SequenceData}; |
| 5 | +pub use trait_counter::{ |
| 6 | + TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, |
| 7 | +}; |
| 8 | + |
| 9 | +// Import types from analysis.rs |
| 10 | +use crate::vm::costs::analysis::{ |
| 11 | + CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, |
| 12 | +}; |
| 13 | +use crate::vm::costs::cost_functions::{linear, CostValues}; |
| 14 | +use crate::vm::costs::costs_3::Costs3; |
| 15 | +use crate::vm::costs::ExecutionCost; |
| 16 | +use crate::vm::errors::VmExecutionError; |
| 17 | +use crate::vm::functions::NativeFunctions; |
| 18 | +use crate::vm::representations::ClarityName; |
| 19 | +use crate::vm::{ClarityVersion, Value}; |
| 20 | + |
| 21 | +const STRING_COST_BASE: u64 = 36; |
| 22 | +const STRING_COST_MULTIPLIER: u64 = 3; |
| 23 | + |
| 24 | +/// Convert a NativeFunctions enum variant to its corresponding cost function |
| 25 | +/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in |
| 26 | +pub(crate) fn get_cost_function_for_native( |
| 27 | + function: NativeFunctions, |
| 28 | + _clarity_version: &ClarityVersion, |
| 29 | +) -> Option<fn(u64) -> Result<ExecutionCost, VmExecutionError>> { |
| 30 | + use crate::vm::functions::NativeFunctions::*; |
| 31 | + |
| 32 | + // Map NativeFunctions enum variants to their cost functions |
| 33 | + match function { |
| 34 | + Add => Some(Costs3::cost_add), |
| 35 | + Subtract => Some(Costs3::cost_sub), |
| 36 | + Multiply => Some(Costs3::cost_mul), |
| 37 | + Divide => Some(Costs3::cost_div), |
| 38 | + Modulo => Some(Costs3::cost_mod), |
| 39 | + Power => Some(Costs3::cost_pow), |
| 40 | + Sqrti => Some(Costs3::cost_sqrti), |
| 41 | + Log2 => Some(Costs3::cost_log2), |
| 42 | + ToInt | ToUInt => Some(Costs3::cost_int_cast), |
| 43 | + Equals => Some(Costs3::cost_eq), |
| 44 | + CmpGeq => Some(Costs3::cost_geq), |
| 45 | + CmpLeq => Some(Costs3::cost_leq), |
| 46 | + CmpGreater => Some(Costs3::cost_ge), |
| 47 | + CmpLess => Some(Costs3::cost_le), |
| 48 | + BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor), |
| 49 | + Not | BitwiseNot => Some(Costs3::cost_not), |
| 50 | + And | BitwiseAnd => Some(Costs3::cost_and), |
| 51 | + Or | BitwiseOr => Some(Costs3::cost_or), |
| 52 | + Concat => Some(Costs3::cost_concat), |
| 53 | + Len => Some(Costs3::cost_len), |
| 54 | + AsMaxLen => Some(Costs3::cost_as_max_len), |
| 55 | + ListCons => Some(Costs3::cost_list_cons), |
| 56 | + ElementAt | ElementAtAlias => Some(Costs3::cost_element_at), |
| 57 | + IndexOf | IndexOfAlias => Some(Costs3::cost_index_of), |
| 58 | + Fold => Some(Costs3::cost_fold), |
| 59 | + Map => Some(Costs3::cost_map), |
| 60 | + Filter => Some(Costs3::cost_filter), |
| 61 | + Append => Some(Costs3::cost_append), |
| 62 | + TupleGet => Some(Costs3::cost_tuple_get), |
| 63 | + TupleMerge => Some(Costs3::cost_tuple_merge), |
| 64 | + TupleCons => Some(Costs3::cost_tuple_cons), |
| 65 | + ConsSome => Some(Costs3::cost_some_cons), |
| 66 | + ConsOkay => Some(Costs3::cost_ok_cons), |
| 67 | + ConsError => Some(Costs3::cost_err_cons), |
| 68 | + DefaultTo => Some(Costs3::cost_default_to), |
| 69 | + UnwrapRet => Some(Costs3::cost_unwrap_ret), |
| 70 | + UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret), |
| 71 | + IsOkay => Some(Costs3::cost_is_okay), |
| 72 | + IsNone => Some(Costs3::cost_is_none), |
| 73 | + IsErr => Some(Costs3::cost_is_err), |
| 74 | + IsSome => Some(Costs3::cost_is_some), |
| 75 | + Unwrap => Some(Costs3::cost_unwrap), |
| 76 | + UnwrapErr => Some(Costs3::cost_unwrap_err), |
| 77 | + TryRet => Some(Costs3::cost_try_ret), |
| 78 | + If => Some(Costs3::cost_if), |
| 79 | + Match => Some(Costs3::cost_match), |
| 80 | + Begin => Some(Costs3::cost_begin), |
| 81 | + Let => Some(Costs3::cost_let), |
| 82 | + Asserts => Some(Costs3::cost_asserts), |
| 83 | + Hash160 => Some(Costs3::cost_hash160), |
| 84 | + Sha256 => Some(Costs3::cost_sha256), |
| 85 | + Sha512 => Some(Costs3::cost_sha512), |
| 86 | + Sha512Trunc256 => Some(Costs3::cost_sha512t256), |
| 87 | + Keccak256 => Some(Costs3::cost_keccak256), |
| 88 | + Secp256k1Recover => Some(Costs3::cost_secp256k1recover), |
| 89 | + Secp256k1Verify => Some(Costs3::cost_secp256k1verify), |
| 90 | + Print => Some(Costs3::cost_print), |
| 91 | + ContractCall => Some(Costs3::cost_contract_call), |
| 92 | + ContractOf => Some(Costs3::cost_contract_of), |
| 93 | + PrincipalOf => Some(Costs3::cost_principal_of), |
| 94 | + AtBlock => Some(Costs3::cost_at_block), |
| 95 | + // => Some(Costs3::cost_create_map), |
| 96 | + // => Some(Costs3::cost_create_var), |
| 97 | + // ContractStorage => Some(Costs3::cost_contract_storage), |
| 98 | + FetchEntry => Some(Costs3::cost_fetch_entry), |
| 99 | + SetEntry => Some(Costs3::cost_set_entry), |
| 100 | + FetchVar => Some(Costs3::cost_fetch_var), |
| 101 | + SetVar => Some(Costs3::cost_set_var), |
| 102 | + GetBlockInfo => Some(Costs3::cost_block_info), |
| 103 | + GetBurnBlockInfo => Some(Costs3::cost_burn_block_info), |
| 104 | + GetStxBalance => Some(Costs3::cost_stx_balance), |
| 105 | + StxTransfer => Some(Costs3::cost_stx_transfer), |
| 106 | + StxTransferMemo => Some(Costs3::cost_stx_transfer_memo), |
| 107 | + StxGetAccount => Some(Costs3::cost_stx_account), |
| 108 | + MintToken => Some(Costs3::cost_ft_mint), |
| 109 | + MintAsset => Some(Costs3::cost_nft_mint), |
| 110 | + TransferToken => Some(Costs3::cost_ft_transfer), |
| 111 | + GetTokenBalance => Some(Costs3::cost_ft_balance), |
| 112 | + GetTokenSupply => Some(Costs3::cost_ft_get_supply), |
| 113 | + BurnToken => Some(Costs3::cost_ft_burn), |
| 114 | + TransferAsset => Some(Costs3::cost_nft_transfer), |
| 115 | + GetAssetOwner => Some(Costs3::cost_nft_owner), |
| 116 | + BurnAsset => Some(Costs3::cost_nft_burn), |
| 117 | + BuffToIntLe => Some(Costs3::cost_buff_to_int_le), |
| 118 | + BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le), |
| 119 | + BuffToIntBe => Some(Costs3::cost_buff_to_int_be), |
| 120 | + BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be), |
| 121 | + ToConsensusBuff => Some(Costs3::cost_to_consensus_buff), |
| 122 | + FromConsensusBuff => Some(Costs3::cost_from_consensus_buff), |
| 123 | + IsStandard => Some(Costs3::cost_is_standard), |
| 124 | + PrincipalDestruct => Some(Costs3::cost_principal_destruct), |
| 125 | + PrincipalConstruct => Some(Costs3::cost_principal_construct), |
| 126 | + AsContract | AsContractSafe => Some(Costs3::cost_as_contract), |
| 127 | + StringToInt => Some(Costs3::cost_string_to_int), |
| 128 | + StringToUInt => Some(Costs3::cost_string_to_uint), |
| 129 | + IntToAscii => Some(Costs3::cost_int_to_ascii), |
| 130 | + IntToUtf8 => Some(Costs3::cost_int_to_utf8), |
| 131 | + BitwiseLShift => Some(Costs3::cost_bitwise_left_shift), |
| 132 | + BitwiseRShift => Some(Costs3::cost_bitwise_right_shift), |
| 133 | + Slice => Some(Costs3::cost_slice), |
| 134 | + ReplaceAt => Some(Costs3::cost_replace_at), |
| 135 | + GetStacksBlockInfo => Some(Costs3::cost_block_info), |
| 136 | + GetTenureInfo => Some(Costs3::cost_block_info), |
| 137 | + ContractHash => Some(Costs3::cost_contract_hash), |
| 138 | + ToAscii => Some(Costs3::cost_to_ascii), |
| 139 | + InsertEntry => Some(Costs3::cost_set_entry), |
| 140 | + DeleteEntry => Some(Costs3::cost_set_entry), |
| 141 | + StxBurn => Some(Costs3::cost_stx_transfer), |
| 142 | + Secp256r1Verify => Some(Costs3::cost_secp256r1verify), |
| 143 | + RestrictAssets => None, // TODO: add cost function |
| 144 | + AllowanceWithStx => None, // TODO: add cost function |
| 145 | + AllowanceWithFt => None, // TODO: add cost function |
| 146 | + AllowanceWithNft => None, // TODO: add cost function |
| 147 | + AllowanceWithStacking => None, // TODO: add cost function |
| 148 | + AllowanceAll => None, // TODO: add cost function |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +// Calculate function cost with lazy evaluation support |
| 153 | +pub(crate) fn calculate_function_cost( |
| 154 | + function_name: String, |
| 155 | + cost_map: &HashMap<String, Option<StaticCost>>, |
| 156 | + _clarity_version: &ClarityVersion, |
| 157 | +) -> Result<StaticCost, String> { |
| 158 | + match cost_map.get(&function_name) { |
| 159 | + Some(Some(cost)) => { |
| 160 | + // Cost already computed |
| 161 | + Ok(cost.clone()) |
| 162 | + } |
| 163 | + Some(None) => { |
| 164 | + // Should be impossible but alas.. |
| 165 | + // Function exists but cost not yet computed - this indicates a circular dependency |
| 166 | + // For now, return zero cost to avoid infinite recursion |
| 167 | + println!( |
| 168 | + "Circular dependency detected for function: {}", |
| 169 | + function_name |
| 170 | + ); |
| 171 | + Ok(StaticCost::ZERO) |
| 172 | + } |
| 173 | + None => { |
| 174 | + // Function not found |
| 175 | + Ok(StaticCost::ZERO) |
| 176 | + } |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +/// Determine if a function name represents a branching function |
| 181 | +pub(crate) fn is_branching_function(function_name: &ClarityName) -> bool { |
| 182 | + match function_name.as_str() { |
| 183 | + "if" | "match" => true, |
| 184 | + "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and |
| 185 | + // unwrap-err traverse both branches regardless of result, so until this is |
| 186 | + // fixed in clarity we'll set this to false |
| 187 | + _ => false, |
| 188 | + } |
| 189 | +} |
| 190 | + |
| 191 | +/// Helper function to determine if a node represents a branching operation |
| 192 | +/// This is used in tests and cost calculation |
| 193 | +pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool { |
| 194 | + match &node.expr { |
| 195 | + CostExprNode::NativeFunction(NativeFunctions::If) |
| 196 | + | CostExprNode::NativeFunction(NativeFunctions::Match) => true, |
| 197 | + CostExprNode::UserFunction(name) => is_branching_function(name), |
| 198 | + _ => false, |
| 199 | + } |
| 200 | +} |
| 201 | + |
| 202 | +/// Calculate the cost for a string based on its length |
| 203 | +fn string_cost(length: usize) -> StaticCost { |
| 204 | + let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); |
| 205 | + let execution_cost = ExecutionCost::runtime(cost); |
| 206 | + StaticCost { |
| 207 | + min: execution_cost.clone(), |
| 208 | + max: execution_cost, |
| 209 | + } |
| 210 | +} |
| 211 | + |
| 212 | +/// Calculate cost for a value (used for literal values) |
| 213 | +pub(crate) fn calculate_value_cost(value: &Value) -> Result<StaticCost, String> { |
| 214 | + match value { |
| 215 | + Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { |
| 216 | + Ok(string_cost(data.data.len())) |
| 217 | + } |
| 218 | + Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { |
| 219 | + Ok(string_cost(data.data.len())) |
| 220 | + } |
| 221 | + _ => Ok(StaticCost::ZERO), |
| 222 | + } |
| 223 | +} |
| 224 | + |
| 225 | +pub(crate) fn calculate_function_cost_from_native_function( |
| 226 | + native_function: NativeFunctions, |
| 227 | + arg_count: u64, |
| 228 | + clarity_version: &ClarityVersion, |
| 229 | +) -> Result<StaticCost, String> { |
| 230 | + let cost_function = match get_cost_function_for_native(native_function, clarity_version) { |
| 231 | + Some(cost_fn) => cost_fn, |
| 232 | + None => { |
| 233 | + // TODO: zero cost for now |
| 234 | + return Ok(StaticCost::ZERO); |
| 235 | + } |
| 236 | + }; |
| 237 | + |
| 238 | + let cost = get_costs(cost_function, arg_count)?; |
| 239 | + Ok(StaticCost { |
| 240 | + min: cost.clone(), |
| 241 | + max: cost, |
| 242 | + }) |
| 243 | +} |
| 244 | + |
| 245 | +/// Calculate total cost using SummingExecutionCost to handle branching properly |
| 246 | +pub(crate) fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { |
| 247 | + let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); |
| 248 | + |
| 249 | + for child in &node.children { |
| 250 | + let child_summing = calculate_total_cost_with_summing(child); |
| 251 | + summing_cost.add_summing(&child_summing); |
| 252 | + } |
| 253 | + |
| 254 | + summing_cost |
| 255 | +} |
| 256 | + |
| 257 | +pub(crate) fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { |
| 258 | + let mut summing_cost = SummingExecutionCost::new(); |
| 259 | + |
| 260 | + // Check if this is a branching function by examining the node's expression |
| 261 | + let is_branching = is_node_branching(node); |
| 262 | + |
| 263 | + if is_branching { |
| 264 | + match &node.expr { |
| 265 | + CostExprNode::NativeFunction(NativeFunctions::If) |
| 266 | + | CostExprNode::NativeFunction(NativeFunctions::Match) => { |
| 267 | + // TODO match? |
| 268 | + if node.children.len() >= 2 { |
| 269 | + let condition_cost = calculate_total_cost_with_summing(&node.children[0]); |
| 270 | + let condition_total = condition_cost.add_all(); |
| 271 | + |
| 272 | + // Add the root cost + condition cost to each branch |
| 273 | + let mut root_and_condition = node.cost.min.clone(); |
| 274 | + let _ = root_and_condition.add(&condition_total); |
| 275 | + |
| 276 | + for child_cost_node in node.children.iter().skip(1) { |
| 277 | + let branch_cost = calculate_total_cost_with_summing(child_cost_node); |
| 278 | + let branch_total = branch_cost.add_all(); |
| 279 | + |
| 280 | + let mut path_cost = root_and_condition.clone(); |
| 281 | + let _ = path_cost.add(&branch_total); |
| 282 | + |
| 283 | + summing_cost.add_cost(path_cost); |
| 284 | + } |
| 285 | + } |
| 286 | + } |
| 287 | + _ => { |
| 288 | + // For other branching functions, fall back to sequential processing |
| 289 | + let mut total_cost = node.cost.min.clone(); |
| 290 | + for child_cost_node in &node.children { |
| 291 | + let child_summing = calculate_total_cost_with_summing(child_cost_node); |
| 292 | + let combined_cost = child_summing.add_all(); |
| 293 | + let _ = total_cost.add(&combined_cost); |
| 294 | + } |
| 295 | + summing_cost.add_cost(total_cost); |
| 296 | + } |
| 297 | + } |
| 298 | + } else { |
| 299 | + // For non-branching, add all costs sequentially |
| 300 | + let mut total_cost = node.cost.min.clone(); |
| 301 | + for child_cost_node in &node.children { |
| 302 | + let child_summing = calculate_total_cost_with_summing(child_cost_node); |
| 303 | + let combined_cost = child_summing.add_all(); |
| 304 | + let _ = total_cost.add(&combined_cost); |
| 305 | + } |
| 306 | + summing_cost.add_cost(total_cost); |
| 307 | + } |
| 308 | + |
| 309 | + summing_cost |
| 310 | +} |
| 311 | + |
| 312 | +impl From<SummingExecutionCost> for StaticCost { |
| 313 | + fn from(summing: SummingExecutionCost) -> Self { |
| 314 | + StaticCost { |
| 315 | + min: summing.min(), |
| 316 | + max: summing.max(), |
| 317 | + } |
| 318 | + } |
| 319 | +} |
| 320 | + |
| 321 | +/// Helper: calculate min & max costs for a given cost function |
| 322 | +/// This is likely tooo simplistic but for now it'll do |
| 323 | +fn get_costs( |
| 324 | + cost_fn: fn(u64) -> Result<ExecutionCost, VmExecutionError>, |
| 325 | + arg_count: u64, |
| 326 | +) -> Result<ExecutionCost, String> { |
| 327 | + let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; |
| 328 | + Ok(cost) |
| 329 | +} |
0 commit comments