@@ -49,16 +49,14 @@ export function findKNNGPUCosDistNorm<T>(
4949 // pair of points, which we sort using KMin data structure to obtain the
5050 // K nearest neighbors for each point.
5151 const nearest : NearestEntry [ ] [ ] = new Array ( N ) ;
52-
53- const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
54- const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
55- const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
56-
5752 function step ( resolve : ( result : NearestEntry [ ] [ ] ) => void ) {
5853 util
5954 . runAsyncTask (
6055 'Finding nearest neighbors...' ,
6156 async ( ) => {
57+ const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
58+ const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
59+ const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
6260 // 1 - A * A^T.
6361 const bigMatrixSquared = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
6462 const cosDistMatrix = tf . sub ( 1 , bigMatrixSquared ) ;
@@ -68,6 +66,9 @@ export function findKNNGPUCosDistNorm<T>(
6866 // [ 3 4 ],
6967 // `.data()` returns [1, 2, 3, 4].
7068 const partial = await cosDistMatrix . data ( ) ;
69+ // Discard all tensors and free up the memory.
70+ bigMatrix . dispose ( ) ;
71+ bigMatrixTransposed . dispose ( ) ;
7172 bigMatrixSquared . dispose ( ) ;
7273 cosDistMatrix . dispose ( ) ;
7374 for ( let i = 0 ; i < N ; i ++ ) {
@@ -93,15 +94,9 @@ export function findKNNGPUCosDistNorm<T>(
9394 . then (
9495 ( ) => {
9596 logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
96- // Discard all tensors and free up the memory.
97- bigMatrix . dispose ( ) ;
98- bigMatrixTransposed . dispose ( ) ;
9997 resolve ( nearest ) ;
10098 } ,
10199 ( error ) => {
102- // Discard all tensors and free up the memory.
103- bigMatrix . dispose ( ) ;
104- bigMatrixTransposed . dispose ( ) ;
105100 // GPU failed. Reverting back to CPU.
106101 logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
107102 let distFunc = ( a , b , limit ) => vector . cosDistNorm ( a , b ) ;
0 commit comments