从零写一个 shadowsocks(1)

0. 引言

类似 shadowsocks,v2ray,trojan 之类的代理软件原理,用一张很老的图就能概括:
ss原理

是不是很简单呀?那么今天我们就来简单实现一下。不过鉴于这种基石软件用的人数之多,网上想必也有很多类似的文章,我再写这种陈词滥调一是无聊,二是无用。所以我打算加一点新鲜东西,讲述一个代理软件是如何一步步将简单的转发流量的功能拓展到:

  • 多个监听和转发端口
  • 支持多种协议
  • 路由功能

最主要的还是记录下我对这种软件的各个部件如何设计和组织的理解。

因此,写下本文的主要目的是:

  • 记录代理软件的原理和实现
  • 描绘整个软件的设计框架
  • 回顾下我是怎么写出这破代码的

整个项目使用 C++ 编写,其中我使用了 asio 库,真的很好用,谁用谁知道!

Let’s go!

1. 开始

从哪开始?

就像老牛吃南瓜——无从下口,我们应该从哪里开始?代理软件的核心是转发流量,用 v2ray 中的名词,不管是 local 端还是 server 端,都需要 inbound(入站)和 outbound(出站)。由于 local 端和 server 端运行的是同一套软件,我们只需暂时将注意力集中在 local 上的设计即可。

现在我们再次将问题简化,设计一个仅将流量原封不动地转发到互联网的东西,它看起来就像:

1
2
3
         socks  +---------+          +----------+
browser <======>| inbound |<========>| outbound |<======> internet
+---------+ +----------+
监听本地

既然要转发流量,首先就要监听到流量,我们使用 127.0.0.1:8888 作为监听地址监听本地流量。我们需要通过这个 socket 不断 accept 浏览器的请求,随后再通过读写通过 accept 得到的 socket 处理后续连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
static io_context ctx;
static ip::tcp::acceptor listen_sock(
ctx,
ip::tcp::endpoint(ip::address::from_string("127.0.0.1"), 8888)
);

using tcp_sock = ip::tcp::socket;
using tcp_sock_p = shared_ptr<tcp_sock>;
void to_listen() {
tcp_sock_p sp(new tcp_sock(ctx));
listen_sock.async_accept(*sp, [sp](const error_code err) {
if (!err) {
// TODO: 在这里处理通过 accept 得到的 socket
}
to_listen();
});
}

int main() {
to_listen();
ctx.run();
return 0;
}

上面就是整个程序的入口,to_listen() 中通过调用 listen_sock.async_accept(),我们会把 remote sock 的信息写入到第一个参数,而作为第二个参数的 lambda 函数是 async_accept() 的 callback,这个函数会在 accept 工作完成,即,第一个参数被写入 remote sock 的信息后被调用;于是这个 callback 函数的逻辑就很清晰了:处理 remote sock,并调用 to_listen() 继续监听下一个到来的连接。

需要注意两点:

  • 这里处理 remote sock 的代码必须也是异步的,否则会阻塞监听下一个 sock
  • 传入 async_accept() 的 sock 必须分配在堆上。原因是 async_accept() 作为异步函数会立即返回,to_listen() 可能在 async_accept() 完成工作前就早早返回,如果 sock 分配在栈上可能会被过早回收。

我们已经得到了需要进行读写操作的 sock,而我们应该如何继续处理这个烫手山芋呢?

つづく

简记一次 rCore 中 sys_sbrk() 与 lazy allocation 的实现

sbrk() 的作用

sbrk() 系统是做什么的,可以参考这里,简要来说,就是调整用户程序的堆(heap)顶位置,表现为对堆大小的增大或缩小。

rCore 的地址空间布局与调整

这里的 sbrk() 实现仅与用户虚拟地址空间有关。参考这里

对于 rCore 原始布局,在低地址部分,从低地址往高地址依次为:.text.rodata.data.bssguard pageuser stack,随后才是 heap

application address space

但我对这个布局不太满意,于是通过修改 mm::memory_set::MemorySet::from_elf(),将空间布局修改为如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/// High 256GB
/// +-------------------+
/// | Trampoline Code |
/// +-------------------+ <- TRAMPOLINE
/// | TrapContext |
/// +-------------------+ <- TRAP_CONTEXT, user_stack_top
/// | User Stack |
/// +-------------------+ <- user_stack_bottom
/// | ... |
///
/// Low 256GB
/// | ... |
/// +-------------------+
/// | Heap Memory |
/// +-------------------+ <- heap_bottom
/// | .bss |
/// +-------------------+
/// | .data |
/// +-------------------+
/// | .rodata |
/// +-------------------+
/// | .text |
/// +-------------------+ <- BASE_ADDRESS (0x10000 va)

heap 紧挨着 bss 段,而 user stack 调整到高地址处。这么做的原因是当物理地址理论足够的情况下,可以方便 heap 和 user stack 的拓展,同时由于二者位置相差大,不太可能重叠,而原实现相较于此,user stack 的大小限制为 8K(即两个 PAGE SIZE),由由于其位置的特殊性,上有 heap 空间,下有 guard page,kernel 段,难以拓展。这个修改很简单,我们的重点不在这里。

sbrk()

与其他的系统调用实现一样,先添加系统调用接口,再实现具体的逻辑。

1
2
3
4
5
6
7
pub fn sys_sbrk(size: i32) -> isize {
if let Some(old_brk) = change_program_brk(size) {
return old_brk as isize
} else {
-1
}
}

顺着调用链,我们来到 task::task::TaskControlBlock::change_brk(),这是实际的 sbrk 的实现。主要的逻辑为根据判断传入的参数的正负判断堆空间的缩减,并更新相应的堆顶指针:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// ...
let old_brk = self.program_brk;
let new_brk = self.program_brk as isize + size as isize;
let result = if size < 0 {
// shrink heap
self.memory_set.shrink_to(...)
} else {
// extend heap
self.memory_set.append_to(...)
};
if result {
self.program_brk = new_brk as usize;
Some(old_brk)
} else {
None
}

缩小和扩大 heap 空间分别对应 shrink_toappend_to 两个函数,而二者的逻辑较为相似,这里以 append_to 为例:

1
2
3
4
5
6
pub fn append_to(&mut self, page_table: &mut PageTable, new_end: VirtPageNum) {
for vpn in VPNRange::new(self.vpn_range.get_end(), new_end) {
self.map_one(page_table, vpn);
}
self.vpn_range = VPNRange::new(self.vpn_range.get_start(), new_end);
}

即朴实地为每一页在 PageTable 里创建一个映射,并没有什么花里胡哨的东西。

lazy alloc

lazy alloc 可以概括为:只标记,需要时分配。

具体来说:

  • 当调用 sbrk() 申请拓展 heap 空间,内核只增加堆顶指针(在这里是 TaskControlBlock::program_brk)的值,不做实际的内存分配,添加映射。用户程序访问这段内存会引起 PageFault,我们在处理该 Trap 的时候通过 stval 寄存器判断该地址是否位于 heap_bottomprogram_brk 之间,若是,分配实际内存,添加页表映射;
  • 当调用 sbrk() 申请缩小 heap 空间,需要减小 program_brk 的值,同时收回原堆顶和现堆顶之间的 page。
拓展 PageTable

由于原实现中 PageTable::map()PageTable::unmap() 在接口方面实现较为有限,我将其进行了拓展,将

1
2
pub fn map(&mut self, vpn: VirtPageNum, ppn: PhysPageNum, flags: PTEFlags, mut frame: Option<FrameTracker>);
pub fn unmap(&mut self, vpn: VirtPageNum, dealloc: bool);

修改为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
pub fn map(&mut self, args: MapArgs) {
let MapArgs { vpn, ppn, flags, mut frame } = args;
// ...
}
pub fn unmap(&mut self, args: UnmapArgs) {
let vpn = args.vpn;
let pte = match self.find_pte(vpn) {
Some(pte) if pte.is_valid() => pte,
_ => if args.panic {
panic!("vpn {:?} should mapped but not", vpn);
} else {
return;
}
};
// ...
}

目前,MapArgsUnmapArgs 中字段足够使用:

1
2
3
4
5
6
7
8
9
10
11
12
pub struct MapArgs {
vpn: VirtPageNum,
ppn: PhysPageNum,
flags: PTEFlags,
frame: Option<FrameTracker>,
}

pub struct UnmapArgs {
vpn: VirtPageNum,
dealloc: bool,
panic: bool,
}

这两个内以设计模式中的 Builder 设计,使用一系列的 with_... 成员函数调整参数。

为了便于控制后续课程中 lazy_alloc 的使用与否,为 lazy alloc 新增一个 feature 并使用条件编译:

1
2
3
4
5
# Cargo.toml
# ...
[features]
default = ["sbrk_lazy_alloc"]
sbrk_lazy_alloc = []
Trap

trap_handler 中为 lazy alloc 添加的入口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Trap::Exception(Exception::StoreFault)
| Trap::Exception(Exception::StorePageFault)
| Trap::Exception(Exception::LoadFault)
| Trap::Exception(Exception::LoadPageFault) => {
let tcb = get_current_tcb_ref();
let ok = if stval >= tcb.heap_bottom && stval < tcb.program_brk {
// lazy allocation for sbrk()
#[cfg(feature = "sbrk_lazy_alloc")] {
lazy_alloc_page(stval.into())
}
#[cfg(not(feature = "sbrk_lazy_alloc"))] {
false
}
// ...

