summaryrefslogtreecommitdiff
path: root/src/kernel/mem/virt.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernel/mem/virt.c')
-rw-r--r--src/kernel/mem/virt.c37
1 files changed, 23 insertions, 14 deletions
diff --git a/src/kernel/mem/virt.c b/src/kernel/mem/virt.c
index f0dca06..bff4b5e 100644
--- a/src/kernel/mem/virt.c
+++ b/src/kernel/mem/virt.c
@@ -1,5 +1,6 @@
#include <kernel/arch/generic.h>
#include <kernel/mem/virt.h>
+#include <kernel/panic.h>
#include <kernel/util.h>
#include <shared/mem.h>
@@ -52,12 +53,13 @@ bool virt_iter_next(struct virt_iter *iter) {
return true;
}
-bool virt_cpy(
+size_t virt_cpy(
struct pagedir *dest_pages, void __user *dest,
- struct pagedir *src_pages, const void __user *src, size_t length)
+ struct pagedir *src_pages, const void __user *src,
+ size_t length, struct virt_cpy_error *err)
{
struct virt_iter dest_iter, src_iter;
- size_t cur_len;
+ size_t total = 0, partial;
virt_iter_new(&dest_iter, dest, length, dest_pages, true, true);
virt_iter_new( &src_iter, (userptr_t)src, length, src_pages, true, false);
@@ -65,19 +67,26 @@ bool virt_cpy(
src_iter.frag_len = 0;
for (;;) {
- if (dest_iter.frag_len <= 0)
- if (!virt_iter_next(&dest_iter)) break;
- if ( src_iter.frag_len <= 0)
- if (!virt_iter_next( &src_iter)) break;
+ if (dest_iter.frag_len <= 0 && !virt_iter_next(&dest_iter)) break;
+ if ( src_iter.frag_len <= 0 && !virt_iter_next( &src_iter)) break;
- cur_len = min(src_iter.frag_len, dest_iter.frag_len);
- memcpy(dest_iter.frag, src_iter.frag, cur_len);
+ partial = min(src_iter.frag_len, dest_iter.frag_len);
+ total += partial;
+ memcpy(dest_iter.frag, src_iter.frag, partial);
- dest_iter.frag_len -= cur_len;
- dest_iter.frag += cur_len;
- src_iter.frag_len -= cur_len;
- src_iter.frag += cur_len;
+ dest_iter.frag_len -= partial;
+ dest_iter.frag += partial;
+ src_iter.frag_len -= partial;
+ src_iter.frag += partial;
}
- return !(dest_iter.error || src_iter.error);
+ if (err) {
+ err->read_fail = src_iter.error;
+ err->write_fail = dest_iter.error;
+ }
+ if (src_iter.error || dest_iter.error)
+ assert(total != length);
+ else
+ assert(total == length);
+ return total;
}