Skip to content

Commit 6c10d9b

Browse files
committed
replace inplace copy with nn.Parameter
1 parent b176292 commit 6c10d9b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

taming/modules/vqvae/quantize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
342342
self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
343343
self.embedding.weight.requires_grad = False
344344
self.cluster_size = nn.Parameter(torch.zeros(n_embed),requires_grad=False)
345-
self.embed_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embedding_dim),requires_grad=False)
346-
self.embed_avg.data.copy_(self.embedding.weight.data)
345+
self.embed_avg = nn.Parameter(torch.randn(self.n_embed, self.embedding_dim),requires_grad=False)
346+
347347
self.remap = remap
348348
if self.remap is not None:
349349
self.register_buffer("used", torch.tensor(np.load(self.remap)))

0 commit comments

Comments
 (0)