当判断 stval 中的地址为 heap 中的有效地址时,使用 lazy_alloc_page() 分配一页内存,实现较为简单,这里不过分说明。

change_brk()

我们的重点是 change_brk() 的修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
pub fn change_brk(&mut self, size: i32) -> Option<usize> {
let old_brk = self.program_brk;
let new_brk = self.program_brk as isize + size as isize;
if new_brk < self.heap_bottom as isize {
return None;
}
let ok;
cfg_if! {
if #[cfg(feature = "sbrk_lazy_alloc")] {
ok = true;
if size < 0 {
self.memory_set
.remove_framed_area(
VirtAddr::from(new_brk as usize),
VirtAddr::from(self.program_brk),
);
}
} else {
// ...
}
}
if ok {
self.program_brk = new_brk as usize;
Some(old_brk)
} else {
None
}
}

我们注意到,当 size > 0 我们仅仅增加了 program_brk 的值就返回,其他什么事也不做,而当该值小于 0,只增加了一点微不足道的操作:释放减少的那段空间。

至于为什么我拓展了 PageTable 的参数,下面以 remove_framed_area 为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
pub fn remove_framed_area(
&mut self,
start_va: VirtAddr,
end_va: VirtAddr,
) {
for vpn in VPNRange::new(start_va.ceil(), end_va.floor()) {
self.page_table.unmap(
UnmapArgs::builder(vpn)
.with_dealloc(true)
.with_panic(false),
)
}
}

可以看到我们通过链式构造一步一步对 UnmapArgs 做出了要求,我们要求释放物理地址,并且保证不会 panic。这里我们为什么要保证不会 panic?由于原实现中的 unmap(),设计默认认为传给它的参数一定经过映射,当无法找到最后一级页表时,会认为是内核设计出现了 bug,于是直接通过 unwarp() 进行了 panic。

而这里我们释放 heap 空间,由于 lazy alloc 的存在,我们并不能事先知道哪些 page 被映射了而哪些没有,我们又不想付出额外的存储代价,于是选择对这段释放的空间的每一 page 进行遍历,其中有很大概率会遇见由于从未读写而未进行分配的 page,这样直接作为参数传给原来的 unmap() 必然会导致内核 panic,于是抽象出了 unmap() 的参数,为其添加 panic 字段,用于提示 unmap() 是否能够 panic。

最终的实现见 #876012b

从网卡驱动到系统调用

placeholder

随着鸽了半年的操作系统的 lab 的完成,我对操作系统的实现的概念有了一定的了解,接下来便是对一些概念的熟悉了。

不过在此次前,我想记录一下我个人觉得最有意思的一部分:写一个简陋的网卡驱动。简单的任务涵盖了许多方面,这也是标题的由来。

先鸽一下,过几天有时间继续写🐦

6.824 Lab2A Leader Election 可行方案

一些碎碎念

书接上回,这是 6.824 的 lab2A,实现 raft 协议中的 leader election。关于 raft 的更多详细内容,raft paper 和网络上大多数文章一定会比我在这里介绍得详细,这里只给出一个链接,以动图的方式助于理解 raft 的工作原理:Rafthttps://raft.github.io/ 也提供了一个可交互的动画,大家可以去玩一玩。

个人的这个实现相较于网上的各种版本,代码量较大,但个人感觉逻辑更加清晰。

不保证 bug free,但使用课程中的 test 测试了近 1000 轮无一失败。

思路

实现 raft 协议中的 leader election。由 paper 中可知集群中的所有节点会选出一个 leader 节点,选出 leader 后其余节点均为 follower。对集群的各种操作都需要经过 leader 之手,具体表现为 client 直接向 leader 进行请求,或向 follower 请求,随后该 follower 将请求重定向到 leader。

对于单个节点,有三种可能的状态:

  • follower
  • candidate
  • leader

对于每个节点,有以下几种事件会导致状态间的转移:

  • 超时事件
  • 接收到来自 RequestVote RPC 的请求
  • 接收到来自 AppendEntries RPC 的请求

同时由于处于 candidate 和 leader 状态下的节点分别会发出 RequestVote 和 AppendEntries 请求,还应该在上面三个事件中加入:

  • 来自 RequestVote RPC 请求的回应
  • 来自 AppendEntries RPC 请求的回应

于是,整个 leader election 变成了一个填表游戏:

事件\行为\状态 follower candidate leader
timeout a d h
RequestVote recv b e i
AppendEntries recv c f j
RequestVote callback / g /
AppendEntries callback / / k

当表填完,整个流程就基本完成了。

实现

杂项
rpc.go

依照 paper 对 rpc.go 进行了一些修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// ---------- for RequestVote ----------
type RequestVoteArgs struct {
// Your data here (2A, 2B).
Term int // candidate's curTerm
CandidateId int // candidate requesting vote
}

type RequestVoteReply struct {
// Your data here (2A).
Term int // currentTerm, for candidate to update itself
VoteGranted bool // true means candidate received vote
}

type voteParam struct {
args *RequestVoteArgs
reply *RequestVoteReply
notify chan struct{}
}

func (rf *Raft) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) {
// Your code here (2A, 2B).
notify := make(chan struct{})
rf.voteChan <- voteParam{args, reply, notify}
<-notify
}

func (rf *Raft) sendRequestVote(server int, args *RequestVoteArgs, reply *RequestVoteReply) bool {
ok := rf.peers[server].Call("Raft.RequestVote", args, reply)
return ok
}

// ---------- for AppendEntries ----------
type AppendEntriesArgs struct {
Term int // leader's curTerm
LeaderId int // so follower can redirect clients
}

type AppendEntriesReply struct {
Term int // currentTerm, for leader to update itself
Success bool // true if follower contained entry matching prevLogIndex and prevLogTerm
}

type appendEntryParam struct {
args *AppendEntriesArgs
reply *AppendEntriesReply
notify chan struct{}
}

func (rf *Raft) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesReply) {
notify := make(chan struct{})
rf.entryChan <- appendEntryParam{args, reply, notify}
<-notify
}

func (rf *Raft) sendAppendEntries(server int, args *AppendEntriesArgs, reply *AppendEntriesReply) bool {
ok := rf.peers[server].Call("Raft.AppendEntries", args, reply)
return ok
}

我并没有将处理的逻辑直接写到 RPC 处理函数中,而是将其封装到一个结构中,发送到一个专门用于接受这个参数的 channel 中,并传入一个空 channel 作同步作用。

raft.go

type Raft struct 中增加了一些 paper 中提到的本实现需要用到的字段,包括上面提到的接受 RPC 参数的 channel:

1
2
3
4
5
6
7
8
9
10

// Your data here (2A, 2B, 2C).
// Look at the paper's Figure 2 for a description of what
// state a Raft server must maintain.
curTerm int // current curTerm
state RState // current state
votedFor int // candidate id that received vote in current curTerm

voteChan chan voteParam // channel for vote request
entryChan chan appendEntryParam // channel for entry request

有关 RState interface 和具体的实现,定义如下:

1
2
3
4
5
6
7
type RState interface {
Run(tf *Raft)
}

type Follower struct{}
type Candidate struct{}
type Leader struct{}

这里的设计参考了设计模式:可复用面向对象软件的基础一书中的 State 模式,于是,便有了 GetState 函数的如下写法:

1
2
3
4
5
func (rf *Raft) GetState() (int, bool) {
rf.mu.Lock()
defer rf.mu.Unlock()
return rf.curTerm, reflect.TypeOf(rf.state) == reflect.TypeOf(&Leader{})
}

ticker 函数也变得格外简单:

1
2
3
4
5
func (rf *Raft) ticker() {
for !rf.killed() {
rf.state.Run(rf)
}
}

Make 函数只需要初始化我们添加的几个字段即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
func Make(peers []*labrpc.ClientEnd, me int,
persister *Persister, applyCh chan ApplyMsg) *Raft {
rf := &Raft{}
rf.peers = peers
rf.persister = persister
rf.me = me

rf.curTerm = 0
rf.state = &Follower{}
rf.votedFor = -1

rf.voteChan = make(chan voteParam)
rf.entryChan = make(chan appendEntryParam)
rf.readPersist(persister.ReadRaftState())

// start ticker goroutine to start elections
go rf.ticker()

return rf
}
common.go

定义一些常用函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
const (
ReElectLower = 150
ReElectUpper = 300

HeartBeatTimeout = 100
)

func randBetween(lower, upper int) int {
return rand.Intn(upper-lower) + lower
}

func electionTimeout() time.Duration {
return time.Duration(randBetween(ReElectLower, ReElectUpper)) * time.Millisecond
}

func heartbeatTimeout() time.Duration {
return time.Duration(HeartBeatTimeout) * time.Millisecond
}

func resetTimer(timer *time.Timer, timeout time.Duration) {
if !timer.Stop() {
<-timer.C
}
timer.Reset(timeout)
}

