261 changes: 146 additions & 115 deletions lib/sources/pulp_mhsa_fp16.c

Large diffs are not rendered by default.

179 changes: 92 additions & 87 deletions lib/sources/pulp_mhsa_fp32.c

Large diffs are not rendered by default.

23 changes: 5 additions & 18 deletions lib/sources/pulp_residual_fp16.c
@@ -41,20 +41,14 @@ void pulp_residualconn_fp16_fw(void *SkipConn_args_fp16) {
return;
}

int dims[] = {out->dim};

struct vect_sum_args_fp16 args_sum;

args_sum.op_1 = skip->data;
args_sum.op_2 = lout->data;
args_sum.dest = out->data;
args_sum.size = out->dim;

args_sum.op_1_dims = dims;
args_sum.op_2_dims = dims;

args_sum.op_1_dims_len = 1;
args_sum.op_2_dims_len = 1;

pi_cl_team_fork(NUM_CORES, array_broadcast_sum_fp16, &args_sum);
pi_cl_team_fork(NUM_CORES, vect_sum_fp16, &args_sum);
}


@@ -77,21 +71,14 @@ void pulp_sumnode_fp16_bw(void *SkipConn_args_fp16) {
return;
}

int dims[] = {skip->dim};

struct vect_sum_args_fp16 args_sum;

args_sum.op_1 = out->diff;
args_sum.op_2 = skip->diff;
args_sum.dest = skip->diff;
args_sum.size = skip->dim;

args_sum.op_1_dims = dims;
args_sum.op2_dims = dims;

args_sum.op_1_dims_len = 1;
args_sum.op_2_dims_len = 1;

pi_cl_team_fork(NUM_CORES, array_broadcast_sum_fp16, &args_sum);
pi_cl_team_fork(NUM_CORES, vect_sum_fp16, &args_sum);
}
}
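
Note on the change above: both the forward residual sum and the backward gradient accumulation now fork a plain elementwise vector sum instead of the broadcast variant, so the dims/dims_len bookkeeping disappears. Below is a minimal sketch of what such a forked kernel typically looks like; the `fp16` element type and the exact struct layout are assumptions based only on the fields set in the diff (op_1, op_2, dest, size), not the library's actual implementation.

```c
// Illustrative only: per-core elementwise sum matching the argument struct
// filled in above. Field names follow the diff; the real vect_sum_fp16 in
// pulp_train_utils may differ.
void vect_sum_fp16_sketch(void *args) {
    struct vect_sum_args_fp16 *a = (struct vect_sum_args_fp16 *) args;

    // Split the flat vector evenly across the cluster cores.
    int blockSize = (a->size + NUM_CORES - 1) / NUM_CORES;
    int start = pi_core_id() * blockSize;
    int stop = (start + blockSize > a->size) ? a->size : start + blockSize;

    for (int i = start; i < stop; i++) {
        a->dest[i] = a->op_1[i] + a->op_2[i];
    }
}
```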

11 changes: 6 additions & 5 deletions lib/sources/pulp_rnn_fp32.c
@@ -244,13 +244,13 @@ void pulp_rnn_fp32_bw_cl(void *Rnn_args) {

// Calculate gradient for State Weights
// Transpose State
int dims[] = {N, M};
int t_axes[] = {1, 0};
dims[0] = N;
dims[1] = M;

struct transp_args transp_args2;

transp_args2.matrix = hiddState;
transp_args2.transp_matrix = temp;
transp_args2.in_matrix = hiddState;
transp_args2.out_matrix = temp;
transp_args2.dim = dims;
transp_args2.transposed_axes = t_axes;
transp_args2.n_dim = 2;
@@ -301,7 +301,8 @@ void pulp_rnn_fp32_bw_cl(void *Rnn_args) {

// Calculate the Gradient of the Input
// Transpose Input Weights
dims = {K, M};
dims[0] = K;
dims[1] = M;

struct transp_args transp_args3;

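
The hunks above also reflect the renamed transpose interface (matrix/transp_matrix become in_matrix/out_matrix) together with the explicit dim and transposed_axes fields. For the {1, 0} axis order used here this amounts to a plain 2D transpose; a minimal sketch under that assumption follows (element type and parallelization omitted, not the library's kernel):

```c
// Illustrative only: 2D transpose consistent with the arguments set above
// (in_matrix is N x M, out_matrix is M x N, dim = {N, M}, transposed_axes = {1, 0}).
void transpose_2d_sketch(struct transp_args *a) {
    int rows = a->dim[0];   // N
    int cols = a->dim[1];   // M
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            a->out_matrix[j * rows + i] = a->in_matrix[i * cols + j];
        }
    }
}
```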
31 changes: 30 additions & 1 deletion tests/test_matmul/utils/GM.py
@@ -127,12 +127,41 @@
B.transpose(0, 1)
else:
C = torch.mm(input=A, mat2=B, out=C)
elif (data_type == 'bf16'):
# Matrices to be multiplied
A = torch.Tensor(in_size, mid_size).to(torch.bfloat16)
if transp == '1':
B = torch.Tensor(out_size, mid_size).to(torch.bfloat16)
else:
B = torch.Tensor(mid_size, out_size).to(torch.bfloat16)
C = torch.Tensor(in_size, out_size).to(torch.bfloat16)

A = torch.div(torch.randn(in_size, mid_size), divider).to(torch.bfloat16)
for i in range(A.shape[0]):
for j in range(A.shape[1]):
A[i][j] += (i+j+0.1)/divider

if transp == '1':
B = torch.zeros(out_size, mid_size).to(torch.bfloat16)
else:
B = torch.zeros(mid_size, out_size).to(torch.bfloat16)

for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = i*j+0.1

if transp == '1':
C = torch.mm(input=A, mat2=B.transpose(0, 1), out=C)
B.transpose(0, 1)
else:
C = torch.mm(input=A, mat2=B, out=C)

else : # Error message
print('Invalid data type selection!!')
exit()


if data_type == 'bf16':
data_type = 'fp16'

# Print data and create data header file
f = open('net_args.h', "w")
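
For reference, the bf16 branch added above generates golden data for a plain matrix multiply, with B stored transposed when transp == '1' (C = A·Bᵀ). A naive C sketch of that reference computation is given below, using float arithmetic and hypothetical names, purely to pin down the indexing the generated test data assumes.

```c
// Illustrative only: C[in_size x out_size] = A[in_size x mid_size] * B,
// where B is (mid_size x out_size), or (out_size x mid_size) when
// b_transposed is set (the transp == '1' case in the generator above).
void naive_mm_ref(const float *A, const float *B, float *C,
                  int in_size, int mid_size, int out_size, int b_transposed) {
    for (int i = 0; i < in_size; i++) {
        for (int j = 0; j < out_size; j++) {
            float acc = 0.0f;
            for (int k = 0; k < mid_size; k++) {
                float b = b_transposed ? B[j * mid_size + k] : B[k * out_size + j];
                acc += A[i * mid_size + k] * b;
            }
            C[i * out_size + j] = acc;
        }
    }
}
```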
8 changes: 4 additions & 4 deletions tests/test_mhsa_fp16/Makefile
@@ -1,10 +1,10 @@
APP = mhsa_fp16

# User settings
IN_H?=196 # Sequence Length
IN_W?=160 # Token Size
N_HEADS?=5
ATT_DIM?=160 #Hidden dimension
IN_H?=20 # Sequence Length
IN_W?=40 # Token Size
N_HEADS?=2
ATT_DIM?=40 #Hidden dimension

IN_CH?=1
OUT_CH?=1
4 changes: 2 additions & 2 deletions tests/test_mhsa_fp16/net.h
@@ -34,8 +34,8 @@
#define Tker_l0 (Tin_l0*Tout_l0)

// Tensor checksum definition
#define CHECK_TOLERANCE 0.001
#define ERROR_TOLERANCE 0.001
#define CHECK_TOLERANCE 0x00000021
#define ERROR_TOLERANCE 0x00000001

// PULP DEFINES
#define STACK_SIZE 4096