Optional Challenge
The user-level thread package interacts badly with the operating system in several ways. For example, if one user-level thread blocks in a system call, another user-level thread won't run, because the user-level threads scheduler doesn't know that one of its threads has been descheduled by the xv6 scheduler. As another example, two user-level threads will not run concurrently on different cores, because the xv6 scheduler isn't aware that there are multiple threads that could run in parallel. Note that if two user-level threads were to run truly in parallel, this implementation won't work because of several races (e.g., two threads on different processors could call <tt>thread_schedule</tt> concurrently, select the same runnable thread, and both run it on different processors.)
There are several ways of addressing these problems. One is using scheduler activations and another is to use one kernel thread per user-level thread (as Linux kernels do). Implement one of these ways in xv6. This is not easy to get right; for example, you will need to implement TLB shootdown when updating a page table for a multithreaded user process.
參考: https://courses.cs.duke.edu/fall23/compsci310/thread.html
為了實現上述功能,首先我們需要定義 2 個新的系統調用函數,
類似于進程里的 fork 和 wait:
一個(clone)用來創建一個線程,另一個(join)用來等待線程結束。
大概細節如下:
API Details
We describe the API here, including its input parameters, what it does, and its return value.


1 clone 實現
主要是借鑒 fork 函數的框架進行改寫。
int
clone(void(*fcn)(void*, void*), void *arg1, void *arg2, void *stack)
{
int i, pid;
struct proc *np;
struct proc *p = myproc();
// Ensure stack is page align, which help setup guard page.
if(((uint64)stack % PGSIZE) != 0)
return -1;
// Allocate process.
if((np = allocproc()) == 0){
return -1;
}
// mark this process is thread
np->isthread = 1;
// use same page table as parent, to keep same memory space
np->pagetable = p->pagetable;
// share some variable between threads
np->tshared = p->tshared;
np->parent = p;
// copy saved user registers.
*(np->trapframe) = *(p->trapframe);
// setup thread's function address
np->trapframe->epc = (uint64)fcn;
// setup thread's function args
// refer to riscv calling covention: https://pdos.csail.mit.edu/6.828/2023/readings/riscv-calling.pdf
np->trapframe->a0 = (uint64)arg1;
np->trapframe->a1 = (uint64)arg2;
// ensure thread without exit return to a invalid address to trigger trap
np->trapframe->ra = 0xffffffffffffffff;
// Use the second page as the user stack.
np->trapframe->sp = (uint64)(stack + 2 * PGSIZE);
// Keep stack address for "join" to return
np->tstack = (uint64)stack;
// setup first stack page as guard page, remove PTE_U
uvmclear(np->pagetable, np->tstack);
// find a address to remap TRAPFRAME page
// it is important since TRAPFRAME page should not be shared across threads
uint64 trap_va = PHYSTOP;
for(; trap_va < TRAPFRAME ; trap_va += PGSIZE) {
if (kwalkaddr(np->pagetable, trap_va) == 0) {
np->trap_va = trap_va;
mappages(np->pagetable, np->trap_va, PGSIZE,
(uint64)(np->trapframe), PTE_R | PTE_W);
break;
}
}
// failed to find a space
if (trap_va >= TRAPFRAME) {
return -1;
}
// increment reference counts on open file descriptors.
for(i = 0; i < NOFILE; i++)
if(p->ofile[i])
np->ofile[i] = filedup(p->ofile[i]);
np->cwd = idup(p->cwd);
safestrcpy(np->name, p->name, sizeof(p->name));
pid = np->pid;
np->state = RUNNABLE;
release(&np->lock);
return pid;
}
2 改動proc.h
然后我們需要對 proc 結構進行一些改寫。

3 支持動態 TRAPFRAME
因為不同 thread 的 TRAPFRAME 不再是一個常數地址,所以我們需要在 proc 里記錄這個地址(trap_va),並且需要更改相應的 trap 匯編及其調用,改動如下:
// Jump to userret in the trampoline page, passing this thread's own
// TRAPFRAME virtual address (p->trap_va) instead of the fixed TRAPFRAME
// constant, plus the user satp. (For a non-thread process p->trap_va is
// presumably TRAPFRAME -- confirm where allocproc initializes it.)
((void (*)(uint64,uint64))trampoline_userret)(p->trap_va, satp);


