From 90c20d3d66185f8859865675abdd8ca0653c2dd0 Mon Sep 17 00:00:00 2001
From: giberish4040404
Date: Fri, 2 Jan 2026 20:50:14 -0600
Subject: [PATCH 1/4] 4 different attention kernels

---
 asm/attention_new2.py | 28 ++++++++++++++++++++++++++++
 asm/attention_new3.py | 33 +++++++++++++++++++++++++++++++++
 asm/attention_new4.py | 28 ++++++++++++++++++++++++++++
 3 files changed, 89 insertions(+)
 create mode 100644 asm/attention_new2.py
 create mode 100644 asm/attention_new3.py
 create mode 100644 asm/attention_new4.py

diff --git a/asm/attention_new2.py b/asm/attention_new2.py
new file mode 100644
index 0000000..a56a4a1
--- /dev/null
+++ b/asm/attention_new2.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16},    #Q
+            {'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16},  #K
+            {'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
+        ],
+        constant=[],
+        output=[
+            {'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
+        ]
+    )
+
+    def qkv_():
+        api.load1_rm(n=64, addr_in = 0, addr_out = 0)   #Load all of Q (64 rows) into d1 at addr 0.
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   #Load K into the 2nd set of 64 rows in d1 to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=0)        #transpose K from d1[64] to d3[0]
+        api.gemm13(addr_1=0, addr_2=0, addr_out=0)      #multiply Q in d1[0] with K^T in d3[0]. deposit result in d2[0]
+        api.softmax(n=64, addr=0)                       #softmax (Q x K^T) already located in d2[0] in place
+        api.mov3(n=64, addr_in=0, addr_out=0)           #move softmax(QxK^T) from d2[0] to d3[0]. overwrites K^T
+        api.load3_rm(n=64, addr_in=16384, addr_out=64)  #load in v to d3[64].
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     #perform final matrix multiplication. 1st input from d3[0], 2nd input from d3[64]. output in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           #moving our final result to d1 so I can output it
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  #yay we did it we computed attention
+    return qkv_
\ No newline at end of file
diff --git a/asm/attention_new3.py b/asm/attention_new3.py
new file mode 100644
index 0000000..816e15f
--- /dev/null
+++ b/asm/attention_new3.py
@@ -0,0 +1,33 @@
+# Input file: log.hlo
+# Kernel name: qkv
+# PII number: 0
+# Do not edit!
+
+import jax.numpy as jnp
+
+
+def qkv(kernel, api):
+    @kernel(hbm=32768,
+            input=[
+                {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+                {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+                {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+            ],
+            constant=[],
+            output=[
+                {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},
+            ]
+            )
+    def qkv_():
+        api.load1_rm(n = 64, addr_in = 0, addr_out = 0)
+        api.load1_rm(n = 64, addr_in = 8192, addr_out = 64)
+        api.mov_cm(n = 64, addr_in = 64, addr_out = 0)
+        api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
+        api.softmax(n = 64, addr = 0)
+        api.mov1(n = 64, addr_in = 0, addr_out = 0)
+        api.load3_rm(n = 64, addr_in = 16384, addr_out = 0)
+        api.gemm13(addr_1 = 0, addr_2 = 0, addr_out = 0)
+        api.mov1(n = 64, addr_in = 0, addr_out = 0)
+        api.store1_rm(n = 64, addr_in = 0, addr_out = 24576)
+
+    return qkv_
diff --git a/asm/attention_new4.py b/asm/attention_new4.py
new file mode 100644
index 0000000..e840a40
--- /dev/null
+++ b/asm/attention_new4.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16},    #Q
+            {'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16},  #K
+            {'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
+        ],
+        constant=[],
+        output=[
+            {'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
+        ]
+    )
+
+    def qkv_():
+        api.load3_rm(n=64, addr_in = 0, addr_out = 0)   #Load all of Q (64 rows) into d3 at addr 0.
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   #Load K into the 2nd set of 64 rows in d1 to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=64)       #transpose K from d1[64] to d3[64]
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     #multiply Q in d3[0] with K^T in d3[64]. deposit result in d2[0]
+        api.softmax(n=64, addr=0)                       #softmax (Q x K^T) already located in d2[0] in place
+        api.mov1(n=64, addr_in=0, addr_out=0)           #move softmax(QxK^T) from d2[0] to d1[0]
+        api.load3_rm(n=64, addr_in=16384, addr_out=0)   #load in v to d3[0], overwriting Q
+        api.gemm13(addr_1=0, addr_2=0, addr_out=0)      #perform final matrix multiplication. 1st input from d1[0], 2nd input from d3[0]. output in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           #moving our final result to d1 so I can output it
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  #yay we did it we computed attention
+    return qkv_
\ No newline at end of file
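
Each of the kernel variants in this patch targets the same computation on 64x64 bfloat16 tiles, O = softmax(Q x K^T) x V. For reference, a minimal JAX sketch of that computation (illustrative only; the exp/sum softmax with no max subtraction mirrors the softmax instruction defined in QKV-new.py in the next patch, and bit-exact agreement with the oracle or the FPGA golden output is not guaranteed) could look like:

import jax.numpy as jnp

def reference_attention(Q, K, V):
    # S = Q x K^T, kept in bfloat16 like the d2 scratchpad
    S = jnp.dot(Q, K.T).astype(jnp.bfloat16)
    # Row-wise softmax written as exp/sum (no max subtraction),
    # matching the ISA's softmax semantics
    E = jnp.exp(S)
    P = (E / jnp.sum(E, axis=1, keepdims=True)).astype(jnp.bfloat16)
    # O = P x V
    return jnp.dot(P, V).astype(jnp.bfloat16)
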
From 66767d50909ef2e5273b0df82bcb902388374d6f Mon Sep 17 00:00:00 2001
From: giberish4040404
Date: Fri, 2 Jan 2026 20:53:09 -0600
Subject: [PATCH 2/4] actually added the first attention kernel and the isa file

---
 QKV-new.py           | 136 +++++++++++++++++++++++++++++++++++++++++++
 asm/attention.py     |  47 +++++++++++++++
 asm/attention_new.py |  28 +++++++++
 3 files changed, 211 insertions(+)
 create mode 100644 QKV-new.py
 create mode 100644 asm/attention.py
 create mode 100644 asm/attention_new.py

diff --git a/QKV-new.py b/QKV-new.py
new file mode 100644
index 0000000..5103d6e
--- /dev/null
+++ b/QKV-new.py
@@ -0,0 +1,136 @@
+"""QKV Accelerator ISA definition, new exercise"""
+
+from taidl import Accelerator
+
+qkv = Accelerator("QKV")
+
+#data models
+
+#d1 is I/O but only 64/64 because we need to be able to transpose it into d3 and we don't want d3 to be 128 by 128
+qkv.add_data_model("d1", [128], [64], "bf16")
+#d2 is the intermediate value buffer. we're leaving it as 64 by 64 because that's all we need
+qkv.add_data_model("d2", [64], [64], "bf16")
+#d3 has to be 128 rows so that we can gemm (matrix multiplication) from just d3 into d2
+qkv.add_data_model("d3", [128], [64], "bf16")
+
+#instruction semantics
+#notes: @c means computational attributes, @a means addressing attributes.
+#d0 is implicit off-chip HBM/DRAM memory. Its elements are stored in a flat byte-addressed array, whereas all the scratchpads actually have rows and columns
+
+
+instr = qkv.add_instruction("load1_rm", ["n"], ["addr_in", "addr_out"])  #(instruction_name, [list_of_computational_attributes], [list_of_addressing_attributes])
+instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]])  #([[ input_buffer, [addressing_attribute], [size_of_input] ]]). here we do c.n * 128 because c.n is the number of rows and each row holds 64 bf16s at 2 bytes apiece (1 byte = 8 bits), so 128 bytes in total.
+instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]])  #([[ output_buffer, [addressing_attribute], [size_of_output] ]]). here we don't multiply by 128 because d1 already has the row size built in
+instr.add_semantics("""
+ENTRY load1_rm {
+  %In1 = u8[`@c.n * 128`] parameter(0);
+  %a = u8[`@c.n`,64,2] reshape(%In1);
+  ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
+}
+""")
+
+instr = qkv.add_instruction("load3_rm", ["n"], ["addr_in", "addr_out"])  #(instruction_name, [list_of_computational_attributes], [list_of_addressing_attributes])
+instr.set_inputs([[ "d0", ["@a.addr_in"], ["@c.n * 128"] ]])  #([[ input_buffer, [addressing_attribute], [size_of_input] ]]). here we do c.n * 128 because c.n is the number of rows and each row holds 64 bf16s at 2 bytes apiece (1 byte = 8 bits), so 128 bytes in total.
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])  #([[ output_buffer, [addressing_attribute], [size_of_output] ]]). here we don't multiply by 128 because d3 already has the row size built in
+instr.add_semantics("""
+ENTRY load3_rm {
+  %In1 = u8[`@c.n * 128`] parameter(0);
+  %a = u8[`@c.n`,64,2] reshape(%In1);
+  ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
+}
+""")
+
+instr = qkv.add_instruction("store1_rm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
+instr.add_semantics("""
+ENTRY store1_rm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = u8[`@c.n`,64,2] bitcast_convert(%In1);
+  ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
+}
+""")
+
+instr = qkv.add_instruction("store3_rm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d3", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d0", ["@a.addr_out"], ["@c.n * 128"] ]])
+instr.add_semantics("""
+ENTRY store3_rm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = u8[`@c.n`,64,2] bitcast_convert(%In1);
+  ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
+}
+""")
+
+instr = qkv.add_instruction("mov1", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d1", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov1 {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
+}
+""")
+
+instr = qkv.add_instruction("mov3", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d2", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov3 {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
+}
+""")
+
+instr = qkv.add_instruction("mov_cm", ["n"], ["addr_in", "addr_out"])
+instr.set_inputs([[ "d1", ["@a.addr_in"], ["@c.n"] ]])
+instr.set_outputs([[ "d3", ["@a.addr_out"], ["@c.n"] ]])
+instr.add_semantics("""
+ENTRY mov_cm {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  ROOT %Out0 = bf16[64, `@c.n`] transpose(%In1), dimensions={1,0};
+}
+""")
+
+#the other 2 errors are in gemm13 (forgot the = between ..._dims and {1}/{0}, in both this and gemm33)
+
+instr = qkv.add_instruction("gemm13", [], ["addr_1", "addr_2", "addr_out"])
+instr.set_inputs([ ["d1", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
+instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
+instr.add_semantics("""
+ENTRY gemm13 {
+  %In1 = bf16[64, 64] parameter(0);
+  %In2 = bf16[64, 64] parameter(1);
+  ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
+}
+""")
+
+#2 of the errors are in gemm33
+
+instr = qkv.add_instruction("gemm33", [], ["addr_1", "addr_2", "addr_out"])
+instr.set_inputs([ ["d3", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]] ])
+instr.set_outputs([ ["d2", ["@a.addr_out"], ["64"]] ])
+instr.add_semantics("""
+ENTRY gemm33 {
+  %In1 = bf16[64, 64] parameter(0);
+  %In2 = bf16[64, 64] parameter(1);
+  ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
+}
+""")
+
+instr = qkv.add_instruction("softmax", ["n"], ["addr"])
+instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]])
+instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]])
+instr.add_semantics("""
+ENTRY softmax {
+  %In1 = bf16[`@c.n`,64] parameter(0);
+  %a = bf16[`@c.n`,64] exponential(%In1);
+  %reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
+  %b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
+  ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
+}
+""")
+
+qkv.generate_oracle()
+
+qkv.generate_backend()
\ No newline at end of file
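
The @c.n * 128 sizing in the load/store semantics above comes from each row holding 64 bfloat16 values at 2 bytes each. A small JAX sketch of that u8 <-> bf16 round trip (illustrative only; the variable names are made up, but the reshape/bitcast_convert pattern is the same one used in the HLO above and in test_qkv.py's loader later in this series):

import jax
import jax.numpy as jnp

n = 64
tile = jnp.ones((n, 64), dtype=jnp.bfloat16)          # a bf16[n,64] scratchpad tile
raw = jax.lax.bitcast_convert_type(tile, jnp.uint8)   # u8[n,64,2]: 2 bytes per element
flat = raw.reshape(n * 128)                           # u8[n*128], the flat HBM byte view
back = jax.lax.bitcast_convert_type(flat.reshape(n, 64, 2), jnp.bfloat16)  # bf16[n,64] again
assert bool(jnp.all(back == tile))
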
diff --git a/asm/attention.py b/asm/attention.py
new file mode 100644
index 0000000..b421e97
--- /dev/null
+++ b/asm/attention.py
@@ -0,0 +1,47 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,  # 32 KB: enough for 3 inputs + 1 output
+        input=[
+            {'addr': 0, 'shape': (64,64), 'dtype': jnp.bfloat16},      #Q
+            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},  #K
+            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16}, #V
+        ],  # allocate the input addresses here
+        constant=[],  # if needed, we can add constants here (none in this case)
+        output=[
+            {'addr': 24576, 'shape': (64,64), 'dtype': jnp.bfloat16},  #O
+        ]  # allocate the output address here
+    )
+    def qkv_():
+        # Kernel implementation goes here
+
+        # Load Q in row-major (standard)
+        api.load_rm(n=64, addr_in=0, addr_out=0)
+
+        # Load K in column-major (gives us K^T automatically!)
+        api.load_cm(n=64, addr_in=8192, addr_out=64)
+
+        # Compute S = Q × K^T
+        api.gemm(addr_1=0, addr_2=64, addr_out=0)
+
+        # Apply softmax to S, converting it to P (in-place in d2)
+        api.softmax(n=64, addr=0)
+
+        # Move P from d2[0:63] to d1[0:63]
+        # Recall: mov copies between scratchpads
+        api.mov(n=64, addr_in=0, addr_out=0)
+
+        # Load V into d1[64:127] (reusing K^T's space)
+        api.load_rm(n=64, addr_in=16384, addr_out=64)
+
+        # Compute O = P × V, result in d2[0:63]
+        api.gemm(addr_1=0, addr_2=64, addr_out=0)
+
+        # Move O from d2[0:63] to d1[0:63]
+        api.mov(n=64, addr_in=0, addr_out=0)
+
+        # Store O to HBM at address 24576
+        api.store_rm(n=64, addr_in=0, addr_out=24576)
+
+    return qkv_
diff --git a/asm/attention_new.py b/asm/attention_new.py
new file mode 100644
index 0000000..fa87852
--- /dev/null
+++ b/asm/attention_new.py
@@ -0,0 +1,28 @@
+import jax.numpy as jnp
+
+def qkv(kernel, api):
+    @kernel(
+        hbm=32768,
+        input=[
+            {'addr': 0, 'shape':(64,64), 'dtype': jnp.bfloat16},    #Q
+            {'addr':8192, 'shape':(64,64), 'dtype': jnp.bfloat16},  #K
+            {'addr':16384, 'shape':(64,64), 'dtype': jnp.bfloat16}, #V
+        ],
+        constant=[],
+        output=[
+            {'addr':24576, 'shape':(64,64), 'dtype': jnp.bfloat16},
+        ]
+    )
+
+    def qkv_():
+        api.load3_rm(n=64, addr_in = 0, addr_out = 0)   #Load all of Q (64 rows) into d3 at addr 0.
+        api.load1_rm(n=64, addr_in=8192, addr_out=64)   #Load K into the 2nd set of 64 rows in d1 to transpose later
+        api.mov_cm(n=64, addr_in=64, addr_out=64)       #transpose K from d1[64] to d3[64]
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     #multiply Q in d3[0] with K^T in d3[64]. deposit result in d2[0]
+        api.softmax(n=64, addr=0)                       #softmax (Q x K^T) already located in d2[0] in place
+        api.mov3(n=64, addr_in=0, addr_out=0)           #move softmax(QxK^T) from d2[0] to d3[0]. overwrites Q
+        api.load3_rm(n=64, addr_in=16384, addr_out=64)  #load in v to d3[64], overwriting K^T
+        api.gemm33(addr_1=0, addr_2=64, addr_out=0)     #perform final matrix multiplication. 1st input from d3[0], 2nd input from d3[64]. output in d2[0]
+        api.mov1(n=64, addr_in=0, addr_out=0)           #moving our final result to d1 so I can output it
+        api.store1_rm(n=64, addr_in=0, addr_out=24576)  #yay we did it we computed attention
+    return qkv_
\ No newline at end of file

From e2accc909f3d4295529880e8a194070867911833 Mon Sep 17 00:00:00 2001
From: giberish4040404
Date: Fri, 2 Jan 2026 20:55:11 -0600
Subject: [PATCH 3/4] added test_qkv file to commit

---
 test_qkv.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 84 insertions(+)
 create mode 100644 test_qkv.py

diff --git a/test_qkv.py b/test_qkv.py
new file mode 100644
index 0000000..f8e7604
--- /dev/null
+++ b/test_qkv.py
@@ -0,0 +1,84 @@
+try:
+    import sys
+    sys.path.insert(0, '/workspace/targets/QKV')
+    from oracle.decorator import kernel
+    import oracle.api as api
+
+    from oracle.decorator import set_simulation_backend
+    # Run the simulations on CPU
+    set_simulation_backend('CPU')
+except Exception as e:
+    print(f"Error setting path for QKV Oracle: {e}")
+    print("Make sure you generated the Oracle using the generator in Hands-on Exercise 1.")
+    exit(1)
+
+try:
+    import sys
+    sys.path.insert(0, '/workspace/asm')
+    from attention_new4 import qkv  # this is the line to modify if you want to test a different kernel
+except Exception as e:
+    print(f"Error importing attention kernel: {e}")
+    print("Make sure you compiled the correct attention hlo module.")
+    exit(1)
+
+
+import os
+import numpy as np
+import jax
+import jax.numpy as jnp
+
+
+DATA_DIR = "/workspace/data/"
+
+
+def load_bf16_matrix(path, shape):
+    """
+    Load a raw 8-bit file and reinterpret as a jax bfloat16 matrix with the given shape.
+    """
+    np_uint8 = np.fromfile(path, dtype=np.uint8)
+    if np_uint8.size != (shape[0] * shape[1] * 2):
+        raise ValueError(f"Data in {path} has size {np_uint8.size}, expected {shape[0] * shape[1] * 2}")
+    np_uint8 = np_uint8.reshape(shape[0], shape[1], 2)
+    j_uint8 = jnp.array(np_uint8, dtype=jnp.uint8)
+    mat = jax.lax.bitcast_convert_type(j_uint8, jnp.bfloat16)
+    return mat
+
+
+if __name__ == "__main__":
+    # Compile the simulation for the attention kernel
+    qkv = qkv(kernel, api)
+    inputs, compile_time = qkv('fsim-compile')()
+    print(f"Simulation ready in {compile_time}ms")
+
+    # Load input data
+    Q = load_bf16_matrix(os.path.join(DATA_DIR, "Q.dat"), (64, 64))
+    K = load_bf16_matrix(os.path.join(DATA_DIR, "K.dat"), (64, 64))
+    V = load_bf16_matrix(os.path.join(DATA_DIR, "V.dat"), (64, 64))
+    print("Loaded data/Q.dat, data/K.dat, data/V.dat (raw bfloat16 bits)")
+
+    # Run the simulation
+    outputs, elapsed = qkv('fsim')(Q, K, V)
+    qkv_output = outputs[0]
+    print(f"Simulation ran in {elapsed}ms")
+
+    # Print input and output shapes and dtypes
+    print(f"Inputs:")
+    print(f"  Q: {Q.shape}, {Q.dtype}")
+    print(f"  K: {K.shape}, {K.dtype}")
+    print(f"  V: {V.shape}, {V.dtype}")
+    print(f"Outputs:")
+    print(f"  Output: {qkv_output.shape}, {qkv_output.dtype}")
+
+    # Load golden output (from the FPGA implementation of the QKV accelerator)
+    golden = load_bf16_matrix(os.path.join(DATA_DIR, "attention.dat"), (64, 64))
+    print("Loaded data/attention.dat (raw bfloat16 bits) as golden output")
+
+    # Compare simulation output of the attention kernel with the golden output
+    max_diff = jnp.max(jnp.abs(qkv_output - golden))
+    print(f"Max absolute difference between simulation and golden: {max_diff}")
+    if max_diff == 0:
+        print("Output matches golden exactly!")
+        print("Great! The compiled attention kernel is correct.")
+    else:
+        print("Output does not match golden.")
+        print("Oh no! There might be a bug in the compiled attention kernel.")
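
Since the series adds several kernel variants but test_qkv.py imports only one of them, a sweep over all of them could be scripted along these lines (a hypothetical addition to the test's __main__ block after Q, K, V, and golden are loaded; it reuses only calls already shown above and assumes each module exposes the same qkv(kernel, api) entry point and that the oracle tolerates compiling several kernels in one process):

import importlib

for name in ("attention_new", "attention_new2", "attention_new3", "attention_new4"):
    # Build and compile the simulation for this variant
    variant = importlib.import_module(name).qkv(kernel, api)
    variant('fsim-compile')()
    # Run it on the same Q, K, V tiles and compare against the golden output
    outputs, _ = variant('fsim')(Q, K, V)
    diff = jnp.max(jnp.abs(outputs[0] - golden))
    print(f"{name}: max abs diff vs golden = {diff}")
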

From 7f1cae77e9346f053183490f8e89d3d346700198 Mon Sep 17 00:00:00 2001
From: Pedro Couto <115478001+giberish4040404@users.noreply.github.com>
Date: Fri, 2 Jan 2026 21:00:49 -0600
Subject: [PATCH 4/4] Delete asm/attention.py

had accidentally included this file in the commit from local machine
---
 asm/attention.py | 47 -----------------------------------------------
 1 file changed, 47 deletions(-)
 delete mode 100644 asm/attention.py

diff --git a/asm/attention.py b/asm/attention.py
deleted file mode 100644
index b421e97..0000000
--- a/asm/attention.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import jax.numpy as jnp
-
-def qkv(kernel, api):
-    @kernel(
-        hbm=32768,  # 32 KB: enough for 3 inputs + 1 output
-        input=[
-            {'addr': 0, 'shape': (64,64), 'dtype': jnp.bfloat16},      #Q
-            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},  #K
-            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16}, #V
-        ],  # allocate the input addresses here
-        constant=[],  # if needed, we can add constants here (none in this case)
-        output=[
-            {'addr': 24576, 'shape': (64,64), 'dtype': jnp.bfloat16},  #O
-        ]  # allocate the output address here
-    )
-    def qkv_():
-        # Kernel implementation goes here
-
-        # Load Q in row-major (standard)
-        api.load_rm(n=64, addr_in=0, addr_out=0)
-
-        # Load K in column-major (gives us K^T automatically!)
-        api.load_cm(n=64, addr_in=8192, addr_out=64)
-
-        # Compute S = Q × K^T
-        api.gemm(addr_1=0, addr_2=64, addr_out=0)
-
-        # Apply softmax to S, converting it to P (in-place in d2)
-        api.softmax(n=64, addr=0)
-
-        # Move P from d2[0:63] to d1[0:63]
-        # Recall: mov copies between scratchpads
-        api.mov(n=64, addr_in=0, addr_out=0)
-
-        # Load V into d1[64:127] (reusing K^T's space)
-        api.load_rm(n=64, addr_in=16384, addr_out=64)
-
-        # Compute O = P × V, result in d2[0:63]
-        api.gemm(addr_1=0, addr_2=64, addr_out=0)
-
-        # Move O from d2[0:63] to d1[0:63]
-        api.mov(n=64, addr_in=0, addr_out=0)
-
-        # Store O to HBM at address 24576
-        api.store_rm(n=64, addr_in=0, addr_out=24576)
-
-    return qkv_