/*
 * pbltool.c - Talks to the PBL of the Amstrad E3 (Delta)
 *
 * Copyright 2005 Jonathan McDowell <noodles@earth.li>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <errno.h>
#include <fcntl.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <termios.h>
#include <unistd.h>

/* Non-zero enables extra diagnostic output from sendpacket(). */
int debug = 1;

/* Timeouts (milliseconds) used when talking to the PBL. */
enum {
	/* Wait for the first byte of a reply after sending a request. */
	PBL_REPLY_TIMEOUT_MS = 10000,
	/* Wait for each further chunk once a reply has started arriving. */
	PBL_BYTE_TIMEOUT_MS = 1000,
	/* Quiet period that ends the post-error input flush. */
	PBL_FLUSH_TIMEOUT_MS = 1000
};

/*
 * Read exactly len bytes from fd into dst, waiting up to timeout_ms for
 * data to become available before each read().  Works whether or not fd
 * is in non-blocking (O_NDELAY) mode.
 *
 * Returns the number of bytes read; a value smaller than len means the
 * wait timed out, the peer hung up, or a read error occurred.
 */
static size_t readexact(int fd, unsigned char *dst, size_t len, int timeout_ms)
{
	struct pollfd p = { .fd = fd, .events = POLLIN };
	size_t got = 0;

	while (got < len) {
		ssize_t n;
		int pr = poll(&p, 1, timeout_ms);

		if (pr < 0 && errno == EINTR)
			continue;
		if (pr <= 0)
			break;		/* timed out or poll() failed */

		n = read(fd, dst + got, len - got);
		if (n > 0) {
			got += (size_t)n;
		} else if (n == 0 ||
			   (errno != EAGAIN && errno != EWOULDBLOCK &&
			    errno != EINTR)) {
			break;		/* EOF/hang-up or hard error */
		}
		/* EAGAIN/EINTR: poll again. */
	}

	return got;
}

/*
 * Write all len bytes of src to fd.  When a non-blocking descriptor
 * reports EAGAIN, wait up to PBL_REPLY_TIMEOUT_MS for buffer space.
 *
 * Returns 0 on success, -1 on error or timeout.
 */
static int writeexact(int fd, const unsigned char *src, size_t len)
{
	struct pollfd p = { .fd = fd, .events = POLLOUT };

	while (len > 0) {
		ssize_t n = write(fd, src, len);

		if (n > 0) {
			src += n;
			len -= (size_t)n;
		} else if (n < 0 && errno == EINTR) {
			continue;
		} else if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
			if (poll(&p, 1, PBL_REPLY_TIMEOUT_MS) <= 0)
				return -1;
		} else {
			return -1;
		}
	}

	return 0;
}

/*
 * Send one PBL request frame and read the reply frame back into buf.
 *
 * Frame layout, as built and parsed below (both directions):
 *   0x02 (start), 0x00 (not compressed), length low byte, length high
 *   byte, <length> payload bytes, checksum byte.
 * The transmitted checksum is the 8-bit sum of the two length bytes and
 * the payload.
 *
 * fd:  serial descriptor (blocking or O_NDELAY).
 * buf: holds the len-byte request payload on entry; on success it is
 *      overwritten with the reply payload.
 * len: request payload length.
 * max: capacity of buf, i.e. the largest reply payload accepted.
 *
 * Returns 0 on success, 1 if no reply started within PBL_REPLY_TIMEOUT_MS,
 * and 2 on any other failure.  On a failure of type 2, pending input is
 * drained (and echoed) so the next request starts on a clean line.
 */
