diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index 16c034702de..3cbeca98a76 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -123,7 +123,7 @@ bvt boolbvt::convert_bitvector(const exprt &expr) else if(expr.id() == ID_update_bit) return convert_update_bit(to_update_bit_expr(expr)); else if(expr.id()==ID_case) - return convert_case(expr); + return convert_case(to_case_expr(expr)); else if(expr.id()==ID_cond) return convert_cond(to_cond_expr(expr)); else if(expr.id()==ID_if) @@ -390,7 +390,7 @@ literalt boolbvt::convert_rest(const exprt &expr) } else if(expr.id()==ID_case) { - bvt bv=convert_case(expr); + bvt bv = convert_case(to_case_expr(expr)); CHECK_RETURN(bv.size() == 1); return bv[0]; } diff --git a/src/solvers/flattening/boolbv.h b/src/solvers/flattening/boolbv.h index c305b555ee0..8a452fe9405 100644 --- a/src/solvers/flattening/boolbv.h +++ b/src/solvers/flattening/boolbv.h @@ -185,7 +185,7 @@ class boolbvt:public arrayst virtual bvt convert_update(const update_exprt &); virtual bvt convert_update_bit(const update_bit_exprt &); virtual bvt convert_update_bits(const update_bits_exprt &); - virtual bvt convert_case(const exprt &expr); + virtual bvt convert_case(const case_exprt &); virtual bvt convert_cond(const cond_exprt &); virtual bvt convert_shift(const binary_exprt &expr); virtual bvt convert_bitwise(const exprt &expr); diff --git a/src/solvers/flattening/boolbv_case.cpp b/src/solvers/flattening/boolbv_case.cpp index 54eb819fd38..2192cdfa6f4 100644 --- a/src/solvers/flattening/boolbv_case.cpp +++ b/src/solvers/flattening/boolbv_case.cpp @@ -6,14 +6,13 @@ Author: Daniel Kroening, kroening@kroening.com \*******************************************************************/ -#include "boolbv.h" - #include +#include -bvt boolbvt::convert_case(const exprt &expr) -{ - PRECONDITION(expr.id() == ID_case); +#include "boolbv.h" +bvt boolbvt::convert_case(const case_exprt &expr) +{ const std::vector &operands=expr.operands(); std::size_t width=boolbv_width(expr.type()); diff --git a/src/util/std_expr.h b/src/util/std_expr.h index 4cf659779f8..c076d6fcc2c 100644 --- a/src/util/std_expr.h +++ b/src/util/std_expr.h @@ -3578,6 +3578,124 @@ inline cond_exprt &to_cond_expr(exprt &expr) return ret; } +/// \brief Case expression: evaluates to the value corresponding to the first +/// matching case. The first operand is the value to compare against. Subsequent +/// operands alternate between compare values and result values. The syntax is: +/// case(select_value, case1_value, result1, case2_value, result2, ...) +class case_exprt : public multi_ary_exprt +{ +public: + case_exprt(operandst _operands, typet _type) + : multi_ary_exprt(ID_case, std::move(_operands), std::move(_type)) + { + } + + /// Constructor with select value + case_exprt(exprt _select_value, typet _type) + : multi_ary_exprt(ID_case, {std::move(_select_value)}, std::move(_type)) + { + } + + /// Get the value that is being compared against + const exprt &select_value() const + { + PRECONDITION(!operands().empty()); + return operands()[0]; + } + + /// Get the value that is being compared against + exprt &select_value() + { + PRECONDITION(!operands().empty()); + return operands()[0]; + } + + /// Add a case: value to compare and corresponding result + /// \param case_value: the value to compare against select_value + /// \param result_value: the value to return if case_value matches + /// select_value + void add_case(const exprt &case_value, const exprt &result_value) + { + operands().reserve(operands().size() + 2); + operands().push_back(case_value); + operands().push_back(result_value); + } + + /// Get the number of cases (excluding the select value) + std::size_t number_of_cases() const + { + PRECONDITION(operands().size() >= 1); + return (operands().size() - 1) / 2; + } + + /// Get the case value for the i-th case + const exprt &case_value(std::size_t i) const + { + PRECONDITION(i < number_of_cases()); + return operands()[1 + 2 * i]; + } + + /// Get the case value for the i-th case + exprt &case_value(std::size_t i) + { + PRECONDITION(i < number_of_cases()); + return operands()[1 + 2 * i]; + } + + /// Get the result value for the i-th case + const exprt &result_value(std::size_t i) const + { + PRECONDITION(i < number_of_cases()); + return operands()[1 + 2 * i + 1]; + } + + /// Get the result value for the i-th case + exprt &result_value(std::size_t i) + { + PRECONDITION(i < number_of_cases()); + return operands()[1 + 2 * i + 1]; + } + + static void validate_expr(const case_exprt &value) + { + DATA_INVARIANT( + value.operands().size() >= 1, + "case expression must have at least one operand"); + DATA_INVARIANT( + value.operands().size() % 2 == 1, + "case expression must have odd number of operands"); + } +}; + +template <> +inline bool can_cast_expr(const exprt &base) +{ + return base.id() == ID_case; +} + +/// \brief Cast an exprt to a \ref case_exprt +/// +/// \a expr must be known to be \ref case_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref case_exprt +inline const case_exprt &to_case_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_case); + const case_exprt &ret = static_cast(expr); + case_exprt::validate_expr(ret); + return ret; +} + +/// \copydoc to_case_expr(const exprt &) +inline case_exprt &to_case_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_case); + case_exprt &ret = static_cast(expr); + case_exprt::validate_expr(ret); + return ret; +} + /// \brief Expression to define a mapping from an argument (index) to elements. /// This enables constructing an array via an anonymous function. /// Not all kinds of array comprehension can be expressed, only those of the diff --git a/unit/Makefile b/unit/Makefile index a708a86b214..978bc6fec14 100644 --- a/unit/Makefile +++ b/unit/Makefile @@ -142,6 +142,7 @@ SRC += analyses/ai/ai.cpp \ solvers/strings/string_refinement/substitute_array_list.cpp \ solvers/strings/string_refinement/union_find_replace.cpp \ util/bitvector_expr.cpp \ + util/case_expr.cpp \ util/cmdline.cpp \ util/dense_integer_map.cpp \ util/edit_distance.cpp \ diff --git a/unit/util/case_expr.cpp b/unit/util/case_expr.cpp new file mode 100644 index 00000000000..15352c7e530 --- /dev/null +++ b/unit/util/case_expr.cpp @@ -0,0 +1,97 @@ +/*******************************************************************\ + +Module: Unit tests for case_exprt + +Author: Unit test + +\*******************************************************************/ + +#include +#include +#include + +#include + +TEST_CASE("case_exprt construction and access", "[core][util][case_expr]") +{ + const signedbv_typet int_type(32); + const symbol_exprt select_value("x", int_type); + + SECTION("Basic construction") + { + case_exprt case_expr(select_value, int_type); + + REQUIRE(case_expr.id() == ID_case); + REQUIRE(case_expr.select_value() == select_value); + REQUIRE(case_expr.number_of_cases() == 0); + } + + SECTION("Adding cases") + { + case_exprt case_expr(select_value, int_type); + + const constant_exprt case1_value = from_integer(1, int_type); + const constant_exprt result1_value = from_integer(10, int_type); + + const constant_exprt case2_value = from_integer(2, int_type); + const constant_exprt result2_value = from_integer(20, int_type); + + case_expr.add_case(case1_value, result1_value); + REQUIRE(case_expr.number_of_cases() == 1); + REQUIRE(case_expr.case_value(0) == case1_value); + REQUIRE(case_expr.result_value(0) == result1_value); + + case_expr.add_case(case2_value, result2_value); + REQUIRE(case_expr.number_of_cases() == 2); + REQUIRE(case_expr.case_value(1) == case2_value); + REQUIRE(case_expr.result_value(1) == result2_value); + + // Verify operands structure: 1 select + 2*2 case/result pairs = 5 + REQUIRE(case_expr.operands().size() == 5); + // Verify odd number of operands + REQUIRE(case_expr.operands().size() % 2 == 1); + } + + SECTION("to_case_expr conversion") + { + case_exprt case_expr(select_value, int_type); + const constant_exprt case_value = from_integer(1, int_type); + const constant_exprt result_value = from_integer(10, int_type); + case_expr.add_case(case_value, result_value); + + exprt &base = case_expr; + case_exprt &converted = to_case_expr(base); + + REQUIRE(&converted == &case_expr); + REQUIRE(converted.number_of_cases() == 1); + REQUIRE(converted.case_value(0) == case_value); + } + + SECTION("can_cast_expr") + { + case_exprt case_expr(select_value, int_type); + exprt &base = case_expr; + + REQUIRE(can_cast_expr(base)); + REQUIRE_FALSE(can_cast_expr(base)); + } + + SECTION("Construction with operands") + { + const constant_exprt case_value = from_integer(1, int_type); + const constant_exprt result_value = from_integer(10, int_type); + + case_exprt::operandst ops; + ops.push_back(select_value); + ops.push_back(case_value); + ops.push_back(result_value); + + case_exprt case_expr(std::move(ops), int_type); + + REQUIRE(case_expr.id() == ID_case); + REQUIRE(case_expr.number_of_cases() == 1); + REQUIRE(case_expr.select_value() == select_value); + REQUIRE(case_expr.case_value(0) == case_value); + REQUIRE(case_expr.result_value(0) == result_value); + } +}