Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,18 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
class ContinuousBatchingImpl;

class ContinuousBatchingForSpeculativeDecodingImpl;
class ContinuousBatchingForEagle3DecodingImpl;
class ContinuousBatchingForPromptLookupImpl;
class SpeculativeDecodingImpl;
class Eagle3DecodingImpl;
class PromptLookupImpl;

friend class ContinuousBatchingForSpeculativeDecodingImpl;

friend class ContinuousBatchingForPromptLookupImpl;
friend class ContinuousBatchingForEagle3DecodingImpl;
friend class SpeculativeDecodingImpl;
friend class Eagle3DecodingImpl;
friend class PromptLookupImpl;

std::shared_ptr<IContinuousBatchingPipeline> m_impl;
Expand Down
230 changes: 229 additions & 1 deletion src/cpp/src/continuous_batching/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec
return ss.str();
}

/**
 * @brief Bit flags controlling how the model runner treats hidden states
 *        during Eagle3-style speculative decoding.
 *
 * Flags are combined into a uint8_t mask (see m_hidden_state_flags), so each
 * enumerator occupies a distinct bit.
 */
enum HiddenStateFlags : uint8_t {
    HS_NONE = 0,          ///< No hidden-state handling.
    HS_EXPORT = 1 << 0,   ///< Export hidden states after inference (read back from the model output).
    HS_IMPORT = 1 << 1,   ///< Import externally supplied hidden states before inference.
    HS_INTERNAL = 1 << 2  ///< Feed hidden states tracked internally per running sequence.
};

// Identifies one running sequence by its (request id, grouped sequence id) pair.
// Strict-weak-ordered lexicographically so it can serve as a std::map key.
struct SequenceKey {
    size_t request_id{};
    size_t grouped_sequence_id{};

    bool operator<(const SequenceKey& rhs) const {
        if (request_id != rhs.request_id) {
            return request_id < rhs.request_id;
        }
        return grouped_sequence_id < rhs.grouped_sequence_id;
    }
};

// Half-open token span [start_token_idx, start_token_idx + length) locating one
// sequence's tokens inside the batched hidden-state tensor of a forward() call.
struct HiddenStateRange {
    size_t start_token_idx{};  // first token index of this sequence within the batch
    size_t length{};           // number of tokens scheduled for this sequence
};