int sendpacket(int fd, unsigned char *buf, size_t len, size_t max)
{
	unsigned char hdr[4];
	unsigned char c;
	uint8_t checksum;
	size_t replylen = 0;
	size_t got;
	size_t i;
	int ret = 0;

	hdr[0] = 2;			/* start byte */
	hdr[1] = 0;			/* not compressed */
	hdr[2] = len & 0xFF;		/* length, little-endian */
	hdr[3] = (len >> 8) & 0xFF;

	checksum = (uint8_t)(hdr[2] + hdr[3]);
	for (i = 0; i < len; i++)
		checksum += buf[i];

	if (writeexact(fd, hdr, sizeof(hdr)) != 0 ||
	    writeexact(fd, buf, len) != 0 ||
	    writeexact(fd, &checksum, 1) != 0) {
		printf("Failed to send packet.\n");
		ret = 2;
	}

	if (!ret) {
		if (readexact(fd, &c, 1, PBL_REPLY_TIMEOUT_MS) != 1) {
			ret = 1;	/* Timed out */
		} else if (c != 2) {
			printf("Didn't get expected 0x02 header -- 0x%02X.\n",
					c);
			ret = 2;
		}
	}

	if (!ret) {
		if (readexact(fd, &c, 1, PBL_BYTE_TIMEOUT_MS) != 1 || c != 0) {
			printf("Compressed return block?\n");
			ret = 2;
		}
	}

	if (!ret) {
		if (readexact(fd, &c, 1, PBL_BYTE_TIMEOUT_MS) != 1) {
			printf("Couldn't read low byte of reply len.\n");
			ret = 2;
		} else {
			replylen = c;
		}
	}

	if (!ret) {
		if (readexact(fd, &c, 1, PBL_BYTE_TIMEOUT_MS) != 1) {
			printf("Couldn't read high byte of reply len.\n");
			ret = 2;
		} else {
			replylen += (size_t)c << 8;
		}
	}

	/* A reply exactly max bytes long still fits; anything larger doesn't. */
	if (!ret && replylen > max) {
		printf("Reply of %zu bytes doesn't fit in %zu byte buffer.\n",
				replylen, max);
		ret = 2;
	}

	if (!ret && replylen > 0) {
		got = readexact(fd, buf, replylen, PBL_BYTE_TIMEOUT_MS);
		if (debug) {
			printf("Got %zu of %zu bytes.\n", got, replylen);
		}
		if (got != replylen)
			ret = 2;
	}

	if (!ret) {
		if (readexact(fd, &c, 1, PBL_BYTE_TIMEOUT_MS) != 1) {
			printf("Couldn't read reply checksum.\n");
			ret = 2;
		} else if (debug) {
			/*
			 * NOTE(review): the reply checksum is only displayed,
			 * never verified -- confirm the PBL's reply checksum
			 * formula before rejecting mismatches.
			 */
			printf("Checksum = 0x%02X\n", c);
		}
	}

	/*
	 * Flush input if error: keep reading until the line has been quiet
	 * for PBL_FLUSH_TIMEOUT_MS.  Stops on hang-up or read error instead
	 * of spinning.
	 */
	if (ret > 1) {
		struct pollfd p = { .fd = fd, .events = POLLIN };

		while (poll(&p, 1, PBL_FLUSH_TIMEOUT_MS) > 0 &&
		       read(fd, &c, 1) == 1) {
			printf("Flushing: 0x%02X\n", c);
		}
	}

	return ret;
}


/*
 * Upload the contents of file to target memory starting at address start,
 * 1024 bytes per PBL command 5 request.
 *
 * Request layout built below: cmd (5, 0), 32-bit little-endian target
 * address, 16-bit little-endian payload length, payload.
 *
 * Returns 0 when the whole file was sent.  Returns -1 if the file could
 * not be opened or read, or if a block was not acknowledged; the transfer
 * stops at that point rather than leaving a silent hole in the image.
 */
