feat: add 1F1B schedule #96
base: master
Conversation
float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype);
LOG(INFO) << "=== Schedule Table ===";
LOG(INFO) << "n=" << n << ", stages=" << num_stages << ", vpp=" << vpp_size
          << ", total_chunks=" << total_global_chunks;
To improve readability, build the string with format instead of chaining operator<<.
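For illustration, a minimal sketch of the suggestion, assuming C++20 <format> is available to the project (whether std::format or another formatting helper is used is an open choice):

#include <format>

// Builds the whole line in one call instead of chaining operator<<.
LOG(INFO) << std::format("n={}, stages={}, vpp={}, total_chunks={}",
                         n, num_stages, vpp_size, total_global_chunks);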
OK
struct StageInfo {
    bool is_first_stage;
    bool is_last_stage;
    std::vector<std::pair<int, int>> layer_chunks;
Add a comment explaining that this vector stores the start positions of the layers contained in each chunk, and consider a more intuitive name, e.g. chunk_layer_ranges.
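A sketch of what that could look like (chunk_layer_ranges is the proposed name; the exact range semantics in the comment below are an assumption):

struct StageInfo {
    bool is_first_stage;
    bool is_last_stage;
    // For each chunk on this stage: the start positions (as a [begin, end) pair)
    // of the layers that chunk contains.
    std::vector<std::pair<int, int>> chunk_layer_ranges;
};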
OK
example/llama3/net.cc
Outdated
std::vector<std::shared_ptr<nn::Module>> chunk_blocks;
int current_index = 0;
for (auto it = h_layers->begin(); it != h_layers->end(); ++it, ++current_index) {
Overload an indexing operator for the ModuleList type, so the corresponding layer can be fetched here directly by index.
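A possible shape for that overload (hypothetical; assumes ModuleList keeps its submodules in a std::vector<std::shared_ptr<Module>> member named modules_):

// Fetch the idx-th submodule directly, no iterator bookkeeping at call sites.
std::shared_ptr<Module> &ModuleList::operator[](size_t idx) { return modules_[idx]; }

// The loop above could then index h_layers directly:
auto &layer = (*h_layers)[current_index];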
OK
example/gpt2/net.h
Outdated
    Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
};

class GPT2Chunk {
Have GPT2Chunk inherit from Module:

class GPT2Chunk : public Module {
public:
    GPT2Chunk(
        GPT2* parent,
        int layer_begin,
        int layer_end,
        bool has_embedding,
        bool has_lm_head
    );
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>>& x) override;
private:
    GPT2* parent_ = nullptr;
    int layer_begin_ = 0;
    int layer_end_ = 0;
    bool has_embedding_ = false;
    bool has_lm_head_ = false;
};
This definition is much the same for any Transformer architecture; it could be factored out into class TransformerChunk : public Module, with the definition placed in the pp folder. Each net.cc would then define its own class GPT2 : public TransformerChunk and only needs to override Forward.
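A hypothetical sketch of that extraction (member layout copied from the GPT2Chunk suggestion above; nothing here is confirmed project API):

// pp folder: shared chunk definition for Transformer-style models.
class TransformerChunk : public Module {
public:
    TransformerChunk(int layer_begin, int layer_end, bool has_embedding, bool has_lm_head)
        : layer_begin_(layer_begin), layer_end_(layer_end),
          has_embedding_(has_embedding), has_lm_head_(has_lm_head) {}

protected:
    int layer_begin_ = 0;
    int layer_end_ = 0;
    bool has_embedding_ = false;
    bool has_lm_head_ = false;
};

// example/gpt2/net.cc: only the model-specific Forward is overridden.
class GPT2 : public TransformerChunk {
public:
    using TransformerChunk::TransformerChunk;
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &x) override;
};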
OK
}

-std::tuple<bool, bool, int, int> PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank) {
+StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int chunks_per_stage) {
pp_rank should still be passed in from net.cc; keep the scope in which thread_local variables are used as small as possible.
Storing pp_rank/tp_rank in thread_local variables is only a stopgap: if a thread spawns child threads, those threads do not inherit the variables, so the approach is unsafe. Everywhere else in the framework we obtain this information from the rank data structure stored in the device, but at model-initialization time the device has not been created yet, which is why we do it only here. To contain the unsafety of widespread thread_local usage, we will later need a thread pool that takes over every thread the framework spawns and propagates these thread_local variables to the threads that need them.
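A minimal, self-contained illustration of the inheritance problem described above (not project code):

#include <iostream>
#include <thread>

thread_local int pp_rank = -1;

int main() {
    pp_rank = 3; // set on the parent thread
    std::thread([] {
        // A child thread gets a fresh thread_local instance rather than
        // inheriting the parent's value: this prints -1, not 3.
        std::cout << pp_rank << std::endl;
    }).join();
}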
OK
example/gpt2/net.h
Outdated
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

void BuildChunks();
Have BuildChunks return all the chunks obtained from splitting the stage; when constructing PipelineStage, call the module's BuildChunks method and store all the chunks in the PipelineStage.
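A hypothetical sketch of that ownership change (class shape and member names are assumptions, not the actual PipelineStage definition):

class PipelineStage {
public:
    explicit PipelineStage(std::shared_ptr<nn::Module> model)
        : model_(std::move(model)),
          // BuildChunks returns every chunk of this stage; the stage keeps them.
          chunks_(model_->BuildChunks()) {}

private:
    std::shared_ptr<nn::Module> model_;
    std::vector<std::shared_ptr<nn::Module>> chunks_; // one entry per local chunk
};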
OK
-    return model_->Forward(inputs);
+std::vector<std::shared_ptr<Tensor>> PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
+                                                                    int local_chunk_idx) {
+    return model_->ForwardChunk(local_chunk_idx, inputs);
Here, index directly with local_chunk_idx to get the chunk stored on the stage, and call that chunk's Forward method.
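Assuming the chunks_ member sketched earlier, the method could reduce to (hypothetical):

std::vector<std::shared_ptr<Tensor>>
PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
                               int local_chunk_idx) {
    // Index the chunks owned by this stage and run the selected chunk directly.
    return chunks_[local_chunk_idx]->Forward(inputs);
}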
OK
example/gpt2/net.cc
Outdated
-std::vector<std::shared_ptr<infini_train::Tensor>>
-GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
-    int pp_rank = nn::parallel::pp_rank;
+void GPT2::BuildChunks() {
For Transformer models, BuildChunks can be merged as well; gpt2 and llama differ only in a pos_emb, so an if check is enough.
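A rough sketch of the merged logic (the has_pos_emb_ flag and member names are invented for illustration):

// Inside a shared BuildChunks: the only gpt2/llama difference is whether the
// embedding chunk also carries a learned positional-embedding table.
if (has_embedding) {
    chunk_modules.push_back(wte_);     // token embedding (both models)
    if (has_pos_emb_) {                // true for GPT-2, false for llama3
        chunk_modules.push_back(wpe_); // learned positional embedding
    }
}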
if (tp_world_size > 1) {
    auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
        nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
    tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank());
In the multi-node case, the communication group has to be obtained with the global rank.
OK
example/gpt2/net.cc
Outdated
int tp_rank = 0;
if (tp_world_size > 1) {
    auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
        nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank()));
Use GlobalRank, and the same applies elsewhere in this file: except in main.cc, where thread_rank is needed to pass the device_id, every place that obtains a communication group should pass GlobalRank.
OK
example/gpt2/net.cc
Outdated
auto [is_first_stage, is_last_stage, layer_chunks]
    = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, vpp_size);
// ========== layer to chunk ==========
std::unordered_map<int, bool> owned_layers;
A map seems unnecessary here; a vector is enough, and lookups are faster:
std::vector<bool> owned_layers(n_layer, false)
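For illustration, the ownership marking could then look like this (treating the layer_chunks pairs as half-open [begin, end) ranges, which is an assumption):

std::vector<bool> owned_layers(n_layer, false);
for (const auto &[begin, end] : layer_chunks) {
    for (int layer = begin; layer < end; ++layer) {
        owned_layers[layer] = true; // this stage owns the layer
    }
}
// Membership checks become a direct index instead of a hash lookup:
bool owns = owned_layers[layer_idx];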
OK