/**
* @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs,
* KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences.
Expand All @@ -48,6 +69,10 @@ class ModelRunner {
// Input shape: [N, conversation length].
// Output shape: [1, conversation length, hidden_size].
EmbeddingsModel::Ptr m_embedding;
uint8_t m_hidden_state_flags = HS_NONE;
std::map<SequenceKey, HiddenStateRange> m_sequence_hidden_state_mapping;
// Container keyed by request id that stores hidden states imported from the target model.
std::map<size_t, ov::Tensor> m_initial_hidden_states; // shape: [N, seq_len, hidden_size]

std::shared_ptr<InputsEmbedder> m_inputs_embedder;

Expand Down Expand Up @@ -107,6 +132,10 @@ class ModelRunner {
return m_request;
}

// Turns the HS_EXPORT bit on/off: when set, forward() reads back the model's
// "last_hidden_state" output and records it per sequence.
void enable_hidden_state_export(bool on) {
    if (on) {
        m_hidden_state_flags |= HS_EXPORT;
    } else {
        m_hidden_state_flags &= ~HS_EXPORT;
    }
}
// Turns the HS_IMPORT bit on/off: when set, forward() feeds tensors registered
// via set_initial_hidden_state() into the model's "hidden_states" input.
void enable_hidden_state_import(bool on) {
    if (on) {
        m_hidden_state_flags |= HS_IMPORT;
    } else {
        m_hidden_state_flags &= ~HS_IMPORT;
    }
}
// Turns the HS_INTERNAL bit on/off: when set, forward() fills the hidden-state
// input from each running sequence's own stored hidden state.
void enable_hidden_state_internal(bool on) {
    if (on) {
        m_hidden_state_flags |= HS_INTERNAL;
    } else {
        m_hidden_state_flags &= ~HS_INTERNAL;
    }
}

void set_inputs_embedder(const std::shared_ptr<InputsEmbedder>& inputs_embedder) {
m_inputs_embedder = inputs_embedder;
m_embedding = inputs_embedder->get_embedding_model();
Expand Down Expand Up @@ -134,6 +163,9 @@ class ModelRunner {
m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer);
}

// Registers the hidden state tensor (shape: [N, seq_len, hidden_size]) that
// forward() should import for the given request when HS_IMPORT is enabled.
// NOTE(review): the parameter is uint64_t while the map key is size_t — the same
// width on common 64-bit ABIs, but worth unifying with set_initial_hidden_state's callers.
void set_initial_hidden_state(uint64_t request_id, const ov::Tensor& hidden_state) {
    m_initial_hidden_states[request_id] = hidden_state;
}
/**
* Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences
* taking into account the supplied scheduler output struct.
Expand All @@ -142,6 +174,7 @@ class ModelRunner {
* @return An ov::Tensor with next-token logit scores for each sequence processed during this `forward` call.
*/
ov::Tensor forward(const std::vector<SequenceGroup::Ptr> & sequence_groups, const Scheduler::Output& scheduler_output) {
m_sequence_hidden_state_mapping.clear();
size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size();

size_t batch_size_in_sequences = 0;
Expand Down Expand Up @@ -185,6 +218,12 @@ class ModelRunner {
ov::Tensor score_aggregation_window = _get_or_resize_tensor(m_cached_score_aggregation_window, "score_aggregation_window",
{batch_size_in_sequences}, ov::element::i32);

ov::Tensor hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size);
float* hidden_state_data = nullptr;
if (hidden_state_input) {
hidden_state_data = hidden_state_input.data<float>();
}

ov::Tensor generated_ids_embeds;
float *generated_ids_embeds_data = nullptr;

Expand Down Expand Up @@ -236,6 +275,7 @@ class ModelRunner {

std::map<size_t, std::set<size_t>> seq_id_to_skipped_blocks_map;
size_t position_ids_idx = 0;
size_t current_token_idx = 0;
for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
Expand Down Expand Up @@ -265,6 +305,64 @@ class ModelRunner {

output_seq_len = 0;
Sequence::CPtr sequence = running_sequences[seq_idx];
if (_is_hs_export()) {
size_t start_token_idx = current_token_idx;
size_t sequence_length = num_scheduled_tokens;

SequenceKey key{sequence_group->get_request_id(), sequence->get_grouped_id()};
m_sequence_hidden_state_mapping[key] = HiddenStateRange{start_token_idx, sequence_length};
}
if (_is_hs_import()) {
auto it = m_initial_hidden_states.find(sequence_group->get_request_id());

if (it != m_initial_hidden_states.end()) {
const auto& stored_hidden_state = it->second;

if (stored_hidden_state.get_size() > 0) {
auto stored_shape = stored_hidden_state.get_shape();

if (stored_shape.size() >= 2) {
size_t stored_seq_len = stored_shape[0];
size_t stored_hidden_size = stored_shape[stored_shape.size() - 1];

if (stored_hidden_size == hidden_size) {
if (stored_seq_len == total_num_tokens) {
hidden_state_input = stored_hidden_state; // all tokens from eagle are accepted
} else {
size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens);

size_t source_start_idx =
stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0;
_copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx);
}
}
}
} else {
OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model");
}
}
} else if (_is_hs_internal()) {
// fill hidden_state_data with m_hidden_states
if (hidden_state_data) {
OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode");
std::memset(hidden_state_data + current_token_idx * hidden_size,
0,
num_scheduled_tokens * hidden_size * sizeof(float));
auto hidden_state = running_sequences[seq_idx]->get_hidden_state();
if (hidden_state.get_size() > 0) {
auto shape = hidden_state.get_shape();
if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) {
size_t seq_len = shape[0];
size_t copy_length = std::min(seq_len, num_scheduled_tokens);

size_t src_start_idx = seq_len >= copy_length ? seq_len - copy_length : 0;
auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size};
ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size);
_copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0);
}
}
}
}
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
if (sequence_group_type == SequenceGroupType::TOKENS) {
Expand Down Expand Up @@ -343,6 +441,7 @@ class ModelRunner {
*score_aggregation_window_data = 1;
}
}
current_token_idx += num_scheduled_tokens;
past_lens_data += 1;
subsequence_begins_data += 1;
block_indices_begins_data += 1;
Expand All @@ -367,6 +466,13 @@ class ModelRunner {
m_request.set_tensor("token_type_ids", token_type_ids);
}
}
if (hidden_state_input && hidden_state_input.get_size() > 0) {
try {
m_request.set_tensor("hidden_states", hidden_state_input);
} catch (const ov::Exception& e) {
OPENVINO_THROW("Failed to set hidden states tensor: ", e.what());
}
}
if (position_ids.get_shape().size() == 3) {
// flatten positions ids for 3D position ids case
position_ids.set_shape({ov::shape_size(position_ids.get_shape())});
Expand Down Expand Up @@ -423,7 +529,23 @@ class ModelRunner {
}

_reset_cache_rotation_coefficients();

if (_is_hs_export()) {
try {
m_hidden_states = m_request.get_tensor("last_hidden_state");
for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) {
Sequence::Ptr sequence = running_sequences[seq_idx];
sequence->update_hidden_state(
_get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id()));
}
}
} catch (const ov::Exception&) {
m_hidden_states = ov::Tensor();
}
}
// return logits
return m_request.get_tensor("logits");
}
Expand Down Expand Up @@ -490,6 +612,112 @@ class ModelRunner {
}

private:
ov::Tensor m_hidden_states;

// Predicates reporting which hidden-state handling modes are currently active
// in the m_hidden_state_flags bit mask.
bool _is_hs_export() const { return (m_hidden_state_flags & HS_EXPORT) != 0; }
bool _is_hs_import() const { return (m_hidden_state_flags & HS_IMPORT) != 0; }
bool _is_hs_internal() const { return (m_hidden_state_flags & HS_INTERNAL) != 0; }

/**
 * @brief Returns a zero-copy ROI view over m_hidden_states covering the tokens that
 *        belong to the given (request_id, grouped_sequence_id) pair in the last
 *        forward() call.
 * @return The ROI tensor, or an empty ov::Tensor when no hidden states were captured,
 *         the sequence is unknown, or the recorded range does not fit the tensor.
 */
ov::Tensor _get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const {
    if (m_hidden_states.get_size() == 0) {
        return ov::Tensor();
    }

    SequenceKey key{request_id, seq_grouped_id};
    const auto it = m_sequence_hidden_state_mapping.find(key);
    if (it == m_sequence_hidden_state_mapping.end()) {
        return ov::Tensor();
    }

    const size_t start_idx = it->second.start_token_idx;
    const size_t length = it->second.length;

    const auto shape = m_hidden_states.get_shape();
    // Fail soft (empty tensor) on a malformed mapping rather than letting the ROI
    // constructor throw on an out-of-bounds range.
    if (shape.size() < 2 || start_idx + length > shape[0]) {
        return ov::Tensor();
    }

    // Restrict dim 0 to [start_idx, start_idx + length); keep every other dim whole.
    ov::Coordinate start_coord(shape.size(), 0);
    ov::Coordinate end_coord(shape.size(), 0);
    start_coord[0] = start_idx;
    end_coord[0] = start_idx + length;
    for (size_t dim = 1; dim < shape.size(); ++dim) {
        end_coord[dim] = shape[dim];
    }

    return ov::Tensor(m_hidden_states, start_coord, end_coord);
}

/**
 * @brief Allocates the zero-filled f32 hidden-state input tensor for the current
 *        forward() call when HS_IMPORT or HS_INTERNAL is active.
 *
 * When hidden_size is unknown (0) it is deduced from the last dimension of any
 * stored initial hidden state of rank >= 2 and written back through the reference.
 * Returns an empty tensor when no hidden-state input is required or the hidden
 * size cannot be determined.
 */
ov::Tensor _prepare_hidden_state_input(size_t total_num_tokens,
                                       size_t& hidden_size /*in/out*/) {
    const bool input_required = (m_hidden_state_flags & (HS_IMPORT | HS_INTERNAL)) != 0;
    if (!input_required) {
        return {};
    }

    if (hidden_size == 0) {
        // Deduce the hidden dimension from any usable stored tensor.
        for (const auto& [request_id, stored_tensor] : m_initial_hidden_states) {
            (void)request_id;
            if (stored_tensor && stored_tensor.get_shape().size() >= 2) {
                hidden_size = stored_tensor.get_shape().back();
                break;
            }
        }
    }
    if (hidden_size == 0) {
        return {};
    }

    ov::Tensor buffer(ov::element::f32, {total_num_tokens, 1, hidden_size});
    std::memset(buffer.data<float>(), 0, buffer.get_byte_size());
    return buffer;
}

/**
 * @brief Copies a contiguous first-dimension slice from @p src into @p dst_base using
 *        zero-copy ROI views; every trailing dimension is copied in full.
 * @param src                 Source tensor (rank must be >= 1).
 * @param src_start_idx       Start index along the first dimension of @p src.
 * @param copy_length         Number of first-dimension entries to copy; no-op when 0.
 * @param dst_base            Destination base tensor (full buffer, or a wrapper around
 *                            a raw pointer).
 * @param dst_first_dim_start First-dimension offset in @p dst_base receiving the slice.
 */
static void _copy_roi_between_tensors(const ov::Tensor& src,
                                      size_t src_start_idx,
                                      size_t copy_length,
                                      const ov::Tensor& dst_base,
                                      size_t dst_first_dim_start) {
    if (copy_length == 0) {
        return;
    }

    // Builds [start, end) ROI coordinates restricting only the first dimension;
    // all other dimensions span their full extent (start entries are value-init to 0).
    auto make_roi = [](const std::vector<size_t>& shape, size_t first_dim_start, size_t first_dim_end) {
        ov::Coordinate start(shape.size(), 0), end(shape.size(), 0);
        start[0] = first_dim_start;
        end[0] = first_dim_end;
        for (size_t d = 1; d < shape.size(); ++d) {
            end[d] = shape[d];
        }
        return std::make_pair(start, end);
    };

    // Source ROI: validate the range before the ROI constructor can throw.
    const auto src_shape = src.get_shape();
    OPENVINO_ASSERT(!src_shape.empty(), "source tensor rank is zero");
    OPENVINO_ASSERT(src_start_idx + copy_length <= src_shape[0], "source ROI exceeds tensor extent");
    auto [src_start, src_end] = make_roi(src_shape, src_start_idx, src_start_idx + copy_length);
    ov::Tensor src_roi(src, src_start, src_end);

    // Destination ROI with the same validation.
    const auto dst_shape = dst_base.get_shape();
    OPENVINO_ASSERT(!dst_shape.empty(), "destination tensor rank is zero");
    OPENVINO_ASSERT(dst_first_dim_start + copy_length <= dst_shape[0], "destination ROI exceeds tensor extent");
    auto [tgt_start, tgt_end] = make_roi(dst_shape, dst_first_dim_start, dst_first_dim_start + copy_length);
    ov::Tensor tgt_roi(dst_base, tgt_start, tgt_end);

    // Bulk element-wise copy between the two views.
    src_roi.copy_to(tgt_roi);
}
ov::Tensor _get_or_resize_tensor(ov::Tensor& cached_tensor,
const std::string& tensor_name,
const ov::Shape& required_shape,
Expand Down
Loading
Loading