int writeblock(int fd, uint32_t start, char *file)
{
	unsigned char buf[16384 + 8];
	uint32_t addr;
	size_t count = 0;
	ssize_t blocklen;	/* signed: read() returns -1 on error */
	int imgfd;
	int ret = 0;
	int rc;

	printf("Entering writeblock\n");

	imgfd = open(file, O_RDONLY);
	if (imgfd == -1) {
		perror(file);
		return -1;
	}

	while ((blocklen = read(imgfd, buf + 8, 1024)) > 0) {
		addr = start + (uint32_t)count;

		buf[0] = 5;
		buf[1] = 0;
		/* Byte-wise little-endian stores: buf + 2 is not aligned. */
		buf[2] = addr & 0xFF;
		buf[3] = (addr >> 8) & 0xFF;
		buf[4] = (addr >> 16) & 0xFF;
		buf[5] = (addr >> 24) & 0xFF;
		buf[6] = blocklen & 0xFF;
		buf[7] = (blocklen >> 8) & 0xFF;

		rc = sendpacket(fd, buf, (size_t)blocklen + 8, sizeof(buf));
		printf("sendpacket: %d ", rc);
		if (rc != 0) {
			printf("\nFailed to write block at 0x%08X.\n",
					(unsigned)addr);
			ret = -1;
			break;
		}

		count += (size_t)blocklen;
		printf("%zu bytes transfered so far (%zu KB) \n",
				count, count / 1024);
	}

	if (blocklen < 0) {
		perror(file);
		ret = -1;
	}

	close(imgfd);

	return ret;
}

/*
 * Ask the PBL to run the code at address start (command 4).
 *
 * Request: cmd (4, 0) followed by the 32-bit little-endian address.
 * Returns sendpacket()'s status: 0 success, 1 timeout, 2 other error.
 */
int execute(int fd, uint32_t start)
{
	unsigned char buf[10];
	int rc;

	buf[0] = 4;
	buf[1] = 0;
	/* Byte-wise little-endian store: buf + 2 is not aligned. */
	buf[2] = start & 0xFF;
	buf[3] = (start >> 8) & 0xFF;
	buf[4] = (start >> 16) & 0xFF;
	buf[5] = (start >> 24) & 0xFF;

	rc = sendpacket(fd, buf, 6, sizeof(buf));
	printf("sendpacket: %d ", rc);

	return rc;
}

/*
 * Fetch one byte at target address addr using PBL command 3, making up to
 * three attempts.
 *
 * Request: cmd (3, 0), 32-bit LE address, 32-bit LE count of 1.  The byte
 * is taken from reply offset 2 (low byte of the 32-bit reply value).
 * NOTE(review): readflashblock() uses the same command as a running sum;
 * with a count of 1 the value is presumably the byte itself -- confirm.
 *
 * Returns 0 and stores the byte in *out, or -1 if every attempt failed.
 */
static int readmembyte(int fd, uint32_t addr, unsigned char *out)
{
	unsigned char buf[10];
	int attempt;

	for (attempt = 0; attempt < 3; attempt++) {
		/* Rebuilt every attempt: sendpacket() overwrites buf. */
		buf[0] = 3;
		buf[1] = 0;
		buf[2] = addr & 0xFF;
		buf[3] = (addr >> 8) & 0xFF;
		buf[4] = (addr >> 16) & 0xFF;
		buf[5] = (addr >> 24) & 0xFF;
		buf[6] = 1;
		buf[7] = 0;
		buf[8] = 0;
		buf[9] = 0;
		if (sendpacket(fd, buf, 10, sizeof(buf)) == 0) {
			*out = buf[2];
			return 0;
		}
	}

	return -1;
}

/*
 * Dump len bytes of target memory starting at start into file (created or
 * truncated), one request per byte, with a 16-byte-per-line hex listing
 * on stdout.
 *
 * Returns 0 on success, -1 if the file cannot be opened/written or a byte
 * cannot be read.  On a read failure the dump stops there, so the file
 * never contains shifted data.
 */
