Types of changes
Motivation and Context / Related issue
At present, it is not possible to do things like
`torch_tensor.add(cryptensor)` or `torch_tensor + cryptensor`. The problem is that functions like `__radd__` never get called, because `torch.Tensor.add` fails with a `TypeError` rather than a `NotImplementedError` (which would trigger the reverse function to get called). This limitation leads to issues such as #403.

This PR fixes the issue for the `add`, `sub`, and `mul` functions. The general approach is as follows:

- Register `torch.Tensor.{add, sub, mul}` in the `__torch_function__` handler via an `@implements` decorator (see the sketch after this list).
- Add an `__init_subclass__` function in `CrypTensor` that ensures these decorators are inherited by subclasses of `CrypTensor`.
- Because `MPCTensor` dynamically adds functions like `add`, `sub`, and `mul` after the subclass is created, the registration is also done manually for those functions in `MPCTensor`.
- Because `MPCTensor.binary_wrapper_function` assumes a specific structure of `MPCTensor` that `torch.Tensor` does not have, we switch the order of the arguments if needed and alter the function name to be `__radd__`, `__rsub__`, etc.
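To make the dispatch mechanism concrete, here is a minimal, self-contained sketch of the idea under discussion. The `HANDLED_FUNCTIONS` registry, the `implements` decorator, and the `CrypTensorSketch` class are illustrative stand-ins, not CrypTen's actual implementation; the operand swap in the handler mirrors the `binary_wrapper_function` adjustment described in the last bullet above.

```python
import torch

# Dispatch table: torch function -> our override (illustrative, not CrypTen's code).
HANDLED_FUNCTIONS = {}


def implements(*torch_functions):
    """Register the decorated function as the handler for the given torch functions."""
    def decorator(func):
        for torch_function in torch_functions:
            HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class CrypTensorSketch:
    """Toy stand-in for CrypTensor: wraps a plain tensor instead of secret shares."""

    def __init__(self, data):
        self.data = torch.as_tensor(data, dtype=torch.float)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # Unhandled torch functions fall back to the default behavior.
        return NotImplemented


@implements(torch.Tensor.add, torch.Tensor.__add__)
def _add(input, other, *, alpha=1):
    # torch hands us (torch_tensor, cryptensor); swap the operands so the
    # wrapped object comes first -- the moral equivalent of calling __radd__.
    if not isinstance(input, CrypTensorSketch):
        input, other = other, input
    other = other.data if isinstance(other, CrypTensorSketch) else other
    return CrypTensorSketch(input.data + alpha * other)


# Both spellings now reach our handler instead of raising TypeError:
t = torch.tensor([1.0, 2.0])
ct = CrypTensorSketch([3.0, 4.0])
print(t.add(ct).data)  # tensor([4., 6.])
print((t + ct).data)   # tensor([4., 6.])
```

The key point is that PyTorch's `__torch_function__` protocol inspects all arguments, so even though `t.add(ct)` is invoked on a plain `torch.Tensor`, the call is routed to the handler defined on the CrypTensor-like type.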
Note that it is not immediately clear how to make the same approach work for other functions like `matmul` that do not have an `__rmatmul__`, or for functions that do not exist in PyTorch, like `conv1d`. It can be done, but things will get pretty messy. So the question with this PR is whether this is a path we want to continue on.

How Has This Been Tested
This PR is currently an RFC, so I have not yet deeply tested all the changes. I would first like to get feedback on whether we want to make this change at all.
That said, these simple examples pass:
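(The original snippets are not preserved in this capture; the following is a hedged reconstruction of the kind of check meant here, using CrypTen's public `crypten.init`, `crypten.cryptensor`, and `get_plain_text` APIs. These calls only succeed with this PR's changes applied.)

```python
import torch
import crypten

crypten.init()

x = torch.tensor([1.0, 2.0, 3.0])
y = crypten.cryptensor([4.0, 5.0, 6.0])

# Reverse-operand calls that previously raised TypeError:
assert torch.allclose((x + y).get_plain_text(), torch.tensor([5.0, 7.0, 9.0]))
assert torch.allclose(x.add(y).get_plain_text(), torch.tensor([5.0, 7.0, 9.0]))
assert torch.allclose((x - y).get_plain_text(), torch.tensor([-3.0, -3.0, -3.0]))
assert torch.allclose((x * y).get_plain_text(), torch.tensor([4.0, 10.0, 18.0]))
```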
Similarly, the example from #403 passes:
If we want to proceed in this direction, I will add full unit tests.
Checklist