重头戏

下面正片开始,我将不同状态的逻辑分开写,虽然代码看上去挺长的,但是我感觉逻辑挺清晰的

follower.go

follower 是最简单的一个实现对照我上面画的一个表,只需要处理三个事件,大致框架如下:

1
2
3
4
5
6
7
8
9
10
11
12
func (f *Follower) Run(rf *Raft) {
timer := time.NewTimer(electionTimeout())
for {
select {
case <-timer.C:
// a 处理超时
case vote := <-rf.voteChan:
// b 处理 RequestVote 请求
case entry := <-rf.entryChan:
// c 处理 AppendEntries 请求
}
}

当 follower 超时,自动变为 candidate,并为自己投票,因此,a 处填写:

1
2
3
4
5
rf.mu.Lock()
rf.state = &Candidate{}
rf.votedFor = rf.me
rf.mu.Unlock()
return

当 candidate 收到 RequestVote RPC 的请求,首先检查 Term,若小于 curTerm 则拒绝为其投票;若满足 Term > curTerm 或者 Term == curTerm && vote.args.CandidateId,同意为其投票,并重置超时计时器。b 处填写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
vote.reply.Term = vote.args.Term
vote.reply.VoteGranted = false

vote.args.CandidateId)
rf.mu.Lock()
if vote.args.Term < rf.curTerm {
rf.mu.Unlock()
vote.notify <- struct{}{}
continue
}

// grant vote, update curTerm
if vote.args.Term > rf.curTerm ||
(vote.args.Term == rf.curTerm && rf.votedFor == vote.args.CandidateId) {
rf.curTerm = vote.args.Term
rf.votedFor = vote.args.CandidateId
vote.reply.Term = rf.curTerm
vote.reply.VoteGranted = true
// reset timer
resetTimer(timer, electionTimeout())
}
rf.mu.Unlock()
vote.notify <- struct{}{}

这里提一嘴,time.Timertime.Ticker 真的很容易用错,这里建议参考 golang 关于这两者的文档

当收到 AppendEntries RPC,同样依照 paper 中所说的检查任期等等逻辑如法炮制,c 处填写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
entry.args.LeaderId)
rf.mu.Lock()
if entry.args.Term < rf.curTerm {
// stale curTerm, do nothing
entry.reply.Term = rf.curTerm
entry.reply.Success = false
} else if entry.args.Term > rf.curTerm {
// larger curTerm, may this server is out of date
rf.curTerm = entry.args.Term
rf.votedFor = entry.args.LeaderId
entry.reply.Term = rf.curTerm
entry.reply.Success = true
resetTimer(timer, electionTimeout())
} else {
// same curTerm, reset timer
entry.reply.Term = rf.curTerm
entry.reply.Success = true
resetTimer(timer, electionTimeout())
}
rf.mu.Unlock()
entry.notify <- struct{}{}
candidate.go

先给出框架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func (c *Candidate) Run(rf *Raft) {
rf.mu.Lock()
rf.curTerm++

// TODO: send RequestVote RPC
rf.mu.Unlock()

timer := time.NewTimer(electionTimeout())
for {
select {
case <-timer.C:
// d 超时
case reply := <-replyChan:
// g RequestVote 回答
case vote := <-rf.voteChan:
// e RequestVote 请求
case entry := <-rf.entryChan:
// f AppendEntries 请求
}
}
}

首先自增 curTerm,发送 RequestVote 请求其他节点的选票,随后开启循环等待事件驱动,其中 d,e,f 遵随 paper 所描述即可,只需要注意加锁释放锁的时机,这里简单贴一下代码:

d:

1
return

e:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
rf.mu.Lock()
if vote.args.Term > rf.curTerm {
rf.curTerm = vote.args.Term
rf.votedFor = vote.args.CandidateId
rf.state = &Follower{}
rf.mu.Unlock()
vote.reply.VoteGranted = true
vote.reply.Term = vote.args.Term
vote.notify <- struct{}{}
return
} else if vote.args.Term < rf.curTerm {
vote.reply.VoteGranted = false
vote.reply.Term = rf.curTerm
} else {
vote.reply.VoteGranted = false
vote.reply.Term = rf.curTerm
}
rf.mu.Unlock()
vote.notify <- struct{}{}

f:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
rf.mu.Lock()
// a leader send AppendEntries, if args.term >= curTerm
// acknowledge the leader and become follower
// or reject the request then stay as candidate
if entry.args.Term >= rf.curTerm {
rf.curTerm = entry.args.Term
rf.votedFor = entry.args.LeaderId
rf.state = &Follower{}
rf.mu.Unlock()
entry.reply.Term = entry.args.Term
entry.reply.Success = true
entry.notify <- struct{}{}
return
} else {
entry.reply.Term = rf.curTerm
entry.reply.Success = false
}
rf.mu.Unlock()
entry.notify <- struct{}{}

而 candidate 相较于 follower 多出来的逻辑部分就全部在下面了,我们需要在进入循环前以类似异步的方式发送 RequestVote 请求,并在循环中通过添加了一个 replyChan 进行处理 RequestVote 请求的回答。这样将请求和事件分开,逻辑没有那么混乱。于是,发送请求的代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// send RequestVote RPC
replyChan := make(chan RequestVoteReply, len(rf.peers)-1)
args := RequestVoteArgs{
Term: rf.curTerm,
CandidateId: rf.me,
}
for i := range rf.peers {
if i == rf.me {
continue
}
go func(pees *labrpc.ClientEnd) {
var reply RequestVoteReply
if pees.Call("Raft.RequestVote", &args, &reply) {
replyChan <- reply
}
}(rf.peers[i])
}

grantedCnt := 1
minVote := len(rf.peers)/2 + 1

而对应的事件处理,当收到的选票大于集群数量的一半时,转换为 leader,但同时仍要处理任期 Term,g:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
rf.mu.Lock()
if reply.Term > rf.curTerm {
// received larger term, become follower
rf.curTerm = reply.Term
rf.votedFor = -1
rf.state = &Follower{}
rf.mu.Unlock()
return
} else if reply.Term == rf.curTerm && reply.VoteGranted {
// received a grantVote
grantedCnt++
if grantedCnt >= minVote {
rf.state = &Leader{}
rf.mu.Unlock()
return
}
}
rf.mu.Unlock()
leader.go

框架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
func (l *Leader) Run(rf *Raft) {
// TODO: send heartbeat
// TODO: send AppendEntries RPCs to all other servers

// make heartbeat timer
timer := time.NewTimer(heartbeatTimeout())
for {
select {
case <-timer.C:
// h 超时
case reply := <-replyChan:
// k AppendEntries 回答
case vote := <-rf.voteChan:
// i RequestVote 请求
case entry := <-rf.entryChan:
// j AppendEntries 请求
}
}
}

h: 超时:

1
return

i:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
rf.mu.Lock()
if vote.args.Term > rf.curTerm {
rf.curTerm = vote.args.Term
rf.votedFor = vote.args.CandidateId
rf.state = &Follower{}
vote.reply.VoteGranted = true
vote.reply.Term = rf.curTerm
rf.mu.Unlock()
vote.notify <- struct{}{}
return
} else if vote.args.Term <= rf.curTerm {
vote.reply.VoteGranted = false
vote.reply.Term = rf.curTerm
}
rf.mu.Unlock()
vote.notify <- struct{}{}

j:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
rf.mu.Lock()
if entry.args.Term > rf.curTerm {
rf.curTerm = entry.args.Term
rf.votedFor = entry.args.LeaderId
rf.state = &Follower{}
entry.reply.Success = true
entry.reply.Term = rf.curTerm
rf.mu.Unlock()
entry.notify <- struct{}{}
return
} else if entry.args.Term <= rf.curTerm {
entry.reply.Success = false
entry.reply.Term = rf.curTerm
}
rf.mu.Unlock()
entry.notify <- struct{}{}

模仿 RequestVote 的发送方式,发送 AppendEntries 请求,这个实验中只需要发送 heartbeat,实现起来并不困难:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// send heartbeat
rf.mu.Lock()
replyChan := make(chan AppendEntriesReply, len(rf.peers))
args := AppendEntriesArgs{
Term: rf.curTerm,
LeaderId: rf.me,
}
// send AppendEntries RPCs to all other servers
for i := range rf.peers {
if i == rf.me {
continue
}
go func(peer *labrpc.ClientEnd) {
var reply AppendEntriesReply
if peer.Call("Raft.AppendEntries", &args, &reply) {
replyChan <- reply
}
}(rf.peers[i])
}
rf.mu.Unlock()

k:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
rf.mu.Lock()
if vote.args.Term > rf.curTerm {
rf.curTerm = vote.args.Term
rf.votedFor = vote.args.CandidateId
rf.state = &Follower{}
vote.reply.VoteGranted = true
vote.reply.Term = rf.curTerm
rf.mu.Unlock()
vote.notify <- struct{}{}
return
} else if vote.args.Term <= rf.curTerm {
vote.reply.VoteGranted = false
vote.reply.Term = rf.curTerm
}
rf.mu.Unlock()
vote.notify <- struct{}{}

最后