int readblock(int fd, uint32_t start, size_t len, char *file)
{
	uint32_t addr;
	unsigned char c;
	size_t i;
	int fdout;
	int ret = 0;

	fdout = open(file, O_CREAT | O_WRONLY | O_TRUNC, 0644);
	if (fdout == -1) {
		perror(file);
		return -1;
	}

	for (i = 0; i < len; i++) {
		addr = start + (uint32_t)i;
		if ((i % 16) == 0) {
			printf("\n0x%08X ", (unsigned)addr);
		}
		if (readmembyte(fd, addr, &c) != 0) {
			printf("\nFailed to read byte at 0x%08X.\n",
					(unsigned)addr);
			ret = -1;
			break;
		}
		printf("%02X ", c);
		if (write(fdout, &c, 1) != 1) {
			perror(file);
			ret = -1;
			break;
		}
	}
	printf("\n");

	close(fdout);

	return ret;
}


/*
 * Issue PBL command 3 for the region starting at 0x00400000 with length
 * len, making up to three attempts.  On success, store the 32-bit LE reply
 * value from offset 2 in *sum.
 *
 * NOTE(review): readflashblock() treats this value as a running byte sum
 * of [0x00400000, 0x00400000 + len) -- confirm against PBL documentation.
 *
 * Returns 0 on success, -1 if every attempt failed.
 */
static int flashsum(int fd, uint32_t len, uint32_t *sum)
{
	unsigned char buf[10];
	int attempt;

	for (attempt = 0; attempt < 3; attempt++) {
		/* Rebuilt every attempt: sendpacket() overwrites buf. */
		buf[0] = 3;
		buf[1] = 0;
		buf[2] = 0x00;		/* address 0x00400000, LE */
		buf[3] = 0x00;
		buf[4] = 0x40;
		buf[5] = 0x00;
		buf[6] = len & 0xFF;
		buf[7] = (len >> 8) & 0xFF;
		buf[8] = (len >> 16) & 0xFF;
		buf[9] = (len >> 24) & 0xFF;
		if (sendpacket(fd, buf, 10, sizeof(buf)) == 0) {
			*sum = (uint32_t)buf[2] |
			       ((uint32_t)buf[3] << 8) |
			       ((uint32_t)buf[4] << 16) |
			       ((uint32_t)buf[5] << 24);
			return 0;
		}
	}

	return -1;
}

/*
 * Dump len bytes of flash starting at absolute address start (which must
 * be >= 0x00400000) into file (created or truncated).
 *
 * Each byte is recovered as the low 8 bits of the difference between the
 * command-3 values for lengths (offset + 1) and offset, measured from
 * 0x00400000.  A failed query therefore corrupts every later byte, so
 * the dump stops at the first failure.
 *
 * Returns 0 on success, -1 on any failure.
 */
int readflashblock(int fd, uint32_t start, size_t len, char *file)
{
	const uint32_t flashbase = 0x00400000;
	uint32_t prev = 0;
	uint32_t sum;
	uint32_t addr;
	unsigned char c;
	size_t i;
	int fdout;
	int ret = 0;

	if (start < flashbase) {
		printf("Flash address 0x%08X is below flash base 0x%08X.\n",
				(unsigned)start, (unsigned)flashbase);
		return -1;
	}

	fdout = open(file, O_CREAT | O_WRONLY | O_TRUNC, 0644);
	if (fdout == -1) {
		perror(file);
		return -1;
	}

	/* Value for everything before start, so the first difference is byte 0. */
	if (start > flashbase && flashsum(fd, start - flashbase, &prev) != 0) {
		printf("Failed to read initial flash checksum.\n");
		close(fdout);
		return -1;
	}

	for (i = 0; i < len; i++) {
		addr = start + (uint32_t)i;
		if ((i % 16) == 0) {
			printf("\n0x%08X ", (unsigned)addr);
		}
		if (flashsum(fd, addr - flashbase + 1, &sum) != 0) {
			printf("\nFailed to read flash byte at 0x%08X.\n",
					(unsigned)addr);
			ret = -1;
			break;
		}
		c = (unsigned char)(sum - prev);
		prev = sum;
		printf("%02X ", c);
		if (write(fdout, &c, 1) != 1) {
			perror(file);
			ret = -1;
			break;
		}
	}
	printf("\n");

	close(fdout);

	return ret;
}

