136 changes: 136 additions & 0 deletions QKV-new.py
@@ -0,0 +1,136 @@
"""QKV Accelerator ISA definition, new exercise"""

from taidl import Accelerator

qkv = Accelerator("QKV")

#data models

#d1 is an I/O scratchpad. we only ever transpose 64x64 tiles of it into d3, which is why d3 doesn't need to be 128 by 128
qkv.add_data_model("d1", [128], [64], "bf16")
#d2 is intermediate value buffer. we're leaving it as 64 by 64 because that's all we need
qkv.add_data_model("d2", [64], [64], "bf16")
#d3 has 128 rows so that both operands of a gemm (matrix multiplication) can sit in d3 at once and feed straight from d3 into d2
qkv.add_data_model("d3", [128], [64], "bf16")

#instruction semantics
#notes: @c means computational attributes, @a means addressing attributes.
#d0 is implicit off-chip HBM/DRAM memory. Its elements are stored in a flat byte-addressed array, whereas all the scratch pads actually have rows and columns
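#worked example of the d0 byte math used by the loads/stores below (illustrative only; the
#concrete base addresses are the ones the asm kernels in this PR happen to use):
#  one 64x64 bf16 matrix = 64 rows * 64 cols * 2 bytes = 8192 bytes of d0,
#  so Q, K, V, and the output fit at byte offsets 0, 8192, 16384, 24576 in a 32768-byte HBM,
#  and row i of a matrix based at byte address base starts at base + i * 128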


instr = qkv.add_instruction("load1_rm", ["n"], ["addr_in", "addr_out"]) #(instruction_name, [list_of_computational_attributes], [list_of_addressing_attributes])
instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]]) #([[ input_buffer, [addressing_attribute], [size_of_input] ]]). the size is @c.n * 128 because @c.n is the number of rows and each row holds 64 bf16 values at 2 bytes apiece, i.e. 128 bytes per row
instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]]) #([[ output_buffer, [addressing_attribute], [size_of_output] ]]). no factor of 128 here because d1 is row-addressed, so the size is given in rows rather than bytes
instr.add_semantics("""
ENTRY load1_rm {
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}
""")

instr = qkv.add_instruction("load3_rm", ["n"], ["addr_in", "addr_out"]) #(instruction_name, [list_of_computational_attributes], [list_of_addressing_attributes])
instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]]) #same size math as load1_rm: @c.n rows of 128 bytes each (64 bf16 values x 2 bytes)
instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]]) #no factor of 128 here because d3 is row-addressed, so the size is given in rows rather than bytes
instr.add_semantics("""
ENTRY load3_rm {
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}
""")

instr = qkv.add_instruction("store1_rm", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
instr.add_semantics("""
ENTRY store1_rm {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")

instr = qkv.add_instruction("store3_rm", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([[ "d3", ["@a.addr_in"], ["@c.n"] ]])
instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
instr.add_semantics("""
ENTRY store3_rm {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")

instr = qkv.add_instruction("mov1", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]])
instr.add_semantics("""
ENTRY mov1 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}
""")

instr = qkv.add_instruction("mov3", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
instr.add_semantics("""
ENTRY mov3 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}
""")

instr = qkv.add_instruction("mov_cm", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
instr.add_semantics("""
ENTRY mov_cm {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[64, `@c.n`] transpose(%In1), dimensions={1,0};
}
""")

#the other 2 errors are in gemm13: the = was forgotten in rhs_contracting_dims={0}, in both this and gemm33

instr = qkv.add_instruction("gemm13", [], ["addr_1", "addr_2", "addr_out"])
instr.set_inputs([ ["d1", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
instr.add_semantics("""
ENTRY gemm13 {
%In1 = bf16[64, 64] parameter(0);
%In2 = bf16[64, 64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
}
""")

#2 of the errors are in gemm33

instr = qkv.add_instruction("gemm33", [], ["addr_1", "addr_2", "addr_out"])
instr.set_inputs([ ["d3", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
instr.add_semantics("""
ENTRY gemm33 {
%In1 = bf16[64, 64] parameter(0);
%In2 = bf16[64, 64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
}
""")

instr = qkv.add_instruction("softmax", ["n"], ["addr"])
instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.add_semantics("""
ENTRY softmax {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = bf16[`@c.n`,64] exponential(%In1);
%reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
%b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
}
""")

qkv.generate_oracle()

qkv.generate_backend()
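
For reference, a minimal jax.numpy sketch of what the asm kernels below are meant to compute: softmax(Q @ K^T) @ V in bfloat16, using the same max-free softmax as the ISA's softmax instruction (there is no 1/sqrt(d) scaling anywhere in these kernels). The name reference_attention is ours for illustration; it is not part of the generated oracle or the exercise code.

import jax.numpy as jnp

def reference_attention(Q, K, V):
    # scores = Q @ K^T; Q, K, V are (64, 64) bfloat16 tiles
    s = jnp.dot(Q, K.T)
    # row-wise softmax without max-subtraction, matching the ISA's softmax semantics
    a = jnp.exp(s)
    p = a / jnp.sum(a, axis=1, keepdims=True)
    # final matmul with V
    return jnp.dot(p, V).astype(jnp.bfloat16)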
28 changes: 28 additions & 0 deletions asm/attention_new.py
@@ -0,0 +1,28 @@
import jax.numpy as jnp

def qkv(kernel, api):
@kernel(
hbm=32768,
input=[
{'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16}, #Q
{'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16}, #K
{'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
],
constant=[],
output=[
{'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
]
)

def qkv_():
        api.load3_rm(n=64, addr_in = 0, addr_out = 0) #Load all of Q (64 rows) into d3 at addr 0.
api.load1_rm(n=64, addr_in=8192, addr_out=64) #Load K into the 2nd set of 64 rows in d1 to transpose later
        api.mov_cm(n=64, addr_in=64, addr_out=64) #transpose K from d1[64] to d3[64]
        api.gemm33(addr_1=0, addr_2=64, addr_out=0) #multiply Q in d3[0] with K^T in d3[64]. deposit result in d2[0]
api.softmax(n=64, addr=0) #softmax (Q x K^T) already located in d2[0] in place
        api.mov3(n=64, addr_in=0, addr_out=0) #move softmax(QxK^T) from d2[0] to d3[0]. overwrites Q
        api.load3_rm(n=64, addr_in=16384, addr_out=64) #load V into d3[64]. overwrites K^T
api.gemm33(addr_1=0, addr_2=64, addr_out=0) #perform final matrix multiplication. 1st input from d3[0], 2nd input from d3[64]. output in d2[0]
api.mov1(n=64, addr_in=0, addr_out=0) #moving our final result to d1 so I can output it
api.store1_rm(n=64, addr_in=0, addr_out=24576) #yay we did it we computed attention
return qkv_
28 changes: 28 additions & 0 deletions asm/attention_new2.py
@@ -0,0 +1,28 @@
import jax.numpy as jnp

def qkv(kernel, api):
@kernel(
hbm=32768,
input=[
{'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16}, #Q
{'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16}, #K
{'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
],
constant=[],
output=[
{'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
]
)

def qkv_():
api.load1_rm(n=64, addr_in = 0, addr_out = 0) #Load all of Q (64 rows) into d1 at addr 0.
api.load1_rm(n=64, addr_in=8192, addr_out=64) #Load K into the 2nd set of 64 rows in d1 to transpose later
api.mov_cm(n=64, addr_in=64, addr_out=0) #transpose K from d1[64] to d3[0]
api.gemm13(addr_1=0, addr_2=0, addr_out=0) #multiply Q in d1[0] with K^T in d3[0]. deposit result in d2[0]
api.softmax(n=64, addr=0) #softmax (Q x K^T) already located in d2[0] in place
api.mov3(n=64, addr_in=0, addr_out=0) #move softmax(QxK^T) from d2[0] to d3[0]. overwrites K^T
api.load3_rm(n=64, addr_in=16384, addr_out=64) #load in v to d3[64].
api.gemm33(addr_1=0, addr_2=64, addr_out=0) #perform final matrix multiplication. 1st input from d3[0], 2nd input from d3[64]. output in d2[0]
api.mov1(n=64, addr_in=0, addr_out=0) #moving our final result to d1 so I can output it
api.store1_rm(n=64, addr_in=0, addr_out=24576) #yay we did it we computed attention
return qkv_
33 changes: 33 additions & 0 deletions asm/attention_new3.py
@@ -0,0 +1,33 @@
# Input file: log.hlo
# Kernel name: qkv
# PII number: 0
# Do not edit!

import jax.numpy as jnp


def qkv(kernel, api):
@kernel(hbm=32768,
input=[
{'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},
{'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},
{'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},
],
constant=[],
output=[
{'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
]
)
def qkv_():
api.load1_rm(n = 64, addr_in = 0, addr_out = 0)
api.load1_rm(n = 64, addr_in = 8192, addr_out = 64)
api.mov_cm(n = 64, addr_in = 64, addr_out = 0)
api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
api.softmax(n = 64, addr = 0)
api.mov1(n = 64, addr_in = 0, addr_out = 0)
api.load3_rm(n = 64, addr_in = 16384, addr_out = 0)
api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
api.mov1(n = 64, addr_in = 0, addr_out = 0)
api.store1_rm(n = 64, addr_in = 0, addr_out = 24576)

return qkv_
28 changes: 28 additions & 0 deletions asm/attention_new4.py
@@ -0,0 +1,28 @@
import jax.numpy as jnp

def qkv(kernel, api):
@kernel(
hbm=32768,
input=[
{'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16}, #Q
{'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16}, #K
{'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
],
constant=[],
output=[
{'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
]
)

def qkv_():
        api.load3_rm(n=64, addr_in = 0, addr_out = 0) #Load all of Q (64 rows) into d3 at addr 0.
api.load1_rm(n=64, addr_in=8192, addr_out=64) #Load K into the 2nd set of 64 rows in d1 to transpose later
        api.mov_cm(n=64, addr_in=64, addr_out=64) #transpose K from d1[64] to d3[64]
        api.gemm33(addr_1=0, addr_2=64, addr_out=0) #multiply Q in d3[0] with K^T in d3[64]. deposit result in d2[0]
api.softmax(n=64, addr=0) #softmax (Q x K^T) already located in d2[0] in place
        api.mov1(n=64, addr_in=0, addr_out=0) #move softmax(QxK^T) from d2[0] to d1[0]
        api.load3_rm(n=64, addr_in=16384, addr_out=0) #load V into d3[0]. overwrites Q
        api.gemm13(addr_1=0, addr_2=0, addr_out=0) #perform final matrix multiplication. 1st input (softmax result) from d1[0], 2nd input (V) from d3[0]. output in d2[0]
api.mov1(n=64, addr_in=0, addr_out=0) #moving our final result to d1 so I can output it
api.store1_rm(n=64, addr_in=0, addr_out=24576) #yay we did it we computed attention
return qkv_
84 changes: 84 additions & 0 deletions test_qkv.py
@@ -0,0 +1,84 @@
try:
import sys
sys.path.insert(0, '/workspace/targets/QKV')
from oracle.decorator import kernel
import oracle.api as api

from oracle.decorator import set_simulation_backend
# Run the simulations on CPU
set_simulation_backend('CPU')
except Exception as e:
print(f"Error setting path for QKV Oracle: {e}")
print("Make sure you generated the Oracle using the generator in Hands-on Exercise 1.")
exit(1)

try:
import sys
sys.path.insert(0, '/workspace/asm')
from attention_new4 import qkv #this is the line to modify if you want to test a different kernel
except Exception as e:
print(f"Error importing attention kernel: {e}")
print("Make sure you compiled the correct attention hlo module.")
exit(1)


import os
import numpy as np
import jax
import jax.numpy as jnp


DATA_DIR = "/workspace/data/"


def load_bf16_matrix(path, shape):
"""
Load a raw 8-bit file and reinterpret as jax bfloat16 matrix with given shape.
"""
np_uint8 = np.fromfile(path, dtype=np.uint8)
if np_uint8.size != (shape[0] * shape[1] * 2):
        raise ValueError(f"Data in {path} has size {np_uint8.size} bytes, expected {shape[0]*shape[1]*2}")
np_uint8 = np_uint8.reshape(shape[0], shape[1], 2)
j_uint8 = jnp.array(np_uint8, dtype=jnp.uint8)
mat = jax.lax.bitcast_convert_type(j_uint8, jnp.bfloat16)
return mat
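
# Hypothetical inverse helper (name is ours, not used by this test): write a bfloat16 matrix
# back out as the raw byte layout that load_bf16_matrix expects.
def save_bf16_matrix(mat, path):
    """
    Reinterpret a jax bfloat16 matrix as raw bytes and write them to a file.
    """
    j_uint8 = jax.lax.bitcast_convert_type(mat, jnp.uint8)  # bf16[r,c] -> u8[r,c,2]
    np.asarray(j_uint8).tofile(path)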


if __name__ == "__main__":
# Compile the simulation for the attention kernel
qkv = qkv(kernel, api)
inputs, compile_time = qkv('fsim-compile')()
print(f"Simulation ready in {compile_time}ms")

# Load input data
Q = load_bf16_matrix(os.path.join(DATA_DIR, "Q.dat"), (64, 64))
K = load_bf16_matrix(os.path.join(DATA_DIR, "K.dat"), (64, 64))
V = load_bf16_matrix(os.path.join(DATA_DIR, "V.dat"), (64, 64))
print("Loaded data/Q.dat, data/K.dat, data/V.dat (raw bfloat16 bits)")

# Run the simulation
outputs, elapsed = qkv('fsim')(Q, K, V)
qkv_output = outputs[0]
print(f"Simulation ran in {elapsed}ms")

# Print input and output shapes and dtypes
print(f"Inputs:")
print(f" Q: {Q.shape}, {Q.dtype}")
print(f" K: {K.shape}, {K.dtype}")
print(f" V: {V.shape}, {V.dtype}")
print(f"Outputs:")
print(f" Output: {qkv_output.shape}, {qkv_output.dtype}")

# Load golden output (from FPGA implementation of QKV accelerator)
golden = load_bf16_matrix(os.path.join(DATA_DIR, "attention.dat"), (64, 64))
print("Loaded data/attention.dat (raw bfloat16 bits) as golden output")

# Compare simulation output of attention kernel with golden
max_diff = jnp.max(jnp.abs(qkv_output - golden))
print(f"Max absolute difference between simulation and golden: {max_diff}")
if max_diff == 0:
print("Output matches golden exactly!")
print("Great! The compiled attention kernel is correct.")
else:
print("Output does not match golden.")
print("Oh no! There might be a bug in the compiled attention kernel.")