diff options
Diffstat (limited to 'src/kernel/mem/virt.c')
-rw-r--r-- | src/kernel/mem/virt.c | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/src/kernel/mem/virt.c b/src/kernel/mem/virt.c index f0dca06..bff4b5e 100644 --- a/src/kernel/mem/virt.c +++ b/src/kernel/mem/virt.c @@ -1,5 +1,6 @@ #include <kernel/arch/generic.h> #include <kernel/mem/virt.h> +#include <kernel/panic.h> #include <kernel/util.h> #include <shared/mem.h> @@ -52,12 +53,13 @@ bool virt_iter_next(struct virt_iter *iter) { return true; } -bool virt_cpy( +size_t virt_cpy( struct pagedir *dest_pages, void __user *dest, - struct pagedir *src_pages, const void __user *src, size_t length) + struct pagedir *src_pages, const void __user *src, + size_t length, struct virt_cpy_error *err) { struct virt_iter dest_iter, src_iter; - size_t cur_len; + size_t total = 0, partial; virt_iter_new(&dest_iter, dest, length, dest_pages, true, true); virt_iter_new( &src_iter, (userptr_t)src, length, src_pages, true, false); @@ -65,19 +67,26 @@ bool virt_cpy( src_iter.frag_len = 0; for (;;) { - if (dest_iter.frag_len <= 0) - if (!virt_iter_next(&dest_iter)) break; - if ( src_iter.frag_len <= 0) - if (!virt_iter_next( &src_iter)) break; + if (dest_iter.frag_len <= 0 && !virt_iter_next(&dest_iter)) break; + if ( src_iter.frag_len <= 0 && !virt_iter_next( &src_iter)) break; - cur_len = min(src_iter.frag_len, dest_iter.frag_len); - memcpy(dest_iter.frag, src_iter.frag, cur_len); + partial = min(src_iter.frag_len, dest_iter.frag_len); + total += partial; + memcpy(dest_iter.frag, src_iter.frag, partial); - dest_iter.frag_len -= cur_len; - dest_iter.frag += cur_len; - src_iter.frag_len -= cur_len; - src_iter.frag += cur_len; + dest_iter.frag_len -= partial; + dest_iter.frag += partial; + src_iter.frag_len -= partial; + src_iter.frag += partial; } - return !(dest_iter.error || src_iter.error); + if (err) { + err->read_fail = src_iter.error; + err->write_fail = dest_iter.error; + } + if (src_iter.error || dest_iter.error) + assert(total != length); + else + assert(total == length); + return total; } |