/*
 * Erase size bytes of flash starting at start (PBL command 6).
 *
 * Request (12 bytes): cmd (6, 0), 32-bit LE start, 32-bit LE size, then a
 * 16-bit LE field set to 1.  NOTE(review): the meaning of that trailing
 * field is not visible here.
 *
 * Returns sendpacket()'s status: 0 success, 1 timeout, 2 other error.
 */
int eraseflash(int fd, uint32_t start, size_t size)
{
	/* Must hold the whole 12-byte request (bytes 10-11 included). */
	unsigned char buf[12];
	uint32_t sz = (uint32_t)size;

	buf[0] = 6;
	buf[1] = 0;
	buf[2] = start & 0xFF;
	buf[3] = (start >> 8) & 0xFF;
	buf[4] = (start >> 16) & 0xFF;
	buf[5] = (start >> 24) & 0xFF;
	buf[6] = sz & 0xFF;
	buf[7] = (sz >> 8) & 0xFF;
	buf[8] = (sz >> 16) & 0xFF;
	buf[9] = (sz >> 24) & 0xFF;
	buf[10] = 1;
	buf[11] = 0;

	return sendpacket(fd, buf, sizeof(buf), sizeof(buf));
}

/*
 * Program the contents of file into flash at address start, 1024 bytes
 * per PBL command 14 request.
 *
 * Request layout built below: cmd (14, 0), 32-bit LE address, 16-bit LE
 * value 1, 16-bit LE payload length, payload.
 *
 * Returns 0 when the whole file was sent.  Returns -1 if the file could
 * not be opened or read, or if a block was not acknowledged.  Programming
 * stops at the first failure.
 */
int programflash(int fd, uint32_t start, char *file)
{
	unsigned char buf[16384 + 10];
	uint32_t addr;
	size_t count = 0;
	ssize_t blocklen;	/* signed: read() returns -1 on error */
	int imgfd;
	int ret = 0;
	int rc;

	imgfd = open(file, O_RDONLY);
	if (imgfd == -1) {
		perror(file);
		return -1;
	}

	while ((blocklen = read(imgfd, buf + 10, 1024)) > 0) {
		addr = start + (uint32_t)count;

		buf[0] = 14;
		buf[1] = 0;
		/* Byte-wise little-endian stores: buf + 2 is not aligned. */
		buf[2] = addr & 0xFF;
		buf[3] = (addr >> 8) & 0xFF;
		buf[4] = (addr >> 16) & 0xFF;
		buf[5] = (addr >> 24) & 0xFF;
		buf[6] = 1;
		buf[7] = 0;
		buf[8] = blocklen & 0xFF;
		buf[9] = (blocklen >> 8) & 0xFF;
		printf("Trying to program %d bytes starting %02X "
				"at 0x%08X.\n",
				(int)blocklen,
				buf[10],
				(unsigned)addr);

		rc = sendpacket(fd, buf, (size_t)blocklen + 10, sizeof(buf));
		printf("sendpacket: %d ", rc);
		printf("%02X %02X %02X %02X\n",
			buf[0], buf[1], buf[2], buf[3]);
		if (rc != 0) {
			printf("Failed to program block at 0x%08X.\n",
					(unsigned)addr);
			ret = -1;
			break;
		}

		count += (size_t)blocklen;
	}

	if (blocklen < 0) {
		perror(file);
		ret = -1;
	}

	close(imgfd);

	return ret;
}

/*
 * Print the tool version and a summary of the supported commands, then
 * exit with EXIT_FAILURE.
 */
void help(void)
{
	printf("pbltool v0.1\n");
	printf("\nUsage: pbltool <command> [arguments]\n"
	       "  meminfo\n"
	       "  read <start addr> <len> <file>\n"
	       "  write <start addr> <file>\n"
	       "  exec <start addr>\n"
	       "  readflash <start addr> <len> <file>\n"
	       "  writeflash <start addr> <file>\n"
	       "  eraseflash <start addr> <number of 16k pages>\n");
	exit(EXIT_FAILURE);
}

