
[BUG] Cannot factorize with repeated tensors in expression #246

@yhtang

Describe the bug
If the same tensor appears more than once in a tensor expression, the expression cannot be factorized correctly. Related to #210, #228, #232.

To Reproduce

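import numpy as np
import funfact as ff

n = 2
V = ff.tensor('V', n, n)
i, j, k = ff.indices('i, j, k')
# V appears twice; the repeated index j is summed, i.e. V @ V
tsrex_eigh = V[i, j] * V[j, k]
# Look for a V such that V @ V equals the identity
fac = ff.factorize(
    tsrex_eigh,
    np.eye(n, dtype=np.float32)
)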
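As a plain-NumPy reference point (this does not use FunFact), the expression above sums over the repeated index j, so it is the matrix product V @ V, and the target np.eye(n) asks for a V that is its own inverse. The helper name `tsrex` below is hypothetical and only mirrors the repro.

```python
import numpy as np

n = 2
target = np.eye(n, dtype=np.float32)

def tsrex(V):
    # V[i, j] * V[j, k]: the repeated index j is summed and i, k remain,
    # which is exactly the matrix product V @ V.
    return np.einsum('ij,jk->ik', V, V)

# Two matrices that are their own inverse, so both satisfy V @ V = I.
identity = np.eye(n, dtype=np.float32)
swap = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=np.float32)

for V in (identity, swap):
    print(np.allclose(tsrex(V), target))  # True for both
```

So the factorization problem in the repro has exact, easy-to-find solutions.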

produces:

> fac()
tensor([[ 0.0072,  0.0116],
        [-0.0031,  0.0020]], grad_fn=<ViewBackward>)

> fac.tsrex.asciitree('data')
 ein: sum:multiply 
 ├── indexed_tensor: [i,j] 
 │   ├── tensor: V (data: tensor([[ 1.0021,  0.0717],        [-0.0569, -1.0022]])) 
 │   ╰── indices: i,j 
 │       ├── index: i 
 │       ╰── index: j 
 ╰── indexed_tensor: [j,k] 
     ├── tensor: V (data: tensor([[ 0.0070,  0.0118],        [ 0.0027, -0.0027]], requires_grad=True)) 
     ╰── indices: j,k 
         ├── index: j 
         ╰── index: k 
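The printed values can be checked by hand. A short NumPy check, using only the numbers copied from the output above (rounded to 4 decimals by PyTorch's printing), suggests that the result of fac() is the product of the two *different* V leaves in the tree, and that the first leaf on its own already squares to roughly the identity. The variable names are illustrative.

```python
import numpy as np

# Values copied from the output above.
V_first = np.array([[1.0021, 0.0717], [-0.0569, -1.0022]])  # first occurrence, no requires_grad
V_second = np.array([[0.0070, 0.0118], [0.0027, -0.0027]])  # second occurrence, requires_grad=True
fac_out = np.array([[0.0072, 0.0116], [-0.0031, 0.0020]])   # printed result of fac()

# The reported result matches V_first @ V_second, i.e. the two leaves
# are multiplied as if they were distinct tensors.
print(np.allclose(V_first @ V_second, fac_out, atol=1e-4))   # True
# The first occurrence alone nearly satisfies V @ V = I.
print(np.allclose(V_first @ V_first, np.eye(2), atol=1e-3))  # True
```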

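Expected behavior
The target requires V @ V = I, which is satisfied by V = I (or by any other matrix that is its own inverse), so a trivial solution should be found easily.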
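To illustrate what a correctly tied repeated tensor means for the optimizer, here is a minimal sketch in plain NumPy. It is not FunFact's implementation; the squared Frobenius-norm loss, plain gradient descent, the learning rate, and the starting point are all illustrative assumptions. The key point is that when both occurrences refer to one shared parameter, the gradient of V is the sum of the contributions from each occurrence.

```python
import numpy as np

def loss_and_grad(V, target):
    # Tied expression: both occurrences are the same array, so
    # V[i, j] * V[j, k] (summed over j) is simply V @ V.
    R = V @ V - target
    loss = float(np.sum(R ** 2))
    # V appears twice, so its gradient sums both contributions:
    # d/dV ||V V - T||_F^2 = 2 (R V^T + V^T R).
    grad = 2.0 * (R @ V.T + V.T @ R)
    return loss, grad

def factorize_tied(target, V0, lr=0.05, steps=500):
    # Plain gradient descent on the single shared parameter V.
    V = V0.copy()
    for _ in range(steps):
        _, grad = loss_and_grad(V, target)
        V -= lr * grad
    loss, _ = loss_and_grad(V, target)
    return V, loss

target = np.eye(2)
# Hypothetical starting point: a small perturbation of the identity.
V0 = np.eye(2) + 0.1 * np.array([[0.3, -0.2], [0.5, 0.1]])
V, loss = factorize_tied(target, V0)
print(loss)  # should be close to zero
```

Starting near the identity keeps the sketch deterministic; from an arbitrary start, gradient descent may settle on another self-inverse matrix, which satisfies the target equally well, or may stall elsewhere.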

Metadata

Labels: bug (Something isn't working), priority
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests