Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci-platform-snitch-tiled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ jobs:
{"name":"Kernels/Integer/Softmax/Large","L1":[5000,10000]},

{"name":"Kernels/FP32/Softmax/Regular","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/RMSNorm_fused","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/MatMul","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/Add/Regular","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/Hardswish","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/Div","L1":[2000,5000,10000]},

{"name":"Kernels/FP32/GEMM/Regular","L1":[2000,5000,10000]},
{"name":"Kernels/FP32/GEMM/TransB","L1":[2000,5000,10000]},
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/ci-platform-snitch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ jobs:
docker-image: ${{ needs.select-env.outputs.image }}
test-names: |
Kernels/FP32/Softmax/Regular
Kernels/FP32/RMSNorm_fused
Kernels/FP32/MatMul
Kernels/FP32/Add/Regular
Kernels/FP32/Hardswish
Kernels/FP32/Div

Kernels/Integer/Add/Large
Kernels/Integer/Add/Regular
Expand Down
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ if(TOOLCHAIN STREQUAL GCC)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()

set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, Generic, Snitch)")
set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open Generic Snitch)
set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, Generic, Snitch, Snitch_tiled)")
set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open Generic Snitch Snitch_tiled)

if(platform STREQUAL MemPool)
message(STATUS "Building for platform 'MemPool'")
Expand All @@ -36,6 +36,8 @@ elseif(platform STREQUAL Generic)
message(STATUS "Building for platform 'Generic'")
elseif(platform STREQUAL Snitch)
message(STATUS "Building for platform 'Snitch'")
elseif(platform STREQUAL Snitch_tiled)
message(STATUS "Building for platform 'Snitch_tiled'")
elseif(platform STREQUAL SoftHier)
message(STATUS "Building for platform 'SoftHier'")
elseif(platform STREQUAL Chimera)
Expand Down Expand Up @@ -211,7 +213,7 @@ if(platform STREQUAL Siracusa OR platform STREQUAL Siracusa_w_neureka OR platfor

endif()

if(platform STREQUAL Snitch)
if(platform STREQUAL Snitch OR platform STREQUAL Snitch_tiled)

if(TOOLCHAIN STREQUAL LLVM)
set(CMAKE_TOOLCHAIN_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/snitch/toolchain_llvm.cmake)
Expand Down
3 changes: 3 additions & 0 deletions Deeploy/Targets/Generic/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@
# Concat bindings for the basic (non-tiled) flow: one binding per integer
# element type, plus a single fp32 binding. Both inputs and the output share
# the same element type.
BasicConcatBindings = [
    NodeBinding(ConcatChecker([PointerClass(dtype), PointerClass(dtype)], [PointerClass(dtype)]),
                ConcatTemplate.referenceTemplate, BasicTransformer) for dtype in IntegerDataTypes
] + [
    NodeBinding(ConcatChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                ConcatTemplate.referenceTemplate, BasicTransformer)
]

BasicQuantBindings = [
Expand Down
28 changes: 28 additions & 0 deletions Deeploy/Targets/Generic/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,3 +709,31 @@ def computeOps(self):
numPx = opRep['dim_im_out_x']

return numPx * opsPerPx


class RMSNormLayer(ONNXLayer):
    # ONNX layer wrapper for RMSNorm; provides an elementwise-op cost model.

    def __init__(self, maps: List[NodeMapper]):
        super().__init__(maps)

    def computeOps(self):
        # Cost model for RMSNorm (square, mean, sqrt, div, mul), expressed in
        # terms of the flattened tensor size and the normalized (last) axis.
        opRep = self.mapper.parser.operatorRepresentation
        total = opRep['size']
        lastDim = opRep['lastDimLength']
        rows = total // lastDim

        # square over all elements (total)
        # + accumulation per row (rows * lastDim)
        # + per-row mean/eps/sqrt/reciprocal bookkeeping (rows * 4)
        # + final divide and weight multiply over all elements (total * 2)
        return total + rows * lastDim + rows * 4 + total * 2


class HardSwishLayer(ONNXLayer):
    # ONNX layer wrapper for HardSwish; provides an elementwise-op cost model.

    def __init__(self, maps: List[NodeMapper]):
        super().__init__(maps)

    def computeOps(self):
        # HardSwish(x) = x * clip(x/6 + 0.5, 0, 1): four elementwise
        # operations (div, add, clip, mul) for every element.
        elementCount = self.mapper.parser.operatorRepresentation['size']
        return 4 * elementCount
47 changes: 43 additions & 4 deletions Deeploy/Targets/Generic/Parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,23 +467,62 @@ def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
    # An Add node is parseable iff it has exactly two inputs and one output.
    return len(node.inputs) == 2 and len(node.outputs) == 1

def parseNodeCtxt(self,
                  ctxt: NetworkContext,
                  node: gs.Node,
                  channels_first: bool = True) -> Tuple[NetworkContext, bool]:
    """Bind the Add node's buffers into the operator representation and,
    when the input shapes differ from the output shape, compute per-input
    broadcast strides following ONNX multidirectional broadcasting rules
    (shapes are aligned from the right; missing leading dims act as 1).
    """

    data_in_1 = ctxt.lookup(node.inputs[0].name)
    data_in_2 = ctxt.lookup(node.inputs[1].name)
    data_out = ctxt.lookup(node.outputs[0].name)

    self.operatorRepresentation['data_in_1'] = data_in_1.name
    self.operatorRepresentation['data_in_2'] = data_in_2.name
    self.operatorRepresentation['data_out'] = data_out.name
    # Element count follows the (broadcast) output shape, not input 1.
    self.operatorRepresentation['size'] = np.prod(data_out.shape)

    shape1 = list(data_in_1.shape)
    shape2 = list(data_in_2.shape)
    out_shape = list(data_out.shape)

    # Broadcasting is needed whenever either input's shape deviates from the
    # output shape (including rank mismatches).
    need_broadcast = (shape1 != out_shape) or (shape2 != out_shape)
    self.operatorRepresentation['need_broadcast'] = need_broadcast

    if need_broadcast:
        ndim = len(out_shape)

        def _broadcast_strides(shape):
            # ONNX broadcasts from the right: left-pad the shorter shape
            # with 1s so every output dimension has a counterpart.
            padded = [1] * (ndim - len(shape)) + list(shape)
            strides = [0] * ndim
            # Running product of the *input's* padded extents gives the
            # contiguous stride; broadcast dims (extent mismatch, i.e.
            # padded extent 1 vs. larger output extent) get stride 0 but
            # still contribute their extent (1) to the running product,
            # which also keeps interior broadcast dims
            # (e.g. [2,1,4] vs [2,3,4] -> [4,0,1]) correct.
            running = 1
            for i in range(ndim - 1, -1, -1):
                if padded[i] == out_shape[i]:
                    strides[i] = running
                else:
                    strides[i] = 0  # broadcast dimension
                running *= padded[i]
            return strides

        self.operatorRepresentation['ndim'] = ndim
        self.operatorRepresentation['strides1'] = _broadcast_strides(shape1)
        self.operatorRepresentation['strides2'] = _broadcast_strides(shape2)
        self.operatorRepresentation['out_shape'] = out_shape
Comment on lines +494 to +525
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's read the actual code in the file
wc -l Deeploy/Targets/Generic/Parsers.py

Repository: pulp-platform/Deeploy

Length of output: 104


🏁 Script executed:

#!/bin/bash
# Read the relevant section of the code
sed -n '480,530p' Deeploy/Targets/Generic/Parsers.py

Repository: pulp-platform/Deeploy

Length of output: 2170


🏁 Script executed:

#!/bin/bash
# Search for test files that test AddParser with broadcasting
find . -type f -name "*.py" -path "*test*" | head -20

Repository: pulp-platform/Deeploy

Length of output: 832


🏁 Script executed:

#!/bin/bash
# Look for AddParser tests specifically
rg -l "AddParser" --type=py | grep -i test

Repository: pulp-platform/Deeploy

Length of output: 47


🏁 Script executed:

#!/bin/bash
# Search for test files that might test AddParser
rg "AddParser" --type=py DeeployTest/

Repository: pulp-platform/Deeploy

Length of output: 47


🏁 Script executed:

#!/bin/bash
# Search for Add operator tests in test files
rg -i "Add\|broadcast" --type=py DeeployTest/ -A3 -B3 | head -100

Repository: pulp-platform/Deeploy

Length of output: 47


🏁 Script executed:

#!/bin/bash
# Check the AddParser class definition more carefully
sed -n '450,530p' Deeploy/Targets/Generic/Parsers.py

Repository: pulp-platform/Deeploy

Length of output: 3564


🏁 Script executed:

#!/bin/bash
# Look for test ONNX models that might use Add with broadcasting
find . -name "*.onnx" -o -name "*.pb" | head -20

Repository: pulp-platform/Deeploy

Length of output: 1241


🏁 Script executed:

#!/bin/bash
# Let's examine one of the Add test cases to understand the shapes used
python3 << 'EOF'
import sys
try:
    import onnx
    model = onnx.load('./DeeployTest/Tests/Kernels/FP32/Add/Large/network.onnx')
    print("Model graph inputs and outputs:")
    for input_tensor in model.graph.input:
        print(f"Input: {input_tensor.name}")
        if input_tensor.type.tensor_type.shape.dim:
            dims = [d.dim_value for d in input_tensor.type.tensor_type.shape.dim]
            print(f"  Shape: {dims}")
    
    for output_tensor in model.graph.output:
        print(f"Output: {output_tensor.name}")
        if output_tensor.type.tensor_type.shape.dim:
            dims = [d.dim_value for d in output_tensor.type.tensor_type.shape.dim]
            print(f"  Shape: {dims}")
    
    for node in model.graph.node:
        print(f"Node: {node.op_type} - {node.name}")
        print(f"  Inputs: {list(node.input)}")
        print(f"  Outputs: {list(node.output)}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
EOF

Repository: pulp-platform/Deeploy

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Check the other Add test case
python3 << 'EOF'
import sys
try:
    import onnx
    model = onnx.load('./DeeployTest/Tests/Kernels/FP32/Add/Regular/network.onnx')
    print("Model graph inputs and outputs:")
    for input_tensor in model.graph.input:
        print(f"Input: {input_tensor.name}")
        if input_tensor.type.tensor_type.shape.dim:
            dims = [d.dim_value for d in input_tensor.type.tensor_type.shape.dim]
            print(f"  Shape: {dims}")
    
    for output_tensor in model.graph.output:
        print(f"Output: {output_tensor.name}")
        if output_tensor.type.tensor_type.shape.dim:
            dims = [d.dim_value for d in output_tensor.type.tensor_type.shape.dim]
            print(f"  Shape: {dims}")
    
    for node in model.graph.node:
        print(f"Node: {node.op_type} - {node.name}")
        print(f"  Inputs: {list(node.input)}")
        print(f"  Outputs: {list(node.output)}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
EOF

Repository: pulp-platform/Deeploy

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for other test files that might test broadcasting scenarios
find DeeployTest/Tests -name "*.onnx" | xargs -I {} sh -c 'echo "=== {} ===" && python3 << EOF 2>&1 | head -20
import onnx
try:
    model = onnx.load("{}")
    for node in model.graph.node:
        if node.op_type == "Add":
            print(f"Found Add node: {node.name}")
except:
    pass
EOF' | grep -A10 "Found Add"

Repository: pulp-platform/Deeploy

Length of output: 47


🏁 Script executed:

#!/bin/bash
# Let's manually trace through the stride calculation logic
# to verify the reviewer's concern about rank differences
python3 << 'EOF'
# Simulate the stride calculation as it appears in the code

def compute_strides_original(shape, out_shape):
    """Original code logic"""
    ndim = len(out_shape)
    strides = [1] * ndim
    for i in range(ndim - 1, -1, -1):
        if i < len(shape) and shape[i] == out_shape[i]:
            if i == ndim - 1:
                strides[i] = 1
            else:
                strides[i] = strides[i + 1] * shape[i + 1] if (
                    i + 1 < len(shape) and shape[i + 1] == out_shape[i + 1]) else strides[i + 1]
        else:
            strides[i] = 0
    return strides

# Test case 1: Same rank (should work)
shape1 = [1, 3, 4]
out_shape = [2, 3, 4]
result = compute_strides_original(shape1, out_shape)
print(f"Test 1 - shape1={shape1}, out_shape={out_shape}")
print(f"  Result: {result}")
print(f"  Expected: [0, 4, 1]")
print()

# Test case 2: Different rank (reviewer's concern)
shape1 = [4]
out_shape = [2, 3, 4]
result = compute_strides_original(shape1, out_shape)
print(f"Test 2 - shape1={shape1}, out_shape={out_shape}")
print(f"  Result: {result}")
print(f"  Expected: [0, 0, 1] (shape1 implicitly [1, 1, 4] per ONNX broadcasting)")
print()

# Test case 3: Another different rank scenario
shape1 = [3, 4]
out_shape = [2, 3, 4]
result = compute_strides_original(shape1, out_shape)
print(f"Test 3 - shape1={shape1}, out_shape={out_shape}")
print(f"  Result: {result}")
print(f"  Expected: [0, 4, 1] (shape1 implicitly [1, 3, 4])")
print()

# Now test the suggested fix
def compute_strides_fixed(shape, out_shape):
    """Fixed version with padding"""
    ndim = len(out_shape)
    # Pad shape from the left to match ndim
    padded_shape = [1] * (ndim - len(shape)) + shape
    
    strides = [1] * ndim
    for i in range(ndim - 1, -1, -1):
        if padded_shape[i] == out_shape[i]:
            if i == ndim - 1:
                strides[i] = 1
            else:
                strides[i] = strides[i + 1] * padded_shape[i + 1] if (
                    padded_shape[i + 1] == out_shape[i + 1]) else strides[i + 1]
        else:
            strides[i] = 0
    return strides

print("=" * 50)
print("WITH FIX (padding shapes):")
print("=" * 50)

# Test case 2 with fix
shape1 = [4]
out_shape = [2, 3, 4]
result = compute_strides_fixed(shape1, out_shape)
print(f"Test 2 - shape1={shape1}, out_shape={out_shape}")
print(f"  Result: {result}")
print(f"  Expected: [0, 0, 1]")
print()

# Test case 3 with fix
shape1 = [3, 4]
out_shape = [2, 3, 4]
result = compute_strides_fixed(shape1, out_shape)
print(f"Test 3 - shape1={shape1}, out_shape={out_shape}")
print(f"  Result: {result}")
print(f"  Expected: [0, 4, 1]")
EOF

Repository: pulp-platform/Deeploy

Length of output: 706


🏁 Script executed:

#!/bin/bash
# Look for documentation or comments about ONNX broadcasting in the codebase
rg -i "onnx.*broadcast|broadcast.*onnx" --type=py -B2 -A2 | head -50

Repository: pulp-platform/Deeploy

Length of output: 412


🏁 Script executed:

#!/bin/bash
# Check if there are any other stride calculations in the codebase for reference
rg "stride" Deeploy/ --type=py -B2 -A2 | head -100

Repository: pulp-platform/Deeploy

Length of output: 8804


🏁 Script executed:

#!/bin/bash
# Let's look at the correct implementation in Snitch/Parsers.py
rg "_compute_broadcast_strides" Deeploy/Targets/Snitch/Parsers.py -A50 | head -80

Repository: pulp-platform/Deeploy

Length of output: 2043


🏁 Script executed:

#!/bin/bash
# Check if AddParser in Generic/Parsers.py should use a similar approach
sed -n '1,100p' Deeploy/Targets/Generic/Parsers.py | grep -i "import\|class"

Repository: pulp-platform/Deeploy

Length of output: 368


🏁 Script executed:

#!/bin/bash
# Let's examine the Snitch AddParser implementation more closely
rg "class.*AddParser" Deeploy/Targets/Snitch/Parsers.py -A100 | head -150

Repository: pulp-platform/Deeploy

Length of output: 3652


🏁 Script executed:

#!/bin/bash
# Check if Generic/Parsers.py imports or reuses any functions from elsewhere
grep -n "class AddParser" Deeploy/Targets/Generic/Parsers.py -A150 | head -200

Repository: pulp-platform/Deeploy

Length of output: 6781


Stride calculation is broken for inputs with different ranks — fix required.

The stride computation in AddParser.parseNodeCtxt (lines 494-525) fails when input shapes have fewer dimensions than the output shape. The code checks i < len(shape1) without padding shapes first, violating ONNX broadcasting semantics which pad shorter shapes with 1s on the left.

For example, shape1=[4] and out_shape=[2,3,4] produces strides1=[0,0,0] instead of [0,0,1].

The correct approach is already implemented in Deeploy/Targets/Snitch/Parsers.py (SnitchAddParser._compute_broadcast_strides). Pad both shapes from the left before computing strides:

         if need_broadcast:
             # Calculate strides for broadcasting
             ndim = len(out_shape)
 
+            # Pad shapes from the left to match ndim (ONNX broadcasts from right)
+            padded_shape1 = [1] * (ndim - len(shape1)) + shape1
+            padded_shape2 = [1] * (ndim - len(shape2)) + shape2
+
             # Compute strides for input 1
             strides1 = [1] * ndim
             for i in range(ndim - 1, -1, -1):
-                if i < len(shape1) and shape1[i] == out_shape[i]:
+                if padded_shape1[i] == out_shape[i]:
                     if i == ndim - 1:
                         strides1[i] = 1
                     else:
-                        strides1[i] = strides1[i + 1] * shape1[i + 1] if (
-                            i + 1 < len(shape1) and shape1[i + 1] == out_shape[i + 1]) else strides1[i + 1]
+                        strides1[i] = strides1[i + 1] * padded_shape1[i + 1] if (
+                            padded_shape1[i + 1] == out_shape[i + 1]) else strides1[i + 1]
                 else:
                     strides1[i] = 0  # Broadcast dimension

Apply the same fix to the strides2 computation below. Caution: carrying strides[i + 1] forward unchanged when dimension i + 1 is a broadcast dimension drops that dimension's extent from the product, so even the padded variant mishandles interior broadcast dims (e.g. shape1=[2,1,4], out_shape=[2,3,4] should give [4,0,1], not [0,0,1]). A running cumulative product of the padded input extents (stride *= padded_shape[i], with stride 0 assigned to mismatched dims) is the robust formulation.

🤖 Prompt for AI Agents
In `@Deeploy/Targets/Generic/Parsers.py` around lines 494 - 525,
AddParser.parseNodeCtxt computes broadcasting strides incorrectly when input
ranks differ because it checks i < len(shape1/shape2) instead of left-padding
shapes with 1s per ONNX rules; update the code to left-pad shape1 and shape2 to
length ndim (length of out_shape) with leading 1s and then compute strides1 and
strides2 exactly as in SnitchAddParser._compute_broadcast_strides (treat
dimensions equal to out_shape as non-broadcast and set stride 0 for broadcast
dims, compute cumulative strides from the right otherwise), and apply the same
padding+stride logic to both strides1 and strides2 so examples like shape1=[4],
out_shape=[2,3,4] produce strides1=[0,0,1].


return ctxt, True

Expand Down
51 changes: 51 additions & 0 deletions Deeploy/Targets/Generic/TypeCheckers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ def _inferSignedness(self, inputs: List[VariableBuffer],
return [False]


class FloatAddChecker(SignPropTypeChecker):
    # Type checker for floating-point elementwise Add.

    def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
        super().__init__(input_types, output_types)

    def _inferNumLevels(self, inputs: List[VariableBuffer],
                        operatorRepresentation: OperatorRepresentation) -> List[int]:
        # The number of representable levels follows the first input's
        # element width.
        width = self.input_types[0].referencedType.typeWidth
        return [2**width]

    def _inferSignedness(self, inputs: List[VariableBuffer],
                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
        # Floating-point results are always treated as signed.
        return [True]


class GatherChecker(SignPropTypeChecker):

def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
Expand Down Expand Up @@ -610,3 +624,40 @@ def _inferNumLevels(self, inputs: List[VariableBuffer],
def _inferSignedness(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> List[bool]:
return [True]


class RMSNormChecker(SignPropTypeChecker):
    # Type checker for RMSNorm (square, mean, sqrt, reciprocal, multiply).

    def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
        super().__init__(input_types, output_types)

    def _inferNumLevels(self, inputs: List[VariableBuffer],
                        operatorRepresentation: OperatorRepresentation) -> List[int]:
        # Output precision mirrors the first input's element width.
        return [2**self.input_types[0].referencedType.typeWidth]

    def _inferSignedness(self, inputs: List[VariableBuffer],
                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
        # Signedness propagates directly from the input tensor.
        return [bool(inputs[0]._signed)]


class HardSwishChecker(SignPropTypeChecker):
    # Type checker for the unary HardSwish activation.

    def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
        super().__init__(input_types, output_types)

    def _inferNumLevels(self, inputs: List[VariableBuffer],
                        operatorRepresentation: OperatorRepresentation) -> List[int]:
        # Output level count follows the input element width.
        return [2**self.input_types[0].referencedType.typeWidth]

    def _inferSignedness(self, inputs: List[VariableBuffer],
                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
        # Signedness propagates directly from the input tensor.
        return [bool(inputs[0]._signed)]
122 changes: 119 additions & 3 deletions Deeploy/Targets/Snitch/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@
from Deeploy.CommonExtensions.DataTypes import float32_t, int8_t, int32_t, uint8_t
from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
from Deeploy.Targets.Generic.Templates import iNoNormTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, GEMMChecker, RQAddChecker, SoftmaxChecker, iNoNormChecker
from Deeploy.Targets.Generic.Templates import ConcatTemplate, iNoNormTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, DivChecker, GatherChecker, GEMMChecker, \
HardSwishChecker, MatMulChecker, MulChecker, ReshapeChecker, RMSNormChecker, RQAddChecker, SoftmaxChecker, \
TransposeChecker, iNoNormChecker
from Deeploy.Targets.Snitch.CodeTransformationPasses import SnitchClusterTiling, SnitchCoreFilterPass, \
SnitchSynchCoresPass
from Deeploy.Targets.Snitch.DMA.SnitchDma import SnitchDma
from Deeploy.Targets.Snitch.Templates import AddTemplate, FloatGemmTemplate, RQAddTemplate, iSoftmaxTemplate
from Deeploy.Targets.Snitch.Templates import AddTemplate, FloatGemmTemplate, FloatMatMulTemplate, GatherTemplate, \
MatMulTemplate, ReshapeTemplate, RQAddTemplate, TransposeTemplate, iSoftmaxTemplate
from Deeploy.Targets.Snitch.Templates.FloatAddTemplate import referenceTemplate as FloatAddTemplate
from Deeploy.Targets.Snitch.Templates.FloatDivTemplate import referenceTemplate as FloatDivTemplate
from Deeploy.Targets.Snitch.Templates.FloatHardSwishTemplate import referenceTemplate as FloatHardSwishTemplate
from Deeploy.Targets.Snitch.Templates.FloatMulTemplate import referenceTemplate as FloatMulTemplate
from Deeploy.Targets.Snitch.Templates.FloatRMSNormTemplate import referenceTemplate as FloatRMSNormTemplate
from Deeploy.Targets.Snitch.Templates.FloatSoftmaxTemplate import FloatSoftmax_Template
from Deeploy.Targets.Snitch.Templates.GemmTemplate import SnitchGemm_Template
from Deeploy.Targets.Snitch.Templates.RqGemmTemplate import SnitchRqGemm_Template
Expand Down Expand Up @@ -45,6 +53,7 @@
ArgumentStructGeneration(),
MemoryManagementGeneration("L1"),
MemoryAwareFunctionCallClosure(writeback = False, generateStruct = True),
MemoryManagementGeneration("L2"),
MemoryManagementGeneration()
])

Expand All @@ -69,7 +78,18 @@
# Tiled Add bindings: int8 inputs accumulate into an int32 output;
# fp32 inputs stay in fp32.
SnitchAddBindings = [
    NodeBinding(AddChecker([PointerClass(_type), PointerClass(_type)], [PointerClass(int32_t)]),
                AddTemplate.referenceTemplate, TiledTransformer) for _type in [int8_t]
] + [
    # fp32 support
    NodeBinding(AddChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatAddTemplate, TiledTransformer)
]

# Basic (non-tiled) FP32 Add Bindings
BasicAddBindings = [
    NodeBinding(AddChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatAddTemplate, BasicTransformer)
]

SnitchGemmBindings = [
NodeBinding(
GEMMChecker([PointerClass(int8_t), PointerClass(int8_t),
Expand All @@ -90,3 +110,99 @@
PointerClass(int32_t)
], [PointerClass(int8_t)]), SnitchRqGemm_Template, TiledTransformer)
]

# ---- Snitch operator bindings: "Snitch*" lists use the tiled transformer,
# ---- "Basic*" lists use the non-tiled (single-shot) transformer. ----

# RMSNorm Bindings (Tiled): fp32 activation plus fp32 weight input.
SnitchRMSNormBindings = [
    NodeBinding(RMSNormChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatRMSNormTemplate, TiledTransformer)
]

# RMSNorm Bindings (Non-tiled)
BasicRMSNormBindings = [
    NodeBinding(RMSNormChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatRMSNormTemplate, BasicTransformer)
]

# HardSwish Bindings (Tiled): unary fp32 op.
SnitchHardSwishBindings = [
    NodeBinding(HardSwishChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatHardSwishTemplate,
                TiledTransformer)
]

# HardSwish Bindings (Non-tiled)
BasicHardSwishBindings = [
    NodeBinding(HardSwishChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatHardSwishTemplate,
                BasicTransformer)
]

# Div Bindings (Tiled): fp32 only.
SnitchDivBindings = [
    NodeBinding(DivChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatDivTemplate, TiledTransformer)
]

# Div Bindings (Non-tiled)
BasicDivBindings = [
    NodeBinding(DivChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatDivTemplate, BasicTransformer)
]

# Mul Bindings (Tiled): fp32 only.
SnitchMulBindings = [
    NodeBinding(MulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatMulTemplate, TiledTransformer)
]

# Mul Bindings (Non-tiled)
BasicMulBindings = [
    NodeBinding(MulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatMulTemplate, BasicTransformer)
]

# MatMul Bindings (Tiled): int8 accumulates into int32; fp32 stays fp32.
SnitchMatMulBindings = [
    NodeBinding(MatMulChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
                MatMulTemplate.referenceTemplate, TiledTransformer),
    NodeBinding(MatMulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatMatMulTemplate.referenceTemplate, TiledTransformer)
]

# Concat Bindings (Tiled): inputs and output share the same element type.
SnitchConcatBindings = [
    NodeBinding(ConcatChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int8_t)]),
                ConcatTemplate.referenceTemplate, TiledTransformer),
    NodeBinding(ConcatChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]),
                ConcatTemplate.referenceTemplate, TiledTransformer),
    NodeBinding(ConcatChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                ConcatTemplate.referenceTemplate, TiledTransformer)
]

# Transpose Bindings (Tiled): element type passes through unchanged.
SnitchTransposeBindings = [
    NodeBinding(TransposeChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), TransposeTemplate.referenceTemplate,
                TiledTransformer),
    NodeBinding(TransposeChecker([PointerClass(int32_t)], [PointerClass(int32_t)]), TransposeTemplate.referenceTemplate,
                TiledTransformer),
    NodeBinding(TransposeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
                TransposeTemplate.referenceTemplate, TiledTransformer)
]

# Reshape Bindings (Tiled): element type passes through unchanged.
SnitchReshapeBindings = [
    NodeBinding(ReshapeChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), ReshapeTemplate.referenceTemplate,
                TiledTransformer),
    NodeBinding(ReshapeChecker([PointerClass(int32_t)], [PointerClass(int32_t)]), ReshapeTemplate.referenceTemplate,
                TiledTransformer),
    NodeBinding(ReshapeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), ReshapeTemplate.referenceTemplate,
                TiledTransformer)
]

# Gather Bindings (Tiled): the second input is always an int32 index tensor.
SnitchGatherBindings = [
    NodeBinding(GatherChecker([PointerClass(int8_t), PointerClass(int32_t)], [PointerClass(int8_t)]),
                GatherTemplate.referenceTemplate, TiledTransformer),
    NodeBinding(GatherChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]),
                GatherTemplate.referenceTemplate, TiledTransformer),
    NodeBinding(GatherChecker([PointerClass(float32_t), PointerClass(int32_t)], [PointerClass(float32_t)]),
                GatherTemplate.referenceTemplate, TiledTransformer)
]
Loading