整个实验做下来的感觉:难,但是豁然开朗。明白了很多东西,很多细节需要处理,以及如何把握整个框架。写这么复杂的状态框架的原因是为了后续实验的更好拓展。

6.824 Lab1 MapReduce 快速实现

这个是 MIT 6.824 分布式系统(Distributed Systems)2022 年的 lab1,lab 链接。字有点多,需要花些时间耐心读一下。

个人认为这个 Lab 在对 Golang 有一定了解的情况下不是很难实现,一定要把题意理解清楚,画一个大致框架再动手。我用了两天时间,第一天主要是读题(再次吐槽一下字有点多)和画大致框架,第二天花了一下午几个小时写了代码。Debug 没有难度,可能是我运气比较好,一把过2333

下面开始讲个人的实现,代码量不大,算上注释大致 200 行。

rpc.go 的修改

coordinator.go 的修改

worker.go 的修改

复述

实现 Google 曾经使用的 MapReduce 的一个简单版本。具体的,修改mr 目录下的 coordinator.gorpc.goworker.go 三个文件实现 MapReduce

思路

主要是修改 coordinator.goworker.go,分别对于 pdf 文档中的 Master 和 Worker。前者负责调度任务,比如如何分配文件,如何规定超时任务的重新安排等等;后者负责实际执行 mapreduce 函数。

由于实际编写的是 coordinator 和 woker 的分布式版本,我们需要模拟这两个函数实际是在不同的机器上运行的,所以 coordinator 和 worker 两者的通信方式需要使用 RPC。于是,我们需要简单定义一下两者的通信方式:

  • 我们需要知道,需要先启动单个 coordinator,随后再启动多个 worker,这个测试脚本为我们做了,我们只需要了解。
  • worker 通过 RPC 向 coordinator 发送请求,表示“我现在空闲,请给我一个任务”。
  • 由于 coordinator 已经注册了 RPC 服务(我们无需关注),其可以收到来自 worker 的 RPC 请求,我们从 coordinator 的任务队列中取出一个任务发送给 worker,并同时启动一个用于接收 worker 完成任务的方法和一个超时计时器。前者用于当 worker 完成任务时通知 coordinator;后者用于当 worker 没有按时完成任务时(paper 中解释的,可能 worker 崩溃了,或者网络阻塞等等),用于将任务重新安排给一个 worker。
  • 当 worker 通过 RPC 获取到任务时,进行实际的工作。如果任务出错,直接终止等待下个任务即可,因为 coordinator 的计时器会帮我们重新安排这个任务;如果任务成功,再通过 RPC 告知 coordinator 任务已完成。若对 coordinator 的 RPC 调用失败,我们认为 coordinator 已经传输了所有的任务已退出,worker 可以终止。

简单的流程图如下:

cw_rpc

程序的流程是:

  • 我们通过对每个原始文件的内容作为 map 函数的参数输入,得到一组 KeyValue,对于不同的 Key 通过 ihash(key) % nReduce 计算出不同的 keyreduceId 将这个结果保存在 mr-{taskId}-{reduceId} 这个中间文件(intermediate)中。
  • 当所有原始文件全部生成完对应的 mr-{taskId}-{reduceId} 文件后,coordinator 开始分配 reduce 函数的任务,于是每个 worker 只需要查找全部对应的 reduceId 的文件(具体的:mr-*-{reduceId})进行合并处理和 reduce 调用即可。对于每个 reduceId ,生成对应的 mr-out-{reduceId} 文件。

简单来说,就是:pg*.txt –> mr-1-0, mr-1-1, …, mr-2-0, mr-2-1, … –> mr-out-0, mr-out-1,… 这个流程。

代码修改

由于没有删除的代码,下面只列出添加的代码:

rpc.go

首先添加一些 rpc 的定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
type Task struct {
Type int // 任务的类型,包括 Map 和 Reduce
Filename string // 当 Type 为 Map 时字段有效,表示需要作用 Map 函数的文件的文件名
TaskId int // 用于标记这个 Task 的 Id,在对 Coordinator 回馈结果的时候作为参数返回
ReduceId int // Coordinator 要求这个 Worker 统计的 Reduce 的 Id
}

const (
Map = iota
Reduce
)

// 创建一个对应 Map 的 Task 结构
func NewMapTask(filename string, taskId int) Task {
return Task{
Type: Map,
Filename: filename,
TaskId: taskId,
}
}

// 创建一个对应 Reduce 的 Task 的结构
func NewReduceTask(reduceId, taskId int) Task {
return Task{
Type: Reduce,
TaskId: taskId,
ReduceId: reduceId,
}
}

// NeedWork RPC: 用于请求 Coordinator 分配新的任务
// NeedWork 的参数
type NeedWorkArgs struct {
}

// NeedWork 的返回值
type NeedWorkReply struct {
T Task
ReduceCnt int
}

// FinishWork RPC: 用于告知 Coordinator 任务已完成
// FinishWork 的参数
type FinishWorkArgs struct {
TaskId int
}

// FinishWork 的返回值
type FinishWorkReply struct {
}
worker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 添加的 import
import (
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
)

// 引入的排序 interface,从 mrsequential.go copy 来
// for sorting by key.
type ByKey []KeyValue

// for sorting by key.
func (a ByKey) Len() int { return len(a) }
func (a ByKey) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByKey) Less(i, j int) bool { return a[i].Key < a[j].Key }

下面是 Worker 函数的定义,全是感情,没有任何 Go 的技巧 :)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
func Worker(mapf func(string, string) []KeyValue,
reducef func(string, []string) string) {
for {
needWordReply := NeedWorkReply{}
ok := call("Coordinator.NeedWork", &NeedWorkArgs{}, &needWordReply)
if !ok {
// 当 RPC 调用失败,我们认为 Coordinator 任务完成退出,
// 所以 Worker 理应退出终止
// Coordinator finish its work
break
}
if needWordReply.T.Type == Map {
// 处理 Map 任务的逻辑
filename := needWordReply.T.Filename
file, err := os.Open(filename)
if err != nil {
log.Fatalf("cannot open %v", filename)
}
content, err := io.ReadAll(file)
if err != nil {
log.Fatalf("cannot read %v", filename)
}
file.Close()
// 得到单个文件的 Key-Value
kva := mapf(filename, string(content))

// 生成 intermediate 文件的逻辑
// 首先计算每个 Key 的 reduceId,保存在 intermediate 这个二维切片中
intermediate := make([][]KeyValue, needWordReply.ReduceCnt)
for _, kv := range kva {
reduceTask := ihash(kv.Key) % needWordReply.ReduceCnt
intermediate[reduceTask] = append(intermediate[reduceTask], kv)
}
// intermediate[i] 对应 mr-{taskId}-{i} 这个中间文件
// 以 paper 中提示的使用 json 的方式写入
for i := 0; i < needWordReply.ReduceCnt; i++ {
ofilename := fmt.Sprintf("mr-%d-%d", needWordReply.T.TaskId, i)
// ofile, _ := os.Create(ofilename)
tf, _ := os.CreateTemp("./", ofilename)
enc := json.NewEncoder(tf)
for _, kv := range intermediate[i] {
enc.Encode(&kv)
}
tf.Close()
os.Rename(tf.Name(), ofilename)
}
} else if needWordReply.T.Type == Reduce {
// 处理 Reduce 任务的逻辑
// 找出目录下所有对应该 Reduce 的文件名
// find all files corresponding to this reduce task
var filenames []string
files, err := os.ReadDir(".")
if err != nil {
log.Fatalf("cannot read current directory")
}
for _, file := range files {
if file.IsDir() {
continue
}
filename := file.Name()
prefix := "mr-"
suffix := fmt.Sprintf("-%d", needWordReply.T.ReduceId)
if strings.HasPrefix(filename, prefix) && strings.HasSuffix(filename, suffix) {
filenames = append(filenames, filename)
}
}

// 对所有已找到的文件进行读取
// do reduce job
var kva []KeyValue
for _, filename := range filenames {
file, err := os.Open(filename)
if err != nil {
log.Fatalf("cannot open %v", filename)
}
dec := json.NewDecoder(file)
for {
var kv KeyValue
if err := dec.Decode(&kv); err != nil {
break
}
kva = append(kva, kv)
}
}

// 对该任务进行 Reduce 调用,逻辑和 mrsequential.go 完全一致,代码直接 copy
// copy from mrsequential.go
sort.Sort(ByKey(kva))
oname := fmt.Sprintf("mr-out-%d", needWordReply.T.ReduceId)
ofile, _ := os.Create(oname)
i := 0
for i < len(kva) {
j := i + 1
for j < len(kva) && kva[j].Key == kva[i].Key {
j++
}
values := []string{}
for k := i; k < j; k++ {
values = append(values, kva[k].Value)
}
output := reducef(kva[i].Key, values)
fmt.Fprintf(ofile, "%v %v\n", kva[i].Key, output)
i = j
}
ofile.Close()
} else {
// unknown task type
log.Fatalf("unknown task type: %v", needWordReply.T.Type)
}

// 运行到此代表该任务成功完成,通过 RPC 告知 Coordinator 任务完成
// make FinishWork RPC call
call("Coordinator.FinishWork", &FinishWorkArgs{TaskId: needWordReply.T.TaskId}, &FinishWorkReply{})
}
// log.Printf("Worker: %v exit", os.Getpid())
}
coordinator.go

