[do not review, test only] port to 25.4 RC3 #3076
Draft: songbell wants to merge 3 commits into openvinotoolkit:releases/2025/4 from songbell:bell/merge_eagle_cb_to_25_4
+1,035 −147
Changes from all commits (3 commits)
@@ -24,6 +24,27 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec
    return ss.str();
}

// Flags controlling how ModelRunner handles hidden states for Eagle3 speculative decoding:
//  - HS_EXPORT:   capture the model's "last_hidden_state" output after forward() and attach
//                 each sequence's slice of it to that sequence.
//  - HS_IMPORT:   feed hidden states supplied externally via set_initial_hidden_state()
//                 (e.g. from the target model into the Eagle draft model).
//  - HS_INTERNAL: feed hidden states stored on the running sequences themselves
//                 (the draft model's own speculative drafting stage).
enum HiddenStateFlags : uint8_t {
    HS_NONE = 0,
    HS_EXPORT = 1 << 0,
    HS_IMPORT = 1 << 1,
    HS_INTERNAL = 1 << 2
};

// Identifies a sequence by (request_id, grouped_sequence_id); strictly ordered so it can
// serve as a std::map key.
struct SequenceKey {
    size_t request_id{};
    size_t grouped_sequence_id{};
    bool operator<(const SequenceKey& other) const {
        return std::tie(request_id, grouped_sequence_id) <
               std::tie(other.request_id, other.grouped_sequence_id);
    }
};

// A contiguous span of token positions within the batched hidden-state tensor.
struct HiddenStateRange {
    size_t start_token_idx{};
    size_t length{};
};

/**
 * @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs,
 * KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences.
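To make the bookkeeping concrete, here is a minimal, self-contained sketch (not part of the diff; all values are illustrative) of how `SequenceKey` and `HiddenStateRange` cooperate: each scheduled sequence is mapped to the half-open token span it occupies in the batched hidden-state tensor.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <tuple>

// Trimmed copies of the structs from the hunk above.
struct SequenceKey {
    size_t request_id{};
    size_t grouped_sequence_id{};
    bool operator<(const SequenceKey& other) const {
        return std::tie(request_id, grouped_sequence_id) <
               std::tie(other.request_id, other.grouped_sequence_id);
    }
};
struct HiddenStateRange {
    size_t start_token_idx{};
    size_t length{};
};

int main() {
    std::map<SequenceKey, HiddenStateRange> mapping;
    // Two sequences packed back-to-back in one batch: 5 tokens, then 3 tokens.
    mapping[{0, 0}] = {0, 5};
    mapping[{1, 0}] = {5, 3};
    for (const auto& [key, range] : mapping)
        std::cout << "request " << key.request_id << " owns tokens ["
                  << range.start_token_idx << ", "
                  << range.start_token_idx + range.length << ")\n";
}
```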
@@ -48,6 +69,10 @@ class ModelRunner {
    // Input shape: [N, conversation length].
    // Output shape: [1, conversation length, hidden_size].
    EmbeddingsModel::Ptr m_embedding;
    uint8_t m_hidden_state_flags = HS_NONE;
    std::map<SequenceKey, HiddenStateRange> m_sequence_hidden_state_mapping;
    // initial hidden states keyed by request ID; each tensor has shape [seq_len, 1, hidden_size]
    std::map<size_t, ov::Tensor> m_initial_hidden_states;

    std::shared_ptr<InputsEmbedder> m_inputs_embedder;
@@ -107,6 +132,10 @@ class ModelRunner {
        return m_request;
    }

    void enable_hidden_state_export(bool on) { on ? m_hidden_state_flags |= HS_EXPORT : m_hidden_state_flags &= ~HS_EXPORT; }
    void enable_hidden_state_import(bool on) { on ? m_hidden_state_flags |= HS_IMPORT : m_hidden_state_flags &= ~HS_IMPORT; }
    void enable_hidden_state_internal(bool on) { on ? m_hidden_state_flags |= HS_INTERNAL : m_hidden_state_flags &= ~HS_INTERNAL; }

    void set_inputs_embedder(const std::shared_ptr<InputsEmbedder>& inputs_embedder) {
        m_inputs_embedder = inputs_embedder;
        m_embedding = inputs_embedder->get_embedding_model();
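The three setters above toggle independent bits, so one runner can hold several roles at once. A hypothetical Eagle3 wiring (a fragment; the runner variable names are assumptions, not from this PR) might look like:

```cpp
// target model: produce hidden states for the draft to consume
target_runner.enable_hidden_state_export(true);

// draft model: accept the target's hidden states on the first pass,
// then fall back to its own per-sequence states while drafting
draft_runner.enable_hidden_state_import(true);
draft_runner.enable_hidden_state_internal(true);
```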
@@ -134,6 +163,9 @@ class ModelRunner {
        m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer);
    }

    void set_initial_hidden_state(uint64_t request_id, const ov::Tensor& hidden_state) {
        m_initial_hidden_states[request_id] = hidden_state;
    }

    /**
     * Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences
     * taking into account the supplied scheduler output struct.
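The tensor registered here is validated later, in the import branch of forward(): its last dimension must equal hidden_size, and its first dimension is treated as the token count. A minimal illustrative call (a fragment; the runner variable, request ID, and sizes are assumed values):

```cpp
// 5 accepted tokens, hidden size 2048, zero-initialized purely for illustration
ov::Tensor initial_hs(ov::element::f32, {5, 1, 2048});
std::memset(initial_hs.data<float>(), 0, initial_hs.get_byte_size());
runner.set_initial_hidden_state(/*request_id=*/0, initial_hs);
```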
@@ -142,6 +174,7 @@ class ModelRunner {
     * @return An ov::Tensor with next-token logit scores for each sequence processed during this `forward` call.
     */
    ov::Tensor forward(const std::vector<SequenceGroup::Ptr>& sequence_groups, const Scheduler::Output& scheduler_output) {
        m_sequence_hidden_state_mapping.clear();
        size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size();

        size_t batch_size_in_sequences = 0;
@@ -185,6 +218,12 @@ class ModelRunner {
        ov::Tensor score_aggregation_window = _get_or_resize_tensor(m_cached_score_aggregation_window, "score_aggregation_window",
                                                                    {batch_size_in_sequences}, ov::element::i32);

        // zero-filled [total_num_tokens, 1, hidden_size] buffer, or an empty tensor when
        // neither HS_IMPORT nor HS_INTERNAL is set
        ov::Tensor hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size);
        float* hidden_state_data = nullptr;
        if (hidden_state_input) {
            hidden_state_data = hidden_state_input.data<float>();
        }

        ov::Tensor generated_ids_embeds;
        float* generated_ids_embeds_data = nullptr;
@@ -236,6 +275,7 @@ class ModelRunner {
        std::map<size_t, std::set<size_t>> seq_id_to_skipped_blocks_map;
        size_t position_ids_idx = 0;
        size_t current_token_idx = 0;
        for (size_t i = 0; i < num_sequence_groups; ++i) {
            size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
            SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
@@ -265,6 +305,64 @@ class ModelRunner {
                output_seq_len = 0;
                Sequence::CPtr sequence = running_sequences[seq_idx];
                if (_is_hs_export()) {
                    // remember which span of the batched hidden-state output belongs to this sequence
                    size_t start_token_idx = current_token_idx;
                    size_t sequence_length = num_scheduled_tokens;

                    SequenceKey key{sequence_group->get_request_id(), sequence->get_grouped_id()};
                    m_sequence_hidden_state_mapping[key] = HiddenStateRange{start_token_idx, sequence_length};
                }
                if (_is_hs_import()) {
                    auto it = m_initial_hidden_states.find(sequence_group->get_request_id());

                    if (it != m_initial_hidden_states.end()) {
                        const auto& stored_hidden_state = it->second;

                        if (stored_hidden_state.get_size() > 0) {
                            auto stored_shape = stored_hidden_state.get_shape();

                            if (stored_shape.size() >= 2) {
                                size_t stored_seq_len = stored_shape[0];
                                size_t stored_hidden_size = stored_shape[stored_shape.size() - 1];

                                if (stored_hidden_size == hidden_size) {
                                    if (stored_seq_len == total_num_tokens) {
                                        hidden_state_input = stored_hidden_state;  // all tokens from eagle are accepted
                                    } else {
                                        // copy only the tail of the stored state that lines up
                                        // with the tokens scheduled in this step
                                        size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens);

                                        size_t source_start_idx =
                                            stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0;
                                        _copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx);
                                    }
                                }
                            }
                        } else {
                            OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model");
                        }
                    }
                } else if (_is_hs_internal()) {
                    // fill hidden_state_data from the sequence's own stored hidden state
                    if (hidden_state_data) {
                        OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode");
                        std::memset(hidden_state_data + current_token_idx * hidden_size,
                                    0,
                                    num_scheduled_tokens * hidden_size * sizeof(float));
                        auto hidden_state = running_sequences[seq_idx]->get_hidden_state();
                        if (hidden_state.get_size() > 0) {
                            auto shape = hidden_state.get_shape();
                            if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) {
                                size_t seq_len = shape[0];
                                size_t copy_length = std::min(seq_len, num_scheduled_tokens);

                                size_t src_start_idx = seq_len >= copy_length ? seq_len - copy_length : 0;
                                auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size};
                                ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size);
                                _copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0);
                            }
                        }
                    }
                }
                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
                    // compute token for current sequence
                    if (sequence_group_type == SequenceGroupType::TOKENS) {
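The tail-alignment arithmetic above is easy to get wrong, so here is a standalone numeric sketch (not part of the diff; values are illustrative) of the same computation:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
    // e.g. the stored hidden state covers 4 tokens, but only 3 were scheduled:
    size_t stored_seq_len = 4, num_scheduled_tokens = 3;
    size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens);  // 3
    size_t source_start_idx =
        stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0;  // 1
    // rows [1, 4) of the stored state land at the sequence's offset in the batch
    std::cout << "copy " << copy_length << " rows starting at row "
              << source_start_idx << "\n";  // copy 3 rows starting at row 1
}
```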
@@ -343,6 +441,7 @@ class ModelRunner {
                    *score_aggregation_window_data = 1;
                }
            }
            current_token_idx += num_scheduled_tokens;
            past_lens_data += 1;
            subsequence_begins_data += 1;
            block_indices_begins_data += 1;
@@ -367,6 +466,13 @@ class ModelRunner {
                m_request.set_tensor("token_type_ids", token_type_ids);
            }
        }
        if (hidden_state_input && hidden_state_input.get_size() > 0) {
            try {
                m_request.set_tensor("hidden_states", hidden_state_input);
            } catch (const ov::Exception& e) {
                OPENVINO_THROW("Failed to set hidden states tensor: ", e.what());
            }
        }
        if (position_ids.get_shape().size() == 3) {
            // flatten position ids for the 3D position ids case
            position_ids.set_shape({ov::shape_size(position_ids.get_shape())});
@@ -423,7 +529,23 @@ class ModelRunner {
        }

        _reset_cache_rotation_coefficients();

        if (_is_hs_export()) {
            try {
                m_hidden_states = m_request.get_tensor("last_hidden_state");
                for (size_t i = 0; i < num_sequence_groups; ++i) {
                    size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
                    SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
                    std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
                    for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) {
                        Sequence::Ptr sequence = running_sequences[seq_idx];
                        sequence->update_hidden_state(
                            _get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id()));
                    }
                }
            } catch (const ov::Exception&) {
                // the model exposes no "last_hidden_state" output; leave the cache empty
                m_hidden_states = ov::Tensor();
            }
        }
        // return logits
        return m_request.get_tensor("logits");
    }
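After an export-enabled forward() call, each running sequence therefore carries its own slice of "last_hidden_state". One plausible consumer (a fragment; the runner variables are assumptions, not from this PR) could seed the draft runner directly from those slices:

```cpp
// run the target model, then hand each sequence's exported state to the draft
ov::Tensor logits = target_runner.forward(sequence_groups, scheduler_output);
for (auto& group : sequence_groups) {
    for (auto& seq : group->get_running_sequences()) {
        draft_runner.set_initial_hidden_state(group->get_request_id(),
                                              seq->get_hidden_state());
    }
}
```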
@@ -490,6 +612,112 @@ class ModelRunner {
    }

private:
    ov::Tensor m_hidden_states;

    // Hidden state flag helpers
    bool _is_hs_export() const { return m_hidden_state_flags & HS_EXPORT; }
    bool _is_hs_import() const { return m_hidden_state_flags & HS_IMPORT; }
    bool _is_hs_internal() const { return m_hidden_state_flags & HS_INTERNAL; }

    ov::Tensor _get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const {
        if (m_hidden_states.get_size() == 0) {
            return ov::Tensor();
        }

        SequenceKey key{request_id, seq_grouped_id};
        const auto it = m_sequence_hidden_state_mapping.find(key);
        if (it == m_sequence_hidden_state_mapping.end()) {
            return ov::Tensor();
        }

        size_t start_idx = it->second.start_token_idx;
        size_t length = it->second.length;

        auto shape = m_hidden_states.get_shape();
        if (shape.size() < 2) {
            return ov::Tensor();
        }

        // slice [start_idx, start_idx + length) along the first (token) dimension,
        // keeping all remaining dimensions whole
        ov::Coordinate start_coord(shape.size(), 0);
        ov::Coordinate end_coord(shape.size(), 0);

        start_coord[0] = start_idx;
        end_coord[0] = start_idx + length;

        for (size_t i = 1; i < shape.size(); ++i) {
            start_coord[i] = 0;
            end_coord[i] = shape[i];
        }

        return ov::Tensor(m_hidden_states, start_coord, end_coord);
    }

    ov::Tensor _prepare_hidden_state_input(size_t total_num_tokens,
                                           size_t& hidden_size /*in/out*/) {
        if (!(m_hidden_state_flags & (HS_IMPORT | HS_INTERNAL))) {
            return {};
        }

        // derive hidden_size from any stored initial state if the caller did not supply it
        if (hidden_size == 0) {
            for (const auto& kv : m_initial_hidden_states) {
                const auto& initial_hidden_states = kv.second;
                if (initial_hidden_states && initial_hidden_states.get_shape().size() >= 2) {
                    auto hidden_states_shape = initial_hidden_states.get_shape();
                    hidden_size = hidden_states_shape.back();
                    break;
                }
            }
        }
        if (hidden_size == 0) {
            return {};
        }

        ov::Tensor hs(ov::element::f32, {total_num_tokens, 1, hidden_size});
        std::memset(hs.data<float>(), 0, hs.get_byte_size());
        return hs;
    }

    // Common helper to copy a contiguous slice (first-dim range) from src to dst using ROI tensors.
    // src_start_idx: start index along src first dimension
    // copy_length: number of elements along first dim to copy
    // dst_base: destination base tensor (may be the full buffer or a wrapper around a raw pointer)
    // dst_first_dim_start: start index in the first dimension of dst_base where the copy should be placed
    static void _copy_roi_between_tensors(const ov::Tensor& src,
                                          size_t src_start_idx,
                                          size_t copy_length,
                                          const ov::Tensor& dst_base,
                                          size_t dst_first_dim_start) {
        if (copy_length == 0) {
            return;
        }

        // lambda to create ROI coordinates covering a first-dimension range
        auto make_roi = [](const std::vector<size_t>& shape, size_t first_dim_start, size_t first_dim_end) {
            ov::Coordinate start(shape.size(), 0), end(shape.size(), 0);
            start[0] = first_dim_start;
            end[0] = first_dim_end;
            for (size_t d = 1; d < shape.size(); ++d) {
                start[d] = 0;
                end[d] = shape[d];
            }
            return std::make_pair(start, end);
        };

        // prepare source ROI coords
        const auto src_shape = src.get_shape();
        OPENVINO_ASSERT(!src_shape.empty(), "source tensor rank is zero");
        auto [src_start, src_end] = make_roi(src_shape, src_start_idx, src_start_idx + copy_length);
        ov::Tensor src_roi(src, src_start, src_end);

        // prepare destination ROI coords
        const auto dst_shape = dst_base.get_shape();
        OPENVINO_ASSERT(!dst_shape.empty(), "destination tensor rank is zero");
        auto [tgt_start, tgt_end] = make_roi(dst_shape, dst_first_dim_start, dst_first_dim_start + copy_length);
        ov::Tensor tgt_roi(dst_base, tgt_start, tgt_end);

        // bulk copy
        src_roi.copy_to(tgt_roi);
    }

    ov::Tensor _get_or_resize_tensor(ov::Tensor& cached_tensor,
                                     const std::string& tensor_name,
                                     const ov::Shape& required_shape,
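As a sanity check on the ROI semantics this helper relies on, here is a minimal standalone program (written against the public ov::Tensor API, not part of the diff; it assumes an OpenVINO development package is available) that copies rows [1, 3) of a 3x1x4 tensor into rows [0, 2) of another:

```cpp
#include <algorithm>
#include <iostream>
#include <openvino/runtime/tensor.hpp>

int main() {
    ov::Tensor src(ov::element::f32, {3, 1, 4});
    ov::Tensor dst(ov::element::f32, {2, 1, 4});
    float* s = src.data<float>();
    for (size_t i = 0; i < src.get_size(); ++i)
        s[i] = static_cast<float>(i);  // rows hold [0..3], [4..7], [8..11]
    std::fill_n(dst.data<float>(), dst.get_size(), 0.0f);

    // ROI over src rows [1, 3) and dst rows [0, 2); both ROI shapes are {2, 1, 4}
    ov::Tensor src_roi(src, {1, 0, 0}, {3, 1, 4});
    ov::Tensor dst_roi(dst, {0, 0, 0}, {2, 1, 4});
    src_roi.copy_to(dst_roi);

    std::cout << dst.data<float>()[0] << "\n";  // 4: first element of src row 1
}
```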
Review comment: The HiddenStateFlags enum lacks documentation. Add comments explaining what each flag controls in the context of Eagle3 decoding.