From 5a416e70402bbbfaf8b2790a12b50c0ac159ec3f Mon Sep 17 00:00:00 2001
From: dzwdz
Date: Tue, 27 Dec 2022 19:14:16 +0100
Subject: amd64/ata: poll properly

---
 src/kernel/arch/amd64/ata.c | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

(limited to 'src/kernel/arch/amd64/ata.c')

diff --git a/src/kernel/arch/amd64/ata.c b/src/kernel/arch/amd64/ata.c
index 4c1f2e2..453ee6d 100644
--- a/src/kernel/arch/amd64/ata.c
+++ b/src/kernel/arch/amd64/ata.c
@@ -47,13 +47,14 @@ static void ata_driveselect(int drive, int lba) {
 	port_out8(ata_iobase(drive) + DRV, v);
 }
 
-static int ata_poll(int drive, int timeout) {
+static int ata_poll(int drive, int timeout, bool pio) {
 	uint16_t iobase = ata_iobase(drive);
 	/* if timeout < 0, cycle forever */
 	while (timeout < 0 || timeout--) {
 		uint8_t v = port_in8(iobase + STATUS);
 		if (v & 0x80) continue; /* BSY */
 		if (v & 0x40) return 0; /* RDY */
+		if (pio && (v & 0x08)) return 0; /* DRQ */
 		// TODO check for ERR
 	}
 	return -1;
@@ -64,7 +65,7 @@ static void ata_softreset(int drive) {
 	port_out8(iobase + CTRL, 4);
 	port_out8(iobase + CTRL, 0);
 	ata_400ns();
-	ata_poll(drive, 10000);
+	ata_poll(drive, 10000, false);
 }
 
 static void ata_detecttype(int drive) {
@@ -146,7 +147,8 @@ int ata_read(int drive, void *buf, size_t len, size_t off) {
 
 	int iobase = ata_iobase(drive);
 	ata_driveselect(drive, lba);
-	port_out8(iobase + FEAT, 0); /* supposedly pointless */
+	ata_400ns();
+	ata_poll(drive, -1, false);
 	port_out8(iobase + SCNT, cnt);
 	port_out8(iobase + LBAlo, lba);
 	port_out8(iobase + LBAmid, lba >> 8);
@@ -158,7 +160,7 @@ int ata_read(int drive, void *buf, size_t len, size_t off) {
 			uint16_t s;
 			char b[2];
 		} d;
-		ata_poll(drive, -1);
+		ata_poll(drive, -1, true);
 		for (int j = 0; j < 256; j++) {
 			d.s = port_in16(iobase);
 			for (int k = 0; k < 2; k++) {
@@ -177,6 +179,8 @@ int ata_read(int drive, void *buf, size_t len, size_t off) {
 static void ata_rawwrite(int drive, const void *buf, uint32_t lba, uint32_t cnt) {
 	int iobase = ata_iobase(drive);
 	ata_driveselect(drive, lba);
+	ata_400ns();
+	ata_poll(drive, -1, false);
 	port_out8(iobase + FEAT, 0);
 	port_out8(iobase + SCNT, cnt);
 	port_out8(iobase + LBAlo, lba);
@@ -185,11 +189,14 @@ static void ata_rawwrite(int drive, const void *buf, uint32_t lba, uint32_t cnt)
 	port_out8(iobase + CMD, 0x30); /* WRITE SECTORS */
 
 	for (uint32_t i = 0; i < cnt; i++) {
-		ata_poll(drive, -1);
+		ata_poll(drive, -1, true);
 		for (int j = 0; j < 256; j++) {
 			port_out16(iobase, ((uint16_t*)buf)[i * 256 + j]);
 		}
 	}
+
+	ata_poll(drive, -1, false);
+	port_out8(iobase + CMD, 0xE7); /* CACHE FLUSH */
 }
 
 int ata_write(int drive, const void *buf, size_t len, size_t off) {
-- 
cgit v1.2.3