个人认为 coordinator.go 的实现颇具技巧,网络上很多实现有大量评论代码“很不 Go”,没有 Go 的风格。

我虽说不是 Go 的高手,但也有一段时间具体学习了一下 Go 这门语言,参加了一些开源项目,对自己的代码风格还有有些信心,在这里自卖自夸一下2333,话不多说,先是 import:

1
2
3
4
5
6
import (
"fmt"
"log"
"sync"
"time"
)

Coordinator 结构体的修改,添加了一些字段用于操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
type Coordinator struct {
// Your definitions here.
// 保存一下 nReduce
nReduce int

// 定义任务队列和任务的计数器,
// 用于任务的发送,判断 Map 和 Reduce 任务的交界以及何时关闭 Coordinator
// task sending definition
taskQueue chan Task
taskWg sync.WaitGroup

// 任务 Id 计数器,自增
// task id counter
taskIdCounter int

// 保存用于通知任务完成的 chan 的字段,
// 若 Worker 完成并调用 RPC,处理函数会向对应的 channel 发送一个信号
// task notification record
taskNotifyMap map[int]chan struct{}
taskNotifyLock sync.Mutex

// 用于标记所有任务是否完成
// done flag
done bool
doneLock sync.Mutex
}

// 获得一个新的 taskId
func (c *Coordinator) NewTaskId() int {
c.taskIdCounter++
return c.taskIdCounter
}

// Done 函数,返回所有任务是否已经完成
func (c *Coordinator) Done() bool {
c.doneLock.Lock()
ret := c.done
c.doneLock.Unlock()

// Your code here.
return ret
}

MakeCoordinator 函数的设计:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
func MakeCoordinator(files []string, nReduce int) *Coordinator {
// 初始化一个 Coordinator
c := Coordinator{
nReduce: nReduce,
taskQueue: make(chan Task, 100),
taskWg: sync.WaitGroup{},
taskIdCounter: 0,
taskNotifyMap: make(map[int]chan struct{}),
done: false,
}
// log.Printf("%d file found\n", len(files))

// 由于 Reduce 任务必须在所有 Map 任务完成后才能开始
// 添加 Map 任务计数到 WaitGroup,用于标识所有的 Map 任务是否完成
c.taskWg.Add(len(files))

// Your code here.

// 开启一个新的 goroutine 向 c.taskQueue channel 中发送任务
// make send task goroutine
go func(files []string) {
// 发送 Map 任务
// send map task
for _, filename := range files {
c.taskQueue <- NewMapTask(filename, c.NewTaskId())
}
// log.Println("all task sent, waiting for next reduce task")
// 等待所有 Map 任务完成,这时才能发送 Reduce 任务
// wait for all map task done
c.taskWg.Wait()
// log.Println("start sending reduce task")

// 添加 Reduce 任务的计数
// send reduce task
c.taskWg.Add(nReduce)

// 下面这个 goroutine 的作用是
// 等待 WaitGroup 值为 0 时,此时表示所有的 Reduce 任务也完成了
// 所以可以将 c.done 置为 true,表示 Coordinator 可以结束了
// make Done() check goroutine
// log.Println("waiting exit goroutine created")
go func() {
c.taskWg.Wait()
c.doneLock.Lock()
c.done = true
c.doneLock.Unlock()
}()
// log.Println("sending reduce task")

// 发送 Reduce 任务
for i := 0; i < nReduce; i++ {
c.taskQueue <- NewReduceTask(i, c.NewTaskId())
}
}(files)

c.server()
return &c
}

NeedWork RPC 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func (c *Coordinator) NeedWork(args *NeedWorkArgs, reply *NeedWorkReply) error {
// 从任务队列 c.taskQueue 中取出一个任务作为 RPC 返回值
task, ok := <-c.taskQueue
if !ok {
return fmt.Errorf("cannot get task from task queue")
}
reply.T = task
reply.ReduceCnt = c.nReduce

// 使用 taskId 作为唯一任务标识,创建一个 struct{} channel
// 用于 FinishWork RPC 函数通知任务已被 Worker 完成
// 注意锁的临界区,否则会导致 condition race
// make task notification channel
c.taskNotifyLock.Lock()
c.taskNotifyMap[task.TaskId] = make(chan struct{})
// 启动 goroutine 定时
// set timer for task
go func(taskChan chan struct{}) {
select {
case <-taskChan:
// 任务完成
// task done
return
case <-time.After(10 * time.Second):
// 任务超时,重新将任务放回任务队列中
c.taskQueue <- task
}
}(c.taskNotifyMap[task.TaskId])
c.taskNotifyLock.Unlock()

return nil
}

FinishWork RPC 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
func (c *Coordinator) FinishWork(args *FinishWorkArgs, reply *FinishWorkReply) error {
taskId := args.TaskId
c.taskNotifyLock.Lock()
// 思考这里为什么要检查 taskId 所对应的值是否有效
// 若考虑网络延迟之类的不确定因素,可能导致同一个 taskId 调用两次 FinishWork 函数
notifyChan, ok := c.taskNotifyMap[taskId]
if !ok {
c.taskNotifyLock.Unlock()
return nil
}
// 向对应 channel 发送信号表示任务完成,同时任务计数 -1
// notify task done
notifyChan <- struct{}{}
c.taskWg.Done()
// log.Printf("task %d done\n", taskId)

// 删除对应的 channel
// delete task notification channel
delete(c.taskNotifyMap, taskId)
c.taskNotifyLock.Unlock()
return nil
}

总结

通过这个实验我大致搞明白了一个简单的分布式系统由哪些部分组成。至此,分布式系统的 Lab1 完成,我认为难点不在代码的实现,而是题目的理解。更干净的代码实现见本人的 github commit history

6.s018 lab5 cow(Copy-on-write fork) 踩坑指南

如题,不讲废话,直接开始。2020 年的 lab,lab 链接

复述

实现调用 fork() 时的写时复制,即 copy-on-write。

思路

根据提示:

  • 修改 uvmcopy() 函数。该函数仅会被 fork() 调用,原本用于将父进程的 pagetable 中所含的所有 page 复制到子进程的 pagetable 中,类似 deepcopy;修改后的 uvmcopy() 函数使用浅拷贝,子进程的 pagetable 结构与父进程完全一致。但同时,我们需要把父进程和子进两者的 pagetable 中的所有 pte 中,PTE_W 有效的 pte 的这一位清空,具体表现为使用 *pte &= ~PTE_W。同时,我们需要使用 pte 中 RSW 中的其中一位,用于标记这个 pte 实现了 cow,这里我选择了使用第 8 位作为标记,定义为:#define PTE_COWPG (1L << 8) 对 pte 做 *pte |= PTE_COW 操作。

  • 修改 usertrap() 函数,为其添加 scause 处理入口。由 riscv 文档,我们仅需处理 15 号 code,该 code 代表 Store/AMO page fault,而 13 号代表的 Load page fault 不在我们实现 cow 的处理范围内。在这个 scause 处理分支里,我们需要通过调用 r_stval() 得知导致 page fault 的虚拟地址,并一定要判断这个地址是否为合法的 cow 地址,若否,将进程标记为 killed 并返回,否则,复制 stval 地址中一个 page 的内容至新跑分配的空间,计算出正确的 pte 目录并重写。

  • 这里是一步非常容易出错,即关于每一个 page 的引用计数 (Reference Count)。基本思路是维护一个关于每个 page 的 array,每个 page 对应其在 array 中的索引为该 page 的物理地址除以 PGSIZE(4096),即 pa / PGSIZE,我们还能粗略计算出这个表的长度为 PHYSTOP >> PGSHIFT。对于:

    • 每次 kalloc() 调用,分配新的 page,我们对该 page 所对应的 rc 加 1

    • 每次 kfree() 调用,将该 page 对应的 rc 减 1,若此时 rc 为 0,则真正释放这段内存,否则什么也不做。
      这里,我们有如下定义:#define RC_SZ PHYSTOP >> PGSHIFT

    • uvmcopy() 中,对每个添加了 PTE_COWPG 的 pte 条目对应的物理地址的 rc 加 1

    • usertrap() 中,在处理 cow 的分支中,对导致 page fault 的虚拟地址所对应的物理地址的 rc 减 1

而任何一步对 rc 的读写,需要进行加锁处理,否则会导致条件竞争,这种情况下,为 xv6 添加 CPUS=1 参数能勉强通过测试,但实际测试中无法得到正确结果。

  • 修改 copyout() 函数,这一步与在 scause 中所做的修改如出一辙,不再复述。

代码修改

  • defs.h:添加几个导出的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  // kalloc.c
+ void acquire_kmem(void);
+ void release_kmem(void);
+ void add_rc(uint64);
+ void sub_rc(uint64);
+ int get_rc(uint64);
+ void set_rc(uint64, int);
void* kalloc(void);
void kfree(void *);
void kinit(void);

// ...

// vm.c
+ pte_t * walk(pagetable_t, uint64, int);
  • riscv.h:添加 PTE_COWPG 定义
