
saved_tensors_hooks auto delete custom attributes #126676

Closed
godweiyang opened this issue May 20, 2024 · 4 comments
Labels
module: autograd Related to torch.autograd, and the autograd engine in general
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


godweiyang commented May 20, 2024

🐛 Describe the bug

I attach a custom attribute to the saved tensors in the pack hook. The attribute is still there inside the unpack hook, but by the time backward reads ctx.saved_tensors it is gone. How can I access the attribute in backward?

import torch

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        y = x @ w
        ctx.save_for_backward(x, w)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        print(hasattr(x, "aaa"))  # the custom attribute is deleted
        return grad_output @ w.t(), x.t() @ grad_output

def pack_hook(tensor):
    tensor.aaa = 1
    return tensor
        
def unpack_hook(tensor):
    print(tensor.aaa)  # the custom attribute still exists before unpack returns
    return tensor

x = torch.randn(100, 200).bfloat16().cuda().requires_grad_()
w = torch.randn(200, 300).bfloat16().cuda().requires_grad_()
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = f.apply(x, w).sum()
    y.backward()

The result is:

1
1
False

(The two 1s come from unpack_hook running once each for x and w; hasattr(x, "aaa") in backward then prints False.)

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230815+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.4.143.bsk.7-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A800-SXM4-80GB
GPU 1: NVIDIA A800-SXM4-80GB
GPU 2: NVIDIA A800-SXM4-80GB
GPU 3: NVIDIA A800-SXM4-80GB
GPU 4: NVIDIA A800-SXM4-80GB
GPU 5: NVIDIA A800-SXM4-80GB
GPU 6: NVIDIA A800-SXM4-80GB
GPU 7: NVIDIA A800-SXM4-80GB

Nvidia driver version: 470.214
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 57 bits virtual
CPU(s): 120
On-line CPU(s) list: 0-119
Thread(s) per core: 2
Core(s) per socket: 30
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz
Stepping: 6
CPU MHz: 2294.608
BogoMIPS: 4589.21
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 2.8 MiB
L1i cache: 1.9 MiB
L2 cache: 75 MiB
L3 cache: 108 MiB
NUMA node0 CPU(s): 0-59
NUMA node1 CPU(s): 60-119
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid md_clear arch_capabilities

Versions of relevant libraries:
[pip3] byted-torch==2.1.0.dev20230815
[pip3] byted-torch-monitor==0.0.1
[pip3] numpy==1.24.4
[pip3] onnx==1.14.1
[pip3] qtorch==0.3.0
[pip3] torch==2.1.0.dev20230815+cu121
[pip3] torchaudio==2.1.0.dev20230815+cu121
[pip3] torchlibrosa==0.1.0
[pip3] torchvision==0.16.0.dev20230815+cu121
[conda] Could not collect

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

@janeyx99 added the module: autograd and triaged labels on May 21, 2024

albanD (Collaborator) commented May 22, 2024

@soulitzer I guess there is a detach() or something there that makes the two not the same Tensor object. Do you think we could remove it?


soulitzer (Contributor) commented May 28, 2024

I'm not sure we can remove it. When we smash the autograd metadata (e.g. grad_fn) back onto whatever was returned by the unpack hook, we may not want to do that in-place: if the user observes that the grad_fn of their tensor changes, it may be unexpected.
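This re-wrapping is observable directly. Below is a minimal check (not from this thread; F, pack_hook, unpack_hook, and the seen dict are illustrative names) that compares the identity of the object unpack_hook returns with the object backward receives; on versions showing the behavior reported above it prints False:

import torch

seen = {}

def pack_hook(t):
    return t

def unpack_hook(t):
    seen["unpacked_id"] = id(t)  # identity of the object this hook returns
    return t

class F(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # autograd re-attaches grad_fn onto a fresh wrapper, so this is a
        # different Python object than the one unpack_hook returned
        print(id(x) == seen["unpacked_id"])  # expected: False
        return grad_output * 2

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    F.apply(x).sum().backward()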


albanD (Collaborator) commented May 30, 2024

Ok!
Sounds like expected behavior then; a subclass should be used if you need a property that is preserved.
Alternatively, if you want a property preserved across different views of the same data, you can add your attribute to the tensor.untyped_storage() (see the sketch below).
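A sketch of the untyped_storage() workaround, adapting the repro from this issue (CPU float tensors for simplicity). This assumes the storage's Python wrapper, and hence attributes set on it, is preserved across tensor.untyped_storage() calls; if it is not on your version, the print below shows None instead of 1:

import torch

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        y = x @ w
        ctx.save_for_backward(x, w)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        # the storage is shared by every view/re-wrap of the same data,
        # so the attribute survives even though the Tensor object changed
        print(getattr(x.untyped_storage(), "aaa", None))  # expected: 1
        return grad_output @ w.t(), x.t() @ grad_output

def pack_hook(tensor):
    tensor.untyped_storage().aaa = 1  # stash on the storage, not the Tensor
    return tensor

def unpack_hook(tensor):
    return tensor

x = torch.randn(100, 200, requires_grad=True)
w = torch.randn(200, 300, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = f.apply(x, w).sum()
    y.backward()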

soulitzer (Contributor)

Closing as it is expected behavior.
