diff --git a/docs/EasyBuilder.md b/docs/EasyBuilder.md new file mode 100644 index 000000000..f9f018977 --- /dev/null +++ b/docs/EasyBuilder.md @@ -0,0 +1,345 @@ +# Easy builder for building IR in C++ + +Sometimes in the transformations of MLIR, developers may need to use C++ to +build complex IR. For example, developers may expand math operators `exp` to +polynormial expressions when mathematical approximation is allowed. This may +result in tens of calls to `builder.create(loc, ...)` in the expansion pass +implementation. Another example is that when a pass expands and tiles a `matmul` +operation, the developer may need to write lines of code to create `scf.for` and +`scf.if`, and manipulate the insertion point for the `OpBuilder` to create IR +with complex control flow to schedule the tile based on thread-id and cache +size. One can imagine that the C++ code to accomplish the above task will be +verbose and hard to read. The easy-builder utilities are introduced to +make it easier to develop C++ code building complex IR. Easy-builder is not +designed to replace the `OpBuilder`. Instead, it is built upon that, and serves +as supplementary IR builder for complex cases, like heavily using `arith` and +`scf` operations. + +An example code to build IR `(x+y-10)/(x-y+1)`, where `x` and `y` are unsigned +16-bit integers: + +```C++ +OpBuilder builder = ...; +Value x = ...; +Value y = ...; +Location loc = ...; + +// start of easy-build +EasyBuilder b{builder, loc}; +EBUnsigned wx = b.wrap(x); +EBUnsigned wy = b.wrap(y); +EBUnsigned result = (wx + wy - uint16_t(10)) / (wx - wy + uint16_t(1)); +``` + +The `result` above can be implicitly converted to `Value` type. + +In contrast, the code using `OpBuilder` to do the same task will have more lines +of code and less readablity: + +```C++ +OpBuilder builder = ...; +Value x = ...; +Value y = ...; +Location loc = ...; + +auto x_p_y = builder.create(loc, x, y); +auto v10 = builder.create(loc, 10, /*width*/16); +auto x_p_y_m10 = builder.create(loc, x_p_y, v10); +auto x_m_y = builder.create(loc, x, y); +auto v1 = builder.create(loc, 1, /*width*/16); +auto x_m_y_p1 = builder.create(loc, x_m_y, v1); +Value result = builder.create(loc, x_p_y_m10, x_m_y_p1); +``` + +[TOC] + +## Overall Design + +There are some observations on the C++ code for IR creation in MLIR transforms: + +1. Consecutive calls to `builder.create(loc, ...)` often use the same + `OpBuilder` and `Location`. Many mutations in the MLIR passes are to expand + an operation into a sequence of operations. So expanded operations will share + the same builder and the same location of the original operation. +2. There is currently no C++ operator overloading on MLIR's `Value` class. + Developers have to break a long arithmatical expression into calls to the + `create<..>` function. +3. It would be helpful if MLIR provides helper functions to convert C++ types + (like `uint16_t` and `OpFoldResult`) into `Value`. +4. C++ RAII and use of macros could be helpful for building + structured-control-flow IR (not just `scf` operations). + +Easy-build is designed to improve the `OpBuilder` in the above aspects. It is +also extendable for different dialects and operations - developers can extend +easy-builder for a new dialect with reasonable efforts. + +Easy-build provides `EBValue` class ("EB" here stands for "easy-build") as a +wrapper for MLIR's `Value` objects. `EBValue` also stores the reference to the +`OpBuilder` and `Location`. An `EBValue` is self-contained for creating new +operations based on it. This enables C++ operator overloading on `EBValue`. Note +that it is hard to implement operator overloading on MLIR's `Value` to build new +operation, because it lacks information of the `OpBuilder` and `Location` of the +new operation to create. + +## Design details + +Most of easy-build's APIs are defined in `mlir::easybuild` namespace. This +section will introduc the key data structures of easy-build. + +### EasyBuildState + +This class holds the "states" of an easy-builder. It contains the a `Location` +object, a reference to the `OpBuilder` and other configurations of an an +easy-builder. A shared-ptr to a `EasyBuildState` is attached to every non-empty +`EBValue`. Further creation of new operations based on a `EBValue` should use +the `OpBuilder` and `Location` in the referenced `EasyBuildState`. The newly +created `EBValue` should hold a shared-ptr to the same `EasyBuildState` of its +operands. + +```c++ +struct EasyBuildState { + OpBuilder &builder; + Location loc; + ... +}; +``` + +### EBValue + +This class is essentially a `mlir::Value` with shared-ptr to `EasyBuildState`. +`EBValue` is a general base class for any values and itself does not enable C++ +operator overloading for IR creation. However, developers can inherit this class +to restrict the `Value` to be held in a `EBValue` and enable some specific +IR-building utilities. In the example at the beginning of this document, a +subclass `EBUnsigned` is used to hold `Value` of "unsigned integer". The line +`EBUnsigned wx = b.wrap(x);` converts `x` of `Value` to +`EBUnsigned`. If `x`'s type is not compatible to `EBUnsigned`, a runtime +assertion failure may occur. Easy-build for `arith` dialects enables operator +overloads for `EBUnsigned` like: + +```c++ +EBUnsigned operator+(EBUnsigned a, EBUnsigned b) { + return EBUnsigned {a.builder, + a.builder->builder.template create(a.builder->loc, + a.v, b.v)}; +} +``` + +The created `EBUnsigned` should share the same `EasyBuildState` pointer of the +operands. + +An `EBValue` object can be implictly be converted to `Value`: + +```c++ +EBValue wrapped = ...; +Value v = wrapped; +``` + +Please refer to sections [Easy-build for arith dialect](#Easy-build-for-arith-dialect) +for the subclasses of `EBValue` for arith operations. See also +[Extending easy-build for dialects](#Extending-easy-build-for-dialects) for +extending `EBValue` for a new dialect. + +### EasyBuilder + +The `EasyBuilder` is a utility class for + +1. creating an initial `EasyBuildState` object +2. wrapping `Value`, C++ numerical values (e.g. `uint32_t`, `float`) or + `OpFoldResult` into `EBValue` or its subclasses, and setting the shared-ptr + `EasyBuildState` of the created values. +3. setting `Location` for the next created operation + +The use of easy-build usually starts from creating an EasyBuilder. The +constructor of it will internally create an `EasyBuildState` object with the +given `OpBuilder` and `Location`. + +#### Wrapping various values to EBValue + +To convert various types of C++ values to `EBValue` or its subclasses, +`EasyBuilder` provides a template function `EasyBuilder::wrap(V)` to convert +C++ type `V` into `T`, which is a subclass of `EBValue`. The result `EBValue` or +its subclasses's should hold a shared-ptr pointing to the `EasyBuildState` of +this `EasyBuilder`. `EasyBuilder::wrap` may introduce runtime type checking for +the input value (implementation provided by the class `T`). When the convertion +fails, an assertion failure may happen at the runtime. An example of such case +is when we try to wrap a `memref` typed `Value` to `EBUnsigned`, the convertion +should fail, because `memref` are not arithmetic values. + +`EasyBuilder` provides a similar function `EasyBuilder::wrapOrFail(V)` which +returns `FailureOr`. It has similar functionality of `wrap()`, except that it +returns `failure()` when the convertion fails, instead of triggering an runtime +abortion. + +```C++ +#include "mlir/Dialect/Arith/Utils/EasyBuild.h" + +EasyBuilder b {...}; +Value v = builder.create(...); +EBFloatPoint u1 = b.wrap(v); // OK +FailureOr u1 = b.wrapOrFail(v); +assert(failed(u1)); // convertion should fail +``` + +`EasyBuilder` overrides `operator()` to provide convenience converter for +general `Value` to `EBValue` and arithmetic C++ values to the corresponding +subclass of `EBValue`: + +```C++ +#include "mlir/Dialect/Arith/Utils/EasyBuild.h" + +EasyBuilder b {...}; +Value v = ...; +EBValue v1 = b(v); // wrap Value to base class EBValue +EBUnsigned u1 = b(uint32_t(2)); // creating arith.constant of i32 +EBFloatPoint u1 = b(2.0f); // creating arith.constant of f32 +``` + +#### Setting source location + +Users can set the `Location` to be used in the `OpBuilder::create` after a call +to `EasyBuilder::setLoc()`. New operations related to a `EasyBuildState` will be +created with the new `Location` set by `EasyBuilder::setLoc()`. The previously +created operations' location before calling `setLoc()` will not be changed. + +#### Creating operations using EBValues as inputs + +Developers can call the template member function `F(...)` of +`EasyBuilder` to create a new operation of type `TOp` and wrap the result to +type `TValue`, which is `EBValue` or its subclasses. This method is used to +generate general operations for `EBValue`s. The operation will be created with +the current `OpBuilder` and `Location` of the `EasyBuilder`. For example, to +create an `mydialect::MyOp` operation with given `EBValue` as operands and get +the single result value of the operation as `EBUnsigned`: + +```c++ +EasyBuilder b {...}; +EBValue v1 = ...; +EBUnsigned v2 = b.F(v1); +``` + +## Typical workflow for using easy-build + +To use easy-build, a developer may first include the easy-build header and +optionally include the header for the subclass of `EBValue` for a dialect. + +```C++ +#include "mlir/IR/EasyBuild.h" +#include "mlir/Dialect/Arith/Utils/EasyBuild.h" +``` + +Then in the code building the IR, create a `EasyBuilder` with an existing +`OpBuilder` and `Location`: + +```c++ +using namespace easybuild; +Operation* originOp = ...; +OpBuilder builder {originOp}; +Location loc = originOp->getLoc(); +EasyBuilder b {builder, loc}; +``` + +Wrap `Value` or other C++ values to `EBValue` or its subclasses: + +```c++ +auto input1 = b.wrap(originOp->getOperand(0)); +auto input2 = b.wrap(originOp->getOperand(1)); +``` + +Generate operations via the wrapped values. The insert point and `Location` is +defined by the `OpBuilder` inside of `EasyBuildState`. The results can be used +as `Value`: + +```c++ +Value result = input1 + input2; +``` + +## Easy-build for arith dialect + +Subclasses of `EBValue` have been defined for `arith` operations, including +`EBUnsigned`, `EBSigned` and `EBFloatPoint`. These subclasses can accept +`EasyBuilder::wrap()` of input values of types of corresponding scalar types, or +their vector or tensor type. `EBUnsigned` accepts scalar, vector or tensor of +unsigned or signless integer-or-index-typed `Value`. `EBUnsigned` accepts +scalar, vector or tensor of signed or signless integer-typed `Value`. +`EBFloatPoint` accepts scalar, vector or tensor of float-point-typed `Value`. + +Developers can also wrap C++ arithmetic types (e.g. `uint32_t`, `float`) to the +corresponding `EBUnsigned`, `EBSigned` or `EBFloatPoint` type, via +`EasyBuilder::operator()`. A call to such function will generate an +`arith.constant` operation at the current insertion point. + +Similarly, `EBUnsigned`, `EBSigned` and `EBFloatPoint` enables +`EasyBuilder::wrap()` to convert from `OpFoldResult`. + +```c++ +OpFoldResult f = ...; +auto input1 = b.wrap(f); +``` + +If `OpFoldResult` contains a `Value`, `wrap` will try to convert the +extracted value to `EBUnsigned`. If it contains a constant as `Attr`, +`wrap` will create an `arith.constant` based on the type of the +`Attr`. + +Some of the C++ operators are enabled for `EBUnsigned`, `EBSigned` and +`EBFloatPoint` classes, include arithmetic `+ - * / % `, logical `& | ^`, +integer-shifting `>> <<` and comparison `> >= < <= == !=`. Using these C++ +operators will create the corresponding `arith` operations at the current +`OpBuilder` in the `EasyBuildState` of the `EBValue`. Signed and Unsigned +integers are distinguished, so that they will emit different operations for the +arith operations that is sensitive to the signess, like `divsi` or `divui`. + +## Easy-build for general structured-control-flow + +TBD + +## Extending easy-build for dialects + +To extend easy-build for new dialects or operations, developers usually need to +create a new class to inherit the `EBValue` class and define utility helper +functions for that new class. The definition of a subclass of `EBValue` should +be operation-centric. That is, the developer of a subclass of `EBValue` should +consider which operations are designed to be applied on it, instead of +considering the data type of the `Value` first. For example, when adding support +for `arith` operations, we find that the operations of it can fall into three +categories: unsigned int, signed int and float point. Thus `EBUnsigned`, +`EBSigned` and `EBFloatPoint` classes are introduced. + +The developer needs to implement the `wrapOrFail` static member function in the +subclass, to enable converting values to it via `EasyBuilder::wrap<>()`. An +example for implementing a hypothesis subclass `EBMyFloatPoint` can be: + +```c++ +struct EBMyFloatPoint : EBValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, + Value v) { + ... + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + ... + } + + using EBValue::EBValue; +}; +``` + +Developers can implement the utility functions to help to build the IR: + +```c++ +inline EBMyFloatPoint sin(EBMyFloatPoint input) { + std::shared_ptr state = input.builder; + OpBuilder& builder = state->builder; + return EBMyFloatPoint{ state, + builder.create(state->loc, input) }; +} + +inline EBMyFloatPoint operator+(EBMyFloatPoint a, EBMyFloatPoint b) { + std::shared_ptr state = a.builder; + OpBuilder& builder = state->builder; + return EBMyFloatPoint{ state, + builder.create(state->loc, a, b) }; +} +``` + diff --git a/include/gc/Dialect/Arith/Utils/EasyBuild.h b/include/gc/Dialect/Arith/Utils/EasyBuild.h new file mode 100644 index 000000000..2a45ffcee --- /dev/null +++ b/include/gc/Dialect/Arith/Utils/EasyBuild.h @@ -0,0 +1,443 @@ +//===- EasyBuild.h - Easy Arith IR Builder utilities ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the easy-build utilities for arith dialects. It +// provides the utility functions, classes and operators to make it easir to +// program arith dialect operations in C++ +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H +#define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H +#include "gc/IR/EasyBuild.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include +#include +#include + +namespace mlir { +namespace easybuild { + +namespace impl { + +template struct ToFloatType {}; + +template <> struct ToFloatType<4> { + using type = Float32Type; +}; +template <> struct ToFloatType<8> { + using type = Float64Type; +}; + +inline Type getElementType(Value v) { + auto type = v.getType(); + if (type.isa() || type.isa()) { + type = type.cast().getElementType(); + } + return type; +} + +} // namespace impl + +struct EBUnsigned; + +struct EBArithValue : public EBValue { + template + static T toIndex(const impl::StatePtr &state, uint64_t v); + + template + static auto wrapOrFail(const impl::StatePtr &state, T &&v); + + template static auto wrap(const impl::StatePtr &state, T &&v) { + auto ret = wrapOrFail(state, std::forward(v)); + if (failed(ret)) { + llvm_unreachable("Bad wrap"); + } + return *ret; + } + +protected: + using EBValue::EBValue; +}; + +struct EBUnsigned : public EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, + Value v) { + auto type = impl::getElementType(v); + if (type.isUnsignedInteger() || type.isSignlessInteger() || + type.isIndex()) { + return EBUnsigned{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) { + if (val.getType().isIndex()) + return EBUnsigned{state, state->builder.create( + state->loc, val.getInt())}; + else + return EBUnsigned{state, state->builder.create( + state->loc, val.getInt(), val.getType())}; + } + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + +struct EBSigned : EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { + auto type = impl::getElementType(v); + if (type.isSignedInteger() || type.isSignlessInteger()) { + return EBSigned{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) + return EBSigned{state, state->builder.create( + state->loc, val.getInt(), val.getType())}; + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + +struct EBFloatPoint : EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, + Value v) { + auto type = impl::getElementType(v); + if (type.isa()) { + return EBFloatPoint{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) + return EBFloatPoint{state, state->builder.create( + state->loc, val.getValue(), + val.getType().cast())}; + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + +template +inline T EBArithValue::toIndex(const impl::StatePtr &state, uint64_t v) { + return EBUnsigned{ + state, state->builder.create(state->loc, v)}; +} + +template +inline auto EBArithValue::wrapOrFail(const impl::StatePtr &state, T &&v) { + using DT = std::decay_t; + static_assert(std::is_arithmetic_v
, "Expecting arithmetic types"); + if constexpr (std::is_same_v) { + if (state->u64AsIndex) { + return FailureOr{toIndex(state, v)}; + } + } + + if constexpr (std::is_same_v) { + return FailureOr{ + EBUnsigned{state, state->builder.create( + state->loc, static_cast(v), 1)}}; + } else if constexpr (std::is_integral_v
) { + if constexpr (!std::is_signed_v
) { + return FailureOr{EBUnsigned{ + state, state->builder.create( + state->loc, static_cast(v), sizeof(T) * 8)}}; + } else { + return FailureOr{EBSigned{ + state, state->builder.create( + state->loc, static_cast(v), sizeof(T) * 8)}}; + } + } else { + using DType = typename impl::ToFloatType::type; + return FailureOr{ + EBFloatPoint{state, state->builder.create( + state->loc, APFloat{v}, + DType::get(state->builder.getContext()))}}; + } +} + +struct OperatorHandlers { + template + static V handleBinary(const V &a, const V &b) { + assert(a.builder == b.builder); + return {a.builder, + a.builder->builder.template create(a.builder->loc, a.v, b.v)}; + } + + template + static V handleBinaryConst(const V &a, const T2 &b) { + return handleBinary(a, EBArithValue::wrap(a.builder, b)); + } + + template + static V handleBinaryConst(const T2 &a, const V &b) { + return handleBinary(EBArithValue::wrap(b.builder, a), b); + } + + template + static EBUnsigned handleCmp(const V &a, const V &b, Pred predicate) { + assert(a.builder == b.builder); + return {a.builder, a.builder->builder.template create( + a.builder->loc, predicate, a.v, b.v)}; + } + + template + static EBUnsigned handleCmpConst(const V &a, const T2 &b, Pred predicate) { + return handleCmp(a, EBArithValue::wrap(a.builder, b), predicate); + } + + template + static EBUnsigned handleCmpConst(const T2 &a, const V &b, Pred predicate) { + return handleCmp(EBArithValue::wrap(b.builder, a), b, predicate); + } + + template + static T create(const impl::StatePtr &state, Args &&...v) { + return {state, + state->builder.create(state->loc, std::forward(v)...)}; + } +}; + +#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, OPCLASS, TYPE) \ + inline TYPE operator OP(const TYPE &a, const TYPE &b) { \ + return OperatorHandlers::handleBinary(a, b); \ + } \ + template inline TYPE operator OP(const TYPE &a, T b) { \ + return OperatorHandlers::handleBinaryConst(a, b); \ + } \ + template inline TYPE operator OP(T a, const TYPE &b) { \ + return OperatorHandlers::handleBinaryConst(a, b); \ + } + +#define DEF_EASYBUILD_BINARY_OPERATOR(OP, SIGNED, UNSIGNED, FLOAT) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, FLOAT, EBFloatPoint) + +DEF_EASYBUILD_BINARY_OPERATOR(+, arith::AddIOp, arith::AddIOp, arith::AddFOp) +DEF_EASYBUILD_BINARY_OPERATOR(-, arith::SubIOp, arith::SubIOp, arith::SubFOp) +DEF_EASYBUILD_BINARY_OPERATOR(*, arith::MulIOp, arith::MulIOp, arith::MulFOp) +DEF_EASYBUILD_BINARY_OPERATOR(/, arith::DivSIOp, arith::DivUIOp, arith::DivFOp) +DEF_EASYBUILD_BINARY_OPERATOR(%, arith::RemSIOp, arith::RemUIOp, arith::RemFOp) + +#undef DEF_EASYBUILD_BINARY_OPERATOR +#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(OP, SIGNED, UNSIGNED) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) + +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(>>, arith::ShRSIOp, arith::ShRUIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(<<, arith::ShLIOp, arith::ShLIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(&, arith::AndIOp, arith::AndIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(|, arith::OrIOp, arith::OrIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(^, arith::XOrIOp, arith::XOrIOp) + +#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT +#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE + +inline EBFloatPoint operator-(const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, a.v); +} + +#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \ + EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \ + return OperatorHandlers::handleCmp(a, b, PRED); \ + } \ + template EBUnsigned operator OP(const TYPE &a, T b) { \ + return OperatorHandlers::handleCmpConst(a, b, PRED); \ + } \ + template EBUnsigned operator OP(T a, const TYPE &b) { \ + return OperatorHandlers::handleCmpConst(a, b, PRED); \ + } + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ult) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ule) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ugt) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::uge) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::eq) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ne) + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::slt) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sle) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sgt) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sge) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::eq) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::ne) + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OLT) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OLE) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OGT) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OGE) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OEQ) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::ONE) + +#undef DEF_EASYBUILD_CMP_OPERATOR + +namespace arithops { +inline EBFloatPoint castIntToFP(Type type, const EBSigned &v) { + return OperatorHandlers::create(v.builder, + type, v); +} + +inline EBFloatPoint castIntToFP(Type type, const EBUnsigned &v) { + return OperatorHandlers::create(v.builder, + type, v); +} + +template inline T castFPToInt(const EBFloatPoint &v) { + if constexpr (std::is_same_v) { + return OperatorHandlers::create(v.builder, v); + } else { + static_assert(std::is_same_v, + "Expecting EBUnsigned or EBSigned"); + return OperatorHandlers::create(v.builder, v); + } +} + +inline EBSigned ceildiv(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBUnsigned ceildiv(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBSigned floordiv(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBSigned extend(Type type, const EBSigned &a) { + return OperatorHandlers::create(a.builder, type, a); +} + +inline EBUnsigned extend(Type type, const EBUnsigned &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBFloatPoint extend(Type type, const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBSigned trunc(Type type, const EBSigned &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBFloatPoint trunc(Type type, const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, + type, a); +} + +template +inline T select(const EBUnsigned &pred, const T &trueValue, + const T &falseValue) { + static_assert(std::is_base_of_v, + "Expecting T to be a subclass of EBArithValue"); + return OperatorHandlers::create(pred.builder, pred, + trueValue, falseValue); +} + +template +inline TyTo bitcast(Type type, const TyFrom &v) { + return OperatorHandlers::create(v.builder, type, v); +} + +inline EBSigned min(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBSigned max(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBUnsigned min(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBUnsigned max(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBFloatPoint minnum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint maxnum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint minimum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint maximum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +} // namespace arithops + +} // namespace easybuild +} // namespace mlir +#endif diff --git a/include/gc/IR/EasyBuild.h b/include/gc/IR/EasyBuild.h new file mode 100644 index 000000000..da3da952d --- /dev/null +++ b/include/gc/IR/EasyBuild.h @@ -0,0 +1,94 @@ +//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the easy-build utilities core data structures for +// building IR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EASYBUILD_H +#define MLIR_IR_EASYBUILD_H +#include "mlir/IR/Builders.h" +#include +#include +#include + +namespace mlir { +namespace easybuild { + +namespace impl { +struct EasyBuildState { + OpBuilder &builder; + Location loc; + bool u64AsIndex; + EasyBuildState(OpBuilder &builder, Location loc, bool u64AsIndex) + : builder{builder}, loc{loc}, u64AsIndex{u64AsIndex} {} +}; + +using StatePtr = std::shared_ptr; + +} // namespace impl + +struct EBValue { + std::shared_ptr builder; + Value v; + EBValue() = default; + EBValue(const impl::StatePtr &builder, Value v) : builder{builder}, v{v} {} + Value get() const { return v; } + operator Value() const { return v; } + + static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { + return EBValue{state, v}; + } +}; + +struct EBArithValue; + +struct EasyBuilder { + std::shared_ptr builder; + EasyBuilder(OpBuilder &builder, Location loc, bool u64AsIndex = false) + : builder{ + std::make_shared(builder, loc, u64AsIndex)} {} + EasyBuilder(const std::shared_ptr &builder) + : builder{builder} {} + void setLoc(const Location &l) { builder->loc = l; } + + template auto wrapOrFail(V &&v) { + return W::wrapOrFail(builder, std::forward(v)); + } + + template auto wrap(V &&v) { + auto ret = wrapOrFail(std::forward(v)); + if (failed(ret)) { + llvm_unreachable("wrap failed!"); + } + return *ret; + } + + template auto operator()(V &&v) { + if constexpr (std::is_convertible_v) { + return EBValue{builder, std::forward(v)}; + } else { + return wrap(std::forward(v)); + } + } + + template auto toIndex(uint64_t v) const { + return W::toIndex(builder, v); + } + + template + auto F(Args &&...v) { + return wrap( + builder->builder.create(builder->loc, std::forward(v)...)); + } +}; + +} // namespace easybuild +} // namespace mlir +#endif diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index c93735c63..9c7152441 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -17,5 +17,6 @@ function(add_mlir_unittest test_dirname) add_unittest(GCUnitTests ${test_dirname} ${ARGN}) endfunction() +add_subdirectory(Dialect) add_subdirectory(Example) diff --git a/unittests/Dialect/Arith/CMakeLists.txt b/unittests/Dialect/Arith/CMakeLists.txt new file mode 100644 index 000000000..e9776f11e --- /dev/null +++ b/unittests/Dialect/Arith/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_unittest(MLIRArithTests + EasyBuildTest.cpp) +target_link_libraries(MLIRArithTests + PRIVATE + MLIRFuncDialect + MLIRArithDialect) diff --git a/unittests/Dialect/Arith/EasyBuildTest.cpp b/unittests/Dialect/Arith/EasyBuildTest.cpp new file mode 100644 index 000000000..07eb31c5a --- /dev/null +++ b/unittests/Dialect/Arith/EasyBuildTest.cpp @@ -0,0 +1,635 @@ +//===- EasyBuildTest.cpp - Tests Arith Op Easy builders -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/Arith/Utils/EasyBuild.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::easybuild; + +namespace { +class ArithTest : public ::testing::Test { +protected: + ArithTest() { + context.getOrLoadDialect(); + context.getOrLoadDialect(); + } + + mlir::MLIRContext context; +}; +} // namespace + +TEST_F(ArithTest, EasyBuildConst) { + OpBuilder builder{&context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create(loc, "funcname", + FunctionType::get(&context, {}, {})); + + builder.setInsertionPointToStart(func.addEntryBlock()); + auto i1 = b(true); + i1 = b(false); + auto i8 = b(int8_t(3)); + i8 = b(int8_t(-3)); + auto u8 = b(uint8_t(33)); + + auto i16 = b(int16_t(33)); + i16 = b(int16_t(-33)); + auto u16 = b(uint16_t(33)); + + auto i32 = b(int32_t(33)); + i32 = b(int32_t(-33)); + auto u32 = b(uint32_t(33)); + + auto i64 = b(int64_t(33)); + i64 = b(int64_t(-33)); + auto u64 = b(uint64_t(33)); + + auto idx = b.toIndex(23); + + { + EasyBuilder b2{builder, loc, /*u64AsIndex*/ true}; + auto idx2 = b(uint64_t(33)); + } + builder.create(loc); + std::string out; + llvm::raw_string_ostream os{out}; + os << func; + + const char *expected = + R"mlir(func.func @funcname() { + %true = arith.constant true + %false = arith.constant false + %c3_i8 = arith.constant 3 : i8 + %c-3_i8 = arith.constant -3 : i8 + %c33_i8 = arith.constant 33 : i8 + %c33_i16 = arith.constant 33 : i16 + %c-33_i16 = arith.constant -33 : i16 + %c33_i16_0 = arith.constant 33 : i16 + %c33_i32 = arith.constant 33 : i32 + %c-33_i32 = arith.constant -33 : i32 + %c33_i32_1 = arith.constant 33 : i32 + %c33_i64 = arith.constant 33 : i64 + %c-33_i64 = arith.constant -33 : i64 + %c33_i64_2 = arith.constant 33 : i64 + %c23 = arith.constant 23 : index + %c33_i64_3 = arith.constant 33 : i64 + return +})mlir"; + ASSERT_EQ(out, expected); +} + +#define SKIP_IF_UNEXPECTED_FP_SIZE() \ + if constexpr (sizeof(float) != 4 || sizeof(double) != 8) { \ + GTEST_SKIP(); \ + } + +TEST_F(ArithTest, EasyBuildFloatConst) { + SKIP_IF_UNEXPECTED_FP_SIZE() + OpBuilder builder{&context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create(loc, "funcname", + FunctionType::get(&context, {}, {})); + + builder.setInsertionPointToStart(func.addEntryBlock()); + auto a = b(1.0f); + auto a2 = b(1.0); + builder.create(loc); + std::string out; + llvm::raw_string_ostream os{out}; + os << func; + const char *expected = + R"mlir(func.func @funcname() { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f64 + return +})mlir"; + ASSERT_EQ(out, expected); +} + +template +static std::string composeIR(MLIRContext *context, T1 &&getA, T2 &&getB) { + OpBuilder builder{context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create( + loc, "funcname", FunctionType::get(builder.getContext(), {}, {})); + builder.setInsertionPointToStart(func.addEntryBlock()); + auto A = getA(b); + auto B = getB(b); + auto v1 = A + B; + v1 = A - B; + v1 = A * B; + v1 = A / B; + v1 = A % B; + v1 = A >> B; + v1 = A << B; + v1 = A & B; + v1 = A | B; + v1 = A ^ B; + auto cmp = A < B; + cmp = cmp & (A <= B); + cmp = cmp & (A > B); + cmp = cmp & (A >= B); + cmp = cmp ^ (A == B); + cmp = cmp ^ (A != B); + builder.create(loc); + + std::string out; + llvm::raw_string_ostream os{out}; + os << func; + return out; +} + +// check X+Y, where both X and Y are WrappedValues +TEST_F(ArithTest, EasyBuildSignedOperatorsBothValues) { + auto out = composeIR( + &context, [](EasyBuilder b) { return b(int32_t(33)); }, + [](EasyBuilder b) { return b(int32_t(31)); }); + const char *signedExpected = + R"mlir(func.func @funcname() { + %c33_i32 = arith.constant 33 : i32 + %c31_i32 = arith.constant 31 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %1 = arith.subi %c33_i32, %c31_i32 : i32 + %2 = arith.muli %c33_i32, %c31_i32 : i32 + %3 = arith.divsi %c33_i32, %c31_i32 : i32 + %4 = arith.remsi %c33_i32, %c31_i32 : i32 + %5 = arith.shrsi %c33_i32, %c31_i32 : i32 + %6 = arith.shli %c33_i32, %c31_i32 : i32 + %7 = arith.andi %c33_i32, %c31_i32 : i32 + %8 = arith.ori %c33_i32, %c31_i32 : i32 + %9 = arith.xori %c33_i32, %c31_i32 : i32 + %10 = arith.cmpi slt, %c33_i32, %c31_i32 : i32 + %11 = arith.cmpi sle, %c33_i32, %c31_i32 : i32 + %12 = arith.andi %10, %11 : i1 + %13 = arith.cmpi sgt, %c33_i32, %c31_i32 : i32 + %14 = arith.andi %12, %13 : i1 + %15 = arith.cmpi sge, %c33_i32, %c31_i32 : i32 + %16 = arith.andi %14, %15 : i1 + %17 = arith.cmpi eq, %c33_i32, %c31_i32 : i32 + %18 = arith.xori %16, %17 : i1 + %19 = arith.cmpi ne, %c33_i32, %c31_i32 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, signedExpected); +} + +// check X+Y, where X is compile-time value (like 1) and Y is WrappedValue +TEST_F(ArithTest, EasyBuildSignedOperatorsLHSConst) { + auto out = composeIR( + &context, [](EasyBuilder b) { return int32_t(33); }, + [](EasyBuilder b) { return b(int32_t(31)); }); + const char *signedExpected = + R"mlir(func.func @funcname() { + %c31_i32 = arith.constant 31 : i32 + %c33_i32 = arith.constant 33 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %c33_i32_0 = arith.constant 33 : i32 + %1 = arith.subi %c33_i32_0, %c31_i32 : i32 + %c33_i32_1 = arith.constant 33 : i32 + %2 = arith.muli %c33_i32_1, %c31_i32 : i32 + %c33_i32_2 = arith.constant 33 : i32 + %3 = arith.divsi %c33_i32_2, %c31_i32 : i32 + %c33_i32_3 = arith.constant 33 : i32 + %4 = arith.remsi %c33_i32_3, %c31_i32 : i32 + %c33_i32_4 = arith.constant 33 : i32 + %5 = arith.shrsi %c33_i32_4, %c31_i32 : i32 + %c33_i32_5 = arith.constant 33 : i32 + %6 = arith.shli %c33_i32_5, %c31_i32 : i32 + %c33_i32_6 = arith.constant 33 : i32 + %7 = arith.andi %c33_i32_6, %c31_i32 : i32 + %c33_i32_7 = arith.constant 33 : i32 + %8 = arith.ori %c33_i32_7, %c31_i32 : i32 + %c33_i32_8 = arith.constant 33 : i32 + %9 = arith.xori %c33_i32_8, %c31_i32 : i32 + %c33_i32_9 = arith.constant 33 : i32 + %10 = arith.cmpi slt, %c33_i32_9, %c31_i32 : i32 + %c33_i32_10 = arith.constant 33 : i32 + %11 = arith.cmpi sle, %c33_i32_10, %c31_i32 : i32 + %12 = arith.andi %10, %11 : i1 + %c33_i32_11 = arith.constant 33 : i32 + %13 = arith.cmpi sgt, %c33_i32_11, %c31_i32 : i32 + %14 = arith.andi %12, %13 : i1 + %c33_i32_12 = arith.constant 33 : i32 + %15 = arith.cmpi sge, %c33_i32_12, %c31_i32 : i32 + %16 = arith.andi %14, %15 : i1 + %c33_i32_13 = arith.constant 33 : i32 + %17 = arith.cmpi eq, %c33_i32_13, %c31_i32 : i32 + %18 = arith.xori %16, %17 : i1 + %c33_i32_14 = arith.constant 33 : i32 + %19 = arith.cmpi ne, %c33_i32_14, %c31_i32 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, signedExpected); +} + +// check X+Y, where Y is compile-time value (like 1) and X is WrappedValue +TEST_F(ArithTest, EasyBuildSignedOperatorsRHSConst) { + auto out = composeIR( + &context, [](EasyBuilder b) { return b(int32_t(33)); }, + [](EasyBuilder b) { return int32_t(31); }); + const char *signedExpected = + R"mlir(func.func @funcname() { + %c33_i32 = arith.constant 33 : i32 + %c31_i32 = arith.constant 31 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %c31_i32_0 = arith.constant 31 : i32 + %1 = arith.subi %c33_i32, %c31_i32_0 : i32 + %c31_i32_1 = arith.constant 31 : i32 + %2 = arith.muli %c33_i32, %c31_i32_1 : i32 + %c31_i32_2 = arith.constant 31 : i32 + %3 = arith.divsi %c33_i32, %c31_i32_2 : i32 + %c31_i32_3 = arith.constant 31 : i32 + %4 = arith.remsi %c33_i32, %c31_i32_3 : i32 + %c31_i32_4 = arith.constant 31 : i32 + %5 = arith.shrsi %c33_i32, %c31_i32_4 : i32 + %c31_i32_5 = arith.constant 31 : i32 + %6 = arith.shli %c33_i32, %c31_i32_5 : i32 + %c31_i32_6 = arith.constant 31 : i32 + %7 = arith.andi %c33_i32, %c31_i32_6 : i32 + %c31_i32_7 = arith.constant 31 : i32 + %8 = arith.ori %c33_i32, %c31_i32_7 : i32 + %c31_i32_8 = arith.constant 31 : i32 + %9 = arith.xori %c33_i32, %c31_i32_8 : i32 + %c31_i32_9 = arith.constant 31 : i32 + %10 = arith.cmpi slt, %c33_i32, %c31_i32_9 : i32 + %c31_i32_10 = arith.constant 31 : i32 + %11 = arith.cmpi sle, %c33_i32, %c31_i32_10 : i32 + %12 = arith.andi %10, %11 : i1 + %c31_i32_11 = arith.constant 31 : i32 + %13 = arith.cmpi sgt, %c33_i32, %c31_i32_11 : i32 + %14 = arith.andi %12, %13 : i1 + %c31_i32_12 = arith.constant 31 : i32 + %15 = arith.cmpi sge, %c33_i32, %c31_i32_12 : i32 + %16 = arith.andi %14, %15 : i1 + %c31_i32_13 = arith.constant 31 : i32 + %17 = arith.cmpi eq, %c33_i32, %c31_i32_13 : i32 + %18 = arith.xori %16, %17 : i1 + %c31_i32_14 = arith.constant 31 : i32 + %19 = arith.cmpi ne, %c33_i32, %c31_i32_14 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, signedExpected); +} + +// check X+Y, where both X and Y are WrappedValues +TEST_F(ArithTest, EasyBuildUnsignedOperatorsBothValues) { + auto out = composeIR( + &context, [](EasyBuilder b) { return b(uint32_t(33)); }, + [](EasyBuilder b) { return b(uint32_t(31)); }); + const char *unsignedExpected = + R"mlir(func.func @funcname() { + %c33_i32 = arith.constant 33 : i32 + %c31_i32 = arith.constant 31 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %1 = arith.subi %c33_i32, %c31_i32 : i32 + %2 = arith.muli %c33_i32, %c31_i32 : i32 + %3 = arith.divui %c33_i32, %c31_i32 : i32 + %4 = arith.remui %c33_i32, %c31_i32 : i32 + %5 = arith.shrui %c33_i32, %c31_i32 : i32 + %6 = arith.shli %c33_i32, %c31_i32 : i32 + %7 = arith.andi %c33_i32, %c31_i32 : i32 + %8 = arith.ori %c33_i32, %c31_i32 : i32 + %9 = arith.xori %c33_i32, %c31_i32 : i32 + %10 = arith.cmpi ult, %c33_i32, %c31_i32 : i32 + %11 = arith.cmpi ule, %c33_i32, %c31_i32 : i32 + %12 = arith.andi %10, %11 : i1 + %13 = arith.cmpi ugt, %c33_i32, %c31_i32 : i32 + %14 = arith.andi %12, %13 : i1 + %15 = arith.cmpi uge, %c33_i32, %c31_i32 : i32 + %16 = arith.andi %14, %15 : i1 + %17 = arith.cmpi eq, %c33_i32, %c31_i32 : i32 + %18 = arith.xori %16, %17 : i1 + %19 = arith.cmpi ne, %c33_i32, %c31_i32 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, unsignedExpected); +} + +// check X+Y, where X is compile-time value (like 1) and Y is WrappedValue +TEST_F(ArithTest, EasyBuildUnsignedOperatorsLHSConst) { + auto out = composeIR( + &context, [](EasyBuilder b) { return uint32_t(33); }, + [](EasyBuilder b) { return b(uint32_t(31)); }); + const char *unsignedExpected = + R"mlir(func.func @funcname() { + %c31_i32 = arith.constant 31 : i32 + %c33_i32 = arith.constant 33 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %c33_i32_0 = arith.constant 33 : i32 + %1 = arith.subi %c33_i32_0, %c31_i32 : i32 + %c33_i32_1 = arith.constant 33 : i32 + %2 = arith.muli %c33_i32_1, %c31_i32 : i32 + %c33_i32_2 = arith.constant 33 : i32 + %3 = arith.divui %c33_i32_2, %c31_i32 : i32 + %c33_i32_3 = arith.constant 33 : i32 + %4 = arith.remui %c33_i32_3, %c31_i32 : i32 + %c33_i32_4 = arith.constant 33 : i32 + %5 = arith.shrui %c33_i32_4, %c31_i32 : i32 + %c33_i32_5 = arith.constant 33 : i32 + %6 = arith.shli %c33_i32_5, %c31_i32 : i32 + %c33_i32_6 = arith.constant 33 : i32 + %7 = arith.andi %c33_i32_6, %c31_i32 : i32 + %c33_i32_7 = arith.constant 33 : i32 + %8 = arith.ori %c33_i32_7, %c31_i32 : i32 + %c33_i32_8 = arith.constant 33 : i32 + %9 = arith.xori %c33_i32_8, %c31_i32 : i32 + %c33_i32_9 = arith.constant 33 : i32 + %10 = arith.cmpi ult, %c33_i32_9, %c31_i32 : i32 + %c33_i32_10 = arith.constant 33 : i32 + %11 = arith.cmpi ule, %c33_i32_10, %c31_i32 : i32 + %12 = arith.andi %10, %11 : i1 + %c33_i32_11 = arith.constant 33 : i32 + %13 = arith.cmpi ugt, %c33_i32_11, %c31_i32 : i32 + %14 = arith.andi %12, %13 : i1 + %c33_i32_12 = arith.constant 33 : i32 + %15 = arith.cmpi uge, %c33_i32_12, %c31_i32 : i32 + %16 = arith.andi %14, %15 : i1 + %c33_i32_13 = arith.constant 33 : i32 + %17 = arith.cmpi eq, %c33_i32_13, %c31_i32 : i32 + %18 = arith.xori %16, %17 : i1 + %c33_i32_14 = arith.constant 33 : i32 + %19 = arith.cmpi ne, %c33_i32_14, %c31_i32 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, unsignedExpected); +} + +// check X+Y, where Y is compile-time value (like 1) and X is WrappedValue +TEST_F(ArithTest, EasyBuildUnsignedOperatorsRHSConst) { + auto out = composeIR( + &context, [](EasyBuilder b) { return b(uint32_t(33)); }, + [](EasyBuilder b) { return uint32_t(31); }); + const char *unsignedExpected = + R"mlir(func.func @funcname() { + %c33_i32 = arith.constant 33 : i32 + %c31_i32 = arith.constant 31 : i32 + %0 = arith.addi %c33_i32, %c31_i32 : i32 + %c31_i32_0 = arith.constant 31 : i32 + %1 = arith.subi %c33_i32, %c31_i32_0 : i32 + %c31_i32_1 = arith.constant 31 : i32 + %2 = arith.muli %c33_i32, %c31_i32_1 : i32 + %c31_i32_2 = arith.constant 31 : i32 + %3 = arith.divui %c33_i32, %c31_i32_2 : i32 + %c31_i32_3 = arith.constant 31 : i32 + %4 = arith.remui %c33_i32, %c31_i32_3 : i32 + %c31_i32_4 = arith.constant 31 : i32 + %5 = arith.shrui %c33_i32, %c31_i32_4 : i32 + %c31_i32_5 = arith.constant 31 : i32 + %6 = arith.shli %c33_i32, %c31_i32_5 : i32 + %c31_i32_6 = arith.constant 31 : i32 + %7 = arith.andi %c33_i32, %c31_i32_6 : i32 + %c31_i32_7 = arith.constant 31 : i32 + %8 = arith.ori %c33_i32, %c31_i32_7 : i32 + %c31_i32_8 = arith.constant 31 : i32 + %9 = arith.xori %c33_i32, %c31_i32_8 : i32 + %c31_i32_9 = arith.constant 31 : i32 + %10 = arith.cmpi ult, %c33_i32, %c31_i32_9 : i32 + %c31_i32_10 = arith.constant 31 : i32 + %11 = arith.cmpi ule, %c33_i32, %c31_i32_10 : i32 + %12 = arith.andi %10, %11 : i1 + %c31_i32_11 = arith.constant 31 : i32 + %13 = arith.cmpi ugt, %c33_i32, %c31_i32_11 : i32 + %14 = arith.andi %12, %13 : i1 + %c31_i32_12 = arith.constant 31 : i32 + %15 = arith.cmpi uge, %c33_i32, %c31_i32_12 : i32 + %16 = arith.andi %14, %15 : i1 + %c31_i32_13 = arith.constant 31 : i32 + %17 = arith.cmpi eq, %c33_i32, %c31_i32_13 : i32 + %18 = arith.xori %16, %17 : i1 + %c31_i32_14 = arith.constant 31 : i32 + %19 = arith.cmpi ne, %c33_i32, %c31_i32_14 : i32 + %20 = arith.xori %18, %19 : i1 + return +})mlir"; + ASSERT_EQ(out, unsignedExpected); +} + +template +static std::string composeFPIR(MLIRContext *context, T1 &&getA, T2 &&getB) { + OpBuilder builder{context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create( + loc, "funcname", FunctionType::get(builder.getContext(), {}, {})); + builder.setInsertionPointToStart(func.addEntryBlock()); + auto A = getA(b); + auto B = getB(b); + auto v1 = A + B; + v1 = A - B; + v1 = A * B; + v1 = A / B; + v1 = A % B; + (void)-A; + auto cmp = A < B; + cmp = cmp & (A <= B); + cmp = cmp & (A > B); + cmp = cmp & (A >= B); + cmp = cmp ^ (A == B); + cmp = cmp ^ (A != B); + builder.create(loc); + + std::string out; + llvm::raw_string_ostream os{out}; + os << func; + return out; +} + +// check X+Y, where both X and Y are WrappedValues +TEST_F(ArithTest, EasyBuildFloatOperatorsValues) { + SKIP_IF_UNEXPECTED_FP_SIZE() + auto out = composeFPIR( + &context, [](EasyBuilder b) { return b(33.0f); }, + [](EasyBuilder b) { return b(31.0f); }); + const char *expected = + R"mlir(func.func @funcname() { + %cst = arith.constant 3.300000e+01 : f32 + %cst_0 = arith.constant 3.100000e+01 : f32 + %0 = arith.addf %cst, %cst_0 : f32 + %1 = arith.subf %cst, %cst_0 : f32 + %2 = arith.mulf %cst, %cst_0 : f32 + %3 = arith.divf %cst, %cst_0 : f32 + %4 = arith.remf %cst, %cst_0 : f32 + %5 = arith.negf %cst : f32 + %6 = arith.cmpf olt, %cst, %cst_0 : f32 + %7 = arith.cmpf ole, %cst, %cst_0 : f32 + %8 = arith.andi %6, %7 : i1 + %9 = arith.cmpf ogt, %cst, %cst_0 : f32 + %10 = arith.andi %8, %9 : i1 + %11 = arith.cmpf oge, %cst, %cst_0 : f32 + %12 = arith.andi %10, %11 : i1 + %13 = arith.cmpf oeq, %cst, %cst_0 : f32 + %14 = arith.xori %12, %13 : i1 + %15 = arith.cmpf one, %cst, %cst_0 : f32 + %16 = arith.xori %14, %15 : i1 + return +})mlir"; + ASSERT_EQ(out, expected); +} + +// check X+Y, where X is compile-time value (like 1) and Y is WrappedValue +TEST_F(ArithTest, EasyBuildFloatOperatorsLHSConst) { + SKIP_IF_UNEXPECTED_FP_SIZE() + auto out = composeFPIR( + &context, [](EasyBuilder b) { return 33.0f; }, + [](EasyBuilder b) { return b(31.0f); }); + const char *expected = + R"mlir(func.func @funcname() { + %cst = arith.constant 3.100000e+01 : f32 + %cst_0 = arith.constant 3.300000e+01 : f32 + %0 = arith.addf %cst_0, %cst : f32 + %cst_1 = arith.constant 3.300000e+01 : f32 + %1 = arith.subf %cst_1, %cst : f32 + %cst_2 = arith.constant 3.300000e+01 : f32 + %2 = arith.mulf %cst_2, %cst : f32 + %cst_3 = arith.constant 3.300000e+01 : f32 + %3 = arith.divf %cst_3, %cst : f32 + %cst_4 = arith.constant 3.300000e+01 : f32 + %4 = arith.remf %cst_4, %cst : f32 + %cst_5 = arith.constant 3.300000e+01 : f32 + %5 = arith.cmpf olt, %cst_5, %cst : f32 + %cst_6 = arith.constant 3.300000e+01 : f32 + %6 = arith.cmpf ole, %cst_6, %cst : f32 + %7 = arith.andi %5, %6 : i1 + %cst_7 = arith.constant 3.300000e+01 : f32 + %8 = arith.cmpf ogt, %cst_7, %cst : f32 + %9 = arith.andi %7, %8 : i1 + %cst_8 = arith.constant 3.300000e+01 : f32 + %10 = arith.cmpf oge, %cst_8, %cst : f32 + %11 = arith.andi %9, %10 : i1 + %cst_9 = arith.constant 3.300000e+01 : f32 + %12 = arith.cmpf oeq, %cst_9, %cst : f32 + %13 = arith.xori %11, %12 : i1 + %cst_10 = arith.constant 3.300000e+01 : f32 + %14 = arith.cmpf one, %cst_10, %cst : f32 + %15 = arith.xori %13, %14 : i1 + return +})mlir"; + ASSERT_EQ(out, expected); +} + +// check X+Y, where Y is compile-time value (like 1) and X is WrappedValue +TEST_F(ArithTest, EasyBuildFloatOperatorsRHSConst) { + SKIP_IF_UNEXPECTED_FP_SIZE() + auto out = composeFPIR( + &context, [](EasyBuilder b) { return b(33.0f); }, + [](EasyBuilder b) { return 31.0f; }); + const char *expected = + R"mlir(func.func @funcname() { + %cst = arith.constant 3.300000e+01 : f32 + %cst_0 = arith.constant 3.100000e+01 : f32 + %0 = arith.addf %cst, %cst_0 : f32 + %cst_1 = arith.constant 3.100000e+01 : f32 + %1 = arith.subf %cst, %cst_1 : f32 + %cst_2 = arith.constant 3.100000e+01 : f32 + %2 = arith.mulf %cst, %cst_2 : f32 + %cst_3 = arith.constant 3.100000e+01 : f32 + %3 = arith.divf %cst, %cst_3 : f32 + %cst_4 = arith.constant 3.100000e+01 : f32 + %4 = arith.remf %cst, %cst_4 : f32 + %5 = arith.negf %cst : f32 + %cst_5 = arith.constant 3.100000e+01 : f32 + %6 = arith.cmpf olt, %cst, %cst_5 : f32 + %cst_6 = arith.constant 3.100000e+01 : f32 + %7 = arith.cmpf ole, %cst, %cst_6 : f32 + %8 = arith.andi %6, %7 : i1 + %cst_7 = arith.constant 3.100000e+01 : f32 + %9 = arith.cmpf ogt, %cst, %cst_7 : f32 + %10 = arith.andi %8, %9 : i1 + %cst_8 = arith.constant 3.100000e+01 : f32 + %11 = arith.cmpf oge, %cst, %cst_8 : f32 + %12 = arith.andi %10, %11 : i1 + %cst_9 = arith.constant 3.100000e+01 : f32 + %13 = arith.cmpf oeq, %cst, %cst_9 : f32 + %14 = arith.xori %12, %13 : i1 + %cst_10 = arith.constant 3.100000e+01 : f32 + %15 = arith.cmpf one, %cst, %cst_10 : f32 + %16 = arith.xori %14, %15 : i1 + return +})mlir"; + ASSERT_EQ(out, expected); +} + +// check wrap() +TEST_F(ArithTest, EasyBuildCheckWrap) { + OpBuilder builder{&context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create( + loc, "funcname", + FunctionType::get(&context, + {MemRefType::get({100}, IntegerType::get(&context, 16)), + IntegerType::get(&context, 16)}, + {})); + + builder.setInsertionPointToStart(func.addEntryBlock()); + // arg0 is of memref type + auto arg0 = func.getArgument(0); + EBValue wb = b(arg0); // check that it is ok to wrap generic value to EBValue + auto expectedFail = b.wrapOrFail(arg0); + ASSERT_TRUE(failed(expectedFail)); + + auto arg1 = func.getArgument(1); + auto expectedFail1 = b.wrapOrFail(arg1); + ASSERT_TRUE(failed(expectedFail1)); + auto expectedOK1 = b.wrapOrFail(arg1); + ASSERT_TRUE(succeeded(expectedOK1)); + EBUnsigned u = b.wrap(arg1); + + OpFoldResult foldresult = arg1; + expectedFail1 = b.wrapOrFail(foldresult); + ASSERT_TRUE(failed(expectedFail1)); + expectedOK1 = b.wrapOrFail(foldresult); + ASSERT_TRUE(succeeded(expectedOK1)); + u = b.wrap(foldresult); + + foldresult = builder.getIndexAttr(123); + expectedFail1 = b.wrapOrFail(foldresult); + ASSERT_TRUE(failed(expectedFail1)); + expectedOK1 = b.wrapOrFail(foldresult); + ASSERT_TRUE(succeeded(expectedOK1)); + u = b.wrap(foldresult); +} + +TEST_F(ArithTest, EasyBuildCheckOpCall) { + OpBuilder builder{&context}; + auto loc = builder.getUnknownLoc(); + EasyBuilder b{builder, loc}; + auto func = builder.create( + loc, "funcname", + FunctionType::get(&context, {IntegerType::get(&context, 16)}, {})); + + builder.setInsertionPointToStart(func.addEntryBlock()); + + auto arg0 = func.getArgument(0); + auto v = b.wrap(arg0); + auto v2 = b.F(v, b(uint16_t(1))); + v2 = b.F(v2, b(uint16_t(100))); + builder.create(loc); + + const char *expected = + R"mlir(func.func @funcname(%arg0: i16) { + %c1_i16 = arith.constant 1 : i16 + %0 = arith.minui %arg0, %c1_i16 : i16 + %c100_i16 = arith.constant 100 : i16 + %1 = arith.maxui %0, %c100_i16 : i16 + return +})mlir"; + std::string out; + llvm::raw_string_ostream os{out}; + os << func; + ASSERT_EQ(out, expected); +} diff --git a/unittests/Dialect/CMakeLists.txt b/unittests/Dialect/CMakeLists.txt new file mode 100644 index 000000000..42445996f --- /dev/null +++ b/unittests/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Arith) \ No newline at end of file