1
#define PTE_COWPG (1L << 8)
  • kalloc.c:添加一些 rc 的函数,修改 kalloc()kfree()。需要注意操作 rc 时的加锁时机。这里的代码优化空间很大,我在这里的实现是 kmem.freelistkmem.rc 使用同一把锁,可以再研究一下锁的颗粒度,或者为 rc 单独细分出一个锁。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
+ #define RC_SZ PHYSTOP >> PGSHIFT
struct {
struct spinlock lock;
struct run *freelist;
+ uint8 rc[RC_SZ];
} kmem;

+ void acquire_kmem() {
+ acquire(&kmem.lock);
+ }

+ void release_kmem() {
+ release(&kmem.lock);
+ }

+ void add_rc(uint64 a) {
+ ++kmem.rc[a >> PGSHIFT];
+ }

+ void sub_rc(uint64 a) {
+ --kmem.rc[a >> PGSHIFT];
+ }

+ int get_rc(uint64 a) {
+ return kmem.rc[a >> PGSHIFT];
+ }

+ void set_rc(uint64 a, int rc) {
+ kmem.rc[a >> PGSHIFT] = rc;
+ }

注意下面对 freerange() 函数的修改:该函数在 xv6 启动时调用,并对每个 page 调用一次 kfree() ,由于 kfree() 会将所对应 page 的 rc 减 1,我们需要在 kfree() 调用之前将所对应的 page 的 rc 置为 1,否则会导致系统初始化后这些 page 的 rc 为 -1,引起未知的错误。这里不用为 rc 加锁的原因是 vx6 启动时只会使用一个核,不会产生竞争问题。

1
2
3
4
5
6
7
8
9
10
11
  void
freerange(void *pa_start, void *pa_end)
{
char *p;
p = (char*)PGROUNDUP((uint64)pa_start);
- for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+ for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE) {
+ kmem.rc[((uint64)p) >> PGSHIFT] = 1;
kfree(p);
+ }
}

下面是 kalloc()kfree() 的修改,注意锁的操控时机:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
  void
kfree(void *pa)
{
struct run *r;

if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
panic("kfree");

+ acquire(&kmem.lock);
+ int rc = get_rc((uint64) pa);
+ if (rc == 0)
+ panic("kfree: bad rc");
+ sub_rc((uint64) pa);
+ if (get_rc((uint64) pa) != 0) {
+ release(&kmem.lock);
+ return;
+ }
// Fill with junk to catch dangling refs.
memset(pa, 1, PGSIZE);

r = (struct run*)pa;
- acquire(&kmem.lock);
r->next = kmem.freelist;
kmem.freelist = r;
release(&kmem.lock);
}

void *
kalloc(void)
{
struct run *r;

acquire(&kmem.lock);
r = kmem.freelist;
- if(r)
+ if(r) {
+ kmem.freelist = r->next;
+ if (get_rc((uint64) r) != 0)
+ panic("kalloc: bad rc");
+ add_rc((uint64) r);
+ }
release(&kmem.lock);

if(r)
memset((char*)r, 5, PGSIZE); // fill with junk
return (void*)r;
}
  • vm.c:uvmcopy:删除原来的分配内存的代码,取而代之的是仅对 pte 的修改。注意:对 rc 的操作需要加锁。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
