Skip to content

ENH: add BMM operation #6

@Xylambda

Description

@Xylambda
class BatchMatrixMultiplication(BinaryOp):
    """Batched matrix multiplication op.

    For inputs a of shape (batch, n, m) and b of shape (batch, m, p),
    produces a tensor of shape (batch, n, p) where out[i] = a[i] @ b[i].
    """

    def forward(self):
        # NOTE(review): assumes get_value() returns array-likes consumable by
        # np.einsum (as the original forward already did) — confirm upstream.
        data_a, data_b = self.get_value()
        # Batched matmul in one einsum call, equivalent to
        # np.stack([data_a[i] @ data_b[i] for i in range(batch)]).
        # Fixes: np.eisum -> np.einsum, undefined tensor_a/tensor_b -> the
        # locals unpacked above, and removes the duplicated parents=/is_leaf=
        # keyword arguments (a SyntaxError in the original).
        return Tensor(
            np.einsum('ijk, ikz -> ijz', data_a, data_b),
            parents=self.parents,
            is_leaf=False,
            track_gradient=self.track_gradient,
            op_name=self.__repr__(),
        )

    def backward(self, gradient=None):
        """Accumulate gradients for both operands.

        For out = a @ b (per batch):
            grad_a = gradient @ b^T   -> shape of a
            grad_b = a^T @ gradient   -> shape of b
        """
        data_a, data_b = self.get_value()
        grad_np = gradient.numpy()
        # grad_a[i] = grad[i] @ b[i].T, expressed directly in einsum subscripts
        # (no explicit transpose needed). Both operands are used consistently
        # raw here; the original detached only data_b, which disagreed with
        # how forward consumes get_value() output.
        grad_a = np.einsum('ijk, izk -> ijz', grad_np, data_b)
        # grad_b[i] = a[i].T @ grad[i]
        grad_b = np.einsum('ikj, ikz -> ijz', data_a, grad_np)
        self._set_gradients(Tensor(grad_a), Tensor(grad_b))

    def __repr__(self):
        return "BatchMatrixMultiplication(BinaryOp)"

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement — New feature or request

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions