Skip to content

Commit 64a2619

Browse files
committed
refactor(uffd): update UffdPfHandler
Rename `UffdPfHandler` -> `UffdHandler` Replace direct access to `uffd` field with separate method. Rearrange couple functions/struct positions in the file. Signed-off-by: Egor Lazarchuk <yegorlz@amazon.co.uk>
1 parent 276e6e4 commit 64a2619

File tree

3 files changed

+53
-51
lines changed

3 files changed

+53
-51
lines changed

src/firecracker/examples/uffd/malicious_handler.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod uffd_utils;
99
use std::fs::File;
1010
use std::os::unix::net::UnixListener;
1111

12-
use uffd_utils::{Runtime, UffdPfHandler};
12+
use uffd_utils::{Runtime, UffdHandler};
1313

1414
fn main() {
1515
let mut args = std::env::args();
@@ -23,10 +23,9 @@ fn main() {
2323
let (stream, _) = listener.accept().expect("Cannot listen on UDS socket");
2424

2525
let mut runtime = Runtime::new(stream, file);
26-
runtime.run(|uffd_handler: &mut UffdPfHandler| {
26+
runtime.run(|uffd_handler: &mut UffdHandler| {
2727
// Read an event from the userfaultfd.
2828
let event = uffd_handler
29-
.uffd
3029
.read_event()
3130
.expect("Failed to read uffd_msg")
3231
.expect("uffd_msg not ready");

src/firecracker/examples/uffd/uffd_utils.rs

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::os::unix::net::UnixStream;
1111
use std::ptr;
1212

1313
use serde::Deserialize;
14-
use userfaultfd::Uffd;
14+
use userfaultfd::{Error, Event, Uffd};
1515
use utils::get_page_size;
1616
use utils::sock_ctrl_msg::ScmSocket;
1717

@@ -33,28 +33,28 @@ pub struct GuestRegionUffdMapping {
3333
pub offset: u64,
3434
}
3535

36+
#[derive(Debug, Clone)]
37+
pub enum MemPageState {
38+
Uninitialized,
39+
FromFile,
40+
Removed,
41+
Anonymous,
42+
}
43+
3644
#[derive(Debug)]
3745
struct MemRegion {
3846
mapping: GuestRegionUffdMapping,
3947
page_states: HashMap<u64, MemPageState>,
4048
}
4149

4250
#[derive(Debug)]
43-
pub struct UffdPfHandler {
51+
pub struct UffdHandler {
4452
mem_regions: Vec<MemRegion>,
4553
backing_buffer: *const u8,
46-
pub uffd: Uffd,
47-
}
48-
49-
#[derive(Debug, Clone)]
50-
pub enum MemPageState {
51-
Uninitialized,
52-
FromFile,
53-
Removed,
54-
Anonymous,
54+
uffd: Uffd,
5555
}
5656

57-
impl UffdPfHandler {
57+
impl UffdHandler {
5858
pub fn from_unix_stream(stream: &UnixStream, backing_buffer: *const u8, size: usize) -> Self {
5959
let mut message_buf = vec![0u8; 1024];
6060
let (bytes_read, file) = stream
@@ -83,6 +83,10 @@ impl UffdPfHandler {
8383
}
8484
}
8585

86+
pub fn read_event(&mut self) -> Result<Option<Event>, Error> {
87+
self.uffd.read_event()
88+
}
89+
8690
pub fn update_mem_state_mappings(&mut self, start: u64, end: u64, state: &MemPageState) {
8791
for region in self.mem_regions.iter_mut() {
8892
for (key, value) in region.page_states.iter_mut() {
@@ -93,36 +97,6 @@ impl UffdPfHandler {
9397
}
9498
}
9599

96-
fn populate_from_file(&self, region: &MemRegion, dst: u64, len: usize) -> (u64, u64) {
97-
let offset = dst - region.mapping.base_host_virt_addr;
98-
let src = self.backing_buffer as u64 + region.mapping.offset + offset;
99-
100-
let ret = unsafe {
101-
self.uffd
102-
.copy(src as *const _, dst as *mut _, len, true)
103-
.expect("Uffd copy failed")
104-
};
105-
106-
// Make sure the UFFD copied some bytes.
107-
assert!(ret > 0);
108-
109-
(dst, dst + len as u64)
110-
}
111-
112-
fn zero_out(&mut self, addr: u64) -> (u64, u64) {
113-
let page_size = get_page_size().unwrap();
114-
115-
let ret = unsafe {
116-
self.uffd
117-
.zeropage(addr as *mut _, page_size, true)
118-
.expect("Uffd zeropage failed")
119-
};
120-
// Make sure the UFFD zeroed out some bytes.
121-
assert!(ret > 0);
122-
123-
(addr, addr + page_size as u64)
124-
}
125-
126100
pub fn serve_pf(&mut self, addr: *mut u8, len: usize) {
127101
let page_size = get_page_size().unwrap();
128102

@@ -160,6 +134,36 @@ impl UffdPfHandler {
160134
addr
161135
);
162136
}
137+
138+
fn populate_from_file(&self, region: &MemRegion, dst: u64, len: usize) -> (u64, u64) {
139+
let offset = dst - region.mapping.base_host_virt_addr;
140+
let src = self.backing_buffer as u64 + region.mapping.offset + offset;
141+
142+
let ret = unsafe {
143+
self.uffd
144+
.copy(src as *const _, dst as *mut _, len, true)
145+
.expect("Uffd copy failed")
146+
};
147+
148+
// Make sure the UFFD copied some bytes.
149+
assert!(ret > 0);
150+
151+
(dst, dst + len as u64)
152+
}
153+
154+
fn zero_out(&mut self, addr: u64) -> (u64, u64) {
155+
let page_size = get_page_size().unwrap();
156+
157+
let ret = unsafe {
158+
self.uffd
159+
.zeropage(addr as *mut _, page_size, true)
160+
.expect("Uffd zeropage failed")
161+
};
162+
// Make sure the UFFD zeroed out some bytes.
163+
assert!(ret > 0);
164+
165+
(addr, addr + page_size as u64)
166+
}
163167
}
164168

165169
#[derive(Debug)]
@@ -168,7 +172,7 @@ pub struct Runtime {
168172
backing_file: File,
169173
backing_memory: *mut u8,
170174
backing_memory_size: usize,
171-
uffds: HashMap<i32, UffdPfHandler>,
175+
uffds: HashMap<i32, UffdHandler>,
172176
}
173177

174178
impl Runtime {
@@ -207,7 +211,7 @@ impl Runtime {
207211
/// When uffd is polled, page fault is handled by
208212
/// calling `pf_event_dispatch` with corresponding
209213
/// uffd object passed in.
210-
pub fn run(&mut self, pf_event_dispatch: impl Fn(&mut UffdPfHandler)) {
214+
pub fn run(&mut self, pf_event_dispatch: impl Fn(&mut UffdHandler)) {
211215
let mut pollfds = vec![];
212216

213217
// Poll the stream for incoming uffds
@@ -240,7 +244,7 @@ impl Runtime {
240244
nready -= 1;
241245
if pollfds[i].fd == self.stream.as_raw_fd() {
242246
// Handle new uffd from stream
243-
let handler = UffdPfHandler::from_unix_stream(
247+
let handler = UffdHandler::from_unix_stream(
244248
&self.stream,
245249
self.backing_memory,
246250
self.backing_memory_size,

src/firecracker/examples/uffd/valid_handler.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod uffd_utils;
1010
use std::fs::File;
1111
use std::os::unix::net::UnixListener;
1212

13-
use uffd_utils::{MemPageState, Runtime, UffdPfHandler};
13+
use uffd_utils::{MemPageState, Runtime, UffdHandler};
1414
use utils::get_page_size;
1515

1616
fn main() {
@@ -30,10 +30,9 @@ fn main() {
3030
let len = get_page_size().unwrap();
3131

3232
let mut runtime = Runtime::new(stream, file);
33-
runtime.run(|uffd_handler: &mut UffdPfHandler| {
33+
runtime.run(|uffd_handler: &mut UffdHandler| {
3434
// Read an event from the userfaultfd.
3535
let event = uffd_handler
36-
.uffd
3736
.read_event()
3837
.expect("Failed to read uffd_msg")
3938
.expect("uffd_msg not ready");

0 commit comments

Comments
 (0)