+ #include "spinlock.h"
int uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
{
pte_t *pte;
uint64 pa, i;
uint flags;
- char *mem;

for(i = 0; i < sz; i += PGSIZE){
if((pte = walk(old, i, 0)) == 0)
panic("uvmcopy: pte should exist");
if((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");
pa = PTE2PA(*pte);
flags = PTE_FLAGS(*pte);
- if((mem = kalloc()) == 0)
- goto err;
- memmove(mem, (char*)pa, PGSIZE);
+ if (flags & PTE_W) {
+ flags &= ~PTE_W;
+ flags |= PTE_COWPG;
+ *pte = PA2PTE(pa) | flags;
+ }
- if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
+ if(mappages(new, i, PGSIZE, pa, flags) != 0){
- kfree(mem);
goto err;
}
+ acquire_kmem(); // NOTE!!!
+ add_rc(pa);
+ release_kmem(); // NOTE!!!
}
return 0;

err:
uvmunmap(new, 0, i / PGSIZE, 1);
return -1;
}
  • trap.c:usertrap:在判断 scause 的分支中插入一段判断 15 号 code 的代码即可,这里我在处理 PTE_COWPG 的 page 时做了一个优化:当产生 page fault 的地址的 rc 为 1 时不再做重新分配内存,复制,释放原内存的操作,而是直接修改原来 pte 的 flag 使其拥有 PTE_W,并移除掉 PTE_COWPG 标记。老样子,还是得注意锁的使用。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// ...
else if (cause == 15) {
uint64 stval = r_stval();
if (stval > p->sz || stval >= MAXVA) {
p->killed = 1;
goto out;
}
stval = PGROUNDDOWN(stval);
pte_t *pte = walk(p->pagetable, stval, 0);
if(pte == 0 || (*pte & PTE_V) == 0) {
p->killed = 1;
goto out;
}
uint64 pa = PTE2PA(*pte);
uint flags = PTE_FLAGS(*pte);
if (pa == 0 || (flags & PTE_COWPG) == 0) {
p->killed = 1;
goto out;
}
flags &= ~PTE_COWPG;
flags |= PTE_W;
acquire_kmem();
int rc = get_rc(pa);
if (rc == 0) {
panic("usertrap: bad rc");
} else if (rc == 1) {
*pte |= PTE_W;
*pte &= ~PTE_COWPG;
release_kmem();
} else {
release_kmem();
uint64 mem = (uint64) kalloc();
if (mem == 0) {
p->killed = 1;
goto out;
}
memmove((char*)mem, (char*)pa, PGSIZE);
// NOTE: cannot use sub_rc(pa)
kfree((void *) pa);
*pte = PA2PTE(mem) | flags;
}
}
// ...

代码中有一段不能使用 sub_rc() 的注释,原因如下:在 CPUS=1 的参数下使用是没有问题的,但若是在多核模式中运行,在中间一段没有持有锁的时期,其他进程也许会修改 rc 的数量,此时仅仅调用 sub_rc() 可能导致 rc 为 0 却没有被释放的情况。

  • vm.c:copyout:和 usertrap() 的修改如出一辙,不多解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
  int
copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
{
uint64 n, va0, pa0;

while(len > 0){
va0 = PGROUNDDOWN(dstva);
pa0 = walkaddr(pagetable, va0);
if(pa0 == 0)
return -1;

+ if (va0 >= MAXVA)
+ return -1;
+ va0 = PGROUNDDOWN(va0);
+ pte_t *pte = walk(pagetable, va0, 0);
+ if (*pte & PTE_COWPG) {
+ uint64 pa = PTE2PA(*pte);
+ uint16 flag = PTE_FLAGS(*pte);
+ flag &= ~PTE_COWPG;
+ flag |= PTE_W;
+ acquire_kmem();
+ int rc = get_rc(pa);
+ if (rc == 1) {
+ *pte = PA2PTE(pa) | flag;
+ release_kmem();
+ } else {
+ release_kmem();
+ uint64 mem = (uint64) kalloc();
+ if (mem == 0) {
+ return -1;
+ }
+ memmove((char *) mem, (char *) pa, PGSIZE);
+ acquire_kmem();
+ sub_rc(pa);
+ release_kmem();
+ *pte = PA2PTE(mem) | flag;
+ pa0 = mem;
+ }
+ }

n = PGSIZE - (dstva - va0);
if(n > len)
n = len;
memmove((void *)(pa0 + (dstva - va0)), src, n);

len -= n;
src += n;
dstva = va0 + PGSIZE;
}
return 0;
}

总结

这个 lab 总体思路和实现不是很难,但却让我究竟了几个小时,主要就是卡在了引用计数的同步问题上,这个一定要多加注意。可以通过先添加 CPUS=1 参数看看测试能否通过,用这种方式判断是否问题出在同步上。

测试通过结果没有什么意义,就懒得贴上来了(

将 linux 0.11 进程切换由基于 tss 改写为基于内核栈

来自哈工大李治军老师的操作系统网课 link 实验四。

具体是 linux 0.11 使用 CPU 提供的 TSS 进行进程的切换,虽然使用这个机制能很方便地进行进程切换,但是效率不太高,于是要求我们在课程的理解上将其改写为基于内核栈(kernel stack)的进程切换。

写在前面的废话

我的理解能力有限,外加周围没有人能指导,摸爬滚打查阅了数天的资料,理解了其中大致的原理,故才有了下面这篇文章。我认为只是看看网文,大致了解个思路然后心里想着“哦好像是这样”,然后草草了然进入下一课是不行的。你得了解核心代码每一行做了什么,为什么这样做,反复问自己。

前置知识

下面写一些前置知识,了解的可以当做复习,不了解的必须了解啊,有助于下文的理解。

中断做了什么

李老师在第五课中借用系统调用的实现顺便讲了讲中断是怎么做的,简单来说,就是用 int 指令发出中断信号,系统根据 int 后的操作数选择提前注册好的中断处理函数进行中断处理。举个例,int 0x80 指令发出后会陷入到系统调用,执行 system_call (定义于 kernel/system_call.s)中的指令。为什么这条指令执行后会直接执行 system_call?linux 在初始化时调用了 set_system_gate(0x80, &system_call)system_call 注册为 0x80 号中断的处理函数。后文提到的中断默认以系统调用作为例子。

执行 int 指令后由用户栈陷入内核栈,请注意,int 调用后并不是马上执行 system_call 中的代码。在 int 调用后,system_call 执行前,需要把用户态对应的一些重要寄存器保存,用以中断返回后的现场恢复。具体的,我们需要依次保存 ss(用户数据栈底),sp(用户数据栈顶),eflags(标志寄存器),csip 这五个寄存器的值到内核栈。随后进入到 syscall,此时内核栈是这样的:

kernel stack0

随后执行 system_call 中的代码。而要从内核栈返回到用户栈,需要使用 iret 指令,此时必须保证内核栈中保存有用于恢复用户态的信息,如上图。

TSS

TSS 的结构由 include/linux/sched.h 中的 struct tss_struct 定义。在 linux 0.11 代码中,TSS 用于切换进程时,将当前 CPU 中寄存器的所有值保存到当前进程的 TSS,随后使用即将切换的进程的 TSS 中的信息来填写 CPU 中的寄存器,这样就完成了进程的切换。其中较重要的寄存器是 ssspcsip。注意,这四个寄存器,不管是现在 CPU 中的值,还是即将切换的进程的 TSS 中所保存的这四个寄存器的值,记录的都是相应的内核态的状态。

switch_to 函数

定义于 include/linux/sched.h这篇博文解释地非常好,重点理解其中的 ljmp 指令,顺便复习一下 GDT 表的结构。

其他的

时刻记住 call 相当于 push ipret 相当于 pop ip

流程分析

首先应该明确,修改进程切换的方法同时需要修改 fork() 系统调用,因为 linux 启动时需要通过 fork() 启动 shell。

首先来看 system_call 中的部分代码,不重要的省去:

1
2
3
4
5
6
7
8
9
10
11
system_call:
cmpl $nr_system_calls-1,%eax
ja bad_sys_call
push %ds
push %es
push %fs
pushl %edx
pushl %ecx
pushl %ebx
...
call sys_call_table(,%eax,4)

入栈一些寄存器,使用 call 调用真正的系统调用函数是,内核栈如下图:

kernel stack1

系统调用执行完成后:

1
2
3
4
5
6
pushl %eax
movl current,%eax
cmpl $0,state(%eax) # state
jne reschedule
cmpl $0,counter(%eax) # counter
je reschedule

eax 入内核栈,其中两个 cmpl 分别判断进程的状态(运行?挂起?…)和时间片,根据判断结果决定是否跳转到 reschedule 处,reschedule 定义:

1
2
3
reschedule:
pushl $ret_from_sys_call
jmp schedule

ret_from_sys_call 的地址入栈,调用 schedule 函数进行进程切换。其中,对于 ret_from_sys_call,其核心为:

1
2
3
4
5
6
7
8
9
10
ret_from_sys_call:
...
popl %eax
popl %ebx
popl %ecx
popl %edx
pop %fs
pop %es
pop %ds
iret

将内核栈中的内容 pop 直到仅剩 ssspeflagscsip,这些正好是恢复到用户态需要的寄存器信息。随后调用 iret 回到用户态。

schedule 和 switch_to

schedule 是一个 C 函数,跳转到 schedule 函数时,内核栈中的内容:

kernel stack3

此时 schedule,中的调度算法计算出下一个切换的进程,并调用 switch_to 切换至这个进程。调用 switch_to 前,内核栈中的内容是不变的。

switch_to 的讲解见这篇博文,写得非常好。

fork() 函数

fork() 原始的函数的讲解见李老师的视频后半段,解释地非常清楚。

改写注意事项

改写时一定要清楚程序运行到改写处时,此时寄存器中的值和内核栈中的值。

改写 switch_to() 函数

将原先的 switch_to 宏定义删除,在 system_call.s 添加 switch_to汇编实现。

我们定义 long switch_to(long pnext, long ldt) 为函数签名,其中 pnext 为指向一个 task_struct 的指针,即为 PCB,ldt 为其对应的 LDT 描述符地址。因此在 schedule() 中调用 switch(pnext, _LDT(next)) 并进入 switch_to 函数体后,8(%ebp)pnext12(%ebp)_LDT(next),这是由 C 函数调用时将参数和返回值逆序压入栈中所决定的。下面贴上网络上给出的 switch_to 完整具体实现,并选择一些解释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
switch_to:
pushl %ebp
movl %esp,%ebp
pushl %ecx
pushl %ebx
pushl %eax
movl 8(%ebp),%ebx
cmpl %ebx,current
je 1f
movl %ebx,%eax
xchgl %eax,current
movl tss,%ecx
addl $4096,%ebx
movl %ebx,4(%ecx)
movl %esp,12(%eax)
movl 8(%ebp),%ebx
movl 12(%ebx),%esp
movl 12(%ebp),%ecx
lldt %cx
movl $0x17,%ecx
mov %cx,%fs
cmpl %eax,last_task_used_math
jne 1f
clts
1:
popl %eax
popl %ebx
popl %ecx
popl %ebp
ret

从第 1 行到第 7 行所作的任务是处理 C 调用的帧栈;保存一些寄存器的值,因为这些寄存器稍后将使用,第 7 行 movl 8(%ebp),%ebx 执行完毕后,内核栈的状态:

kernel stack4

1
movl 8(%ebp),%ebx

pnext,即新进程的 PCB 赋给 ebx

1
2
cmpl %ebx,current
je 1f

比较新进程的 PBC 与当前进程的 PCB,若相同,调到标签 1 处,即什么也不做返回函数。

1
2
movl %ebx,%eax
xchgl %eax,current

执行完这两句,ebxcurrent 值为新的进程,eax 为当前进程。

1
2
3
movl tss,%ecx
addl $4096,%ebx
movl %ebx,4(%ecx)

此处的 tss 为我们自己定义的全局变量,作为所有进程公共的 TSS 表,我们在 kernel/sched.h 中加上 struct tss_struct *global_tss = &(init_task.task.tss); 一句,定义全局 TSS。

对于 addl $4096,%ebx,由于 ebx 中此前的值为新的进程的 PCB,对于 linux 0.11 其一页内存大小为 4k。对于一个进程,其 PCB 存储在一页内存的低地址,其内核栈在该页内存的高地址,结构如图:

pcb

随后 movl %ebx,4(%ecx),将内核栈栈顶赋值给 TSS 的 sp0 字段,该字段在 tss_struct 结构的第 4 个偏移位置,至于为什么要设置 sp0,引用一段来自这里的解释:

虽然所有进程共用一个 tss,但不同进程的内核栈是不同的,所以在每次进程切换时,需要更新 tss 中 esp0 的值,让它指向新的进程的内核栈,并且要指向新的进程的内核栈的栈底,即要保证此时的内核栈是个空栈,帧指针和栈指针都指向内核栈的栈底。 这是因为新进程每次中断进入内核时,其内核栈应该是一个空栈。

接下来三句:

1
2
3
movl %esp,12(%eax)
movl 8(%ebp),%ebx
movl 12(%ebx),%esp

eax 为当前进程 PCB。

这里我们需要先修改 include/linux/sched.htask_struct ,即 PCB 的定义。新增一个 long kernel_stack 字段于结构体第四个字段,因此表现为相对于结构体首地址偏移 12 个单位。

这三句的作用是,将当前 esp ,即当前内核栈栈顶指针存储到 PCB 中;使 ebx 指向新进程 PCB,将新进程 PCB 的 esp 即内核栈栈顶指针赋值给 esp 寄存器。此时已经完成内核栈由当前进程到新进程的切换

接下来两句:

1
2
movl 12(%ebp),%ecx
lldt %cx

切换 LDT。

1
2
3
4
5
movl $0x17,%ecx
mov %cx,%fs
cmpl %eax,last_task_used_math
jne 1f
clts

为固定指令,后续课程会讲到,目前不做要求。具体解释引用蓝桥云课中一段解释,目前我还不了解:

这两句代码的含义是重新取一下段寄存器 fs 的值,这两句话必须要加、也必须要出现在切换完 LDT 之后,这是因为在实践项目 2 中曾经看到过 fs 的作用——通过 fs 访问进程的用户态内存,LDT 切换完成就意味着切换了分配给进程的用户态内存地址空间,所以前一个 fs 指向的是上一个进程的用户态内存,而现在需要执行下一个进程的用户态内存,所以就需要用这两条指令来重取 fs。

不过,细心的读者可能会发现:fs 是一个选择子,即 fs 是一个指向描述符表项的指针,这个描述符才是指向实际的用户态内存的指针,所以上一个进程和下一个进程的 fs 实际上都是 0x17,真正找到不同的用户态内存是因为两个进程查的 LDT 表不一样,所以这样重置一下 fs=0x17 有用吗,有什么用?要回答这个问题就需要对段寄存器有更深刻的认识,实际上段寄存器包含两个部分:显式部分和隐式部分,如下图给出实例所示,就是那个著名的 jmpi 0, 8,虽然我们的指令是让 cs=8,但在执行这条指令时,会在段表(GDT)中找到 8 对应的那个描述符表项,取出基地址和段限长,除了完成和 eip 的累加算出 PC 以外,还会将取出的基地址和段限长放在 cs 的隐藏部分,即图中的基地址 0 和段限长 7FF。为什么要这样做?下次执行 jmp 100 时,由于 cs 没有改过,仍然是 8,所以可以不再去查 GDT 表,而是直接用其隐藏部分中的基地址 0 和 100 累加直接得到 PC,增加了执行指令的效率。现在想必明白了为什么重新设置 fs=0x17 了吧?而且为什么要出现在切换完 LDT 之后?

最后,标签 1 后的代码为函数返回所做的工作,最后返回到系统调用 iret,退出中断。

改写 fork() 函数

调用 fork() 函数时,函数调用链大致为 fork(C)->syscall(asm)->sys_fork(asm)->copy_process(C)

进入 copy_process 前,内核栈内容如下图:

kernel stack5

对应函数签名:

1
2
3
4
int copy_process(int nr,long ebp,long edi,long esi,long gs,long none,
long ebx,long ecx,long edx,
long fs,long es,long ds,
long eip,long cs,long eflags,long esp,long ss)

copy_process 中使用 p = get_free_page() 获得一页新的内存,该内存即为新进程的内核栈和 PCB 所使用的的空间,p 指向 PCB 首地址,我们添加下面的代码设置子进程的内核栈:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
long *kernel_stack_top = (long *)(PAGE_SIZE + (long)p);
// prepare info for user stack
*(--kernel_stack_top) = ss & 0xffff;
*(--kernel_stack_top) = esp;
*(--kernel_stack_top) = eflags;
*(--kernel_stack_top) = cs & 0xffff;
*(--kernel_stack_top) = eip;
*(--kernel_stack_top) = ds & 0xffff;
*(--kernel_stack_top) = es & 0xffff;
*(--kernel_stack_top) = fs & 0xffff;
*(--kernel_stack_top) = gs & 0xffff;
*(--kernel_stack_top) = esi;
*(--kernel_stack_top) = edi;
*(--kernel_stack_top) = edx;
*(--kernel_stack_top) = (long)first_ret_from_fork;
*(--kernel_stack_top) = ebp;
*(--kernel_stack_top) = ecx;
*(--kernel_stack_top) = ebx;
*(--kernel_stack_top) = 0; // fork return 0 in child process
p->kernel_stack = (long)kernel_stack_top;

第一条指令将 kernel_stack_top 设置为内核栈栈顶,后面依次填写数据,最后一条指令记录子进程的内核栈栈顶,也就是 sp 指针位置。

对于子进程,fork() 返回 0,因此将内核栈栈顶,也就是对应的 eax 的位置设置为 0。同时可以删掉 copy_process 中一些不必要的代码,不删除不影响结果。

改写后的运行逻辑

现在假设我们将所有修改应用,我们现在分析 fork() 后子进程的行为。

switch_to 切换到子进程并运行到其中的标签 1 处,依次对 eaxebxecxebp 进行出栈并调用 ret,调用 ret 时内核栈栈顶为 first_ret_from_fork() 函数,考虑这个函数应该做什么。既然子进程已经切换完毕,我们则需要从内核态返回到用户态执行用户代码,因此该函数需要承担从内核态返回的任务,调用 first_ret_from_fork() 时,此时对应的子进程内核栈的内容如下:

kernel stack6

很自然写出 first_ret_from_fork() 的定义:

1
2
3
4
5
6
7
8
9
first_ret_from_fork:
popl %edx
popl %edi
popl %esi
pop %gs
pop %fs
pop %es
pop %ds
iret

写在后面的废话

上面记录了一些我认为较难理解的细节,我的完整的修改过程记录在了这个 commit 里。个人认为这个更像是对于线程的修改,因为父子进程共享了一个用户数据栈,实际写代码验证也正是如此,修改 fork() 后子进程中的变量会对父进程中的变量造成影响,反之亦然。

浅谈创建型设计模式

通常认为创建型构造模式(creational design pattern)共有这五种 Abstract FactoryBuilderFactory MethodPrototypeSingleton。书中所对应的中文一般分别为抽象工厂,生成器,工厂方法,原型,单件。

This is a draft.

To be continue…

实现 dup2() 函数

来自 APUE 上的第 3.3 题,原文如下:

编写一个与 3.12 节中 dup2 功能相同的函数,要求不调用 fcntl 函数,并且要有正确的出错处理。

函数原型是 int dup2(int fd, int fd2) ,作用是复制文件描述符 fd ,使 fd2 为新的描述符的值,即 fdfd2 共享同一个文件表项。若 fd2 已经打开,则先将其关闭。

很有意思的一个题。由于要求不调用 fcntl ,则基本思路为使用 dup 函数实现之。

实现基本逻辑 myDup2

我们令自己实现的函数原型为 int myDup2(int fd, int fd2) ,先考虑该函数的基本逻辑。

我们让所有的 errno 处理在该函数进行处理。

首先,我们需要要求传入函数的两个文件描述符必须有效,于是有:

1
2
3
4
5
6
7
int myDup2(int fd, int fd2) {
int max_fd = getdtablesize(); // 获取文件描述符表的大小
if (fd < 0 || fd >= max_fd || fd2 < 0 || fd2 >= max_fd || !isOpen(fd)) {
errno = EBADF;
return -1;
}
// ...

其中 getdtablesize() 用于获取文件描述符表的的大小,限制了系统最大的文件描述符。

isOpen 函数用于判断给定文件描述符是否为已打开的文件描述符,后续实现。

若给定的文件描述符 fdfd2 无效,则设置 errno 并返回 -1

接下来,由 dup2 的定义,若 fd 等于 fd2 ,则 dup2 返回 fd2 ,而不关闭它;以及如果 fd2 已经打开,则先将其关闭:

1
2
3
4
5
6
// ...
if (fd == fd2)
return fd;
if (isOpen(fd2))
close(fd2);
// ...

接下来为 myDup2 的核心逻辑。根据「 dup 返回的新文件描述符一定是当前可用文件描述符的最小数值」,基本思路为不断使用 dup(fd) 直至其返回值与 fd2 相等,随后将之前使用 dup 打开的文件描述符关闭即可。

考虑到需要记录由 dup(fd) 创建的文件描述符用于删除,加之 C 需要手动实现类似动态数组之类的数据结构,我们偷懒使用递归替代手动存储的工作。我们将这个操作使用 int recursiveDup(int fd, int fd2) 完成,我们定义返回值为成功复制后的 fd2 ,若失败,则返回 -1

1
2
3
4
5
// ...
if ((fd2 = recursiveDup(fd, fd2)) == -1)
errno = EBADF;
return fd2;
}

调用 recursiveDup 后期望得到 fd2 的值,若失败,则设置 errno

实现核心递归函数 recursiveDup

首先:

1
2
3
4
5
6
7
8
int recursiveDup(int fd, int fd2) {
int nf = dup(fd);
if (nf == -1)
return -1;
if (nf == fd2)
return nf;
int rf = recursiveDup(fd, fd2);
// ...

使用 dup 复制文件描述符,若失败,返回 -1 ,这里并不处理 errno 因为我们将所有的 errno 处理放在 myDup 中。若得到期望的文件描述符的值,返回之。若否,继续递归调用 recursiveDup 直至其产生。

之后:

1
2
3
4
// ...
close(nf);
return rf;
}

关闭不合要求的 nf 返回结果(不论是否为 -1)。

收尾,实现 isOpen

实现如下:

1
2
3
4
5
6
7
8
9
10
int isOpen(int fd) {
int old_errno = errno;
int test_f = dup(fd);
errno = old_errno;
if (test_f == -1) {
return 0;
}
close(test_f);
return 1;
}

我们利用 dup 函数只能作用于已打开文件描述符的机制进行判断。注意,我们需要对本来的 errno 进行处理,由于我们的 isOpen 仅进行判断,不应对系统本来的 errno 有影响,dup 后需要及时恢复。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
int isOpen(int fd) {
int old_errno = errno;
int test_f = dup(fd);
errno = old_errno;
if (test_f == -1) {
return 0;
}
close(test_f);
return 1;
}

int recursiveDup(int fd, int fd2) {
int nf = dup(fd);
if (nf == -1)
return -1;
if (nf == fd2)
return nf;
int rf = recursiveDup(fd, fd2);
close(nf);
return rf;
}

int myDup2(int fd, int fd2) {
int max_fd = getdtablesize();
if (fd < 0 || fd >= max_fd || fd2 < 0 || fd2 >= max_fd || !isOpen(fd)) {
errno = EBADF;
return -1;
}
if (fd == fd2)
return fd;
if (isOpen(fd2))
close(fd2);
if ((fd2 = recursiveDup(fd, fd2)) == -1)
errno = EBADF;
return fd2;
}

个人感觉很有意思的一个题,不断反复利用 dup 函数的性质得到结果。