4 使得sz在thread 間共享
因為 thread 彼此之間是共享內存空間的,所以當一個線程增大了內存時,應該對其他線程可見。
我們之前在 proc 結構體中已經通過指針的方式,把 sz 存放到 TRAPFRAME 頁面的最後,並且在 clone 函數里使所有 thread 的這個指針都指向父親的 TRAPFRAME 頁上這個變量的地址。
那麼我們在更改這個值的時候,通過指針修改就可以使其他線程看見最新的 sz。同時,為了防止幾個線程同時修改 sz 而出現更新丟失的情況,我們需要用同一把鎖對修改進行加鎖操作。
更改相應的 growproc 函數:
// Grow or shrink the user memory of the calling process/thread by n bytes.
// All threads of a process share one page table and one size value
// (p->tshared->sz); every update of it is serialized by p->tshared->tlock.
// Returns 0 on success, -1 if allocation fails (size left unchanged).
int
growproc(int n)
{
  struct proc *me = myproc();
  uint64 cursz;
  int rc = 0;

  acquire(&me->tshared->tlock);
  cursz = me->tshared->sz;
  if(n > 0){
    uint64 newsz = uvmalloc(me->pagetable, cursz, cursz + n, PTE_W);
    if(newsz == 0)
      rc = -1;
    else
      me->tshared->sz = newsz;
  } else if(n < 0){
    me->tshared->sz = uvmdealloc(me->pagetable, cursz, cursz + n);
  }
  release(&me->tshared->tlock);
  return rc;
}
同時更改其他所有p->sz 為p->tshared->sz
5 支持 kwalkaddr
因為我們需要在(線程共享的)進程頁表里找到一個未被映射的虛擬地址,把線程自己的 TRAPFRAME 頁動態映射上去,所以借鑒 walkaddr,實現一個 kernel level(不要求 PTE_U)的版本:
// Return the physical address that va maps to in pagetable, or 0 if va is
// out of range or has no valid mapping. Only PTE_V is checked (no PTE_U
// check), so kernel-only mappings such as a thread's TRAPFRAME page count.
uint64
kwalkaddr(pagetable_t pagetable, uint64 va)
{
  pte_t *entry;

  if(va >= MAXVA)
    return 0;
  entry = walk(pagetable, va, 0);
  if(entry == 0 || (*entry & PTE_V) == 0)
    return 0;
  return PTE2PA(*entry);
}
第一個thread測試
這里借鑒了杜克大學寫的一個測試。如果完成了 clone 實現,test1 應該可以跑過,但是會在 exit 的時候報錯,因為 exit 那邊的代碼還沒改完。
#include "kernel/types.h"
#include "user.h"
#undef NULL
#define NULL ((void*)0)
#define PGSIZE (4096)
// pid of the main test process; a failing assert kills it so the run stops.
int ppid;
// Shared by all threads (one address space); a fork()ed child gets its own copy.
int global = 0;
// Written by exittest() so the creator can check the thread's writes.
int res1 = 0;
int res2 = 0;
// Test-only assert: print file/line and the failed expression, kill the
// main test process (ppid) and exit the caller with status 1.
#define assert(x) if (x) {} else { \
printf("%s: %d ", __FILE__, __LINE__); \
printf("assert failed (%s)\n", # x); \
printf("TEST FAILED\n"); \
kill(ppid); \
exit(1); \
}
// Thread body for test4: copy both int arguments into res1/res2 so the
// creator can verify the thread wrote to the shared address space.
void
exittest(void *arg1, void * arg2){
  res1 = *(int*)arg1;
  res2 = *(int*)arg2;
  exit(0);
}
// Minimal thread body for test1: read both int arguments, then exit.
void
emptytest(void *arg1, void* arg2) {
  int sum = *(int*)arg1 + *(int*)arg2;
  (void)sum;
  exit(0);
}
// Thread body for test5: grow the shared heap by 64 KiB and write to it
// repeatedly; the new pages must be usable through the shared page table.
void sbrktest(void* arg1, void* arg2) {
  char* b = sbrk(65536);
  // sbrk() returns (char*)-1 on failure; writing through that would fault
  // instead of reporting a clear test failure.
  assert(b != (char*)-1);
  for (int i = 0; i < 4096000; i++) {
    b[i % 65536] = 0;
  }
  exit(0);
}
// Thread body for test6. With *arg1 == 1234 it first spawns a nested thread
// (whose argument 0 does not spawn again); it then spins forever and only
// ends when its process exits and kills it.
void threadinthread(void* arg1, void* arg2) {
  int counter = *(int*) arg1;

  if (counter == 1234) {
    int a1 = 0, a2 = 0;
    int threadid = thread_create(threadinthread, &a1, &a2);
    assert(threadid > ppid);
  }
  for (int n = 0; n < 4096000; n++)
    counter++;
  for (;;)
    ;
  exit(0);
}
// Thread body for test3: the arguments must arrive intact and local
// arithmetic must work on the thread's own stack.
void
stacktest(void *arg1, void* arg2) {
  int int1 = *(int*)arg1, int2 = *(int*)arg2;

  assert(int1 == 1);
  assert(int2 == 2);
  int1 += int2;
  assert(int1 == 3);
  exit(0);
}
// Thread body for test3: the arguments must arrive intact and the global it
// increments must be the creator's global (shared memory).
void
heaptest(void *arg1, void* arg2) {
  int int1 = *(int*)arg1, int2 = *(int*)arg2;

  assert(int1 == 1);
  assert(int2 == 2);
  assert(global == 0);
  global += 1;
  assert(global == 1);
  exit(0);
}
//test1: thread create function
// Create two emptytest threads; both pids must be above the main pid.
// (test2 joins them.)
int test1(){
  uint64 arg1 = 1, arg2 = 2;
  int thread_pid1, thread_pid2;

  thread_pid1 = thread_create(emptytest, &arg1, &arg2);
  thread_pid2 = thread_create(emptytest, &arg1, &arg2);
  assert(thread_pid1 > ppid);
  assert(thread_pid2 > ppid);
  printf("TEST1 PASSED\n");
  return 0;
}
//test2: thread join function
// Reap the two threads started by test1.
int test2(){
  int join_pid;

  for (int k = 0; k < 2; k++) {
    join_pid = thread_join();
    assert(join_pid > 0);
  }
  printf("TEST2 PASSED\n");
  return 0;
}
//test3: shared address space
// stacktest and heaptest only read the arguments; heaptest bumps the shared
// global, which the creator must then see as 1.
int test3(){
  uint64 arg1 = 1;
  uint64 arg2 = 2;
  int thread_pid1 = thread_create(stacktest, &arg1, &arg2);
  int thread_pid2 = thread_create(heaptest, &arg1, &arg2);
  int join_pid;

  assert(thread_pid1 > 0);
  assert(thread_pid2 > 0);
  for (int k = 0; k < 2; k++) {
    join_pid = thread_join();
    assert(join_pid > 0);
  }
  assert(arg1 == 1);
  assert(arg2 == 2);
  assert(global == 1);
  printf("TEST3 PASSED\n");
  return 0;
}
//test4: wait/exit
// A forked child creates and joins two threads; their writes to res1/res2
// are visible in the child but not in the parent, whose memory is separate
// after fork().
int test4(){
  int pid = fork();

  if(pid != 0){
    int status;
    wait(&status);
    assert(status == 0);
    assert(res1 == 0);
    assert(res2 == 0);
    printf("TEST4 PASSED\n");
    return 0;
  }

  ppid = getpid();
  uint64 arg1 = 1;
  uint64 arg2 = 2;
  int thread_pid1 = thread_create(exittest, &arg1, &arg2);
  int thread_pid2 = thread_create(exittest, &arg1, &arg2);
  assert(thread_pid1 > 0);
  assert(thread_pid2 > 0);
  int join_pid = thread_join();
  assert(join_pid > 0);
  join_pid = thread_join();
  assert(join_pid > 0);
  assert(res1 == 1);
  assert(res2 == 2);
  assert(global == 1);
  exit(0);
}
//test5: shared size
// Two threads grow the shared heap concurrently through sbrk().
int test5() {
  int thread_pid1, thread_pid2;

  thread_pid1 = thread_create(sbrktest, 0, 0);
  thread_pid2 = thread_create(sbrktest, 0, 0);
  assert(thread_pid1 > 0);
  assert(thread_pid2 > 0);
  for (int k = 0; k < 2; k++)
    thread_join();
  printf("TEST5 PASSED\n");
  return 0;
}
//test6: thread in thread
// A forked child starts a thread that starts another thread, then exits
// while both still spin; the exit must tear all of them down.
int test6() {
  int pid = fork();

  if (pid != 0) {
    wait(0);
    printf("TEST6 PASSED\n");
    return 0;
  }
  int arg1 = 1234;
  int thread_pid1 = thread_create(threadinthread, &arg1, 0);
  sleep(20);
  assert(thread_pid1 > ppid);
  exit(0);
}
int
main(int argc, char *argv[])
{
ppid = getpid();
test1();
test2();
test3();
test4();
test5();
test6();
exit(0);
}
7 完成膠水函數
為了能夠讓編譯通過,我們需要先把一些框架性的代碼補上;還沒實現的函數(如 thread_join)可以先 return 0;
sysproc.c
// clone(fcn, arg1, arg2, stack): fetch the four pointer-sized syscall
// arguments and start a new thread. Returns the new pid or -1.
uint64
sys_clone(void)
{
  uint64 a[4];   // fcn, arg1, arg2, stack

  for(int i = 0; i < 4; i++)
    argaddr(i, &a[i]);
  return clone((void *)a[0], (void *)a[1], (void *)a[2], (void *)a[3]);
}
// join(stack): argument 0 is the user address of a void* slot that receives
// the joined thread's stack. Returns the joined pid or -1.
uint64
sys_join(void)
{
  uint64 uaddr;

  argaddr(0, &uaddr);
  return join((void **)uaddr);
}
ulib.c
// Interim thread_create: malloc three pages so that a page-aligned two-page
// stack fits inside, and start the thread on it.
// Returns the thread's pid, or -1 if the stack cannot be allocated or
// clone() fails (the allocation is released again in that case).
// NOTE: the aligned address differs from what malloc returned, so it cannot
// be handed to free() when the thread is joined.
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  void *stack = malloc(3 * 4096);
  int pid;

  if(stack == 0)
    return -1;
  uint64 addr = PGROUNDUP((uint64) stack);
  pid = clone(start_routine, arg1, arg2, (void *)addr);
  if(pid < 0)
    free(stack);
  return pid;
}
// Placeholder so the user library links; always reports pid 0.
int thread_join(){
  const int pid = 0;
  return pid;
}
測試效果:

8 修正exit
根據之前的課程我們知道,exit 負責設置退出狀態,之後由父進程在 wait 的時候去釋放資源。
上面的錯誤是因為,我們任何一個線程退出的時候都會釋放(共享的)內存空間,
那麼其他線程正在運行的時候,pagetable 就會錯亂。所以我們需要保證線程退出的時候不釋放共享的資源。
我們的實現方案是:只有當第一個進程(即創建線程的主進程)退出時,才釋放資源;基於這個進程創建出來的其他線程都不釋放資源。
如果進程最先退出,它會 kill 掉其他所有線程,然後等待它們結束,自己再退出。
這塊代碼可以借鑒 freeproc 和 kill;我們來看下 exit 的實現改動:
8.1 線程exit
首先 reparent 這一步只需要對進程做就可以了,原因是線程的孩子線程不需要交給 init 進程去 wait 和釋放資源。
// Pass p's abandoned children to init.
// Thread children are skipped: they are reaped by their creator through
// join() or tpkill(), not by init's wait().
// Caller must hold wait_lock.
void
reparent(struct proc *p)
{
  struct proc *child;

  for(child = proc; child < &proc[NPROC]; child++){
    if(child->parent != p || child->isthread)
      continue;
    child->parent = initproc;
    wakeup(initproc);
  }
}
然后在 exit 里,解除 clone 時為 trap_va 建立的映射:
p->xstate = status;
p->state = ZOMBIE;
// clone() mapped this thread's private TRAPFRAME page at p->trap_va in the
// shared page table; remove that mapping (do_free = 0: the page itself is
// kfree()d later by freeproc()).
if (p->isthread) {
uvmunmap(p->pagetable, p->trap_va, 1, 0);
}
wait 這邊只考慮進程,而非線程:
// Wait for a child process to exit and return its pid.
// Return -1 if this process has no children.
// Thread children (isthread) are ignored here; join() reaps them.
int
wait(uint64 addr)
{
struct proc *pp;
int havekids, pid;
struct proc *p = myproc();
acquire(&wait_lock);
for(;;){
// Scan through table looking for exited children.
havekids = 0;
for(pp = proc; pp < &proc[NPROC]; pp++){
// only non-thread children are considered; threads are reaped by join()
if(pp->parent == p && !pp->isthread){
freeproc 線程不需要free pagetable
// free a proc structure and the data hanging from it,
// including user pages.
// A thread shares its page table with its creator, so only a non-thread
// process releases the page table and its user memory here.
// p->lock must be held.
static void
freeproc(struct proc *p)
{
  // Per the original note, p->tshared->sz lives inside the trapframe page,
  // so the page table must be released before the trapframe is kfree()d.
  int owns_pagetable = p->pagetable && !p->isthread;
  if(owns_pagetable)
    proc_freepagetable(p->pagetable, p->tshared->sz);
  p->pagetable = 0;

  if(p->trapframe)
    kfree((void*)p->trapframe);
  p->trapframe = 0;

  // thread bookkeeping
  p->tshared = 0;
  p->tstack = 0;
  p->isthread = 0;
  p->trap_va = 0;
  // generic process bookkeeping
  p->pid = 0;
  p->parent = 0;
  p->name[0] = 0;
  p->chan = 0;
  p->killed = 0;
  p->xstate = 0;
  p->state = UNUSED;
}
8.2 進程 exit
// Kill and reap every thread whose creator is curproc (called from exit()).
// Phase 1 marks each such thread killed, making it RUNNABLE if it was
// SLEEPING. Phase 2 frees the threads that are already ZOMBIE and sleeps on
// curproc while any of them is still alive. Takes and releases wait_lock.
void
tpkill(struct proc *curproc)
{
  struct proc *t;
  int pending;

  acquire(&wait_lock);

  // Phase 1: ask every thread we created to die.
  for(t = proc; t < &proc[NPROC]; t++){
    if(t->parent != curproc || !t->isthread)
      continue;
    acquire(&t->lock);
    t->killed = 1;
    if(t->state == SLEEPING)
      t->state = RUNNABLE;
    release(&t->lock);
  }

  // Phase 2: reap them as they turn into zombies.
  do {
    pending = 0;
    for(t = proc; t < &proc[NPROC] && !pending; t++){
      if(t->parent != curproc || !t->isthread)
        continue;
      if(t->state == ZOMBIE){
        acquire(&t->lock);
        freeproc(t);
        release(&t->lock);
      } else {
        pending = 1;
      }
    }
    if(pending)
      sleep(curproc, &wait_lock);
  } while(pending);

  release(&wait_lock);
}
然后在 exit 里調用 tpkill:不論退出的是進程還是線程,都要先殺掉並回收它創建的、仍然活著的線程。
// Exit the current process. Does not return.
// An exited process remains in the zombie state
// until its parent calls wait().
// A thread is reaped by its creator through join() or tpkill() instead.
void
exit(int status)
{
struct proc *p = myproc();
if(p == initproc)
panic("init exiting");
// Close all open files.
for(int fd = 0; fd < NOFILE; fd++){
if(p->ofile[fd]){
struct file *f = p->ofile[fd];
fileclose(f);
p->ofile[fd] = 0;
}
}
begin_op();
iput(p->cwd);
end_op();
p->cwd = 0;
// Kill and reap every thread this process (or thread) created before the
// normal exit path continues.
tpkill(p);
...
}
test1 通過

9 實現 join
基本照抄 wait 函數,針對 isthread 做一些修改:
// Wait for a child thread (created by clone()) to exit.
// On success: copy the thread's stack address to *stack (when stack is
// non-null), restore PTE_U on its guard page, free its proc slot and return
// its pid. Returns -1 if there are no child threads, the caller was
// killed, or the copyout fails.
int
join(void **stack)
{
  struct proc *me = myproc();
  struct proc *t;
  int found, tid;

  acquire(&wait_lock);
  for(;;){
    found = 0;
    for(t = proc; t < &proc[NPROC]; t++){
      if(t->parent != me || !t->isthread)
        continue;
      acquire(&t->lock);
      found = 1;
      if(t->state != ZOMBIE){
        release(&t->lock);
        continue;
      }
      tid = t->pid;
      if(stack != 0 &&
         copyout(me->pagetable, (uint64)stack, (char *)&t->tstack,
                 sizeof(t->tstack)) < 0){
        release(&t->lock);
        release(&wait_lock);
        return -1;
      }
      // the stack memory goes back to the user allocator, so the guard page
      // must be user-accessible again
      uvmset(t->pagetable, t->tstack);
      freeproc(t);
      release(&t->lock);
      release(&wait_lock);
      return tid;
    }
    if(!found || me->killed){
      release(&wait_lock);
      return -1;
    }
    // wait for one of our threads to exit
    sleep(me, &wait_lock);
  }
}
我們之前在 clone 時設置了 guard page;因為這塊內存之後要還給用戶態的 malloc 使用,所以需要把 PTE_U 重新設置上:
// Set PTE_U on the page-table entry for va so user mode can access the page
// again (the inverse of clearing PTE_U for a thread's guard page).
// Panics if va has no page-table entry.
void
uvmset(pagetable_t pagetable, uint64 va)
{
  pte_t *pte;

  pte = walk(pagetable, va, 0);
  if(pte == 0)
    panic("uvmset");
  *pte |= PTE_U;
}
// Wait for a thread created by the caller to finish, then free its stack.
// Returns the thread's pid, or -1 if there is nothing to join.
// NOTE(review): free() runs without any lock while other threads may be in
// the allocator; the user malloc is not thread-safe by itself.
int thread_join(){
  void *stack = 0;
  int pid = join(&stack);

  // join() writes *stack only on success; freeing an unset pointer would
  // corrupt the heap.
  if(pid > 0 && stack != 0)
    free(stack);
  return pid;
}
不過還存在一個問題:free 需要拿到的是 malloc 分配的起始地址,但是我們在 malloc 之後做了一次 PGROUNDUP,所以 free 其實沒法真正釋放之前 malloc 的內存。
10 新增malloc_align
為了解決上述問題,我們需要修改下 umalloc.c,增加一個 malloc_align 函數。它會找到一個 4096 對齊的地址空間,然後返回。
// Allocate oribytes bytes at a page-aligned (4096-byte) address that can
// later be passed to free(). Returns 0 if more memory cannot be obtained.
//
// A free block large enough for the request plus one page of slack is
// split into up to three pieces:
//   [p: leading remainder, sz units][ret: header + oribytes][right: tail]
// ret is positioned so that ret + 1 (the caller's pointer) is page aligned.
void*
malloc_align(uint oribytes)
{
  Header *p, *prevp, *ret, *right;
  uint nunits, ounits;
  // one extra page guarantees an aligned start exists inside the block
  uint nbytes = oribytes + 4096;

  ounits = (oribytes + sizeof(Header) - 1)/sizeof(Header) + 1;
  nunits = (nbytes + sizeof(Header) - 1)/sizeof(Header) + 1;
  if((prevp = freep) == 0){
    base.s.ptr = freep = prevp = &base;
    base.s.size = 0;
  }
  for(p = prevp->s.ptr; ; prevp = p, p = p->s.ptr){
    if(p->s.size >= nunits){
      uint64 paddr = (uint64) p;
      // header position such that the payload (ret + 1) is page aligned
      uint64 align_addr = PGROUNDUP(paddr + sizeof(Header)) - sizeof(Header);
      uint sz = (align_addr - paddr)/sizeof(Header), psz = p->s.size;
      Header *next = p->s.ptr;

      ret = (Header *)align_addr;
      ret->s.size = ounits;
      // tail after the allocation; nunits reserves a full page of slack, so
      // it always has at least one unit
      right = ret + ounits;
      right->s.size = psz - ounits - sz;

      // The free list is kept in ascending address order (the K&R free()
      // walks it that way to find the insertion point and to coalesce), so
      // the tail must take p's place or follow p -- never precede it.
      if (sz == 0) {
        // ret == p: the tail replaces p in the list
        right->s.ptr = next;
        prevp->s.ptr = right;
      } else {
        // p keeps its leading sz units; the tail goes directly after it
        p->s.size = sz;
        right->s.ptr = next;
        p->s.ptr = right;
      }
      freep = prevp;
      return (void*)(ret + 1);
    }
    if(p == freep)
      if((p = morecore(nunits)) == 0)
        return 0;
  }
}
然后在 thread_create 中使用 malloc_align 去分配對齊的內存。
// Start a thread running start_routine(arg1, arg2) on a freshly allocated,
// page-aligned two-page stack (guard page + stack page).
// Returns the thread's pid, or -1 if the stack cannot be allocated or
// clone() fails (the stack is freed again in that case).
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  void *stack = malloc_align(8192);
  int pid;

  // a null stack must never reach clone(): address 0 passes its alignment check
  if(stack == 0)
    return -1;
  pid = clone(start_routine, arg1, arg2, stack);
  if(pid < 0)
    free(stack);
  return pid;
}
duke 的6個測試全部通過

11 用戶級別的鎖
我們之前用的 spinlock 是 kernel 層面的。但是現在在用戶態可以進行多線程編程了,所以我們需要支持用戶級別的鎖來保證線程安全。比如我們在調用 malloc_align 時,如果 2 個線程一起操作勢必會出問題,所以需要用鎖來保護。
首先我們實現一個原子的「讀取後增加」(fetch-and-add)。
// Atomically add `value` to *variable and return the previous value.
// Sequentially consistent: besides being atomic it orders the surrounding
// memory accesses (on RISC-V GCC emits an amoadd.w with aq+rl set). A plain
// amoadd.w is atomic but gives no acquire/release ordering, which a lock
// built on top of it needs on a multi-core RVWMO machine.
static inline int fetch_and_add(int* variable, int value) {
  return __atomic_fetch_add(variable, value, __ATOMIC_SEQ_CST);
}
然后鎖里有 2 個變量:一個是下一個要發放的號碼 ticket,另一個是當前輪到、可以持有鎖的號碼 turn。
// Ticket lock: lock_acquire() draws the next ticket with an atomic
// increment and spins until `turn` equals it; lock_release() advances `turn`.
typedef struct _lock_t {
int ticket;   // next ticket number to hand out
int turn;     // ticket number currently allowed to hold the lock
} lock_t;
比如第一個線程上鎖,拿到 ticket = 0,此時 turn = 0,直接獲得鎖;
第二個線程嘗試上同一把鎖,拿到 ticket = 1,而 turn = 0,於是開始 spin,等待 turn 變成 1;
第一個線程釋放鎖,turn = turn + 1 = 1;
第二個線程的 spin 等待結束,獲得鎖。
// Reset a ticket lock to the unlocked state: no tickets issued, turn 0.
void lock_init(lock_t *lock) {
  *lock = (lock_t){ .ticket = 0, .turn = 0 };
}
// Acquire a ticket lock: draw a ticket, then spin until it is our turn.
// The spin is an acquire load, so accesses inside the critical section
// cannot be performed before ownership is observed, and waiters only read
// `turn` instead of writing its cache line on every iteration.
void lock_acquire(lock_t *lock) {
  int myturn = fetch_and_add(&lock->ticket, 1);
  while(__atomic_load_n(&lock->turn, __ATOMIC_ACQUIRE) != myturn) {
    ;
  }
}
// Release a ticket lock by handing it to the next ticket.
// The release store orders every write made inside the critical section
// before the hand-off; a plain `turn = turn + 1` provides no such ordering,
// neither on RISC-V hardware nor against compiler reordering.
// Only the holder writes `turn`, so reading it non-atomically here is fine.
void lock_release(lock_t *lock) {
  __atomic_store_n(&lock->turn, lock->turn + 1, __ATOMIC_RELEASE);
}
然后對 thread_create 進行上鎖保護
// Start a thread running start_routine(arg1, arg2) on a freshly allocated,
// page-aligned two-page stack. The user allocator is not thread-safe, so the
// allocation (and the free on failure) is serialized with
// thread_create_lock. Returns the thread's pid, or -1 on failure.
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  void *stack;
  int pid;

  lock_acquire(&thread_create_lock);
  stack = malloc_align(8192);
  lock_release(&thread_create_lock);
  // a null stack must never reach clone(): address 0 passes its alignment check
  if(stack == 0)
    return -1;

  pid = clone(start_routine, arg1, arg2, stack);
  if(pid < 0){
    lock_acquire(&thread_create_lock);
    free(stack);
    lock_release(&thread_create_lock);
  }
  return pid;
}
12 更多的測試
clone, join 測試9個
// Atomically store newval into *addr and return the previous value.
// Sequentially consistent (a full barrier, like x86 xchg). On RISC-V this
// becomes a single amoswap.d with aq+rl instead of an lr.d/sc.d retry loop,
// whose plain (non-.aq/.rl) form gave no memory ordering at all.
static inline uint64
xchg(volatile uint64 *addr, uint64 newval) {
  return __atomic_exchange_n(addr, newval, __ATOMIC_SEQ_CST);
}
#include "kernel/types.h"
#include "user/user.h"
#include "kernel/fcntl.h"
#include "kernel/riscv.h"
#undef NULL
#define NULL ((void*)0)
// pid of the main test process; a failing assert kills it.
int ppid;
// Shared with the threads; volatile because the main thread spins on them.
volatile int arg1 = 11;
volatile int arg2 = 22;
volatile int global = 1;
volatile uint64 newfd = 0;
// Test-only assert: print file/line and the failed expression, kill the
// main test process and exit with a non-zero status so the failure does
// not look like success.
#define assert(x) if (x) {} else { \
   printf("%s: %d ", __FILE__, __LINE__); \
   printf("assert failed (%s)\n", # x); \
   printf("TEST FAILED\n"); \
   kill(ppid); \
   exit(1); \
}
void worker(void *arg1, void *arg2);
void worker2(void *arg1, void *arg2);
void worker3(void *arg1, void *arg2);
void worker4(void *arg1, void *arg2);
void worker5(void *arg1, void *arg2);
void worker6(void *arg1, void *arg2);
/* clone and verify that address space is shared */
// The thread sets global to 5; the main thread must observe the write.
void test1(void *stack)
{
  void *join_stack;
  int clone_pid, join_pid;

  clone_pid = clone(worker, 0, 0, stack);
  assert(clone_pid > 0);
  while(global != 5)
    ;
  printf("TEST1 PASSED\n");
  join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
  global = 1;
}
/* clone and play with the argument */
// The thread overwrites *arg1/*arg2 and publishes their old sum via global.
void test2(void *stack)
{
  void *join_stack;
  int clone_pid = clone(worker2, (void*)&arg1, (void*)&arg2, stack);

  assert(clone_pid > 0);
  while(global != 33)
    ;
  assert(arg1 == 44);
  assert(arg2 == 55);
  printf("TEST2 PASSED\n");
  int join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
}
/* clone copies file descriptors, but doesn't share */
// fd 3 is inherited by the thread; a file the thread opens afterwards must
// not be usable by the main thread.
void test3(void *stack)
{
  int fd, clone_pid, join_pid;
  void *join_stack;

  fd = open("tmp", O_WRONLY|O_CREATE);
  assert(fd == 3);
  clone_pid = clone(worker3, 0, 0, stack);
  assert(clone_pid > 0);
  while(newfd == 0)
    ;
  assert(write(newfd, "goodbye\n", 8) == -1);
  printf("TEST3 PASSED\n");
  join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
}
/* clone with bad stack argument */
// A stack pointer that is not page aligned must be rejected by clone().
void test4(void *stack)
{
  assert(clone(worker4, 0, 0, stack+4) == -1);

  printf("TEST4 PASSED\n");
}
/* clone and join syscalls */
// join() must return the cloned pid and hand back exactly the stack passed in.
void test5(void *stack)
{
  int arg1 = 42, arg2 = 24;
  void *join_stack;
  int clone_pid, join_pid;

  global = 1;
  clone_pid = clone(worker5, &arg1, &arg2, stack);
  assert(clone_pid > 0);
  join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
  assert(stack == join_stack);
  assert(global == 2);
  printf("TEST5 PASSED\n");
}
/* join argument checking */
// A result slot that straddles the end of the heap must make join() fail;
// a slot in the last 8 bytes of the heap must then succeed.
void test6(void *stack)
{
  int arg1 = 42, arg2 = 24;
  void **join_stack;
  int clone_pid;

  global = 1;
  clone_pid = clone(worker5, &arg1, &arg2, stack);
  assert(clone_pid > 0);
  sbrk(PGSIZE);
  join_stack = (void**) ((uint64)sbrk(0) - 8);
  assert(join((void**)((uint64)join_stack + 4)) == -1);
  assert(join(join_stack) == clone_pid);
  assert(stack == *join_stack);
  assert(global == 2);
  printf("TEST6 PASSED\n");
}
/* join should not handle child processes (forked) */
void test7(void *stack)
{
  int fork_pid, join_pid;
  void *join_stack;

  global = 1;
  fork_pid = fork();
  if(fork_pid == 0)
    exit(0);
  assert(fork_pid > 0);
  join_pid = join(&join_stack);
  assert(join_pid == -1);
  assert(wait(0) > 0);
  printf("TEST7 PASSED\n");
}
/* join, not wait, should handle threads */
void test8(void *stack)
{
  int arg1 = 42, arg2 = 24;
  void *join_stack;
  int clone_pid, join_pid;

  global = 1;
  clone_pid = clone(worker5, &arg1, &arg2, stack);
  assert(clone_pid > 0);
  sleep(10);
  // the thread has most likely exited by now, but wait() must ignore it
  assert(wait(0) == -1);
  join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
  assert(stack == join_stack);
  assert(global == 2);
  printf("TEST8 PASSED\n");
}
/* set up stack correctly (and without extra items) */
// worker6 gets its own stack base as arg1 and inspects the frame layout.
void test9(void *stack)
{
  int clone_pid, join_pid;
  void *join_stack;

  global = 1;
  clone_pid = clone(worker6, stack, 0, stack);
  assert(clone_pid > 0);
  while(global != 5)
    ;
  printf("TEST9 PASSED\n");
  join_pid = join(&join_stack);
  assert(join_pid == clone_pid);
}
void (*functions[])() = {test1, test2, test3, test4, test5, test6, test7, test8, test9};
// Run each clone/join test with its own page-aligned stack.
// clone() uses the two pages starting at `stack` (guard page + stack page,
// sp starts at stack + 2 * PGSIZE). Rounding the malloc'd address up to a
// page boundary can consume up to one page, so three pages are allocated;
// with only two, the thread's stack top would lie past the end of the block
// and the thread would scribble over the heap.
int
main(int argc, char *argv[])
{
  int len = sizeof(functions) / sizeof(functions[0]);
  for(int i = 0; i < len; i++) {
    ppid = getpid();
    void *stack, *p = malloc(PGSIZE * 3);
    assert(p != NULL);
    stack = ((uint64)p % PGSIZE) ? (p + (PGSIZE - (uint64)p % PGSIZE)) : p;
    (*functions[i])(stack);
    free(p);
  }
  exit(0);
}
// Thread body for test1: signal that it ran by setting global to 5.
void
worker(void *arg1, void *arg2) {
  assert(global == 1);

  global = 5;
  exit(0);
}
// Thread body for test2: overwrite both int arguments in place and publish
// their previous sum through global.
void
worker2(void *arg1, void *arg2) {
  int *p1 = (int*)arg1, *p2 = (int*)arg2;
  int tmp1 = *p1;
  int tmp2 = *p2;

  *p1 = 44;
  *p2 = 55;
  assert(global == 1);
  global = tmp1 + tmp2;
  exit(0);
}
// Thread body for test3: fd 3 must be inherited from the creator; then open
// a new file and publish its fd so the creator can check it is not shared.
void
worker3(void *arg1, void *arg2) {
  int fd;

  assert(write(3, "hello\n", 6) == 6);
  fd = open("tmp2", O_WRONLY|O_CREATE);
  xchg(&newfd, fd);
  exit(0);
}
// Thread body for test4; expected never to run, since clone() must reject
// the misaligned stack.
void
worker4(void *arg1, void *arg2) {
  exit(0);
}
// Thread body for test5/6/8: check the (42, 24) arguments and bump global.
void
worker5(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1, tmp2 = *(int*)arg2;

  assert(tmp1 == 42);
  assert(tmp2 == 24);
  assert(global == 1);
  global += 1;
  exit(0);
}
// Thread body for test9. arg1 is the stack base passed to clone(), so the
// initial sp (top of stack) is arg1 + 2 * PGSIZE. Checks that the frame
// sits directly at the top with no extra items. These fixed offsets assume
// a particular compiler frame layout, so do not add locals or calls here.
void
worker6(void *arg1, void *arg2) {
// Let top = arg1 + 2 * PGSIZE:
// top -  8 -> saved ra (clone() sets ra to 0xffffffffffffffff)
// top - 16 -> saved fp
// top - 24 -> spilled a0 (&arg1)
// top - 32 -> spilled a1 (&arg2)
assert(*((uint64*) (arg1 + 2 * PGSIZE - 8)) == 0xffffffffffffffff);
assert((uint64)&arg2 == ((uint64)arg1 + 2 * PGSIZE - 32));
assert((uint64)&arg1 == ((uint64)arg1 + 2 * PGSIZE - 24));
global = 5;
exit(0);
}
// Thread body expecting (35, 42) and global == 1; it is not referenced by
// the tests registered in functions[].
void
worker7(void *arg1, void *arg2) {
  int arg1_int = *(int*)arg1, arg2_int = *(int*)arg2;

  assert(arg1_int == 35);
  assert(arg2_int == 42);
  assert(global == 1);
  global += 1;
  exit(0);
}
thread_create, thread_join 測試13個
#include "kernel/types.h"
#include "user.h"
#include "kernel/fcntl.h"
#include "kernel/riscv.h"
#undef NULL
#define NULL ((void*)0)
int ppid;             // pid of the main test process; killed by a failing assert
int global = 1;
uint64 size = 0;      // sbrk(0) value the sbrk-checking threads compare against
lock_t lock, lock2;
int num_threads = 30;
int loops = 10;
int* global_arr;
// Test-only assert: print file/line and the failed expression, kill the
// main test process and exit with a non-zero status so the failure does
// not look like success.
#define assert(x) if (x) {} else { \
   printf("%s: %d ", __FILE__, __LINE__); \
   printf("assert failed (%s)\n", # x); \
   printf("TEST FAILED\n"); \
   kill(ppid); \
   exit(1); \
}
void worker(void *arg1, void *arg2);
void worker2(void *arg1, void *arg2);
void worker3(void *arg1, void *arg2);
void worker4(void *arg1, void *arg2);
void worker5(void *arg1, void *arg2);
void worker6(void *arg1, void *arg2);
void worker7(void *arg1, void *arg2);
void merge_sort(void *array, void *size);
void worker9(void *array, void *size);
void worker10(void *array, void *size);
void worker11(void *array, void *size);
void worker12(void *array, void *size);
void worker13(void *array, void *size);
/* thread user library functions */
// One thread checks its (35, 42) arguments and increments global.
void test1()
{
  int arg1 = 35, arg2 = 42;
  int thread_pid, join_pid;

  thread_pid = thread_create(worker, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid == thread_pid);
  assert(global == 2);
  printf("TEST1 PASSED\n");
}
/* memory leaks from thread library? */
// Create and join 2000 threads one after another; the heap break must stay
// well below 150 pages.
void test2()
{
  for(int i = 0; i < 2000; i++) {
    int thread_pid, join_pid;

    global = 1;
    thread_pid = thread_create(worker2, 0, 0);
    assert(thread_pid > 0);
    join_pid = thread_join();
    assert(join_pid == thread_pid);
    assert(global == 5);
    assert((uint64)sbrk(0) < (150 * 4096) && "shouldn't even come close");
  }
  printf("TEST2 PASSED\n");
}
/* check that address space size is updated in threads */
// Phase 1: the threads check sbrk(0) one at a time under `lock`.
// Phase 2: the main thread grows the heap, then the threads re-check under
// `lock2`. Each phase ends once every thread has incremented global.
void test3()
{
  int arg1 = 11, arg2 = 22;
  int i;

  global = 0;
  lock_init(&lock);
  lock_init(&lock2);
  lock_acquire(&lock);
  lock_acquire(&lock2);
  for (i = 0; i < num_threads; i++) {
    int thread_pid = thread_create(worker3, &arg1, &arg2);
    assert(thread_pid > 0);
  }

  // phase 1
  size = (uint64)sbrk(0);
  while (global < num_threads) {
    lock_release(&lock);
    sleep(2);
    lock_acquire(&lock);
  }

  // phase 2
  global = 0;
  sbrk(10000);
  size = (uint64)sbrk(0);
  lock_release(&lock);
  while (global < num_threads) {
    lock_release(&lock2);
    sleep(2);
    lock_acquire(&lock2);
  }
  lock_release(&lock2);

  for (i = 0; i < num_threads; i++) {
    int join_pid = thread_join();
    assert(join_pid > 0);
  }
  printf("TEST3 PASSED\n");
}
/* multiple threads with some depth of function calls */
// Naive recursive Fibonacci (fib(0) = 0, fib(1) = 1), used to put call
// depth on the thread stacks.
uint fib(uint n) {
  if (n < 2)
    return n;
  return fib(n - 1) + fib(n - 2);
}
// num_threads threads each run a few fib() calls on their own stacks.
void test4()
{
  int arg1 = 11, arg2 = 22;
  int i;

  assert(fib(28) == 317811);
  for (i = 0; i < num_threads; i++) {
    int thread_pid = thread_create(worker4, &arg1, &arg2);
    assert(thread_pid > 0);
  }
  for (i = 0; i < num_threads; i++) {
    int join_pid = thread_join();
    assert(join_pid > 0);
  }
  printf("TEST4 PASSED\n");
}
/* no exit call in thread, should trap at bogus address */
// worker5 returns instead of calling exit(); it must fault on the bogus
// return address and still be joinable.
void test5()
{
  int arg1 = 42, arg2 = 24;
  int thread_pid, join_pid;

  thread_pid = thread_create(worker5, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid == thread_pid);
  assert(global == 2);
  printf("TEST5 PASSED\n");
}
/* test lock correctness */
// num_threads threads each increment global `loops` times under `lock`;
// no increment may be lost.
void test6()
{
  global = 0;
  lock_init(&lock);
  for (int i = 0; i < num_threads; i++) {
    int thread_pid = thread_create(worker6, 0, 0);
    assert(thread_pid > 0);
  }
  for (int i = 0; i < num_threads; i++) {
    int join_pid = thread_join();
    assert(join_pid > 0);
  }
  assert(global == num_threads * loops);
  printf("TEST6 PASSED\n");
}
/* nested thread user library functions */
// worker7 bumps global and then runs nested_worker in a nested thread,
// which bumps it again: 1 -> 2 -> 3.
void test7()
{
  int arg1 = 35, arg2 = 42;
  int thread_pid, join_pid;

  thread_pid = thread_create(worker7, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid == thread_pid);
  assert(global == 3);
  printf("TEST7 PASSED\n");
}
/* merge sort using nested threads */
// Fill global_arr with 10, 9, ..., 0 and sort it with merge_sort, which
// sorts each half in its own nested thread, joins both and merges them.
// Then spot-check the first, middle and last slots.
void test8()
{
  int size = 11;
  int thread_pid, join_pid;

  global_arr = (int*)malloc(size * sizeof(int));
  for(int i = 0; i < size; i++)
    global_arr[i] = size - i - 1;

  thread_pid = thread_create(merge_sort, global_arr, &size);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid == thread_pid);
  assert(global_arr[0] == 0);
  assert(global_arr[5] == 5);
  assert(global_arr[10] == 10);
  printf("TEST8 PASSED\n");
}
/* test lock correctness using nested threads */
// Each of num_threads threads increments global under `lock` and then
// spawns one nested thread that does the same: 2 * num_threads in total.
void test9()
{
  global = 0;
  lock_init(&lock);
  for (int i = 0; i < num_threads; i++) {
    int thread_pid = thread_create(worker9, 0, 0);
    assert(thread_pid > 0);
  }
  for (int i = 0; i < num_threads; i++) {
    int join_pid = thread_join();
    assert(join_pid > 0);
  }
  assert(global == num_threads * 2);
  printf("TEST9 PASSED\n");
}
/* no exit call in nested thread, should trap at bogus address */
void test10()
{
  int arg1 = 42, arg2 = 24;
  int thread_pid, join_pid;

  thread_pid = thread_create(worker10, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid == thread_pid);
  assert(global == 3);
  printf("TEST10 PASSED\n");
}
/* check that address space size is updated in threads */
// Same protocol as test3, but driven from inside a thread (worker11) that
// coordinates a single nested thread.
void test11()
{
  int arg1 = 11, arg2 = 22;
  int thread_pid, join_pid;

  size = (uint64)sbrk(0);
  thread_pid = thread_create(worker11, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid > 0);
  printf("TEST11 PASSED\n");
}
/* check that thread stack overflow, should trap */
// worker12 recurses without bound; it must fault on its guard page and
// still be joinable. global == 3 shows call_forever actually ran.
void test12()
{
  int arg1 = 11, arg2 = 22;
  int thread_pid, join_pid;

  size = (uint64)sbrk(0);
  thread_pid = thread_create(worker12, &arg1, &arg2);
  assert(thread_pid > 0);
  join_pid = thread_join();
  assert(join_pid > 0);
  assert(global == 3);
  printf("TEST12 PASSED\n");
}
/* check no malloc stack race condition */
void test13()
{
num_threads = 30;
int i;
int arg1 = 35;
int arg2 = 42;
uint64 origin = (uint64)sbrk(0);
for (i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker13, &arg1, &arg2);
assert(thread_pid > 0);
}
for (i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
assert((uint64)sbrk(0) < (origin + (16 + num_threads * 2 * 3) * 4096) && "shouldn't even come close");
printf("TEST13 PASSED\n");
}
void (*functions[])() = {test13};
int
main(int argc, char *argv[])
{
int len = sizeof(functions) / sizeof(functions[0]);
for(int i = 0; i < len; i++) {
global = 1;
ppid = getpid();
(*functions[i])();
}
exit(0);
}
// Thread body for test1: expects (35, 42) and global == 1, then bumps global.
void
worker(void *arg1, void *arg2) {
  int arg1_int = *(int*)arg1, arg2_int = *(int*)arg2;

  assert(arg1_int == 35);
  assert(arg2_int == 42);
  assert(global == 1);
  global += 1;
  exit(0);
}
// Thread body for test2: move global from 1 to 5.
void
worker2(void *arg1, void *arg2) {
  assert(global == 1);
  global = global + 4;
  exit(0);
}
void
worker3(void *arg1, void *arg2) {
lock_acquire(&lock);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock);
lock_acquire(&lock2);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock2);
exit(0);
}
// Thread body for test4: check the (11, 22) arguments, then run several
// recursive fib() calls on the thread's stack.
void
worker4(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1, tmp2 = *(int*)arg2;

  assert(tmp1 == 11);
  assert(tmp2 == 22);
  assert(global == 1);
  assert(fib(2) == 1);
  assert(fib(3) == 2);
  assert(fib(9) == 34);
  assert(fib(15) == 610);
  exit(0);
}
// Thread body for test5: check the (42, 24) arguments and bump global, then
// deliberately return without calling exit().
void
worker5(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1, tmp2 = *(int*)arg2;

  assert(tmp1 == 42);
  assert(tmp2 == 24);
  assert(global == 1);
  global += 1;
  // no exit() in thread
}
// Thread body for test6: `loops` times, perform a deliberately slow
// read-modify-write of global while holding `lock`.
void
worker6(void *arg1, void *arg2) {
  for (int i = 0; i < loops; i++) {
    int tmp;

    lock_acquire(&lock);
    tmp = global;
    for (int j = 0; j < 50; j++)
      ; // take some time
    global = tmp + 1;
    lock_release(&lock);
  }
  exit(0);
}
// Nested thread body for test7: expects (35, 42) and global == 2.
void nested_worker(void *arg1, void *arg2){
  int arg1_int = *(int*)arg1, arg2_int = *(int*)arg2;

  assert(arg1_int == 35);
  assert(arg2_int == 42);
  assert(global == 2);
  global += 1;
  exit(0);
}
// Thread body for test7: check arguments, bump global, then create and join
// a nested thread (nested_worker) that bumps global again.
void
worker7(void *arg1, void *arg2) {
  int arg1_int = *(int*)arg1;
  int arg2_int = *(int*)arg2;
  assert(arg1_int == 35);
  assert(arg2_int == 42);
  assert(global == 1);
  global++;
  int nested_thread_pid = thread_create(nested_worker, &arg1_int, &arg2_int);
  // without this, a failed create (-1) would "match" a failed join (-1)
  assert(nested_thread_pid > 0);
  int nested_join_pid = thread_join();
  assert(nested_join_pid == nested_thread_pid);
  exit(0);
}
// Merge the sorted runs array[0..size_left) and array_right[0..size_right)
// into temp_array, then copy the result back to array[0..size_left+size_right).
// Because of that copy-back, array_right must directly follow array.
// On equal keys the element from array_right is taken first.
void merge(int* array, int* array_right,int size_left, int size_right,int*temp_array){
  int l = 0, r = 0, out = 0;
  int total = size_left + size_right;

  while (l < size_left && r < size_right)
    temp_array[out++] = (array[l] < array_right[r]) ? array[l++] : array_right[r++];
  while (l < size_left)
    temp_array[out++] = array[l++];
  while (r < size_right)
    temp_array[out++] = array_right[r++];
  for (int k = 0; k < total; k++)
    array[k] = temp_array[k];
}
void merge_sort(void *arg1, void *arg2) {
int *array = (int*)arg1;
int size = *(int*)arg2;
if (size==1){
exit(0);
}
int size_left = size/2;
int size_right = size-size/2;
int* array_right = (int*)(array + size_left);
int nested_thread_pid_l = thread_create(merge_sort, array, &size_left);
int nested_thread_pid_r = thread_create(merge_sort, array_right, &size_right);
int nested_join_pid_1 = thread_join();
int nested_join_pid_2 = thread_join();
int* temp_array = malloc(size*sizeof(int));
merge(array,array_right,size_left,size_right,temp_array);
free(temp_array);
assert(nested_thread_pid_l == nested_join_pid_1 || nested_thread_pid_l == nested_join_pid_2);
assert(nested_thread_pid_r == nested_join_pid_1 || nested_thread_pid_r == nested_join_pid_2);
exit(0);
}
// Nested thread body for test9: one slow increment of global under `lock`.
void nest_worker(void *arg1,void *arg2){
  lock_acquire(&lock);
  for (int j = 0; j < 50; j++)
    ;
  global += 1;
  lock_release(&lock);
  exit(0);
}
// Thread body for test9: one slow locked increment of global, then spawn
// and join a nest_worker thread that performs another.
void
worker9(void *arg1, void *arg2) {
  int nested_thread_pid, nested_join_pid;

  lock_acquire(&lock);
  for (int j = 0; j < 50; j++)
    ; // take some time
  global++;
  lock_release(&lock);

  nested_thread_pid = thread_create(nest_worker, 0, 0);
  assert(nested_thread_pid > 0);
  nested_join_pid = thread_join();
  assert(nested_join_pid > 0);
  assert(nested_thread_pid==nested_join_pid);
  exit(0);
}
// Nested thread body for test10: checks (42, 24) and global == 2, bumps
// global, then returns without exit() so that it traps.
void nested_worker2(void *arg1, void *arg2){
  int arg1_int = *(int*)arg1, arg2_int = *(int*)arg2;

  assert(arg1_int == 42);
  assert(arg2_int == 24);
  assert(global == 2);
  global += 1;
  // no exit() in thread
}
// Thread body for test10: check arguments, bump global, then start a nested
// thread that returns without exit() and must be reaped after it traps.
void
worker10(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1;
  int tmp2 = *(int*)arg2;
  assert(tmp1 == 42);
  assert(tmp2 == 24);
  assert(global == 1);
  global++;
  int nested_thread_pid = thread_create(nested_worker2, &tmp1, &tmp2);
  assert(nested_thread_pid > 0);
  for(int j=0;j<10000;j++);
  int nested_join_pid = thread_join();
  // thread_join() reports failure as -1, so a bare non-zero check would
  // accept a failed join; require a real pid.
  assert(nested_join_pid > 0);
  assert(nested_join_pid == nested_thread_pid);
  exit(0);
}
void nest_worker3(void *arg1, void *arg2)
{
lock_acquire(&lock);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock);
lock_acquire(&lock2);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock2);
exit(0);
}
// Thread body for test11: run the two-phase sbrk(0) protocol with a single
// nested thread (nest_worker3), coordinated from inside this thread.
void worker11(void *arg1, void *arg2) {
  int nested_thread_id, nested_join_pid;

  num_threads = 1;
  lock_init(&lock);
  lock_init(&lock2);
  lock_acquire(&lock);
  lock_acquire(&lock2);
  nested_thread_id = thread_create(nest_worker3, 0, 0);
  assert(nested_thread_id > 0);

  // phase 1
  size = (uint64)sbrk(0);
  while (global < num_threads) {
    lock_release(&lock);
    sleep(2);
    lock_acquire(&lock);
  }

  // phase 2: grow the heap before letting the nested thread re-check
  global = 0;
  sbrk(10000);
  size = (uint64)sbrk(0);
  lock_release(&lock);
  while (global < num_threads) {
    lock_release(&lock2);
    sleep(2);
    lock_acquire(&lock2);
  }
  lock_release(&lock2);

  nested_join_pid = thread_join();
  assert(nested_join_pid > 0);
  exit(0);
}
// Recurse without bound so the thread overflows its one-page stack and
// faults on the guard page below it (test12).
// The recursive call is deliberately not in tail position: with sibling-call
// optimisation (GCC -O2 and above) a self tail call becomes a jump, which
// would loop forever without ever touching the guard page.
void call_forever()
{
  volatile int k = 3;
  global = k;
  call_forever();
  global = k;  // keeps the call above from being a tail call
}
// Thread body for test12: check arguments, then recurse until the thread
// stack overflows into its guard page.
void
worker12(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1, tmp2 = *(int*)arg2;

  assert(tmp1 == 11);
  assert(tmp2 == 22);
  assert(global == 1);
  call_forever();
  exit(0);
}
// Nested thread body for test13: only checks its (35, 42) arguments.
void empty(void *arg1, void *arg2)
{
  int arg1_int = *(int*)arg1, arg2_int = *(int*)arg2;

  assert(arg1_int == 35);
  assert(arg2_int == 42);
  exit(0);
}
// Thread body for test13: sleep briefly (presumably so sibling threads
// overlap), check arguments, then create and join one nested thread.
void
worker13(void *arg1, void *arg2) {
  int arg1_int, arg2_int;
  int nested_thread_pid, nested_join_pid;

  sleep(3);
  arg1_int = *(int*)arg1;
  arg2_int = *(int*)arg2;
  assert(arg1_int == 35);
  assert(arg2_int == 42);
  sleep(3);
  nested_thread_pid = thread_create(empty, &arg1_int, &arg2_int);
  nested_join_pid = thread_join();
  assert(nested_join_pid == nested_thread_pid);
  exit(0);
}
測試結果

跑一下回歸測試 usertests
