
[pipelining] add back support for multi-use parameters/buffers #126653

Closed

Conversation

kwen2501
Contributor

@kwen2501 kwen2501 commented May 19, 2024

Stack from ghstack (oldest at bottom):

Motivation

Resolves #126626 to support TorchTitan.

With this PR, we add back support for cases where a parameter or buffer is used in multiple stages. An example of such usage is in LLaMA (torchtitan), code snippet:

for layer in self.layers.values():
    h = layer(h, self.freqs_cis)

Solution

Step 1:
Remove the previous guards of `if len(node.users) == 1`.
Step 2:
Call `move_param_to_callee` multiple times, once for each stage ("callee").
Step 3:
Delay deletion of the `get_attr` node (for getting the param) from the root until the param has been sunk into every stage that uses it.

The PR also cleans up the old code around this (dropping the TRANSMIT mode and supporting REPLICATE mode only). A minimal sketch of the per-stage result follows below.
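For intuition only, here is a minimal sketch of the REPLICATE behavior (plain PyTorch with made-up names, not the splitter's actual output): the multi-use parameter ends up registered on each stage submodule, and only those per-stage copies remain in the model.

```
import torch
import torch.nn as nn

d_hid = 16

class Stage(nn.Module):
    # Illustrative stand-in for a submodule produced by splitting.
    def __init__(self, shared: torch.Tensor):
        super().__init__()
        # Each stage holds its own registered copy of the multi-use parameter.
        self.mm_param1 = nn.Parameter(shared.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.mm(x, self.mm_param1)

shared = torch.randn(d_hid, d_hid)
stages = [Stage(shared), Stage(shared)]  # the same value is replicated per stage

x = torch.randn(4, d_hid)
for stage in stages:
    x = stage(x)  # the replicated parameter participates in every stage that uses it
```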

Test

Changed the ExampleCode model to use mm_param1 in multiple stages.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


pytorch-bot bot commented May 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126653

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 89de9ec with merge base 5ea956a:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label May 19, 2024
kwen2501 added a commit that referenced this pull request May 19, 2024
Resolves #126626

ghstack-source-id: 35a3783f260d57972079289291f4ce827584d037
Pull Request resolved: #126653
@kwen2501 kwen2501 requested review from wconstab and H-Huang May 20, 2024 16:56
@wconstab
Contributor

Which titan issue was this addressing? something with freqs_cis?

@kwen2501
Contributor Author

See #126626. I filed it against pytorch rather than titan.
But yeah, it is wrt this code block in titan:

for layer in self.layers.values():
    h = layer(h, self.freqs_cis)

freqs_cis will be used in multiple stages once we cut the model into groups of layers.

self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x, y):
    x = torch.mm(x, self.mm_param0)
    x = torch.mm(x, self.mm_param1) # mutli-use param
Contributor


typo. (again, typo below)

logger.info(
    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
)
for user in node.users:
    assert user.op == "call_module"
    # Move parameter into submodule
    move_param_to_callee(
Contributor


does this affect the fqn of the shared parameter?

Contributor Author


No. This PR targets parameters (single FQN) used by multiple stages once the original model is split.

Contributor Author


@pianpwk 's PR targets the tied parameter case (aliasing):
#127094

skip_connection = x
x = x + y
x = torch.relu(x)
pipe_split()
x = torch.mm(x, self.mm_param1)
x = torch.mm(x, self.mm_param1) # mutli-use param
Contributor


do we have tests that verify fqn sanity (perhaps you added them along with unflattener)?

it'd be nice to confirm that when using multi-use param, the model's state_dict is clean and only has the original copy so checkpoint save/load will work as expected.

Contributor Author

@kwen2501 kwen2501 May 28, 2024


tbh, we don't have support for multi-use params in training yet, because that would require an all-reduce between the multiple copies of that param before the next batch's forward happens. So it would be kind of early to talk about how to save them before we can train them :)
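For context, a hedged sketch of roughly what that all-reduce would look like, assuming each rank's stage holds a local replica of the shared parameter and a process group spanning those stages is already initialized (the helper name and `group` argument are illustrative, not an existing API):

```
import torch
import torch.distributed as dist

def sync_shared_param_grad(param: torch.nn.Parameter, group=None) -> None:
    # Average this replica's gradient across the ranks whose stages also hold
    # a copy of the same parameter, so every replica applies the same update
    # before the next batch's forward.
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
        param.grad /= dist.get_world_size(group=group)
```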

Contributor Author


But multi-use buffers (as in the titan case) and multi-use params in inference are a different story; they can be supported today.

@wconstab
Contributor

I pulled this PR to see if it helps run torchtitan with the tracer. It does get further (it no longer errors during tracing), so presumably the freqs_cis thing is worked out.

But there is still a tracer issue with applying TP/DP while iterating over the transformer layers.

[screenshot of the tracer error]

@wconstab
Contributor

I checked the FQNs and they look correct to me. So I think this PR is good to land based on fixing the immediate issue with freqs_cis. However, we will need to do more work to verify e2e.
[screenshot of the FQN check]

callee = root.get_submodule(callee_name)
assert not hasattr(
    callee, param_fqn
), f"Module {callee_name} already has a parameter named {param_fqn}"

# Assign the parameter to the submodule
if is_buffer:
    _assign_attr(
Contributor


I'm kinda confused though, how come we can assign the attr to a submodule and not cause FQN duplication?

Contributor Author


We are moving the attr to the submodule.
The original attr will be removed IIRC.
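A self-contained sketch of that "move, not copy" behavior using plain `nn.Module` calls (the module and parameter names are illustrative; this is not the PR's `_assign_attr` helper itself):

```
import torch
import torch.nn as nn

root = nn.Module()
root.callee = nn.Module()  # stand-in for a pipeline stage submodule
root.register_parameter("mm_param1", nn.Parameter(torch.randn(4, 4)))

# Move the attribute: register it on the callee, then delete it from the root.
root.callee.register_parameter("mm_param1", root.mm_param1)
delattr(root, "mm_param1")

# Only one FQN remains, now under the stage submodule.
assert dict(root.named_parameters()).keys() == {"callee.mm_param1"}
```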

@pianpwk pianpwk self-requested a review May 28, 2024 18:46
Contributor

@pianpwk pianpwk left a comment


changes make sense, and stacked tests seem to work well

@kwen2501
Contributor Author

I pulled this PR to see if it helps run torchtitan with the tracer. It does get further (it no longer errors during tracing), so presumably the freqs_cis thing is worked out.

But there is still a tracer issue with applying TP/DP while iterating over the transformer layers.

[screenshot of the tracer error]

Thanks for checking.
The error you see is basically saying:
"I want a ModuleDict after split to be still a ModuleDict, and I want .items() to still work on it."
But that is currently not in pippy's contract -- what's broken is broken.
User code needs change to support all cases, e.g. .items() --> .children().
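A hedged sketch of that kind of user-side change, with toy names rather than torchtitan's actual code: `named_children()` and `register_module()` are available on every `nn.Module`, so the loop keeps working even after the split replaces the `ModuleDict`:

```
import torch.nn as nn

def rewrap_layers(layers: nn.Module, wrap) -> None:
    # Works whether or not `layers` is still a ModuleDict after splitting.
    for name, child in layers.named_children():
        layers.register_module(name, wrap(child))

# Toy usage: apply some per-layer wrapping (e.g. for TP/DP) to each child.
layers = nn.ModuleDict({"0": nn.Linear(4, 4), "1": nn.Linear(4, 4)})
rewrap_layers(layers, lambda m: nn.Sequential(m, nn.ReLU()))
```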

@kwen2501
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label May 28, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job

kwen2501 added a commit to pytorch/torchtitan that referenced this pull request May 29, 2024
…dules"


This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)):
"Module object has no attributed items."

The reason is, a split `ModuleDict` is no longer a `ModuleDict`. (Future support is not guaranteed.)

It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules.

[ghstack-poisoned]
kwen2501 added a commit to pytorch/torchtitan that referenced this pull request May 29, 2024
This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)):
"Module object has no attributed items."

The reason is, a split `ModuleDict` is no longer a `ModuleDict`. (Future support is not guaranteed.)

It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules.

[ghstack-poisoned]
@kwen2501 kwen2501 added the release notes: distributed (pipeline) label May 29, 2024
@kwen2501
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kwen2501 added a commit to pytorch/torchtitan that referenced this pull request May 29, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #362
* __->__ #371

This PR fixes the issue mentioned
[here](pytorch/pytorch#126653 (comment)):
"Module object has no attributed items."

The reason is, a split `ModuleDict` is no longer a `ModuleDict`.

It would be more generally applicable if we use `named_children()` and
`register_module()` to access and update submodules.
Aidyn-A pushed a commit to tinglvv/pytorch that referenced this pull request May 30, 2024
…ch#126653)

## Motivation
Resolves pytorch#126626 to support TorchTitan.

With this PR, we add back support for cases where a parameter or buffer is used in multiple stages. An example of such usage is in LLaMA (torchtitan), code snippet:
```
for layer in self.layers.values():
    h = layer(h, self.freqs_cis)
```

## Solution
Step 1:
Remove the previous guards of `if len(node.users) == 1`.
Step 2:
Call `move_param_to_callee` multiple times, one for each stage ("callee").
Step 3:
Delay deletion of the `get_attr` node (for getting the param) from root till this param has been sunk into each stage that uses it.

The PR also cleans up the old code around this (dropping the TRANSMIT mode and supporting REPLICATE mode only).

## Test
Changed the `ExampleCode` model to use `mm_param1` in multiple stages.

Pull Request resolved: pytorch#126653
Approved by: https://github.com/pianpwk
kwen2501 added a commit to pytorch/torchtitan that referenced this pull request Jun 1, 2024
…dules"


This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)):
"Module object has no attributed items."

The reason is, a split `ModuleDict` is no longer a `ModuleDict`.

It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules.

[ghstack-poisoned]
kwen2501 added a commit to pytorch/torchtitan that referenced this pull request Jun 1, 2024
This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)):
"Module object has no attributed items."

The reason is, a split `ModuleDict` is no longer a `ModuleDict`.

It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules.

[ghstack-poisoned]