summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/kernel/arch/amd64/pagedir.c10
-rw-r--r--src/user/tests/main.c26
2 files changed, 32 insertions, 4 deletions
diff --git a/src/kernel/arch/amd64/pagedir.c b/src/kernel/arch/amd64/pagedir.c
index 6f5fd4e..f043289 100644
--- a/src/kernel/arch/amd64/pagedir.c
+++ b/src/kernel/arch/amd64/pagedir.c
@@ -15,6 +15,12 @@ static void *addr_validate(void *addr) {
return addr;
}
+static bool addr_canonical(const __user void *addr) {
+ const int addr_bits = 48;
+ uintptr_t n = (uintptr_t)addr >> addr_bits;
+ return (n == 0) || ((~n) << addr_bits == 0);
+}
+
struct pagedir *pagedir_new(void) {
struct pagedir *dir = page_alloc(1);
@@ -57,7 +63,7 @@ get_entry(struct pagedir *dir, const void __user *virt) {
pe_generic_t *pml4e, *pdpte, *pde, *pte;
const union virt_addr v = {.full = (void __user *)virt};
- // TODO check if sign extension is valid
+ if (!addr_canonical(virt)) return NULL;
pml4e = &dir->e[v.pml4];
if (!pml4e->present) return NULL;
@@ -88,7 +94,7 @@ void pagedir_map(struct pagedir *dir, void __user *virt, void *phys,
pe_generic_t *pml4e, *pdpte, *pde, *pte;
const union virt_addr v = {.full = virt};
- // TODO check if sign extension is valid
+ if (!addr_canonical(virt)) return;
pml4e = &dir->e[v.pml4];
if (!pml4e->present) {
diff --git a/src/user/tests/main.c b/src/user/tests/main.c
index b4587f9..ab6b4a8 100644
--- a/src/user/tests/main.c
+++ b/src/user/tests/main.c
@@ -189,13 +189,35 @@ static void test_malloc(void) {
}
static void test_efault(void) {
+ const char *str = "o, 16 characters";
+ char str2[16];
char *invalid = (void*)0x1000;
handle_t h;
+ memcpy(str2, str, 16);
+
assert((h = _syscall_open(tmpfilepath, strlen(tmpfilepath), OPEN_CREATE)));
- assert(_syscall_write(h, "dzwdz sucks ass!", 16, 0) == 16);
+ assert(_syscall_write(h, str, 16, 0) == 16);
+ assert(_syscall_write(h, str2, 16, 0) == 16);
+
assert(_syscall_write(h, invalid, 16, 0) == -EFAULT);
- assert(_syscall_write(h, "dzwdz is cool!!!", 16, 0) == 16);
+
+ /* x64 canonical pointers */
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x8000000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x1000000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x0100000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x0010000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x0001000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str ^ 0x0000800000000000), 16, 0) == -EFAULT);
+
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x8000000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x1000000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x0100000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x0010000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x0001000000000000), 16, 0) == -EFAULT);
+ assert(_syscall_write(h, (void*)((uintptr_t)str2 ^ 0x0000800000000000), 16, 0) == -EFAULT);
+
+ assert(_syscall_write(h, str, 16, 0) == 16);
close(h);
}