/* Decode a little-endian 16-bit value from p (p need not be aligned). */
static uint16_t getle16(const unsigned char *p)
{
	return (uint16_t)(p[0] | (p[1] << 8));
}

/* Decode a little-endian 32-bit value from p (p need not be aligned). */
static uint32_t getle32(const unsigned char *p)
{
	return (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
	       ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
}

/*
 * Parse a numeric command-line argument with strtoul() base 0, so decimal,
 * 0x-prefixed hex and 0-prefixed octal are accepted.  Unsigned parsing
 * keeps 32-bit addresses >= 0x80000000 intact on hosts with 32-bit long.
 * Exits with a message on malformed or out-of-range input.
 */
static unsigned long parsenum(const char *s)
{
	char *end;
	unsigned long v;

	errno = 0;
	v = strtoul(s, &end, 0);
	if (end == s || *end != '\0' || errno == ERANGE) {
		printf("Invalid number: %s\n", s);
		exit(EXIT_FAILURE);
	}

	return v;
}

/*
 * Open dev and configure it as an 8N1, 9600 baud, raw (c_lflag = 0)
 * non-blocking (O_NDELAY) line.  Pending I/O is discarded.  Exits on failure.
 */
static int openserial(const char *dev)
{
	struct termios serialterm;
	int fd;

	fd = open(dev, O_RDWR | O_NOCTTY | O_NDELAY);
	if (fd == -1 || tcgetattr(fd, &serialterm) != 0) {
		perror(dev);
		exit(EXIT_FAILURE);
	}

	serialterm.c_cflag = CS8 | CLOCAL | CREAD;
	serialterm.c_lflag = 0;
	serialterm.c_oflag = 0;
	serialterm.c_iflag = IGNPAR;
	cfsetspeed(&serialterm, B9600);
	if (tcsetattr(fd, TCSANOW, &serialterm) != 0) {
		perror(dev);
		exit(EXIT_FAILURE);
	}
	tcflush(fd, TCIOFLUSH);

	return fd;
}

/*
 * Send a get-version request (command 2).  On success print the PBL
 * version and build number and return true.
 */
static bool pblversion(int fd)
{
	unsigned char buf[128];

	buf[0] = 2;
	buf[1] = 0;
	if (sendpacket(fd, buf, 2, sizeof(buf)) != 0)
		return false;

	printf("Talking to PBL v%d.%d Build %d\n",
			buf[4], buf[5],
			buf[6] + (buf[7] << 8));
	return true;
}

/*
 * Wake the PBL.  Send ESC (0x1B) every 100 ms until a 0x06 byte comes
 * back.  Then drain further input until the line has been quiet for
 * 100 ms, reporting any byte other than 0x06.
 */
static void prod(int fd)
{
	struct pollfd p = { .fd = fd, .events = POLLIN };
	unsigned char c = 0;	/* must not start as 0x06 */
	ssize_t i;

	printf("Prodding...\n");

	while (c != 0x06) {
		c = 0x1B;
		/* Best effort: resent on the next pass anyway. */
		write(fd, &c, 1);

		if (poll(&p, 1, 100) == 1) {
			do {
				i = read(fd, &c, 1);
			} while (i == 1 && c != 0x06);
		}
	}

	printf("Handshaking...\n");

	while (poll(&p, 1, 100) > 0) {
		while ((i = read(fd, &c, 1)) == 1) {
			if (c != 0x06) {
				printf("Error: Got 0x%02X instead of expected"
						" 0x06.\n", c);
			}
		}
		if (i == 0)
			break;	/* hang-up: don't spin on a dead line */
	}
}

/*
 * Print the memory map.  Command 0x0B returns the region count at reply
 * offset 2.  Command 12 is then sent for each region index, and the four
 * 32-bit and two 16-bit little-endian fields from its reply are printed.
 */
static void meminfo(int fd)
{
	unsigned char buf[128];
	int count = 0;
	int i;

	buf[0] = 0x0B;
	buf[1] = 0;
	if (sendpacket(fd, buf, 2, sizeof(buf)) == 0) {
		count = buf[2];
	} else {
		printf("Couldn't read memory region count.\n");
	}

	for (i = 0; i < count; i++) {
		buf[0] = 12;
		buf[1] = 0;
		buf[2] = (unsigned char)i;
		buf[3] = 0;
		if (sendpacket(fd, buf, 4, sizeof(buf)) == 0) {
			printf("%d 0x%08X 0x%08X 0x%08X 0x%08X "
				"0x%04X 0x%04X\n",
				i,
				(unsigned)getle32(buf + 2),
				(unsigned)getle32(buf + 6),
				(unsigned)getle32(buf + 10),
				(unsigned)getle32(buf + 14),
				(unsigned)getle16(buf + 18),
				(unsigned)getle16(buf + 20));
		}
	}
}

/*
 * Entry point: connect to the PBL on /dev/ttyS0, establishing contact
 * (prodding if needed), then run the command named in argv[1].
 */
int main(int argc, char *argv[])
{
	int fd;

	if (argc < 2) {
		help();
	}

	fd = openserial("/dev/ttyS0");

	/* The PBL may already be listening; otherwise wake it and retry. */
	if (!pblversion(fd)) {
		prod(fd);
		while (!pblversion(fd))
			;
	}

	sleep(1);

	if (strcmp(argv[1], "meminfo") == 0) {
		meminfo(fd);
	} else if (strcmp(argv[1], "read") == 0) {
		if (argc < 5) {
			printf("\nUsage: pbltool read <start addr> <len> "
					"<file>\n");
			exit(EXIT_FAILURE);
		}
		readblock(fd, parsenum(argv[2]),
				parsenum(argv[3]),
				argv[4]);
	} else if (strcmp(argv[1], "write") == 0) {
		if (argc < 4) {
			printf("\nUsage: pbltool write <start addr> "
					"<file>\n");
			exit(EXIT_FAILURE);
		}
		writeblock(fd, parsenum(argv[2]),
				argv[3]);
	} else if (strcmp(argv[1], "exec") == 0) {
		/* Needs argv[2]: the old argc < 2 test let it be NULL. */
		if (argc < 3) {
			printf("\nUsage: pbltool exec <start addr>\n");
			exit(EXIT_FAILURE);
		}
		execute(fd, parsenum(argv[2]));
	} else if (strcmp(argv[1], "readflash") == 0) {
		if (argc < 5) {
			printf("\nUsage: pbltool readflash <start addr> <len> "
					"<file>\n");
			exit(EXIT_FAILURE);
		}
		readflashblock(fd, parsenum(argv[2]) + 0x400000,
				parsenum(argv[3]),
				argv[4]);
	} else if (strcmp(argv[1], "writeflash") == 0) {
		if (argc < 4) {
			printf("\nUsage: pbltool writeflash <start addr> "
					"<file>\n");
			exit(EXIT_FAILURE);
		}
		programflash(fd, parsenum(argv[2]) + 0x400000,
				argv[3]);
	} else if (strcmp(argv[1], "eraseflash") == 0) {
		if (argc < 4) {
			printf("\nUsage: pbltool eraseflash <start addr> "
					"<number of 16k pages>\n");
			exit(EXIT_FAILURE);
		}
		/*
		 * NOTE(review): unlike readflash/writeflash, no 0x400000 base
		 * is added here -- confirm the PBL expects this address as is.
		 */
		eraseflash(fd, parsenum(argv[2]),
				parsenum(argv[3]) * 0x4000);
	} else {
		printf("Unknown command!\n");
		exit(EXIT_FAILURE);
	}

	close(fd);
	exit(EXIT_SUCCESS);
}
