Skip to content

Commit 7daec93

Browse files
committed
clean up modules and simplify listlike builder
1 parent 38aaeaa commit 7daec93

File tree

5 files changed

+863
-979
lines changed

5 files changed

+863
-979
lines changed

clarity/src/vm/ast/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pub mod definition_sorter;
1818
pub mod expression_identifier;
1919
pub mod parser;
20+
pub mod static_cost;
2021
pub mod traits_resolver;
2122

2223
pub mod errors;
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
mod trait_counter;
2+
use std::collections::HashMap;
3+
4+
use clarity_types::types::{CharType, SequenceData};
5+
pub use trait_counter::{
6+
TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor,
7+
};
8+
9+
// Import types from analysis.rs
10+
use crate::vm::costs::analysis::{
11+
CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost,
12+
};
13+
use crate::vm::costs::cost_functions::{linear, CostValues};
14+
use crate::vm::costs::costs_3::Costs3;
15+
use crate::vm::costs::ExecutionCost;
16+
use crate::vm::errors::VmExecutionError;
17+
use crate::vm::functions::NativeFunctions;
18+
use crate::vm::representations::ClarityName;
19+
use crate::vm::{ClarityVersion, Value};
20+
21+
const STRING_COST_BASE: u64 = 36;
22+
const STRING_COST_MULTIPLIER: u64 = 3;
23+
24+
/// Convert a NativeFunctions enum variant to its corresponding cost function
25+
/// TODO: This assumes Costs3 but should find a way to use the clarity version passed in
26+
pub(crate) fn get_cost_function_for_native(
27+
function: NativeFunctions,
28+
_clarity_version: &ClarityVersion,
29+
) -> Option<fn(u64) -> Result<ExecutionCost, VmExecutionError>> {
30+
use crate::vm::functions::NativeFunctions::*;
31+
32+
// Map NativeFunctions enum variants to their cost functions
33+
match function {
34+
Add => Some(Costs3::cost_add),
35+
Subtract => Some(Costs3::cost_sub),
36+
Multiply => Some(Costs3::cost_mul),
37+
Divide => Some(Costs3::cost_div),
38+
Modulo => Some(Costs3::cost_mod),
39+
Power => Some(Costs3::cost_pow),
40+
Sqrti => Some(Costs3::cost_sqrti),
41+
Log2 => Some(Costs3::cost_log2),
42+
ToInt | ToUInt => Some(Costs3::cost_int_cast),
43+
Equals => Some(Costs3::cost_eq),
44+
CmpGeq => Some(Costs3::cost_geq),
45+
CmpLeq => Some(Costs3::cost_leq),
46+
CmpGreater => Some(Costs3::cost_ge),
47+
CmpLess => Some(Costs3::cost_le),
48+
BitwiseXor | BitwiseXor2 => Some(Costs3::cost_xor),
49+
Not | BitwiseNot => Some(Costs3::cost_not),
50+
And | BitwiseAnd => Some(Costs3::cost_and),
51+
Or | BitwiseOr => Some(Costs3::cost_or),
52+
Concat => Some(Costs3::cost_concat),
53+
Len => Some(Costs3::cost_len),
54+
AsMaxLen => Some(Costs3::cost_as_max_len),
55+
ListCons => Some(Costs3::cost_list_cons),
56+
ElementAt | ElementAtAlias => Some(Costs3::cost_element_at),
57+
IndexOf | IndexOfAlias => Some(Costs3::cost_index_of),
58+
Fold => Some(Costs3::cost_fold),
59+
Map => Some(Costs3::cost_map),
60+
Filter => Some(Costs3::cost_filter),
61+
Append => Some(Costs3::cost_append),
62+
TupleGet => Some(Costs3::cost_tuple_get),
63+
TupleMerge => Some(Costs3::cost_tuple_merge),
64+
TupleCons => Some(Costs3::cost_tuple_cons),
65+
ConsSome => Some(Costs3::cost_some_cons),
66+
ConsOkay => Some(Costs3::cost_ok_cons),
67+
ConsError => Some(Costs3::cost_err_cons),
68+
DefaultTo => Some(Costs3::cost_default_to),
69+
UnwrapRet => Some(Costs3::cost_unwrap_ret),
70+
UnwrapErrRet => Some(Costs3::cost_unwrap_err_or_ret),
71+
IsOkay => Some(Costs3::cost_is_okay),
72+
IsNone => Some(Costs3::cost_is_none),
73+
IsErr => Some(Costs3::cost_is_err),
74+
IsSome => Some(Costs3::cost_is_some),
75+
Unwrap => Some(Costs3::cost_unwrap),
76+
UnwrapErr => Some(Costs3::cost_unwrap_err),
77+
TryRet => Some(Costs3::cost_try_ret),
78+
If => Some(Costs3::cost_if),
79+
Match => Some(Costs3::cost_match),
80+
Begin => Some(Costs3::cost_begin),
81+
Let => Some(Costs3::cost_let),
82+
Asserts => Some(Costs3::cost_asserts),
83+
Hash160 => Some(Costs3::cost_hash160),
84+
Sha256 => Some(Costs3::cost_sha256),
85+
Sha512 => Some(Costs3::cost_sha512),
86+
Sha512Trunc256 => Some(Costs3::cost_sha512t256),
87+
Keccak256 => Some(Costs3::cost_keccak256),
88+
Secp256k1Recover => Some(Costs3::cost_secp256k1recover),
89+
Secp256k1Verify => Some(Costs3::cost_secp256k1verify),
90+
Print => Some(Costs3::cost_print),
91+
ContractCall => Some(Costs3::cost_contract_call),
92+
ContractOf => Some(Costs3::cost_contract_of),
93+
PrincipalOf => Some(Costs3::cost_principal_of),
94+
AtBlock => Some(Costs3::cost_at_block),
95+
// => Some(Costs3::cost_create_map),
96+
// => Some(Costs3::cost_create_var),
97+
// ContractStorage => Some(Costs3::cost_contract_storage),
98+
FetchEntry => Some(Costs3::cost_fetch_entry),
99+
SetEntry => Some(Costs3::cost_set_entry),
100+
FetchVar => Some(Costs3::cost_fetch_var),
101+
SetVar => Some(Costs3::cost_set_var),
102+
GetBlockInfo => Some(Costs3::cost_block_info),
103+
GetBurnBlockInfo => Some(Costs3::cost_burn_block_info),
104+
GetStxBalance => Some(Costs3::cost_stx_balance),
105+
StxTransfer => Some(Costs3::cost_stx_transfer),
106+
StxTransferMemo => Some(Costs3::cost_stx_transfer_memo),
107+
StxGetAccount => Some(Costs3::cost_stx_account),
108+
MintToken => Some(Costs3::cost_ft_mint),
109+
MintAsset => Some(Costs3::cost_nft_mint),
110+
TransferToken => Some(Costs3::cost_ft_transfer),
111+
GetTokenBalance => Some(Costs3::cost_ft_balance),
112+
GetTokenSupply => Some(Costs3::cost_ft_get_supply),
113+
BurnToken => Some(Costs3::cost_ft_burn),
114+
TransferAsset => Some(Costs3::cost_nft_transfer),
115+
GetAssetOwner => Some(Costs3::cost_nft_owner),
116+
BurnAsset => Some(Costs3::cost_nft_burn),
117+
BuffToIntLe => Some(Costs3::cost_buff_to_int_le),
118+
BuffToUIntLe => Some(Costs3::cost_buff_to_uint_le),
119+
BuffToIntBe => Some(Costs3::cost_buff_to_int_be),
120+
BuffToUIntBe => Some(Costs3::cost_buff_to_uint_be),
121+
ToConsensusBuff => Some(Costs3::cost_to_consensus_buff),
122+
FromConsensusBuff => Some(Costs3::cost_from_consensus_buff),
123+
IsStandard => Some(Costs3::cost_is_standard),
124+
PrincipalDestruct => Some(Costs3::cost_principal_destruct),
125+
PrincipalConstruct => Some(Costs3::cost_principal_construct),
126+
AsContract | AsContractSafe => Some(Costs3::cost_as_contract),
127+
StringToInt => Some(Costs3::cost_string_to_int),
128+
StringToUInt => Some(Costs3::cost_string_to_uint),
129+
IntToAscii => Some(Costs3::cost_int_to_ascii),
130+
IntToUtf8 => Some(Costs3::cost_int_to_utf8),
131+
BitwiseLShift => Some(Costs3::cost_bitwise_left_shift),
132+
BitwiseRShift => Some(Costs3::cost_bitwise_right_shift),
133+
Slice => Some(Costs3::cost_slice),
134+
ReplaceAt => Some(Costs3::cost_replace_at),
135+
GetStacksBlockInfo => Some(Costs3::cost_block_info),
136+
GetTenureInfo => Some(Costs3::cost_block_info),
137+
ContractHash => Some(Costs3::cost_contract_hash),
138+
ToAscii => Some(Costs3::cost_to_ascii),
139+
InsertEntry => Some(Costs3::cost_set_entry),
140+
DeleteEntry => Some(Costs3::cost_set_entry),
141+
StxBurn => Some(Costs3::cost_stx_transfer),
142+
Secp256r1Verify => Some(Costs3::cost_secp256r1verify),
143+
RestrictAssets => None, // TODO: add cost function
144+
AllowanceWithStx => None, // TODO: add cost function
145+
AllowanceWithFt => None, // TODO: add cost function
146+
AllowanceWithNft => None, // TODO: add cost function
147+
AllowanceWithStacking => None, // TODO: add cost function
148+
AllowanceAll => None, // TODO: add cost function
149+
}
150+
}
151+
152+
// Calculate function cost with lazy evaluation support
153+
pub(crate) fn calculate_function_cost(
154+
function_name: String,
155+
cost_map: &HashMap<String, Option<StaticCost>>,
156+
_clarity_version: &ClarityVersion,
157+
) -> Result<StaticCost, String> {
158+
match cost_map.get(&function_name) {
159+
Some(Some(cost)) => {
160+
// Cost already computed
161+
Ok(cost.clone())
162+
}
163+
Some(None) => {
164+
// Should be impossible but alas..
165+
// Function exists but cost not yet computed - this indicates a circular dependency
166+
// For now, return zero cost to avoid infinite recursion
167+
println!(
168+
"Circular dependency detected for function: {}",
169+
function_name
170+
);
171+
Ok(StaticCost::ZERO)
172+
}
173+
None => {
174+
// Function not found
175+
Ok(StaticCost::ZERO)
176+
}
177+
}
178+
}
179+
180+
/// Determine if a function name represents a branching function
181+
pub(crate) fn is_branching_function(function_name: &ClarityName) -> bool {
182+
match function_name.as_str() {
183+
"if" | "match" => true,
184+
"unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and
185+
// unwrap-err traverse both branches regardless of result, so until this is
186+
// fixed in clarity we'll set this to false
187+
_ => false,
188+
}
189+
}
190+
191+
/// Helper function to determine if a node represents a branching operation
192+
/// This is used in tests and cost calculation
193+
pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool {
194+
match &node.expr {
195+
CostExprNode::NativeFunction(NativeFunctions::If)
196+
| CostExprNode::NativeFunction(NativeFunctions::Match) => true,
197+
CostExprNode::UserFunction(name) => is_branching_function(name),
198+
_ => false,
199+
}
200+
}
201+
202+
/// Calculate the cost for a string based on its length
203+
fn string_cost(length: usize) -> StaticCost {
204+
let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER);
205+
let execution_cost = ExecutionCost::runtime(cost);
206+
StaticCost {
207+
min: execution_cost.clone(),
208+
max: execution_cost,
209+
}
210+
}
211+
212+
/// Calculate cost for a value (used for literal values)
213+
pub(crate) fn calculate_value_cost(value: &Value) -> Result<StaticCost, String> {
214+
match value {
215+
Value::Sequence(SequenceData::String(CharType::UTF8(data))) => {
216+
Ok(string_cost(data.data.len()))
217+
}
218+
Value::Sequence(SequenceData::String(CharType::ASCII(data))) => {
219+
Ok(string_cost(data.data.len()))
220+
}
221+
_ => Ok(StaticCost::ZERO),
222+
}
223+
}
224+
225+
pub(crate) fn calculate_function_cost_from_native_function(
226+
native_function: NativeFunctions,
227+
arg_count: u64,
228+
clarity_version: &ClarityVersion,
229+
) -> Result<StaticCost, String> {
230+
let cost_function = match get_cost_function_for_native(native_function, clarity_version) {
231+
Some(cost_fn) => cost_fn,
232+
None => {
233+
// TODO: zero cost for now
234+
return Ok(StaticCost::ZERO);
235+
}
236+
};
237+
238+
let cost = get_costs(cost_function, arg_count)?;
239+
Ok(StaticCost {
240+
min: cost.clone(),
241+
max: cost,
242+
})
243+
}
244+
245+
/// Calculate total cost using SummingExecutionCost to handle branching properly
246+
pub(crate) fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost {
247+
let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone());
248+
249+
for child in &node.children {
250+
let child_summing = calculate_total_cost_with_summing(child);
251+
summing_cost.add_summing(&child_summing);
252+
}
253+
254+
summing_cost
255+
}
256+
257+
pub(crate) fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost {
258+
let mut summing_cost = SummingExecutionCost::new();
259+
260+
// Check if this is a branching function by examining the node's expression
261+
let is_branching = is_node_branching(node);
262+
263+
if is_branching {
264+
match &node.expr {
265+
CostExprNode::NativeFunction(NativeFunctions::If)
266+
| CostExprNode::NativeFunction(NativeFunctions::Match) => {
267+
// TODO match?
268+
if node.children.len() >= 2 {
269+
let condition_cost = calculate_total_cost_with_summing(&node.children[0]);
270+
let condition_total = condition_cost.add_all();
271+
272+
// Add the root cost + condition cost to each branch
273+
let mut root_and_condition = node.cost.min.clone();
274+
let _ = root_and_condition.add(&condition_total);
275+
276+
for child_cost_node in node.children.iter().skip(1) {
277+
let branch_cost = calculate_total_cost_with_summing(child_cost_node);
278+
let branch_total = branch_cost.add_all();
279+
280+
let mut path_cost = root_and_condition.clone();
281+
let _ = path_cost.add(&branch_total);
282+
283+
summing_cost.add_cost(path_cost);
284+
}
285+
}
286+
}
287+
_ => {
288+
// For other branching functions, fall back to sequential processing
289+
let mut total_cost = node.cost.min.clone();
290+
for child_cost_node in &node.children {
291+
let child_summing = calculate_total_cost_with_summing(child_cost_node);
292+
let combined_cost = child_summing.add_all();
293+
let _ = total_cost.add(&combined_cost);
294+
}
295+
summing_cost.add_cost(total_cost);
296+
}
297+
}
298+
} else {
299+
// For non-branching, add all costs sequentially
300+
let mut total_cost = node.cost.min.clone();
301+
for child_cost_node in &node.children {
302+
let child_summing = calculate_total_cost_with_summing(child_cost_node);
303+
let combined_cost = child_summing.add_all();
304+
let _ = total_cost.add(&combined_cost);
305+
}
306+
summing_cost.add_cost(total_cost);
307+
}
308+
309+
summing_cost
310+
}
311+
312+
impl From<SummingExecutionCost> for StaticCost {
313+
fn from(summing: SummingExecutionCost) -> Self {
314+
StaticCost {
315+
min: summing.min(),
316+
max: summing.max(),
317+
}
318+
}
319+
}
320+
321+
/// Helper: calculate min & max costs for a given cost function
322+
/// This is likely tooo simplistic but for now it'll do
323+
fn get_costs(
324+
cost_fn: fn(u64) -> Result<ExecutionCost, VmExecutionError>,
325+
arg_count: u64,
326+
) -> Result<ExecutionCost, String> {
327+
let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?;
328+
Ok(cost)
329+
}

0 commit comments

Comments
 (0)