diff --git a/QKV-new.py b/QKV-new.py
new file mode 100644
index 0000000..5103d6e
--- /dev/null
+++ b/QKV-new.py
@@ -0,0 +1,136 @@
+"""QKV Accelerator ISA definition, new exercise."""
+
+from taidl import Accelerator
+
+qkv = Accelerator("QKV")
+
+# Data models
+
+# d1 is the I/O scratchpad. Its rows stay 64 elements wide so that a 64x64 block can be
+# transposed into d3 without d3 having to be 128 by 128.
+qkv.add_data_model("d1", [128], [64], "bf16")
+# d2 is the intermediate-value buffer; 64 rows of 64 bf16 values is all we need.
+qkv.add_data_model("d2", [64], [64], "bf16")
+# d3 has 128 rows so that a GEMM (matrix multiplication) can read both operands from d3
+# alone and write the result into d2.
+qkv.add_data_model("d3", [128], [64], "bf16")
+
+# Instruction semantics
+# Notes: @c refers to computational attributes, @a to addressing attributes.
+# d0 is the implicit off-chip HBM/DRAM memory. Its elements live in a flat, byte-addressed
+# array, whereas all the scratchpads are organized into rows and columns.
+
+
+# add_instruction(instruction_name, [computational_attributes], [addressing_attributes])
+instr = qkv.add_instruction("load1_rm", ["n"], ["addr_in", "addr_out"])
+# set_inputs([[ input_buffer, [addressing_attribute], [input_size] ]]). The input size is
+# @c.n * 128 because @c.n is the number of rows, and each row holds 64 bf16 values of
+# 2 bytes apiece (1 byte = 8 bits), i.e. 128 bytes per row.
+instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]])
+# set_outputs([[ output_buffer, [addressing_attribute], [output_size] ]]). No factor of
+# 128 here because d1 is addressed in rows, so the row size is already built in.
+instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY load1_rm {
+  %In1 = u8[`@c.n * 128`] parameter(0);
+  %a = u8[`@c.n`,64,2] reshape(%In1);
+  ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
+}
+""")
+
+instr = qkv.add_instruction("load3_rm", ["n"], ["addr_in", "addr_out"])
+# Same size arithmetic as load1_rm: @c.n rows * 64 bf16 * 2 bytes = @c.n * 128 bytes in.
+instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]])
+# No factor of 128 on the output because d3, like d1, is addressed in rows.
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY load3_rm {
+  %In1 = u8[`@c.n * 128`] parameter(0);
+  %a = u8[`@c.n`,64,2] reshape(%In1);
+  ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
+}
+""")
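+
+# Editor's sketch (illustrative only, not part of the ISA definition): the reshape +
+# bitcast_convert in the load semantics above can be reproduced in plain JAX; the names
+# `raw` and `rows` below are made up for the example.
+#
+#   import jax
+#   import jax.numpy as jnp
+#   n = 4
+#   raw = jnp.zeros((n * 128,), dtype=jnp.uint8)   # n rows * 64 bf16 * 2 bytes each
+#   rows = jax.lax.bitcast_convert_type(raw.reshape(n, 64, 2), jnp.bfloat16)
+#   assert rows.shape == (n, 64)                   # one 64-element bf16 row per 128 bytes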
+
+instr = qkv.add_instruction("store1_rm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
+instr.add_semantics("""
+ENTRY store1_rm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = u8[`@c.n`,64,2] bitcast_convert(%In1);
+  ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
+}
+""")
+
+instr = qkv.add_instruction("store3_rm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d3", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
+instr.add_semantics("""
+ENTRY store3_rm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = u8[`@c.n`,64,2] bitcast_convert(%In1);
+  ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
+}
+""")
+
+instr = qkv.add_instruction("mov1", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov1 {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
+}
+""")
+
+instr = qkv.add_instruction("mov3", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov3 {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
+}
+""")
+
+instr = qkv.add_instruction("mov_cm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov_cm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[64, `@c.n`] transpose(%In1), dimensions={1,0};
+}
+""")
+
+# The other two errors were in gemm13: the `=` between lhs_contracting_dims /
+# rhs_contracting_dims and their {1} / {0} values had been forgotten, in both this
+# instruction and gemm33.
+
+instr = qkv.add_instruction("gemm13", [], ["addr_1", "addr_2", "addr_out"])
+instr.set_inputs([ ["d1", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
+instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
+instr.add_semantics("""
+ENTRY gemm13 {
+  %In1 = bf16[64, 64] parameter(0);
+  %In2 = bf16[64, 64] parameter(1);
+  ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
+}
+""")
+
+# Two of the errors were in gemm33 (the same missing `=`, see the note above gemm13).
+
+instr = qkv.add_instruction("gemm33", [], ["addr_1", "addr_2", "addr_out"])
+instr.set_inputs([ ["d3", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
+instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
+instr.add_semantics("""
+ENTRY gemm33 {
+  %In1 = bf16[64, 64] parameter(0);
+  %In2 = bf16[64, 64] parameter(1);
+  ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
+}
+""")
+
+instr = qkv.add_instruction("softmax", ["n"], ["addr"])
+instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]])
+instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]])
+instr.add_semantics("""
+ENTRY softmax {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = bf16[`@c.n`,64] exponential(%In1);
+  %reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
+  %b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
+  ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
+}
+""")
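+
+# Editor's sketch (illustrative only): the softmax semantics above is a plain row-wise
+# exp-and-normalize with no max-subtraction. In JAX it would read roughly:
+#
+#   import jax.numpy as jnp
+#   def softmax_rows(x):                     # x: (n, 64)
+#       e = jnp.exp(x)
+#       return e / jnp.sum(e, axis=1, keepdims=True)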
+
+qkv.generate_oracle()
+
+qkv.generate_backend()
\ No newline at end of file
diff --git a/asm/attention_new.py b/asm/attention_new.py
new file mode 100644
index 0000000..fa87852
--- /dev/null
+++ b/asm/attention_new.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0,     'shape': (64, 64), 'dtype': jnp.bfloat16},  # Q
+            {'addr': 8192,  'shape': (64, 64), 'dtype': jnp.bfloat16},  # K
+            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # V
+        ],
+        constant=[],
+        output=[
+            {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+        ]
+    )
+    def qkv_():
+        api.load3_rm(n=64, addr_in=0, addr_out=0)       # load all of Q (64 rows) into d3 at row 0
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   # load K into the 2nd set of 64 rows in d1, to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=64)       # transpose K from d1[64] to d3[64]
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     # multiply Q in d3[0] by K^T in d3[64]; result in d2[0]
+        api.softmax(n=64, addr=0)                       # softmax of (Q x K^T), in place in d2[0]
+        api.mov3(n=64, addr_in=0, addr_out=0)           # move softmax(Q x K^T) from d2[0] to d3[0]; overwrites Q
+        api.load3_rm(n=64, addr_in=16384, addr_out=64)  # load V into d3[64]; overwrites K^T
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     # final matmul: d3[0] x d3[64], result in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           # move the final result from d2[0] to d1[0] so it can be stored
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  # store the attention output back to HBM
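+
+        # For reference (editor's note, an illustrative sketch): the sequence above is
+        # meant to compute softmax(Q @ K.T) @ V, where Q, K, V are the three HBM inputs
+        # declared in the decorator. In plain JAX, roughly:
+        #
+        #   scores = jnp.matmul(Q, K.T)
+        #   attn = jnp.exp(scores) / jnp.sum(jnp.exp(scores), axis=1, keepdims=True)
+        #   out = jnp.matmul(attn, V)
+        #
+        # (The accelerator works in bf16 and its softmax skips max-subtraction, so a
+        # float32 JAX reference will only match approximately.)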
+    return qkv_
\ No newline at end of file
diff --git a/asm/attention_new2.py b/asm/attention_new2.py
new file mode 100644
index 0000000..a56a4a1
--- /dev/null
+++ b/asm/attention_new2.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0,     'shape': (64, 64), 'dtype': jnp.bfloat16},  # Q
+            {'addr': 8192,  'shape': (64, 64), 'dtype': jnp.bfloat16},  # K
+            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # V
+        ],
+        constant=[],
+        output=[
+            {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+        ]
+    )
+    def qkv_():
+        api.load1_rm(n=64, addr_in=0, addr_out=0)       # load all of Q (64 rows) into d1 at row 0
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   # load K into the 2nd set of 64 rows in d1, to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=0)        # transpose K from d1[64] to d3[0]
+        api.gemm13(addr_1=0, addr_2=0, addr_out=0)      # multiply Q in d1[0] by K^T in d3[0]; result in d2[0]
+        api.softmax(n=64, addr=0)                       # softmax of (Q x K^T), in place in d2[0]
+        api.mov3(n=64, addr_in=0, addr_out=0)           # move softmax(Q x K^T) from d2[0] to d3[0]; overwrites K^T
+        api.load3_rm(n=64, addr_in=16384, addr_out=64)  # load V into d3[64]
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     # final matmul: d3[0] x d3[64], result in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           # move the final result from d2[0] to d1[0] so it can be stored
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  # store the attention output back to HBM
+    return qkv_
\ No newline at end of file
diff --git a/asm/attention_new3.py b/asm/attention_new3.py
new file mode 100644
index 0000000..816e15f
--- /dev/null
+++ b/asm/attention_new3.py
@@ -0,0 +1,33 @@
+# Input file: log.hlo
+# Kernel name: qkv
+# PII number: 0
+# Do not edit!
+
+import jax.numpy as jnp
+
+
+def qkv(kernel, api):
+    @kernel(hbm=32768,
+            input=[
+                {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+                {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+                {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+            ],
+            constant=[],
+            output=[
+                {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+            ]
+            )
+    def qkv_():
+        api.load1_rm(n = 64, addr_in = 0, addr_out = 0)
+        api.load1_rm(n = 64, addr_in = 8192, addr_out = 64)
+        api.mov_cm(n = 64, addr_in = 64, addr_out = 0)
+        api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
+        api.softmax(n = 64, addr = 0)
+        api.mov1(n = 64, addr_in = 0, addr_out = 0)
+        api.load3_rm(n = 64, addr_in = 16384, addr_out = 0)
+        api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
+        api.mov1(n = 64, addr_in = 0, addr_out = 0)
+        api.store1_rm(n = 64, addr_in = 0, addr_out = 24576)
+
+    return qkv_
diff --git a/asm/attention_new4.py b/asm/attention_new4.py
new file mode 100644
index 0000000..e840a40
--- /dev/null
+++ b/asm/attention_new4.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0,     'shape': (64, 64), 'dtype': jnp.bfloat16},  # Q
+            {'addr': 8192,  'shape': (64, 64), 'dtype': jnp.bfloat16},  # K
+            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # V
+        ],
+        constant=[],
+        output=[
+            {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+        ]
+    )
+    def qkv_():
+        api.load3_rm(n=64, addr_in=0, addr_out=0)       # load all of Q (64 rows) into d3 at row 0
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   # load K into the 2nd set of 64 rows in d1, to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=64)       # transpose K from d1[64] to d3[64]
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     # multiply Q in d3[0] by K^T in d3[64]; result in d2[0]
+        api.softmax(n=64, addr=0)                       # softmax of (Q x K^T), in place in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           # move softmax(Q x K^T) from d2[0] to d1[0]
+        api.load3_rm(n=64, addr_in=16384, addr_out=0)   # load V into d3[0]; overwrites Q
+        api.gemm13(addr_1=0, addr_2=0, addr_out=0)      # final matmul: d1[0] x d3[0], result in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           # move the final result from d2[0] to d1[0] so it can be stored
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  # store the attention output back to HBM
+    return qkv_
\ No newline at end of file
diff --git a/test_qkv.py b/test_qkv.py
new file mode 100644
index 0000000..f8e7604
--- /dev/null
+++ b/test_qkv.py
@@ -0,0 +1,84 @@
+try:
+    import sys
+    sys.path.insert(0, '/workspace/targets/QKV')
+    from oracle.decorator import kernel
+    import oracle.api as api
+
+    from oracle.decorator import set_simulation_backend
+    # Run the simulations on CPU
+    set_simulation_backend('CPU')
+except Exception as e:
+    print(f"Error setting path for QKV Oracle: {e}")
+    print("Make sure you generated the Oracle using the generator in Hands-on Exercise 1.")
+    exit(1)
+
+try:
+    import sys
+    sys.path.insert(0, '/workspace/asm')
+    from attention_new4 import qkv  # modify this line to test a different kernel
+except Exception as e:
+    print(f"Error importing attention kernel: {e}")
+    print("Make sure you compiled the correct attention hlo module.")
+    exit(1)
+
+
+import os
+import numpy as np
+import jax
+import jax.numpy as jnp
+
+
+DATA_DIR = "/workspace/data/"
+
+
+def load_bf16_matrix(path, shape):
+    """
+    Load a raw byte file and reinterpret it as a JAX bfloat16 matrix with the given shape.
+    """
+    np_uint8 = np.fromfile(path, dtype=np.uint8)
+    if np_uint8.size != (shape[0] * shape[1] * 2):
+        raise ValueError(f"Data in {path} has size {np_uint8.size}, expected {shape[0] * shape[1] * 2} bytes")
+    np_uint8 = np_uint8.reshape(shape[0], shape[1], 2)
+    j_uint8 = jnp.array(np_uint8, dtype=jnp.uint8)
+    mat = jax.lax.bitcast_convert_type(j_uint8, jnp.bfloat16)
+    return mat
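+
+
+# Example (editor's note): each of Q.dat, K.dat and V.dat holds a 64 x 64 bf16 matrix,
+# i.e. 64 * 64 * 2 = 8192 raw bytes, matching the HBM layout used by the kernels above
+# (Q at 0, K at 8192, V at 16384, output at 24576). For instance:
+#
+#   Q = load_bf16_matrix(os.path.join(DATA_DIR, "Q.dat"), (64, 64))   # (64, 64) bfloat16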
+ """ + np_uint8 = np.fromfile(path, dtype=np.uint8) + if np_uint8.size != (shape[0] * shape[1] * 2): + raise ValueError(f"Data in {path} has size {np_uint8.size}, expected {shape[0]*shape[1]}") + np_uint8 = np_uint8.reshape(shape[0], shape[1], 2) + j_uint8 = jnp.array(np_uint8, dtype=jnp.uint8) + mat = jax.lax.bitcast_convert_type(j_uint8, jnp.bfloat16) + return mat + + +if __name__ == "__main__": + # Compile the simulation for the attention kernel + qkv = qkv(kernel, api) + inputs, compile_time = qkv('fsim-compile')() + print(f"Simulation ready in {compile_time}ms") + + # Load input data + Q = load_bf16_matrix(os.path.join(DATA_DIR, "Q.dat"), (64, 64)) + K = load_bf16_matrix(os.path.join(DATA_DIR, "K.dat"), (64, 64)) + V = load_bf16_matrix(os.path.join(DATA_DIR, "V.dat"), (64, 64)) + print("Loaded data/Q.dat, data/K.dat, data/V.dat (raw bfloat16 bits)") + + # Run the simulation + outputs, elapsed = qkv('fsim')(Q, K, V) + qkv_output = outputs[0] + print(f"Simulation ran in {elapsed}ms") + + # Print input and output shapes and dtypes + print(f"Inputs:") + print(f" Q: {Q.shape}, {Q.dtype}") + print(f" K: {K.shape}, {K.dtype}") + print(f" V: {V.shape}, {V.dtype}") + print(f"Outputs:") + print(f" Output: {qkv_output.shape}, {qkv_output.dtype}") + + # Load golden output (from FPGA implementation of QKV accelerator) + golden = load_bf16_matrix(os.path.join(DATA_DIR, "attention.dat"), (64, 64)) + print("Loaded data/attention.dat (raw bfloat16 bits) as golden output") + + # Compare simulation output of attention kernel with golden + max_diff = jnp.max(jnp.abs(qkv_output - golden)) + print(f"Max absolute difference between simulation and golden: {max_diff}") + if max_diff == 0: + print("Output matches golden exactly!") + print("Great! The compiled attention kernel is correct.") + else: + print("Output does not match golden.") + print("Oh no! There might be a bug in the compiled attention kernel.")