diff --git a/.github/workflows/ci-platform-siracusa-tiled.yml b/.github/workflows/ci-platform-siracusa-tiled.yml index dc52f6ad7f..e0006ae150 100644 --- a/.github/workflows/ci-platform-siracusa-tiled.yml +++ b/.github/workflows/ci-platform-siracusa-tiled.yml @@ -135,9 +135,7 @@ jobs: - name: "MLPerf/AnomalyDetection" L1: [64000] - name: "CCT/CCT_1_16_16_8" - L1: [2000, 64000] - - name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8" - L1: [4000, 64000] + L1: [64000] - name: "testFloatDemoTinyViT" L1: [4000] num-cores: [8] @@ -168,9 +166,9 @@ jobs: - name: "microLlama/microLlama1" L1: [60000, 10000, 5000] - name: "CCT/CCT_2_32_32_128" - L1: [64000, 128000] - - name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128" - L1: [32000, 64000] + L1: [128000] + - name: "testTrainCCT/CCT2_FT2" + L1: [128000] - name: "testFloatDemoTinyViT" L1: [4000] num-cores: [8] @@ -208,9 +206,9 @@ jobs: - name: "microLlama/microLlama8_parallel" L1: [60000, 20000, 10000] - name: "CCT/CCT_2_32_32_128" - L1: [64000, 128000] - - name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128" - L1: [8000, 64000] + L1: [128000] + - name: "testTrainCCT/CCT2_FT2" + L1: [128000] - name: "testFloatDemoTinyViT" L1: [4000] num-cores: [8] diff --git a/.github/workflows/ci-platform-siracusa.yml b/.github/workflows/ci-platform-siracusa.yml index f59f7fa884..de5dab7f6b 100644 --- a/.github/workflows/ci-platform-siracusa.yml +++ b/.github/workflows/ci-platform-siracusa.yml @@ -95,6 +95,5 @@ jobs: MLPerf/AnomalyDetection CCT/CCT_1_16_16_8 CCT/CCT_2_32_32_128_Opset20 - testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8 testFloatDemoTinyViT num-cores: 8 diff --git a/Deeploy/Targets/Generic/Layers.py b/Deeploy/Targets/Generic/Layers.py index c924895c13..ea86221108 100644 --- a/Deeploy/Targets/Generic/Layers.py +++ b/Deeploy/Targets/Generic/Layers.py @@ -58,6 +58,18 @@ def computeOps(self): return mul1 + neg + exp + add + div + mul2 +class GELUGradLayer(ONNXLayer): + + def __init__(self, maps: List[NodeMapper]): + super().__init__(maps) + + def computeOps(self): + size = self.mapper.parser.operatorRepresentation['size'] + ops_per_element = 9 + gelu_grad_ops = size * ops_per_element + return gelu_grad_ops + + class iHardswishLayer(ONNXLayer): def __init__(self, maps: List[NodeMapper]): @@ -438,6 +450,12 @@ def computeOps(self): return compAverage + compNormalize + compSqr + compSum + compSqrt + compDiv +class LayerNormGradLayer(ONNXLayer): + + def __init__(self, maps: List[NodeMapper]): + super().__init__(maps) + + class TransposeLayer(ONNXLayer): def __init__(self, maps: List[NodeMapper]): diff --git a/Deeploy/Targets/Generic/Parsers.py b/Deeploy/Targets/Generic/Parsers.py index f63bb5411d..4c602367e9 100644 --- a/Deeploy/Targets/Generic/Parsers.py +++ b/Deeploy/Targets/Generic/Parsers.py @@ -770,6 +770,33 @@ def parseNodeCtxt(self, return ctxt, True +class GELUGradParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + + ret = all([len(node.inputs) == 2, len(node.outputs) == 1]) + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + + upstream_grad = ctxt.lookup(node.inputs[0].name) + gelu_input = ctxt.lookup(node.inputs[1].name) + gelu_grad = ctxt.lookup(node.outputs[0].name) + + self.operatorRepresentation['grad_in'] = upstream_grad.name + self.operatorRepresentation['data_in'] = gelu_input.name + self.operatorRepresentation['grad_out'] = gelu_grad.name + 
self.operatorRepresentation['size'] = np.prod(upstream_grad.shape) + + return ctxt, True + + class RQSiGELUParser(GELUParser): def __init__(self): @@ -1647,6 +1674,36 @@ def parseNodeCtxt(self, return ctxt, True +class LayerNormGradParser(iLayerNormParser): + + def parseNode(self, node: gs.Node) -> (bool): + + ret = all(['epsilon' in node.attrs, len(node.inputs) == 4, len(node.outputs) == 1]) + + if ret: + self.operatorRepresentation['epsilon'] = node.attrs['epsilon'] + + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + + inputs = ['grad_in', 'data_in', 'weight', 'bias'] + outputs = ['grad_out'] + + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + for idx, outputNode in enumerate(node.outputs): + self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name + + self.operatorRepresentation['size'] = np.prod(ctxt.lookup(node.inputs[0].name).shape) + self.operatorRepresentation['lastDimLength'] = ctxt.lookup(node.inputs[0].name).shape[-1] + + return ctxt, True + + class MatMulParser(NodeParser): def __init__(self, noBiasHoisting = True): diff --git a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py index b881529f7e..146bcf699e 100644 --- a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py +++ b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py @@ -676,8 +676,8 @@ def _split_transposes_fun(graph: gs.Graph, match: Match, name: str): inputNode.outputs = [postSplitOutput] for node in originalNode.outputs.copy(): - nodeName = node.name + f"_transpose_in" - varName = node.name + f"_transpose_in_var" + nodeName = f"{t1.name}_{node.name}_transpose_in" + varName = f"{t1.name}_{node.name}_transpose_in_var" newOutput = gs.Variable(name = varName, dtype = np.float32, shape = t1.outputs[0].shape) transposeNode = gs.Node(name = nodeName, diff --git a/Deeploy/Targets/PULPOpen/Bindings.py b/Deeploy/Targets/PULPOpen/Bindings.py index 35e7230fb8..d9a2dc254e 100644 --- a/Deeploy/Targets/PULPOpen/Bindings.py +++ b/Deeploy/Targets/PULPOpen/Bindings.py @@ -415,10 +415,22 @@ PointerClass(float32_t)], [PointerClass(float32_t)]), FloatLayernormTemplate.referenceTemplate, ForkTransformer) +PULPLayernormGradBinding = NodeBinding( + LayerNormChecker( + [PointerClass(float32_t), + PointerClass(float32_t), + PointerClass(float32_t), + PointerClass(float32_t)], [PointerClass(float32_t)]), FloatLayernormTemplate.referenceGradTemplate, + ForkTransformer) + PULPFloatGELUBinding = NodeBinding( GELUChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), FloatGELUTemplate.referenceTemplate, ForkTransformer) +PULPFloatGELUGradBinding = NodeBinding( + GELUChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), + FloatGELUTemplate.referenceGradTemplate, ForkTransformer) + PULPGatherBindings = [ NodeBinding(GatherChecker([PointerClass(float32_t), PointerClass(type)], [PointerClass(float32_t)]), GatherTemplate.referenceTemplate, ForkTransformer) for type in IntegerDataTypes diff --git a/Deeploy/Targets/PULPOpen/Platform.py b/Deeploy/Targets/PULPOpen/Platform.py index 133670da02..b822b4f41d 100644 --- a/Deeploy/Targets/PULPOpen/Platform.py +++ b/Deeploy/Targets/PULPOpen/Platform.py @@ -13,17 +13,18 @@ from Deeploy.MemoryLevelExtension.NetworkDeployers.MemoryLevelDeployer import 
MemoryPlatform, MemoryPlatformWrapper from Deeploy.Targets.Generic.Bindings import BasicGEMMBindings, BasicPad1DBindings, BasicPad2DBindings, \ BasicRQIntegerDivBinding -from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELULayer, GEMMLayer, \ - LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \ - ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, RQSiHardswishLayer, SGDLayer, \ - SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, SoftmaxGradLayer, SoftmaxLayer, \ - TransposeLayer, iHardswishLayer, iRMSNormLayer +from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELUGradLayer, GELULayer, \ + GEMMLayer, LayerNormGradLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, \ + ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \ + RQSiHardswishLayer, SGDLayer, SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, \ + SoftmaxGradLayer, SoftmaxLayer, TransposeLayer, iHardswishLayer, iRMSNormLayer from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, DequantParser, FlattenParser, GatherParser, \ - GELUParser, GEMMParser, LayerNormParser, MatMulParser, MaxPool2DParser, MulParser, Pad1DParser, Pad2DParser, \ - QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQAddParser, \ - RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, SGDParser, SliceParser, \ - SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, SoftmaxParser, \ - TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser + GELUGradParser, GELUParser, GEMMParser, LayerNormGradParser, LayerNormParser, MatMulParser, MaxPool2DParser, \ + MulParser, Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \ + RequantShiftParser, ReshapeParser, RQAddParser, RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, \ + SGDParser, SliceParser, SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, \ + SoftmaxParser, TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, \ + iSoftmaxParser from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, IntegerDivRequantMergePass, \ MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, QuantPatternPass, RQSSplitPass, \ @@ -37,14 +38,15 @@ from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \ PULPConv2DTilingReadyBindings, PULPDWConv2DTilingReadyBindings, PULPFlattenTilingReadyBindings, \ - PULPFPGELUTilingReadyBindings, PULPFPGEMMTilingReadyBindings, PULPGatherTilingReadyBindings, \ - PULPiHardswishTilingReadyBindings, PULPiRMSNormTilingReadyBindings, PULPiRQSGELUTilingReadyBindings, \ - PULPLayernormTilingReadyBindings, PULPMatMulTilingReadyBindings, PULPMaxPool2DTilingReadyBindings, \ - PULPMulTilingReadyBindings, PULPReduceMeanTilingReadyBindings, PULPReduceSumTilingReadyBindings, \ - PULPReluTilingReadyBindings, PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, \ - 
PULPRQSDWConv2DTilingReadyBindings, PULPRQSGEMMTilingReadyBindings, PULPRQSiHardswishTilingReadyBindings, \ - PULPRQSMatrixVecTilingReadyBindings, PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, \ - PULPSGDTilingReadyBindings, PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \ + PULPFPGELUGradTilingReadyBindings, PULPFPGELUTilingReadyBindings, PULPFPGEMMTilingReadyBindings, \ + PULPGatherTilingReadyBindings, PULPiHardswishTilingReadyBindings, PULPiRMSNormTilingReadyBindings, \ + PULPiRQSGELUTilingReadyBindings, PULPLayernormGradTilingReadyBindings, PULPLayernormTilingReadyBindings, \ + PULPMatMulTilingReadyBindings, PULPMaxPool2DTilingReadyBindings, PULPMulTilingReadyBindings, \ + PULPReduceMeanTilingReadyBindings, PULPReduceSumTilingReadyBindings, PULPReluTilingReadyBindings, \ + PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, PULPRQSDWConv2DTilingReadyBindings, \ + PULPRQSGEMMTilingReadyBindings, PULPRQSiHardswishTilingReadyBindings, PULPRQSMatrixVecTilingReadyBindings, \ + PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, PULPSGDTilingReadyBindings, \ + PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \ PULPSoftmaxCrossEntropyTilingReadyBindings, PULPSoftmaxGradTilingReadyBindings, PULPSoftmaxTilingReadyBindings, \ PULPTransposeTilingReadyBindings, PULPUniformRQSTilingReadyBindings from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPAddRequantMergePass, \ @@ -54,6 +56,7 @@ AddMapper = NodeMapper(AddParser(), PULPAddTilingReadyBindings) FlattenMapper = NodeMapper(FlattenParser(), PULPFlattenTilingReadyBindings) GELUMapper = NodeMapper(GELUParser(), PULPFPGELUTilingReadyBindings) +GELUGradMapper = NodeMapper(GELUGradParser(), PULPFPGELUGradTilingReadyBindings) GatherMapper = NodeMapper(GatherParser(), PULPGatherTilingReadyBindings) MulMapper = NodeMapper(MulParser(), PULPMulTilingReadyBindings) Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings) @@ -83,6 +86,7 @@ TallGEMMMapper = NodeMapper(PULPTallGEMMParser(), PULPRQSTallGEMMTilingReadyBindings) MaxPool2DMapper = NodeMapper(MaxPool2DParser(), PULPMaxPool2DTilingReadyBindings) LayerNormMapper = NodeMapper(LayerNormParser(), PULPLayernormTilingReadyBindings) +LayerNormGradMapper = NodeMapper(LayerNormGradParser(), PULPLayernormGradTilingReadyBindings) ReluMapper = NodeMapper(ReluParser(), PULPReluTilingReadyBindings) SoftmaxMapper = NodeMapper(SoftmaxParser(), PULPSoftmaxTilingReadyBindings) SoftmaxGradMapper = NodeMapper(SoftmaxGradParser(), PULPSoftmaxGradTilingReadyBindings) @@ -111,7 +115,9 @@ 'RequantizedGemm': PULPRQSGEMMLayer([MatrixVecMapper, TallGEMMMapper, GEMMMapper]), 'Gemm': GEMMLayer([FloatGEMMMapper, GEMMDequantMapper]), 'Gelu': GELULayer([GELUMapper]), + 'GeluGrad': GELUGradLayer([GELUGradMapper]), 'LayerNormalization': LayerNormLayer([LayerNormMapper]), + 'LayerNormalizationGrad': LayerNormGradLayer([LayerNormGradMapper]), 'MaxPool': MaxPoolLayer([MaxPool2DMapper]), 'RequantizediGELU': RQSiGELULayer([RQGELU_int8_Mapper]), 'RQIntegerDiv': RQIntegerDivLayer([RQIntegerDivMapper]), diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatGELUTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatGELUTemplate.py index df2a178662..701d102590 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatGELUTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatGELUTemplate.py @@ -7,4 +7,14 @@ referenceTemplate = NodeTemplate(""" // GELU (Name: ${nodeName}, Op: ${nodeOp}) 
PULP_GELU_fp${data_in_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}(${data_in}, ${data_out}, ${size}); +""") + +referenceGradTemplate = NodeTemplate(""" +// GELU Grad Parallel (Name: ${nodeName}, Op: ${nodeOp}) +int8_t ${nodeName}_core_id = pi_core_id(); +int8_t ${nodeName}_log2Core = log2(NUM_CORES); +int16_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0); +int16_t ${nodeName}_chunk_start = MIN(${nodeName}_chunk*${nodeName}_core_id, ${size}); +int16_t ${nodeName}_chunk_stop = MIN(${nodeName}_chunk_start + ${nodeName}_chunk, ${size}); +GELU_fp${data_in_type.referencedType.typeWidth}_fp${grad_out_type.referencedType.typeWidth}_sigmoid_grad_chunk(${grad_in}, ${data_in}, ${grad_out}, ${nodeName}_chunk_start, ${nodeName}_chunk_stop); """) \ No newline at end of file diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py index d007e60df0..59499706e5 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py @@ -2,16 +2,42 @@ # # SPDX-License-Identifier: Apache-2.0 -from Deeploy.DeeployTypes import NodeTemplate +from typing import Dict, List, Tuple -referenceTemplate = NodeTemplate(""" +from Deeploy.AbstractDataTypes import float32_tPtr +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class PULPFloatGEMMTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + + if 'C' not in operatorRepresentation or operatorRepresentation['C'] is None: + # No bias case - set C to NULL and provide a default type + operatorRepresentation['C'] = None + operatorRepresentation['C_type'] = float32_tPtr # Default to fp32 type + operatorRepresentation['C_batched'] = False + + return ctxt, operatorRepresentation, [] + + +referenceTemplate = PULPFloatGEMMTemplate(""" // GEMM (Name: ${nodeName}, Op: ${nodeOp}) ${A_type.typeName} ref_${data_out}_${A} = ${A}; ${B_type.typeName} ref_${data_out}_${B} = ${B}; +% if C is not None: ${C_type.typeName} ref_${data_out}_${C} = ${C}; +% else: +${C_type.typeName} ref_${data_out}_C = NULL; +% endif ${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out}; for(uint32_t i=0; i<${batch}; i++){ + % if C is not None: PULP_Gemm_fp${A_type.referencedType.typeWidth}_fp${B_type.referencedType.typeWidth}_fp${C_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}( ref_${data_out}_${A}, ref_${data_out}_${B}, @@ -23,7 +49,19 @@ ${transA}, ${transB} ); - + % else: + PULP_Gemm_fp${A_type.referencedType.typeWidth}_fp${B_type.referencedType.typeWidth}_fp${C_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}( + ref_${data_out}_${A}, + ref_${data_out}_${B}, + NULL, + ref_${data_out}_${data_out}, + ${M}, + ${N}, + ${O}, + ${transA}, + ${transB} + ); + % endif % if A_batched: ref_${data_out}_${A} += ${M} * ${N}; % endif @@ -32,7 +70,7 @@ ref_${data_out}_${B} += ${N} * ${O}; % endif - % if C_batched: + % if C is not None and C_batched: ref_${data_out}_${C} += ${M} * ${O}; % endif diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatLayernormTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatLayernormTemplate.py index 9d4f60e8fc..315481741e 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatLayernormTemplate.py +++
b/Deeploy/Targets/PULPOpen/Templates/FloatLayernormTemplate.py @@ -15,4 +15,38 @@ ${size}, ${lastDimLength} ); +""") + +referenceGradTemplate = NodeTemplate(""" +// FloatLayernormGrad Parallel (Name: ${nodeName}, Op: ${nodeOp}) + +int8_t ${nodeName}_core_id = pi_core_id(); +int8_t ${nodeName}_log2Core = log2(NUM_CORES); + +int32_t ${nodeName}_seq_length = ${size} / ${lastDimLength}; +int32_t ${nodeName}_chunk = (${nodeName}_seq_length >> ${nodeName}_log2Core) + + ((${nodeName}_seq_length & (NUM_CORES-1)) != 0); +int32_t ${nodeName}_start = MIN(${nodeName}_chunk * ${nodeName}_core_id, ${nodeName}_seq_length); +int32_t ${nodeName}_end = MIN(${nodeName}_start + ${nodeName}_chunk, ${nodeName}_seq_length); + +int32_t ${nodeName}_elem_start = ${nodeName}_start * ${lastDimLength}; +int32_t ${nodeName}_elem_end = ${nodeName}_end * ${lastDimLength}; +int32_t ${nodeName}_elem_count = ${nodeName}_elem_end - ${nodeName}_elem_start; + +const float${grad_in_type.referencedType.typeWidth}_t* ${nodeName}_grad_in_ptr = ${grad_in} + ${nodeName}_elem_start; +const float${data_in_type.referencedType.typeWidth}_t* ${nodeName}_data_in_ptr = ${data_in} + ${nodeName}_elem_start; +float${grad_out_type.referencedType.typeWidth}_t* ${nodeName}_grad_out_ptr = ${grad_out} + ${nodeName}_elem_start; + +if (${nodeName}_elem_count > 0) { + LayernormGrad_fp${grad_in_type.referencedType.typeWidth}_fp${grad_out_type.referencedType.typeWidth}( + ${nodeName}_grad_in_ptr, // Upstream gradient (dy) + ${nodeName}_data_in_ptr, // Original input (x) + ${nodeName}_grad_out_ptr, // Output gradient (dx) + ${weight}, // Input Scale parameter + ${bias}, // Input Bias parameter + ${epsilon}, // Epsilon for numerical stability + ${nodeName}_elem_count, // Number of elements to process + ${lastDimLength} // Size of the feature dimension + ); +} """) \ No newline at end of file diff --git a/Deeploy/Targets/PULPOpen/Templates/SGDTemplate.py b/Deeploy/Targets/PULPOpen/Templates/SGDTemplate.py index b209a76653..418b41aadf 100644 --- a/Deeploy/Targets/PULPOpen/Templates/SGDTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/SGDTemplate.py @@ -5,16 +5,45 @@ from Deeploy.DeeployTypes import NodeTemplate referenceTemplate = NodeTemplate(""" -// SGD Weight Update (Name: ${nodeName}, Op: ${nodeOp}) -BEGIN_SINGLE_CORE - ${weight_type.typeName} ref_${weight} = ${weight}; - ${grad_type.typeName} ref_${grad} = ${grad}; - ${weight_type.typeName} ref_${weight_updated} = ${weight_updated}; +// SGD Weight Update with Separated Multiplication and Subtraction Unrolling +// (Name: ${nodeName}, Op: ${nodeOp}) +int8_t ${nodeName}_core_id = pi_core_id(); +int8_t ${nodeName}_log2Core = log2(NUM_CORES); +int32_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0); +int32_t ${nodeName}_chunk_start = MIN(${nodeName}_chunk*${nodeName}_core_id, ${size}); +int32_t ${nodeName}_chunk_stop = MIN(${nodeName}_chunk_start + ${nodeName}_chunk, ${size}); - float32_t learning_rate = ${lr}; +${weight_type.typeName} ref_${weight} = ${weight}; +${grad_type.typeName} ref_${grad} = ${grad}; +${weight_type.typeName} ref_${weight_updated} = ${weight_updated}; - for (uint32_t i=0; i<${size}; ++i) { - ref_${weight_updated}[i] = ref_${weight}[i] - learning_rate * ref_${grad}[i]; - } -END_SINGLE_CORE -""") +float32_t learning_rate = ${lr}; + +// Temporary buffer for multiplication results +float32_t temp_mul[6]; + +uint32_t i = ${nodeName}_chunk_start; +for (; i+5 < ${nodeName}_chunk_stop; i+=6) { + // Unrolled multiplication operations + temp_mul[0] = 
learning_rate * ref_${grad}[i]; + temp_mul[1] = learning_rate * ref_${grad}[i+1]; + temp_mul[2] = learning_rate * ref_${grad}[i+2]; + temp_mul[3] = learning_rate * ref_${grad}[i+3]; + temp_mul[4] = learning_rate * ref_${grad}[i+4]; + temp_mul[5] = learning_rate * ref_${grad}[i+5]; + + // Unrolled subtraction operations + ref_${weight_updated}[i] = ref_${weight}[i] - temp_mul[0]; + ref_${weight_updated}[i+1] = ref_${weight}[i+1] - temp_mul[1]; + ref_${weight_updated}[i+2] = ref_${weight}[i+2] - temp_mul[2]; + ref_${weight_updated}[i+3] = ref_${weight}[i+3] - temp_mul[3]; + ref_${weight_updated}[i+4] = ref_${weight}[i+4] - temp_mul[4]; + ref_${weight_updated}[i+5] = ref_${weight}[i+5] - temp_mul[5]; +} + +// Handle remaining elements +for (; i < ${nodeName}_chunk_stop; i++) { + float32_t temp_grad = learning_rate * ref_${grad}[i]; + ref_${weight_updated}[i] = ref_${weight}[i] - temp_grad; +} +""") \ No newline at end of file diff --git a/Deeploy/Targets/PULPOpen/TileConstraints/GEMMTileConstraint.py b/Deeploy/Targets/PULPOpen/TileConstraints/GEMMTileConstraint.py index 2f747a4002..f913b13a2e 100644 --- a/Deeploy/Targets/PULPOpen/TileConstraints/GEMMTileConstraint.py +++ b/Deeploy/Targets/PULPOpen/TileConstraints/GEMMTileConstraint.py @@ -196,11 +196,19 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw # Get to-be-tiled tensor's buffers bufferA = ctxt.lookup(name = parseDict['A']) bufferB = ctxt.lookup(name = parseDict['B']) - bufferC = ctxt.lookup(name = parseDict['C']) outputBuffer = ctxt.lookup(name = parseDict['data_out']) # Add I/O dimensions to the model as variables - for bufferName in [bufferA.name, bufferB.name, bufferC.name, outputBuffer.name]: + has_bias = 'C' in parseDict and parseDict['C'] is not None + bufferC = None + if has_bias: + bufferC = ctxt.lookup(name = parseDict['C']) + + buffer_names = [bufferA.name, bufferB.name, outputBuffer.name] + if has_bias: + buffer_names.append(bufferC.name) + + for bufferName in buffer_names: tilerModel.addTensorDimToModel(ctxt, bufferName) dimOffsetA = len(bufferA.shape) - 2 @@ -223,10 +231,13 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw # Add GEMM Geometrical constraints tilerModel.addConstraint(ASecondDimVar == BFirstDimVar) - addDimVar_1 = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = 0) - addDimVar_2 = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = 1) - tilerModel.addConstraint(outputFirstDimVar == addDimVar_1) - tilerModel.addConstraint(outputSecondDimVar == addDimVar_2) + # Add bias constraints only if bias is present + if has_bias: + dimOffsetC = len(bufferC.shape) - 2 + addDimVar_1 = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = dimOffsetC) + addDimVar_2 = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = dimOffsetC + 1) + tilerModel.addConstraint(outputFirstDimVar == addDimVar_1) + tilerModel.addConstraint(outputSecondDimVar == addDimVar_2) return tilerModel @@ -262,23 +273,29 @@ def serializeTilingSolution( cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], targetMemLevel: str, ctxt: NetworkContext, operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: - outputCubes = [cube.rectangle for cube in absoluteOutputCubes] - addrNames = ['A', 'B', 'C', 'data_out'] + outputCubes = [ + HyperRectangle(tuple(cube.rectangle.offset), tuple(cube.rectangle.dims)) for cube in absoluteOutputCubes + ] + + has_bias = 'C' in 
operatorRepresentation and operatorRepresentation['C'] is not None + + addrNames = ['A', 'B', 'data_out'] + if has_bias: + addrNames.insert(2, 'C') + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, operatorRepresentation, addrNames) transA = operatorRepresentation['transA'] transB = operatorRepresentation['transB'] - buffA = ctxt.lookup(operatorRepresentation['A']) - buffB = ctxt.lookup(operatorRepresentation['B']) - buffC = ctxt.lookup(operatorRepresentation['C']) + varA = operatorRepresentation['A'] if transA == 0: - NSize = buffA.shape[-1] + NSize = ctxt.lookup(varA).shape[-1] else: - NSize = buffA.shape[-2] + NSize = ctxt.lookup(varA).shape[-2] NOffset = 0 @@ -288,65 +305,43 @@ def serializeTilingSolution( replacements = {"M": [], "O": [], "batch": []} - # Every output is constructed by a pair of inputs. Reconstruct this pair. for cube in outputCubes: - MOffset, OOffset = cube.offset[-2:] - MSize, OSize = cube.dims[-2:] - if len(cube.offset) > 2: - BatchSize = math.prod(cube.dims[:-2]) - - if len(cube.offset) > 3: - assert all(off == 0 for off in cube.offset[:-3]), ( - f"Unsupported tiling across leading batch dims: offsets={cube.offset}. " - "Only the last batch dim (besides M/O) may be tiled.") + BSize = 1 + BOffset = 0 + BatchSize = 1 + BatchOffset = 0 + + if len(cube.offset) == 2: + (MOffset, OOffset) = cube.offset + (MSize, OSize) = cube.dims + elif len(cube.offset) == 3: + (BatchOffset, MOffset, OOffset) = cube.offset + (BatchSize, MSize, OSize) = cube.dims else: - BatchSize = 1 + (BatchOffset, BOffset, MOffset, OOffset) = cube.offset + (BatchSize, BSize, MSize, OSize) = cube.dims replacements["M"].append(MSize) replacements["O"].append(OSize) - replacements["batch"].append(BatchSize) + replacements["batch"].append(BSize) if transA == 0: - AMatrixOffsets = (MOffset, NOffset) - AMatrixShape = (MSize, NSize) + ACube = HyperRectangle((BatchOffset, BOffset, MOffset, NOffset), (BatchSize, BSize, MSize, NSize)) else: - AMatrixOffsets = (NOffset, MOffset) - AMatrixShape = (NSize, MSize) - - if len(buffA.shape) > 2: - batchDimCount = len(buffA.shape) - 2 - AMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + AMatrixOffsets - AMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + AMatrixShape - - ACube = HyperRectangle(AMatrixOffsets, AMatrixShape) - inputACubes.append(ACube) + ACube = HyperRectangle((BatchOffset, BOffset, NOffset, MOffset), (BatchSize, BSize, NSize, MSize)) if transB == 0: - BMatrixOffsets = (NOffset, OOffset) - BMatrixShape = (NSize, OSize) + BCube = HyperRectangle((BatchOffset, BOffset, NOffset, OOffset), (BatchSize, BSize, NSize, OSize)) else: - BMatrixOffsets = (OOffset, NOffset) - BMatrixShape = (OSize, NSize) - - if len(buffB.shape) > 2: - batchDimCount = len(buffB.shape) - 2 - BMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + BMatrixOffsets - BMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + BMatrixShape + BCube = HyperRectangle((BatchOffset, BOffset, OOffset, NOffset), (BatchSize, BSize, OSize, NSize)) - BCube = HyperRectangle(BMatrixOffsets, BMatrixShape) + inputACubes.append(ACube) inputBCubes.append(BCube) - CMatrixOffsets = (MOffset, OOffset) - CMatrixShape = (MSize, OSize) - - if len(buffC.shape) > 2: - batchDimCount = len(buffC.shape) - 2 - CMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + CMatrixOffsets - CMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + CMatrixShape - - CCube = HyperRectangle(CMatrixOffsets, CMatrixShape) - inputAddCubes.append(CCube) + if has_bias: + CCube 
= HyperRectangle(tuple(cube.offset), tuple(cube.dims)) + inputAddCubes.append(CCube) inputLoadSchedule = [] outputLoadSchedule = [] @@ -360,8 +355,12 @@ def serializeTilingSolution( "batch": PointerClass(uint8_t) } - for a, b, c in zip(inputACubes, inputBCubes, inputAddCubes): - inputLoadSchedule.append({"A": a, "B": b, "C": c}) + if has_bias: + for a, b, c in zip(inputACubes, inputBCubes, inputAddCubes): + inputLoadSchedule.append({"A": a, "B": b, "C": c}) + else: + for a, b in zip(inputACubes, inputBCubes): + inputLoadSchedule.append({"A": a, "B": b}) for out in outputCubes: outputLoadSchedule.append({"data_out": out}) diff --git a/Deeploy/Targets/PULPOpen/TileConstraints/GeluTileConstraint.py b/Deeploy/Targets/PULPOpen/TileConstraints/GeluTileConstraint.py new file mode 100644 index 0000000000..3b7b284706 --- /dev/null +++ b/Deeploy/Targets/PULPOpen/TileConstraints/GeluTileConstraint.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from Deeploy.Targets.Generic.TileConstraints.BOPTileConstraint import BOPTileConstraint + + +class GeluGradTileConstraint(BOPTileConstraint): + + dataIn1Name = 'grad_in' + dataIn2Name = 'data_in' + dataOutName = 'grad_out' \ No newline at end of file diff --git a/Deeploy/Targets/PULPOpen/TileConstraints/LayernormTileConstraint.py b/Deeploy/Targets/PULPOpen/TileConstraints/LayernormTileConstraint.py index 5f43ad7534..c3593ee6f0 100644 --- a/Deeploy/Targets/PULPOpen/TileConstraints/LayernormTileConstraint.py +++ b/Deeploy/Targets/PULPOpen/TileConstraints/LayernormTileConstraint.py @@ -78,3 +78,82 @@ def serializeTilingSolution( variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) return variableReplacementSchedule, tilingSchedule + + +class LayernormGradTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + grad_in_buffer_name = parseDict['grad_in'] + data_in_buffer_name = parseDict['data_in'] + weight_buffer_name = parseDict['weight'] + bias_buffer_name = parseDict['bias'] + grad_out_buffer_name = parseDict['grad_out'] + + for buffer_name in [ + grad_in_buffer_name, data_in_buffer_name, weight_buffer_name, bias_buffer_name, grad_out_buffer_name + ]: + tilerModel.addTensorDimToModel(ctxt, buffer_name) + + input_shape = ctxt.lookup(data_in_buffer_name).shape + last_dim_idx = len(input_shape) - 1 + last_dim_len = input_shape[-1] + + tilerModel.addConstraint( + tilerModel.getTensorDimVar(tensorName = data_in_buffer_name, dimIdx = last_dim_idx) == last_dim_len) + + tilerModel.addConstraint( + tilerModel.getTensorDimVar(tensorName = data_in_buffer_name, dimIdx = last_dim_idx) == + tilerModel.getTensorDimVar(tensorName = weight_buffer_name, dimIdx = 0)) + + tilerModel.addConstraint( + tilerModel.getTensorDimVar(tensorName = data_in_buffer_name, dimIdx = last_dim_idx) == + tilerModel.getTensorDimVar(tensorName = bias_buffer_name, dimIdx = 0)) + + for idx, dim in enumerate(input_shape): + tilerModel.addConstraint( + tilerModel.getTensorDimVar(tensorName = data_in_buffer_name, dimIdx = idx) == + tilerModel.getTensorDimVar(tensorName = grad_in_buffer_name, dimIdx = idx)) + + for idx, dim in enumerate(input_shape): + tilerModel.addConstraint( + tilerModel.getTensorDimVar(tensorName = data_in_buffer_name, dimIdx = idx) == + tilerModel.getTensorDimVar(tensorName = grad_out_buffer_name, dimIdx = idx)) + + return tilerModel + + @classmethod + def 
serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + + output_cubes = [cube.rectangle for cube in absoluteOutputCubes] + addr_names = ['grad_in', 'data_in', 'weight', 'bias', 'grad_out'] + input_base_offsets, output_base_offsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addr_names) + + replacements = {"size": []} + replacement_types = {"size": PointerClass(uint16_t)} + + input_load_schedule = [] + output_load_schedule = [] + + for cube in output_cubes: + new_size = np.prod(cube.dims) + replacements["size"].append(new_size) + + feature_size = cube.dims[-1] + + weight_cube = HyperRectangle((0,), (feature_size,)) + bias_cube = HyperRectangle((0,), (feature_size,)) + + input_load_schedule.append({"grad_in": cube, "data_in": cube, "weight": weight_cube, "bias": bias_cube}) + + output_load_schedule.append({"grad_out": cube}) + + tiling_schedule = TilingSchedule(input_base_offsets, output_base_offsets, input_load_schedule, + output_load_schedule) + variable_replacement_schedule = VariableReplacementScheme(replacements, replacement_types) + + return variable_replacement_schedule, tiling_schedule diff --git a/Deeploy/Targets/PULPOpen/TileConstraints/ReduceSumTileConstraint.py b/Deeploy/Targets/PULPOpen/TileConstraints/ReduceSumTileConstraint.py new file mode 100644 index 0000000000..cd404dde73 --- /dev/null +++ b/Deeploy/Targets/PULPOpen/TileConstraints/ReduceSumTileConstraint.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple, Union + +from ortools.constraint_solver.pywrapcp import IntVar + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, TilingSchedule, VariableReplacementScheme + + +class ReduceSumTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + outputBufferName = parseDict['data_out'] + + inputBuffer = ctxt.lookup(inputBufferName) + outputBuffer = ctxt.lookup(outputBufferName) + + inputShapeLen = len(inputBuffer.shape) + outputShapeLen = len(outputBuffer.shape) + + # Add I/O dimensions to the model as variables + for bufferName in [inputBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + # For ReduceSum, we need to handle dimension reduction + # If keepdims=True, all dimensions should match (reduced dims become 1) + # If keepdims=False, reduced dimensions are removed from output + + keepdims = parseDict.get('keepdims', True) # Default to True if not specified + + if keepdims: + # keepdims=True: output has same number of dimensions as input + if inputShapeLen == outputShapeLen: + for idx in range(inputShapeLen): + outputDim = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = idx) + inputDim = 
tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = idx) + + # For reduced dimensions, output should be 1 + if 'axis' in parseDict: + axis = parseDict['axis'] + if isinstance(axis, int): + axis = [axis] + + # Handle negative axis indexing + normalized_axis = [] + for ax in axis: + if ax < 0: + ax = inputShapeLen + ax + normalized_axis.append(ax) + + if idx in normalized_axis: + # This dimension is reduced, output should be 1 + tilerModel.addConstraint(outputDim == 1) + else: + # This dimension is preserved + tilerModel.addConstraint(outputDim == inputDim) + else: + # No axis specified, all dimensions are reduced to 1 + tilerModel.addConstraint(outputDim == 1) + else: + raise ValueError("With keepdims=True, input and output should have same number of dimensions") + + else: + # keepdims=False: reduced dimensions are removed from output + if 'axis' in parseDict: + axis = parseDict['axis'] + if isinstance(axis, int): + axis = [axis] + + # Handle negative axis indexing + normalized_axis = [] + for ax in axis: + if ax < 0: + ax = inputShapeLen + ax + normalized_axis.append(ax) + normalized_axis = sorted(normalized_axis) + + # Expected output shape length + expected_output_len = inputShapeLen - len(normalized_axis) + + if outputShapeLen != expected_output_len: + raise ValueError(f"With keepdims=False and axis={axis}, expected output to have " + f"{expected_output_len} dimensions, but got {outputShapeLen}") + + # Map input dimensions to output dimensions (skipping reduced ones) + output_idx = 0 + for input_idx in range(inputShapeLen): + if input_idx not in normalized_axis: + # This dimension is preserved + outputDim = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = output_idx) + inputDim = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = input_idx) + tilerModel.addConstraint(outputDim == inputDim) + output_idx += 1 + + else: + # No axis specified - global reduction, output should be scalar + # In many frameworks, scalar outputs are represented as 1D tensors with size 1 + # or as 0D tensors (empty shape) + if outputShapeLen == 0: + # True scalar (0D tensor) - nothing to constrain + pass + elif outputShapeLen == 1: + # 1D tensor with size 1 representing scalar + outputDim = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 0) + tilerModel.addConstraint(outputDim == 1) + else: + # Allow other representations but warn about potential issues + # Some frameworks might represent scalars differently + # For now, just ensure all output dimensions are 1 + for idx in range(outputShapeLen): + outputDim = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = idx) + tilerModel.addConstraint(outputDim == 1) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + # No constraints - let the tiler handle dimensions normally + # We'll handle the actual ReduceSum logic in serializeTilingSolution + return tilerModel + + @staticmethod + def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict, + ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]: + + inputBufferName = parseDict['data_in'] + inputBuffer = ctxt.lookup(inputBufferName) + + symbolicParseDict = parseDict.copy() + + # Since we force all dimensions to be full size, we can use the actual shape + # This ensures the template gets the correct dimensions for the single cube + symbolicParseDict['data_in_shape'] = list(inputBuffer.shape) + + # Add axes information (normalized) + if 
'axis' in parseDict: + axis = parseDict['axis'] + if isinstance(axis, int): + axes = [axis] + else: + axes = list(axis) + + # Handle negative axis indexing + normalized_axes = [] + for ax in axes: + if ax < 0: + ax = len(inputBuffer.shape) + ax + normalized_axes.append(ax) + + symbolicParseDict['axes'] = normalized_axes + else: + # Global reduction - all axes + symbolicParseDict['axes'] = list(range(len(inputBuffer.shape))) + + # Add keepdims information + symbolicParseDict['keepdims'] = parseDict.get('keepdims', True) + + return symbolicParseDict + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + + # Get original tensor shapes from context + inputBufferName = operatorRepresentation['data_in'] + outputBufferName = operatorRepresentation['data_out'] + inputBuffer = ctxt.lookup(inputBufferName) + outputBuffer = ctxt.lookup(outputBufferName) + + # Use original dimensions for ReduceSum computation + originalInputShape = list(inputBuffer.shape) + originalOutputShape = list(outputBuffer.shape) + + addrNames = ['data_in', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + replacements = {"data_in_shape": [], "axes": [], "keepdims": [], "reduceLength": []} + replacementTypes = { + "data_in_shape": PointerClass(uint32_t), + "axes": PointerClass(uint32_t), + "keepdims": PointerClass(uint32_t), + "reduceLength": PointerClass(uint32_t) + } + + # Get axis and keepdims information from operator representation + # Note: the key might be 'axes' (plural) instead of 'axis' (singular) + axis = operatorRepresentation.get('axis', operatorRepresentation.get('axes', None)) + keepdims = operatorRepresentation.get('keepdims', True) + + # Calculate axes (normalize negative indices) + if axis is not None: + if isinstance(axis, int): + axes = [axis] + else: + axes = list(axis) + + # Handle negative axis indexing + normalized_axes = [] + for ax in axes: + if ax < 0: + ax = len(originalInputShape) + ax + normalized_axes.append(ax) + axes = normalized_axes + else: + # Global reduction - all axes + axes = list(range(len(originalInputShape))) + + # Calculate reduceLength (product of dimensions being reduced) + reduceLength = 1 + for ax in axes: + reduceLength *= originalInputShape[ax] + + # For ReduceSum, we always use the original tensor dimensions + # regardless of how the tiler decides to split them + replacements['data_in_shape'].append(tuple(originalInputShape)) + replacements['axes'].append(tuple(axes)) + replacements['keepdims'].append(1 if keepdims else 0) + replacements['reduceLength'].append(reduceLength) + + # Create scheduling based on original dimensions + inputLoadSchedule = [] + outputLoadSchedule = [] + + # Create HyperRectangles with original dimensions + from Deeploy.TilingExtension.TilingCodegen import HyperRectangle + + inputCube = HyperRectangle(dims = originalInputShape, offset = [0] * len(originalInputShape)) + + outputCube = HyperRectangle(dims = originalOutputShape, offset = [0] * len(originalOutputShape)) + + inputLoadSchedule.append({"data_in": inputCube}) + outputLoadSchedule.append({"data_out": outputCube}) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = 
VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule diff --git a/Deeploy/Targets/PULPOpen/TileConstraints/iSoftmaxTileConstraint.py b/Deeploy/Targets/PULPOpen/TileConstraints/iSoftmaxTileConstraint.py index 1b5eddb51a..f77834043c 100644 --- a/Deeploy/Targets/PULPOpen/TileConstraints/iSoftmaxTileConstraint.py +++ b/Deeploy/Targets/PULPOpen/TileConstraints/iSoftmaxTileConstraint.py @@ -101,3 +101,88 @@ def serializeTilingSolution( variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) return variableReplacementSchedule, tilingSchedule + + +class SoftmaxGradTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + upstream_grad = parseDict['upstream_grad'] + softmax_output = parseDict['softmax_output'] + softmax_grad = parseDict['softmax_grad'] + + shapeLen = len(ctxt.lookup(upstream_grad).shape) + + for bufferName in [upstream_grad, softmax_output, softmax_grad]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + for idx in range(shapeLen): + upstream_dim = tilerModel.getTensorDimVar(tensorName = upstream_grad, dimIdx = idx) + softmax_out_dim = tilerModel.getTensorDimVar(tensorName = softmax_output, dimIdx = idx) + softmax_grad_dim = tilerModel.getTensorDimVar(tensorName = softmax_grad, dimIdx = idx) + + tilerModel.addConstraint(upstream_dim == softmax_out_dim) + tilerModel.addConstraint(upstream_dim == softmax_grad_dim) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + upstream_grad = parseDict['upstream_grad'] + inputBuffer = ctxt.lookup(upstream_grad) + + lastDimLength = inputBuffer.shape[-1] + lastDimIdx = len(inputBuffer.shape) - 1 + lastDimVar = tilerModel.getTensorDimVar(tensorName = upstream_grad, dimIdx = lastDimIdx) + + tilerModel.addConstraint(lastDimVar == lastDimLength) + + return tilerModel + + @staticmethod + def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict, + ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]: + + upstream_grad = parseDict['upstream_grad'] + inputBuffer = ctxt.lookup(upstream_grad) + + lastDimIdx = len(inputBuffer.shape) - 1 + + symbolicParseDict = parseDict.copy() + symbolicParseDict['lastDimLength'] = tilerModel.getTensorDimVar(upstream_grad, lastDimIdx) + + return symbolicParseDict + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['upstream_grad', 'softmax_output', 'softmax_grad'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + replacements = {"lastDimLength": [], "size": []} + + replacementTypes = {"lastDimLength": PointerClass(uint32_t), "size": PointerClass(uint32_t)} + + for cube in outputCubes: + lastDimLength = cube.dims[-1] + size = np.prod(cube.dims) + + replacements['lastDimLength'].append(lastDimLength) + replacements['size'].append(size) + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for out in outputCubes: + inputLoadSchedule.append({"upstream_grad": out, "softmax_output": out}) + 
outputLoadSchedule.append({"softmax_grad": out}) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule diff --git a/Deeploy/Targets/PULPOpen/Tiler.py b/Deeploy/Targets/PULPOpen/Tiler.py index 6de8ca3000..3d7d11f343 100644 --- a/Deeploy/Targets/PULPOpen/Tiler.py +++ b/Deeploy/Targets/PULPOpen/Tiler.py @@ -14,25 +14,29 @@ from Deeploy.Targets.Generic.TileConstraints.RQSiHardswishTileConstraint import RQSiHardswishTileConstraint from Deeploy.Targets.Generic.TileConstraints.TransposeTileConstraint import TransposeTileConstraint from Deeploy.Targets.Generic.TileConstraints.UnaryTileConstraint import UnaryTileConstraint -from Deeploy.Targets.Generic.TileConstraints.UntiledTileConstraint import UntiledTileConstraint from Deeploy.Targets.PULPOpen.Bindings import PULPAddBindings, PULPConcatBindings, PULPFloatConv2DBindings, \ - PULPFloatDWConv2DBindings, PULPFloatGELUBinding, PULPFloatGEMMBindings, PULPGatherBindings, \ - PULPiHardswishBindings, PULPiRMSNormBindings, PULPiRQSGELUBindings, PULPLayernormBinding, PULPMatMulBindings, \ - PULPMaxPool2DBindings, PULPMulBindings, PULPReduceMeanBindings, PULPReduceSumBindings, PULPReluBinding, \ - PULPReshapeBindings, PULPRQAddBindings, PULPRQSBindings, PULPRQSConv2DBindings, PULPRQSDWConv2DBindings, \ - PULPRQSGEMMBindings, PULPRQSiHardswishBindings, PULPRQSMatrixVecBindings, PULPRQSTallGEMMBindings, \ - PULPSGDBindings, PULPSliceBindings, PULPSoftmaxBindings, PULPSoftmaxCrossEntropyLossBindings, \ - PULPSoftmaxCrossEntropyLossGradBindings, PULPSoftmaxGradBindings, PULPTransposeBindings, PULPUniformRQSBindings + PULPFloatDWConv2DBindings, PULPFloatGELUBinding, PULPFloatGELUGradBinding, PULPFloatGEMMBindings, \ + PULPGatherBindings, PULPiHardswishBindings, PULPiRMSNormBindings, PULPiRQSGELUBindings, PULPLayernormBinding, \ + PULPLayernormGradBinding, PULPMatMulBindings, PULPMaxPool2DBindings, PULPMulBindings, PULPReduceMeanBindings, \ + PULPReduceSumBindings, PULPReluBinding, PULPReshapeBindings, PULPRQAddBindings, PULPRQSBindings, \ + PULPRQSConv2DBindings, PULPRQSDWConv2DBindings, PULPRQSGEMMBindings, PULPRQSiHardswishBindings, \ + PULPRQSMatrixVecBindings, PULPRQSTallGEMMBindings, PULPSGDBindings, PULPSliceBindings, PULPSoftmaxBindings, \ + PULPSoftmaxCrossEntropyLossBindings, PULPSoftmaxCrossEntropyLossGradBindings, PULPSoftmaxGradBindings, \ + PULPTransposeBindings, PULPUniformRQSBindings from Deeploy.Targets.PULPOpen.TileConstraints.ConvTileConstraint import Conv2DTileConstraint, RQConv2DTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.DWConvTileConstraint import DWConv2DTileConstraint, \ RQDWConv2DTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.GatherTileConstraint import GatherTileConstraint +from Deeploy.Targets.PULPOpen.TileConstraints.GeluTileConstraint import GeluGradTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.GEMMTileConstraint import FloatGEMMTileConstraint, GEMMTileConstraint -from Deeploy.Targets.PULPOpen.TileConstraints.iSoftmaxTileConstraint import iSoftmaxTileConstraint -from Deeploy.Targets.PULPOpen.TileConstraints.LayernormTileConstraint import LayernormTileConstraint +from Deeploy.Targets.PULPOpen.TileConstraints.iSoftmaxTileConstraint import SoftmaxGradTileConstraint, \ + iSoftmaxTileConstraint +from Deeploy.Targets.PULPOpen.TileConstraints.LayernormTileConstraint import 
LayernormGradTileConstraint, \ + LayernormTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.MatMulTileConstraint import MatMulTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.MaxPoolTileConstraint import MaxPoolCTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.ReduceMeanConstraint import ReduceMeanTileConstraint +from Deeploy.Targets.PULPOpen.TileConstraints.ReduceSumTileConstraint import ReduceSumTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.RequantShiftTileConstraint import RequantShiftTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.SGDTileConstraint import SGDTileConstraint from Deeploy.Targets.PULPOpen.TileConstraints.SliceConstraint import SliceTileConstraint @@ -117,9 +121,15 @@ PULPLayernormTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPLayernormBinding], tileConstraint = LayernormTileConstraint()) +PULPLayernormGradTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPLayernormGradBinding], + tileConstraint = LayernormGradTileConstraint()) + PULPFPGELUTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPFloatGELUBinding], tileConstraint = UnaryTileConstraint()) +PULPFPGELUGradTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPFloatGELUGradBinding], + tileConstraint = GeluGradTileConstraint()) + PULPGatherTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = PULPGatherBindings, tileConstraint = GatherTileConstraint()) @@ -130,10 +140,10 @@ nodeBindings = PULPSoftmaxCrossEntropyLossGradBindings, tileConstraint = SoftmaxCrossEntropyGradTileConstraint()) PULPSoftmaxGradTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = PULPSoftmaxGradBindings, - tileConstraint = UntiledTileConstraint()) + tileConstraint = SoftmaxGradTileConstraint()) PULPReduceSumTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = PULPReduceSumBindings, - tileConstraint = UntiledTileConstraint()) + tileConstraint = ReduceSumTileConstraint()) PULPSGDTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = PULPSGDBindings, tileConstraint = SGDTileConstraint()) diff --git a/Deeploy/TilingExtension/TilingCodegen.py b/Deeploy/TilingExtension/TilingCodegen.py index 604ba23c9d..0974fa337b 100644 --- a/Deeploy/TilingExtension/TilingCodegen.py +++ b/Deeploy/TilingExtension/TilingCodegen.py @@ -31,8 +31,8 @@ def __init__(self, offset: Tuple[int, ...], dims: Tuple[int, ...]): assert len(offset) == len( dims), f"HyperRectangle offset and dims for mismatching dimensions {offset} and {dims}" - self.offset = offset - self.dims = dims + self.offset = tuple(offset) if not isinstance(offset, tuple) else offset + self.dims = tuple(dims) if not isinstance(dims, tuple) else dims @dataclass diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/inputs.npz deleted file mode 100644 index 964a5a3551..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/inputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/network.onnx deleted file mode 100644 index 0216957044..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/network.onnx and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/outputs.npz 
b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/outputs.npz deleted file mode 100644 index ced669e248..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128/outputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/inputs.npz deleted file mode 100644 index 3916dc0ff0..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/inputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/network.onnx deleted file mode 100644 index e2955db62f..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/network.onnx and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/outputs.npz deleted file mode 100644 index 8bc8dc897a..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_16/outputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/inputs.npz deleted file mode 100644 index 13f2cc4c68..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/inputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/network.onnx deleted file mode 100644 index 5f13ba3560..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/network.onnx and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/outputs.npz deleted file mode 100644 index 62e9a7ca96..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_32/outputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/inputs.npz deleted file mode 100644 index f9f1b89c37..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/inputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/network.onnx deleted file mode 100644 index cd0cbb25cc..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/network.onnx and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/outputs.npz deleted file mode 100644 index c850e02db2..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_64/outputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/inputs.npz 
b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/inputs.npz deleted file mode 100644 index de4d4f8e06..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/inputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/network.onnx deleted file mode 100644 index 3ea0bcc0ba..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/network.onnx and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/outputs.npz deleted file mode 100644 index f5b9700baa..0000000000 Binary files a/DeeployTest/Tests/testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8/outputs.npz and /dev/null differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT1/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/inputs.npz new file mode 100644 index 0000000000..a9018350f2 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/inputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network.onnx new file mode 100644 index 0000000000..7473d7e5c1 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network_infer.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network_infer.onnx new file mode 100644 index 0000000000..11b0ca1f69 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/network_infer.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT1/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/outputs.npz new file mode 100644 index 0000000000..d2ad678b76 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT1/outputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT2/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/inputs.npz new file mode 100644 index 0000000000..7af9629e9b Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/inputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network.onnx new file mode 100644 index 0000000000..ac9569fb58 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network_infer.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network_infer.onnx new file mode 100644 index 0000000000..366a0be89e Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/network_infer.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_FT2/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/outputs.npz new file mode 100644 index 0000000000..c2850ae68a Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_FT2/outputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LP/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LP/inputs.npz new file mode 100644 index 0000000000..c32b8dfd64 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LP/inputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LP/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LP/network.onnx new file mode 100644 index 0000000000..798e35f96b Binary files /dev/null and 
b/DeeployTest/Tests/testTrainCCT/CCT2_LP/network.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LP/network_infer.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LP/network_infer.onnx new file mode 100644 index 0000000000..2eae9e8d7e Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LP/network_infer.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LP/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LP/outputs.npz new file mode 100644 index 0000000000..bb23f3a08a Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LP/outputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/inputs.npz new file mode 100644 index 0000000000..c4296c01c6 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/inputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network.onnx new file mode 100644 index 0000000000..8f183a9e2c Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network_infer.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network_infer.onnx new file mode 100644 index 0000000000..6cc128149a Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/network_infer.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/outputs.npz new file mode 100644 index 0000000000..e34b4860ed Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA1/outputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/inputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/inputs.npz new file mode 100644 index 0000000000..71d400304c Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/inputs.npz differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network.onnx new file mode 100644 index 0000000000..93a262b786 Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network_infer.onnx b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network_infer.onnx new file mode 100644 index 0000000000..9c5a0963db Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/network_infer.onnx differ diff --git a/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/outputs.npz b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/outputs.npz new file mode 100644 index 0000000000..b134b08d6a Binary files /dev/null and b/DeeployTest/Tests/testTrainCCT/CCT2_LoRA2/outputs.npz differ diff --git a/TargetLibraries/Generic/inc/kernel/GELU.h b/TargetLibraries/Generic/inc/kernel/GELU.h index f7e56e3beb..ffa104b617 100644 --- a/TargetLibraries/Generic/inc/kernel/GELU.h +++ b/TargetLibraries/Generic/inc/kernel/GELU.h @@ -25,4 +25,8 @@ void GELU_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t dataSize); void GELU_fp32_fp32_sigmoid(float32_t *data_in, float32_t *data_out, int32_t dataSize); +void GELU_fp32_fp32_sigmoid_grad_chunk(float32_t *grad_in, float32_t *data_in, + float32_t *grad_out, int32_t start_idx, + int32_t end_idx); + #endif //__DEEPLOY_BASIC_MATH_GELU_KERNEL_HEADER_ diff --git a/TargetLibraries/Generic/inc/kernel/Layernorm.h b/TargetLibraries/Generic/inc/kernel/Layernorm.h index 3ca4f1d0c5..381f184dd6 100644 --- a/TargetLibraries/Generic/inc/kernel/Layernorm.h +++ 
b/TargetLibraries/Generic/inc/kernel/Layernorm.h @@ -25,4 +25,8 @@ void Layernorm_fp32_fp32(float32_t *data_in, float32_t *data_out, float32_t *scale, float32_t *bias, float32_t epsilon, int32_t size, int32_t lastDimLength); +void LayernormGrad_fp32_fp32(float32_t *grad_in, float32_t *data_in, + float32_t *grad_out, float32_t *scale, + float32_t *bias, float32_t epsilon, int32_t size, + int32_t lastDimLength); #endif //__DEEPLOY_BASIC_MATH_LAYERNORM_KERNEL_HEADER_ diff --git a/TargetLibraries/Generic/src/GELU_fp32.c b/TargetLibraries/Generic/src/GELU_fp32.c index 9dbf15c4a3..6cafed1986 100644 --- a/TargetLibraries/Generic/src/GELU_fp32.c +++ b/TargetLibraries/Generic/src/GELU_fp32.c @@ -30,3 +30,24 @@ void GELU_fp32_fp32_sigmoid(float32_t *data_in, float32_t *data_out, data_out[i] = x * sigmoid; } } + +void GELU_fp32_fp32_sigmoid_grad_chunk(float32_t *grad_in, float32_t *data_in, + float32_t *grad_out, int32_t start_idx, + int32_t end_idx) { + // d(Gelu)/dx ≈ sigmoid(1.702 * x) + x * sigmoid(1.702 * x) * (1 - + // sigmoid(1.702 * x)) * 1.702 + const float COEFF = 1.702f; + for (int32_t i = start_idx; i < end_idx; i++) { + float x = data_in[i]; + float upstream_grad = grad_in[i]; + float z = COEFF * x; + float sigmoid_z = 1.0f / (1.0f + expf(-z)); + + // d(Gelu)/dx = sigmoid(1.702*x) + x * sigmoid(1.702*x) * + // (1-sigmoid(1.702*x)) * 1.702 + float sigmoid_derivative = sigmoid_z * (1.0f - sigmoid_z) * COEFF; + float gelu_derivative = sigmoid_z + x * sigmoid_derivative; + + grad_out[i] = upstream_grad * gelu_derivative; + } +} diff --git a/TargetLibraries/Generic/src/Layernorm_fp32.c b/TargetLibraries/Generic/src/Layernorm_fp32.c index f5337f6154..fb68df8dfe 100644 --- a/TargetLibraries/Generic/src/Layernorm_fp32.c +++ b/TargetLibraries/Generic/src/Layernorm_fp32.c @@ -36,3 +36,58 @@ void Layernorm_fp32_fp32(float32_t *data_in, float32_t *data_out, } } } + +void LayernormGrad_fp32_fp32(float32_t *grad_in, float32_t *data_in, + float32_t *grad_out, float32_t *scale, + float32_t *bias, float32_t epsilon, int32_t size, + int32_t lastDimLength) { + float32_t mean, variance, std, inv_std; + float32_t sum_dy, sum_dy_scaled, sum_dy_scaled_centered; + float32_t centered_input; + + for (int i = 0; i < (size / lastDimLength); i++) { + // RW: Step 1: Recompute mean and variance from forward pass + mean = 0.0f; + variance = 0.0f; + + for (int j = 0; j < lastDimLength; j++) { + mean += data_in[j + i * lastDimLength]; + } + mean = mean / lastDimLength; + + for (int j = 0; j < lastDimLength; j++) { + centered_input = data_in[j + i * lastDimLength] - mean; + variance += centered_input * centered_input; + } + variance = variance / lastDimLength; + variance += epsilon; + std = sqrtf(variance); + inv_std = 1.0f / std; + + // RW: Step 2: Compute intermediate values needed for gradient calculation + sum_dy = 0.0f; + sum_dy_scaled_centered = 0.0f; + + // RW: Calculate sum(dy) and sum(dy * scale * (x - mean) / std) + for (int j = 0; j < lastDimLength; j++) { + sum_dy += grad_in[j + i * lastDimLength]; + centered_input = data_in[j + i * lastDimLength] - mean; + sum_dy_scaled_centered += + grad_in[j + i * lastDimLength] * scale[j] * centered_input * inv_std; + } + + // RW: Step 3: Calculate gradients for each element + for (int j = 0; j < lastDimLength; j++) { + centered_input = data_in[j + i * lastDimLength] - mean; + + // Gradient formula: + // dx = (1/std) * scale * (dy - (1/N)*sum(dy) - + // (x-mean)/(N*std^2)*sum(dy*scale*(x-mean)/std)) + grad_out[j + i * lastDimLength] = + inv_std * scale[j] * + (grad_in[j + 
i * lastDimLength] - (sum_dy / lastDimLength) - + (centered_input * inv_std * inv_std / lastDimLength) * + sum_dy_scaled_centered); + } + } +} diff --git a/TargetLibraries/PULPOpen/src/Gemm.c b/TargetLibraries/PULPOpen/src/Gemm.c index 826960d4f3..a46f8ac6ae 100644 --- a/TargetLibraries/PULPOpen/src/Gemm.c +++ b/TargetLibraries/PULPOpen/src/Gemm.c @@ -26,15 +26,329 @@ void PULP_Gemm_fp32_fp32_fp32_fp32(const float32_t *__restrict__ pSrcA, return; } - for (uint32_t i = M_start; i < M_end; ++i) { - for (uint32_t j = 0; j < O; ++j) { - float32_t sum = 0.0f; - for (uint32_t k = 0; k < N; ++k) { - uint32_t a_idx = transA ? (k * M + i) : (i * N + k); - uint32_t b_idx = transB ? (j * N + k) : (k * O + j); - sum += pSrcA[a_idx] * pSrcB[b_idx]; - } - pDstY[i * O + j] = sum + pDstC[i * O + j]; + const uint32_t has_bias = (pDstC != NULL); + const uint32_t N_unroll = N - (N % 6); + const uint32_t O_unroll = O - (O % 6); + + if (!transA && !transB) { + + for (uint32_t i = M_start; i < M_end; ++i) { + const float32_t *__restrict__ a_row = &pSrcA[i * N]; + float32_t *__restrict__ y_row = &pDstY[i * O]; + const float32_t *__restrict__ c_row = has_bias ? &pDstC[i * O] : NULL; + + uint32_t j = 0; + + for (; j < O_unroll; j += 6) { + float32_t sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, + sum4 = 0.0f, sum5 = 0.0f; + + uint32_t k = 0; + + for (; k < N; ++k) { + const float32_t a_val = a_row[k]; + sum0 += a_val * pSrcB[k * O + j]; + sum1 += a_val * pSrcB[k * O + j + 1]; + sum2 += a_val * pSrcB[k * O + j + 2]; + sum3 += a_val * pSrcB[k * O + j + 3]; + sum4 += a_val * pSrcB[k * O + j + 4]; + sum5 += a_val * pSrcB[k * O + j + 5]; + } + + if (has_bias) { + y_row[j] = sum0 + c_row[j]; + y_row[j + 1] = sum1 + c_row[j + 1]; + y_row[j + 2] = sum2 + c_row[j + 2]; + y_row[j + 3] = sum3 + c_row[j + 3]; + y_row[j + 4] = sum4 + c_row[j + 4]; + y_row[j + 5] = sum5 + c_row[j + 5]; + } else { + y_row[j] = sum0; + y_row[j + 1] = sum1; + y_row[j + 2] = sum2; + y_row[j + 3] = sum3; + y_row[j + 4] = sum4; + y_row[j + 5] = sum5; + } + } + + for (; j < O; ++j) { + float32_t sum = 0.0f; + for (uint32_t k = 0; k < N; ++k) { + sum += a_row[k] * pSrcB[k * O + j]; + } + + y_row[j] = has_bias ? sum + c_row[j] : sum; + } + } + } else if (transA && !transB) { + + for (uint32_t i = M_start; i < M_end; ++i) { + float32_t *__restrict__ y_row = &pDstY[i * O]; + const float32_t *__restrict__ c_row = has_bias ? 
&pDstC[i * O] : NULL; + + uint32_t j = 0; + for (; j < O_unroll; j += 6) { + float32_t sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, + sum4 = 0.0f, sum5 = 0.0f; + + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + const float32_t a0 = pSrcA[k * M + i]; + const float32_t a1 = pSrcA[(k + 1) * M + i]; + const float32_t a2 = pSrcA[(k + 2) * M + i]; + const float32_t a3 = pSrcA[(k + 3) * M + i]; + const float32_t a4 = pSrcA[(k + 4) * M + i]; + const float32_t a5 = pSrcA[(k + 5) * M + i]; + + sum0 += a0 * pSrcB[k * O + j] + a1 * pSrcB[(k + 1) * O + j] + + a2 * pSrcB[(k + 2) * O + j] + a3 * pSrcB[(k + 3) * O + j] + + a4 * pSrcB[(k + 4) * O + j] + a5 * pSrcB[(k + 5) * O + j]; + sum1 += a0 * pSrcB[k * O + j + 1] + a1 * pSrcB[(k + 1) * O + j + 1] + + a2 * pSrcB[(k + 2) * O + j + 1] + + a3 * pSrcB[(k + 3) * O + j + 1] + + a4 * pSrcB[(k + 4) * O + j + 1] + + a5 * pSrcB[(k + 5) * O + j + 1]; + sum2 += a0 * pSrcB[k * O + j + 2] + a1 * pSrcB[(k + 1) * O + j + 2] + + a2 * pSrcB[(k + 2) * O + j + 2] + + a3 * pSrcB[(k + 3) * O + j + 2] + + a4 * pSrcB[(k + 4) * O + j + 2] + + a5 * pSrcB[(k + 5) * O + j + 2]; + sum3 += a0 * pSrcB[k * O + j + 3] + a1 * pSrcB[(k + 1) * O + j + 3] + + a2 * pSrcB[(k + 2) * O + j + 3] + + a3 * pSrcB[(k + 3) * O + j + 3] + + a4 * pSrcB[(k + 4) * O + j + 3] + + a5 * pSrcB[(k + 5) * O + j + 3]; + sum4 += a0 * pSrcB[k * O + j + 4] + a1 * pSrcB[(k + 1) * O + j + 4] + + a2 * pSrcB[(k + 2) * O + j + 4] + + a3 * pSrcB[(k + 3) * O + j + 4] + + a4 * pSrcB[(k + 4) * O + j + 4] + + a5 * pSrcB[(k + 5) * O + j + 4]; + sum5 += a0 * pSrcB[k * O + j + 5] + a1 * pSrcB[(k + 1) * O + j + 5] + + a2 * pSrcB[(k + 2) * O + j + 5] + + a3 * pSrcB[(k + 3) * O + j + 5] + + a4 * pSrcB[(k + 4) * O + j + 5] + + a5 * pSrcB[(k + 5) * O + j + 5]; + } + + for (; k < N; ++k) { + const float32_t a_val = pSrcA[k * M + i]; + sum0 += a_val * pSrcB[k * O + j]; + sum1 += a_val * pSrcB[k * O + j + 1]; + sum2 += a_val * pSrcB[k * O + j + 2]; + sum3 += a_val * pSrcB[k * O + j + 3]; + sum4 += a_val * pSrcB[k * O + j + 4]; + sum5 += a_val * pSrcB[k * O + j + 5]; + } + + if (has_bias) { + y_row[j] = sum0 + c_row[j]; + y_row[j + 1] = sum1 + c_row[j + 1]; + y_row[j + 2] = sum2 + c_row[j + 2]; + y_row[j + 3] = sum3 + c_row[j + 3]; + y_row[j + 4] = sum4 + c_row[j + 4]; + y_row[j + 5] = sum5 + c_row[j + 5]; + } else { + y_row[j] = sum0; + y_row[j + 1] = sum1; + y_row[j + 2] = sum2; + y_row[j + 3] = sum3; + y_row[j + 4] = sum4; + y_row[j + 5] = sum5; + } + } + + for (; j < O; ++j) { + float32_t sum = 0.0f; + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + sum += pSrcA[k * M + i] * pSrcB[k * O + j] + + pSrcA[(k + 1) * M + i] * pSrcB[(k + 1) * O + j] + + pSrcA[(k + 2) * M + i] * pSrcB[(k + 2) * O + j] + + pSrcA[(k + 3) * M + i] * pSrcB[(k + 3) * O + j] + + pSrcA[(k + 4) * M + i] * pSrcB[(k + 4) * O + j] + + pSrcA[(k + 5) * M + i] * pSrcB[(k + 5) * O + j]; + } + for (; k < N; ++k) { + sum += pSrcA[k * M + i] * pSrcB[k * O + j]; + } + + y_row[j] = has_bias ? sum + c_row[j] : sum; + } + } + } else if (!transA && transB) { + + for (uint32_t i = M_start; i < M_end; ++i) { + const float32_t *__restrict__ a_row = &pSrcA[i * N]; + float32_t *__restrict__ y_row = &pDstY[i * O]; + const float32_t *__restrict__ c_row = has_bias ? 
&pDstC[i * O] : NULL; + + uint32_t j = 0; + for (; j < O_unroll; j += 6) { + const float32_t *__restrict__ b_row0 = &pSrcB[j * N]; + const float32_t *__restrict__ b_row1 = &pSrcB[(j + 1) * N]; + const float32_t *__restrict__ b_row2 = &pSrcB[(j + 2) * N]; + const float32_t *__restrict__ b_row3 = &pSrcB[(j + 3) * N]; + const float32_t *__restrict__ b_row4 = &pSrcB[(j + 4) * N]; + const float32_t *__restrict__ b_row5 = &pSrcB[(j + 5) * N]; + + float32_t sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, + sum4 = 0.0f, sum5 = 0.0f; + + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + const float32_t a0 = a_row[k]; + const float32_t a1 = a_row[k + 1]; + const float32_t a2 = a_row[k + 2]; + const float32_t a3 = a_row[k + 3]; + const float32_t a4 = a_row[k + 4]; + const float32_t a5 = a_row[k + 5]; + + sum0 += a0 * b_row0[k] + a1 * b_row0[k + 1] + a2 * b_row0[k + 2] + + a3 * b_row0[k + 3] + a4 * b_row0[k + 4] + a5 * b_row0[k + 5]; + sum1 += a0 * b_row1[k] + a1 * b_row1[k + 1] + a2 * b_row1[k + 2] + + a3 * b_row1[k + 3] + a4 * b_row1[k + 4] + a5 * b_row1[k + 5]; + sum2 += a0 * b_row2[k] + a1 * b_row2[k + 1] + a2 * b_row2[k + 2] + + a3 * b_row2[k + 3] + a4 * b_row2[k + 4] + a5 * b_row2[k + 5]; + sum3 += a0 * b_row3[k] + a1 * b_row3[k + 1] + a2 * b_row3[k + 2] + + a3 * b_row3[k + 3] + a4 * b_row3[k + 4] + a5 * b_row3[k + 5]; + sum4 += a0 * b_row4[k] + a1 * b_row4[k + 1] + a2 * b_row4[k + 2] + + a3 * b_row4[k + 3] + a4 * b_row4[k + 4] + a5 * b_row4[k + 5]; + sum5 += a0 * b_row5[k] + a1 * b_row5[k + 1] + a2 * b_row5[k + 2] + + a3 * b_row5[k + 3] + a4 * b_row5[k + 4] + a5 * b_row5[k + 5]; + } + + for (; k < N; ++k) { + const float32_t a_val = a_row[k]; + sum0 += a_val * b_row0[k]; + sum1 += a_val * b_row1[k]; + sum2 += a_val * b_row2[k]; + sum3 += a_val * b_row3[k]; + sum4 += a_val * b_row4[k]; + sum5 += a_val * b_row5[k]; + } + + if (has_bias) { + y_row[j] = sum0 + c_row[j]; + y_row[j + 1] = sum1 + c_row[j + 1]; + y_row[j + 2] = sum2 + c_row[j + 2]; + y_row[j + 3] = sum3 + c_row[j + 3]; + y_row[j + 4] = sum4 + c_row[j + 4]; + y_row[j + 5] = sum5 + c_row[j + 5]; + } else { + y_row[j] = sum0; + y_row[j + 1] = sum1; + y_row[j + 2] = sum2; + y_row[j + 3] = sum3; + y_row[j + 4] = sum4; + y_row[j + 5] = sum5; + } + } + + for (; j < O; ++j) { + const float32_t *__restrict__ b_row = &pSrcB[j * N]; + float32_t sum = 0.0f; + + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + sum += a_row[k] * b_row[k] + a_row[k + 1] * b_row[k + 1] + + a_row[k + 2] * b_row[k + 2] + a_row[k + 3] * b_row[k + 3] + + a_row[k + 4] * b_row[k + 4] + a_row[k + 5] * b_row[k + 5]; + } + for (; k < N; ++k) { + sum += a_row[k] * b_row[k]; + } + + y_row[j] = has_bias ? sum + c_row[j] : sum; + } + } + } else { + + for (uint32_t i = M_start; i < M_end; ++i) { + float32_t *__restrict__ y_row = &pDstY[i * O]; + const float32_t *__restrict__ c_row = has_bias ? 
&pDstC[i * O] : NULL; + + uint32_t j = 0; + for (; j < O_unroll; j += 6) { + const float32_t *__restrict__ b_row0 = &pSrcB[j * N]; + const float32_t *__restrict__ b_row1 = &pSrcB[(j + 1) * N]; + const float32_t *__restrict__ b_row2 = &pSrcB[(j + 2) * N]; + const float32_t *__restrict__ b_row3 = &pSrcB[(j + 3) * N]; + const float32_t *__restrict__ b_row4 = &pSrcB[(j + 4) * N]; + const float32_t *__restrict__ b_row5 = &pSrcB[(j + 5) * N]; + + float32_t sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, + sum4 = 0.0f, sum5 = 0.0f; + + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + const float32_t a0 = pSrcA[k * M + i]; + const float32_t a1 = pSrcA[(k + 1) * M + i]; + const float32_t a2 = pSrcA[(k + 2) * M + i]; + const float32_t a3 = pSrcA[(k + 3) * M + i]; + const float32_t a4 = pSrcA[(k + 4) * M + i]; + const float32_t a5 = pSrcA[(k + 5) * M + i]; + + sum0 += a0 * b_row0[k] + a1 * b_row0[k + 1] + a2 * b_row0[k + 2] + + a3 * b_row0[k + 3] + a4 * b_row0[k + 4] + a5 * b_row0[k + 5]; + sum1 += a0 * b_row1[k] + a1 * b_row1[k + 1] + a2 * b_row1[k + 2] + + a3 * b_row1[k + 3] + a4 * b_row1[k + 4] + a5 * b_row1[k + 5]; + sum2 += a0 * b_row2[k] + a1 * b_row2[k + 1] + a2 * b_row2[k + 2] + + a3 * b_row2[k + 3] + a4 * b_row2[k + 4] + a5 * b_row2[k + 5]; + sum3 += a0 * b_row3[k] + a1 * b_row3[k + 1] + a2 * b_row3[k + 2] + + a3 * b_row3[k + 3] + a4 * b_row3[k + 4] + a5 * b_row3[k + 5]; + sum4 += a0 * b_row4[k] + a1 * b_row4[k + 1] + a2 * b_row4[k + 2] + + a3 * b_row4[k + 3] + a4 * b_row4[k + 4] + a5 * b_row4[k + 5]; + sum5 += a0 * b_row5[k] + a1 * b_row5[k + 1] + a2 * b_row5[k + 2] + + a3 * b_row5[k + 3] + a4 * b_row5[k + 4] + a5 * b_row5[k + 5]; + } + + for (; k < N; ++k) { + const float32_t a_val = pSrcA[k * M + i]; + sum0 += a_val * b_row0[k]; + sum1 += a_val * b_row1[k]; + sum2 += a_val * b_row2[k]; + sum3 += a_val * b_row3[k]; + sum4 += a_val * b_row4[k]; + sum5 += a_val * b_row5[k]; + } + + if (has_bias) { + y_row[j] = sum0 + c_row[j]; + y_row[j + 1] = sum1 + c_row[j + 1]; + y_row[j + 2] = sum2 + c_row[j + 2]; + y_row[j + 3] = sum3 + c_row[j + 3]; + y_row[j + 4] = sum4 + c_row[j + 4]; + y_row[j + 5] = sum5 + c_row[j + 5]; + } else { + y_row[j] = sum0; + y_row[j + 1] = sum1; + y_row[j + 2] = sum2; + y_row[j + 3] = sum3; + y_row[j + 4] = sum4; + y_row[j + 5] = sum5; + } + } + + for (; j < O; ++j) { + const float32_t *__restrict__ b_row = &pSrcB[j * N]; + float32_t sum = 0.0f; + + uint32_t k = 0; + for (; k < N_unroll; k += 6) { + sum += pSrcA[k * M + i] * b_row[k] + + pSrcA[(k + 1) * M + i] * b_row[k + 1] + + pSrcA[(k + 2) * M + i] * b_row[k + 2] + + pSrcA[(k + 3) * M + i] * b_row[k + 3] + + pSrcA[(k + 4) * M + i] * b_row[k + 4] + + pSrcA[(k + 5) * M + i] * b_row[k + 5]; + } + for (; k < N; ++k) { + sum += pSrcA[k * M + i] * b_row[k]; + } + + y_row[j] = has_bias ? sum + c_row[j] : sum; + } } } } \ No newline at end of file
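
For reference, the sigmoid-approximated GELU backward added in TargetLibraries/Generic/src/GELU_fp32.c implements d/dx [x * sigmoid(1.702 x)] = sigmoid(1.702 x) + 1.702 * x * sigmoid(1.702 x) * (1 - sigmoid(1.702 x)). The sketch below is a hypothetical host-side sanity check and is not part of this patch: it assumes float32_t is a plain `float` typedef (as in the Deeploy basic-math headers) and that the file is linked against the Generic target library so that GELU_fp32_fp32_sigmoid_grad_chunk resolves; the reference forward pass is defined locally and compared against the kernel's analytic gradient via a central finite difference.

// Hypothetical harness, not part of the patch. Assumes float32_t == float and
// linking against TargetLibraries/Generic for the kernel under test.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

typedef float float32_t;

// Prototype as declared in TargetLibraries/Generic/inc/kernel/GELU.h
void GELU_fp32_fp32_sigmoid_grad_chunk(float32_t *grad_in, float32_t *data_in,
                                       float32_t *grad_out, int32_t start_idx,
                                       int32_t end_idx);

// Local reference forward pass: sigmoid-approximated GELU, x * sigmoid(1.702 x).
static float32_t gelu_sigmoid_ref(float32_t x) {
  return x / (1.0f + expf(-1.702f * x));
}

int main(void) {
  enum { N = 16 };
  float32_t x[N], grad_in[N], grad_out[N];
  const float32_t eps = 1e-3f;
  float32_t max_err = 0.0f;

  for (int32_t i = 0; i < N; i++) {
    x[i] = -4.0f + 0.5f * (float32_t)i; // sweep inputs over [-4.0, 3.5]
    grad_in[i] = 1.0f;                  // upstream gradient of 1 isolates dGELU/dx
  }

  // Analytic gradient from the kernel under test, over the full chunk [0, N).
  GELU_fp32_fp32_sigmoid_grad_chunk(grad_in, x, grad_out, 0, N);

  // Central finite difference of the local reference forward pass.
  for (int32_t i = 0; i < N; i++) {
    float32_t fd =
        (gelu_sigmoid_ref(x[i] + eps) - gelu_sigmoid_ref(x[i] - eps)) / (2.0f * eps);
    float32_t err = fabsf(fd - grad_out[i]);
    if (err > max_err)
      max_err = err;
  }

  printf("max |finite-diff - analytic| = %e\n", (double)max_err);
  return max_err < 1e-3f ? 0 : 1;
}

The same pattern could be extended to LayernormGrad_fp32_fp32: perturb one element of data_in at a time, finite-difference Layernorm_fp32_fp32 over the affected row, and compare against the analytic gradient, reusing the signatures declared in TargetLibraries/Generic/inc/kernel/Layernorm.h.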