Add NHWC support for group normalization #126635
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126635. Note: links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 7ad6380 with merge base 71f4915. This comment was automatically generated by Dr. CI and updates every 15 minutes.
auto cur_b = b[cur_sample * C + cur_channel];
Y[index] = (static_cast<T_ACC>(X[index]) + cur_b) * cur_a;
  }
}
I'm actually not sure how to achieve this behavior using tensor iterators and gpu_kernel, so it may be the case that this is unneeded.
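As a side note, the per-element indexing this NHWC kernel performs can be sketched in plain Python (`apply_scale_bias_nhwc` is a hypothetical name; flat lists stand in for device buffers, and the loop stands in for the CUDA thread grid):

```python
def apply_scale_bias_nhwc(X, a, b, N, H, W, C):
    # X is a flat list in NHWC (channels-last) order, so the channel index
    # is the fastest-moving dimension; a and b are flat (N*C) per-sample,
    # per-channel scale/bias tables, as in the kernel above.
    Y = [0.0] * (N * H * W * C)
    for index in range(N * H * W * C):
        cur_sample = index // (H * W * C)  # which image in the batch
        cur_channel = index % C            # channel is innermost in NHWC
        cur_a = a[cur_sample * C + cur_channel]
        cur_b = b[cur_sample * C + cur_channel]
        Y[index] = (X[index] + cur_b) * cur_a
    return Y
```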
for (int64_t c = 0; c < group_channels; c++) {
  val = welford_op.reduce(val, static_cast<T_ACC>(X[index + c]), index + c);
}
This kernel uses a different indexing strategy that works for NHWC tensors and uses Welford's algorithm. Aside from that, the logic is very similar.
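For reference, Welford's algorithm computes the mean and variance in a single, numerically stable pass; a minimal Python sketch:

```python
def welford(xs):
    # Welford's online algorithm: one pass, no large intermediate sums,
    # so it stays accurate even when values are large and close together.
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # note: uses the *updated* mean
    variance = m2 / count  # population variance, as used in normalization
    return mean, variance
```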
case MemoryFormat::ChannelsLast: {
  ApplyScaleBiasNHWCKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
      X_data, Y_data, N, height, width, C, D * HxW, a_data, b_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  break;
I do not know how to do this with a tensor iterator. If someone could show me how, that would be great.
I intend to add tests for this.
Really looking forward to this! I've found that, at least in some cases, a naive implementation is actually faster than the existing native group norm kernel when using channels-last format, and I'd love to see how much better a proper channels-last kernel does.
I'm interested to know the context: what GPU and architecture are you on? If possible, could you give me a minimal reproducible example of a naive implementation outperforming native, so I could take a look?
This was on a 3090, using Stable Diffusion 1.5 in inference mode. I'm not sure that it would be easy to make a minimal reproducible example, because I think it is at least partially dependent on having operations dispatched fairly far ahead of how fast they execute on the GPU. But to summarize, I first made sure that every
Running the tests that failed before in
  break;
}
default: {
  break; // is this okay?
If it's unsupported, we might want to explicitly raise an exception here.
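A minimal Python sketch of that suggestion (hypothetical function and format names; the real code would use TORCH_CHECK in C++): fail loudly on unsupported memory formats instead of silently falling through the default case.

```python
def dispatch_kernel(memory_format, supported=("contiguous", "channels_last")):
    # Raise on unsupported formats rather than silently doing nothing,
    # so users learn immediately that their layout isn't handled.
    if memory_format not in supported:
        raise ValueError(f"unsupported memory format: {memory_format}")
    return memory_format
```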
int num_blocks =
    (N * height * width * C + kCUDANumThreads - 1) / kCUDANumThreads;
ApplyScaleBiasNHWCKernel<T>
Could the previous TensorIterator code just be adapted by adding a permute before the view here?
Do you mean to use permute to convert them to NCHW, and then convert them back to NHWC after it's done? Please correct me if I'm wrong; I apologize, as I'm not that familiar with getting these iterators to work.
Something like this?
auto X_permuted = X.permute({0, 3, 1, 2}); // N, C, H, W
auto Y_permuted = Y.permute({0, 3, 1, 2}); // N, C, H, W
TensorIterator iter =
TensorIteratorConfig()
.check_all_same_dtype(std::is_same<T, T_ACC>::value)
.resize_outputs(false)
.add_owned_output(Y_permuted.view({N * C, H * W}))
.add_owned_const_input(X_permuted.view({N * C, H * W}))
.add_owned_input(a.view({N * C, 1}))
.add_owned_input(b.view({N * C, 1}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
return a * static_cast<T_ACC>(x) + b;
});
Y = Y_permuted.permute({0, 2, 3, 1}); // N, H, W, C
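For what it's worth, the reason the `.view` after a permute is problematic can be sketched in plain Python (`nhwc_to_nchw_flat` is a hypothetical helper; a flat list stands in for tensor storage): in NHWC storage, the `H*W` elements belonging to one channel are strided `C` apart, so the `(N*C, H*W)` layout the NCHW-style code expects requires a gather/copy rather than a zero-copy reinterpretation.

```python
def nhwc_to_nchw_flat(X, N, C, H, W):
    # X is flat channels-last storage: element (n, h, w, c) lives at
    # offset n*H*W*C + h*W*C + w*C + c. Collect each channel's H*W
    # elements into a row, producing the (N*C, H*W) shape the NCHW
    # pointwise code views the tensor as. This is a copy: the source
    # elements of one row are not contiguous in X.
    out = []
    for n in range(N):
        for c in range(C):
            row = [X[n * H * W * C + h * W * C + w * C + c]
                   for h in range(H) for w in range(W)]
            out.append(row)
    return out
```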
Yes, something like that. I'm not sure if TensorIterator is smart about getting the right (fast) memory access pattern for the permuted tensor, but it's worth an attempt.
@eqy So I get this with the above code:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Am I missing some steps here, or is it just not straightforward to work with these iterators when the tensors are in NHWC?
(Somewhat off-topic) Curious to know whether these reused tensor iterator kernels have better compile times than hand-defined kernels, or do they have clever dispatching and optimizations going on?
Reshape should be fine in this case as well (same semantics as what we intend).
I wouldn't say it's really about compile time in this case. TensorIterator kernels save a lot of boilerplate code (imagine writing the same kernel for every pointwise op variant), but they also do a good job of achieving high B/W utilization for common-case workloads. There's a fair amount of optimization work that was done by e.g. @zasdfgbnm, and if you look at the underlying kernels they dispatch to, you'll find the same optimizations that you would find in manually optimized kernels (vectorization, loop unrolling, etc.).
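The boilerplate-saving idea can be sketched in plain Python (hypothetical names, not the real TensorIterator API): the elementwise driver is written once, and each pointwise op reduces to a one-line lambda.

```python
def elementwise(op, *arrays):
    # Generic pointwise driver. In the real thing, broadcasting, dtype
    # promotion, vectorization, and launch configuration would live here
    # once, instead of being rewritten in every hand-coded kernel.
    return [op(*vals) for vals in zip(*arrays)]

# Each new pointwise op is now a one-liner:
scale_bias = lambda x, a, b: a * x + b
relu = lambda x: max(x, 0.0)
```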
Hm, I can't seem to get the above code to work with reshape either: 100% mismatched elements in the tests, lol. I've tried quite a few different approaches, but I can't seem to get it right...
I can spend some time optimizing this kernel if need be with some vectorization, but I'll have to defer to you on whether that level of hand-tuned optimization is less maintainable/desired at this layer of PyTorch.
Do you think you could help me get the iterator part right? Sorry for the trouble 😅
helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)

if device == 'cpu':
Maybe leave a comment about channels_last_3d not being supported on CUDA?
Please back out the kineto submodule change.
Is there an easy way to achieve this? I'd like to avoid messing up and doing a bad rebase, somehow leading to 100+ people getting notified.
Edit: I seem to have figured it out, but I'd like to avoid running into this problem in the future. Could you outline your git workflow and what commands you run, etc.? Do you do an interactive add, or do you run commands for the submodules regularly?
Fixes #111824
Currently, if the user specifies their group normalization to be of NHWC format, PyTorch will default to NCHW tensors and convert. This conversion is not immediately obvious to the user unless they check the format themselves, which is not intuitive. This PR adds support for NHWC on CUDA by adding the necessary kernels.
cc: @mikaylagawarecki