diff --git a/core/src/commonMain/kotlin/world/phantasmal/core/PrimitiveExtensions.kt b/core/src/commonMain/kotlin/world/phantasmal/core/PrimitiveExtensions.kt new file mode 100644 index 00000000..5b513d39 --- /dev/null +++ b/core/src/commonMain/kotlin/world/phantasmal/core/PrimitiveExtensions.kt @@ -0,0 +1,3 @@ +package world.phantasmal.core + +fun Char.isDigit(): Boolean = this in '0'..'9' diff --git a/core/src/commonMain/kotlin/world/phantasmal/core/PwResult.kt b/core/src/commonMain/kotlin/world/phantasmal/core/PwResult.kt index 3ace047e..b500b057 100644 --- a/core/src/commonMain/kotlin/world/phantasmal/core/PwResult.kt +++ b/core/src/commonMain/kotlin/world/phantasmal/core/PwResult.kt @@ -18,12 +18,17 @@ class Success(val value: T, problems: List = emptyList()) : PwResult class Failure(problems: List) : PwResult(problems) -class Problem( +open class Problem( val severity: Severity, /** * Readable message meant for users. */ val uiMessage: String, + /** + * Message meant for developers. + */ + val message: String? = null, + val cause: Throwable? = null, ) enum class Severity { @@ -38,6 +43,22 @@ enum class Severity { class PwResultBuilder(private val logger: KLogger) { private val problems: MutableList = mutableListOf() + /** + * Add a problem to the problems list and log it with [logger]. + */ + fun addProblem( + problem: Problem, + ): PwResultBuilder { + when (problem.severity) { + Severity.Info -> logger.info(problem.cause) { problem.message ?: problem.uiMessage } + Severity.Warning -> logger.warn(problem.cause) { problem.message ?: problem.uiMessage } + Severity.Error -> logger.error(problem.cause) { problem.message ?: problem.uiMessage } + } + + problems.add(problem) + return this + } + /** * Add a problem to the problems list and log it with [logger]. */ @@ -46,16 +67,8 @@ class PwResultBuilder(private val logger: KLogger) { uiMessage: String, message: String? = null, cause: Throwable? 
= null, - ): PwResultBuilder { - when (severity) { - Severity.Info -> logger.info(cause) { message ?: uiMessage } - Severity.Warning -> logger.warn(cause) { message ?: uiMessage } - Severity.Error -> logger.error(cause) { message ?: uiMessage } - } - - problems.add(Problem(severity, uiMessage)) - return this - } + ): PwResultBuilder = + addProblem(Problem(severity, uiMessage, message, cause)) /** * Add the given result's problems. diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt new file mode 100644 index 00000000..98601e11 --- /dev/null +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt @@ -0,0 +1,716 @@ +package world.phantasmal.lib.assembly + +import mu.KotlinLogging +import world.phantasmal.core.Problem +import world.phantasmal.core.PwResult +import world.phantasmal.core.PwResultBuilder +import world.phantasmal.core.Severity +import world.phantasmal.lib.buffer.Buffer + +private val logger = KotlinLogging.logger {} + +class AssemblyProblem( + severity: Severity, + uiMessage: String, + message: String? = null, + cause: Throwable? = null, + val lineNo: Int, + val col: Int, + val length: Int, +) : Problem(severity, uiMessage, message, cause) + +class AssemblySettings( + val manualStack: Boolean, +) + +fun assemble( + assembly: List, + manualStack: Boolean = false, +): PwResult> { + logger.trace { "Assembly start." } + + val result = Assembler(assembly, manualStack).assemble() + + logger.trace { + val warnings = result.problems.count { it.severity == Severity.Warning } + val errors = result.problems.count { it.severity == Severity.Error } + + "Assembly finished with $warnings warnings and $errors errors." 
+ } + + return result +} + +private class Assembler(private val assembly: List, private val manualStack: Boolean) { + private var lineNo = 1 + private lateinit var tokens: MutableList + private var objectCode: MutableList = mutableListOf() + + /** + * The current segment. + */ + private var segment: Segment? = null + + /** + * Encountered labels. + */ + private val labels: MutableSet = mutableSetOf() + private var section: SegmentType = SegmentType.Instructions + private var firstSectionMarker = true + private var prevLineHadLabel = false + + private val result = PwResultBuilder>(logger) + + fun assemble(): PwResult> { + // Tokenize and assemble line by line. + for (line in assembly) { + tokens = tokenizeLine(line) + + if (tokens.isNotEmpty()) { + val token = tokens.removeFirst() + var hasLabel = false + + when (token) { + is LabelToken -> { + parseLabel(token) + hasLabel = true + } + is SectionToken, + -> { + parseSection(token) + } + is IntToken -> { + if (section == SegmentType.Data) { + parseBytes(token) + } else { + addUnexpectedTokenError(token) + } + } + is StringToken -> { + if (section == SegmentType.String) { + parseString(token) + } else { + addUnexpectedTokenError(token) + } + } + is IdentToken -> { + if (section === SegmentType.Instructions) { + parseInstruction(token) + } else { + addUnexpectedTokenError(token) + } + } + is InvalidSectionToken -> { + addError(token, "Invalid section type.") + } + is InvalidIdentToken -> { + addError(token, "Invalid identifier.") + } + else -> { + addUnexpectedTokenError(token) + } + } + + prevLineHadLabel = hasLabel + } + + lineNo++ + } + + return result.success(objectCode) + } + + private fun addInstruction( + opcode: Opcode, + args: List, + stackArgs: List, + token: Token?, + argTokens: List, + stackArgTokens: List, + ) { + when (val seg = segment) { + null -> { + // Unreachable code, technically valid. 
+ segment = InstructionSegment( + labels = mutableListOf(), + instructions = mutableListOf(), + srcLoc = SegmentSrcLoc() + ) + + objectCode.add(segment!!) + + // The fresh segment is empty, dispatch again so this instruction is stored in it instead of being dropped. + addInstruction(opcode, args, stackArgs, token, argTokens, stackArgTokens) + } + + is InstructionSegment -> { + seg.instructions.add( + Instruction( + opcode, + args, + InstructionSrcLoc( + mnemonic = token?.let { + SrcLoc(lineNo, token.col, token.len) + }, + args = argTokens.map { + SrcLoc(lineNo, it.col, it.len) + }, + stackArgs = stackArgTokens.mapIndexed { i, sat -> + StackArgSrcLoc(lineNo, sat.col, sat.len, stackArgs[i].value) + }, + ) + ) + ) + } + + else -> { + logger.error { "Line $lineNo: Expected instructions segment." } + } + } + } + + /** + * Appends [bytes] to the current data segment, or starts an unlabeled data segment holding them when there is no current segment. + */ + private fun addBytes(bytes: ByteArray) { + when (val seg = segment) { + null -> { + // Unaddressable data, technically valid. + segment = DataSegment( + labels = mutableListOf(), + data = Buffer.fromByteArray(bytes), + srcLoc = SegmentSrcLoc() + ) + + objectCode.add(segment!!) + } + + is DataSegment -> { + val oldSize = seg.data.size + seg.data.size += bytes.size.toUInt() + + for (i in bytes.indices) { + seg.data.setI8(i.toUInt() + oldSize, bytes[i]) + } + } + + else -> { + logger.error { "Line $lineNo: Expected data segment." } + } + } + } + + /** + * Appends [str] to the current string segment, or starts an unlabeled string segment holding it when there is no current segment. + */ + private fun addString(str: String) { + when (val seg = segment) { + null -> { + // Unaddressable data, technically valid. + segment = StringSegment( + labels = mutableListOf(), + value = str, + srcLoc = SegmentSrcLoc() + ) + + objectCode.add(segment!!) + } + + is StringSegment -> { + seg.value += str + } + + else -> { + logger.error { "Line $lineNo: Expected string segment."
} + } + } + } + /** + * Records an error for the given column span on the current line. + */ + private fun addError(col: Int, length: Int, message: String) { + result.addProblem( + AssemblyProblem( + Severity.Error, + message, + lineNo = lineNo, + col = col, + length = length + ) + ) + } + + private fun addError(token: Token, message: String) { + addError(token.col, token.len, message) + } + + private fun addUnexpectedTokenError(token: Token) { + addError(token, "Unexpected token.") + } + + /** + * Records a warning spanning [token] on the current line. + */ + private fun addWarning(token: Token, message: String) { + result.addProblem( + AssemblyProblem( + Severity.Warning, + message, + lineNo = lineNo, + col = token.col, + length = token.len + ) + ) + } + + /** + * Registers the label in [token], attaches it to the segment opened by the previous label line or opens a new segment for the current section, then hands the next token, if any, to the parser for that section. + */ + private fun parseLabel(token: LabelToken) { + val label = token.value + + // MutableSet.add returns false when the element was already present, i.e. when the label is a duplicate. + if (!labels.add(label)) { + addError(token, "Duplicate label.") + } + + val nextToken = tokens.removeFirstOrNull() + + val srcLoc = SrcLoc(lineNo, token.col, token.len) + + if (prevLineHadLabel) { + val segment = objectCode.last() + segment.labels.add(label) + segment.srcLoc.labels.add(srcLoc) + } + + when (section) { + SegmentType.Instructions -> { + if (!prevLineHadLabel) { + segment = InstructionSegment( + labels = mutableListOf(label), + instructions = mutableListOf(), + srcLoc = SegmentSrcLoc(labels = mutableListOf(srcLoc)), + ) + + objectCode.add(segment!!) + } + + if (nextToken != null) { + if (nextToken is IdentToken) { + parseInstruction(nextToken) + } else { + addError(nextToken, "Expected opcode mnemonic.") + } + } + } + + SegmentType.Data -> { + if (!prevLineHadLabel) { + segment = DataSegment( + labels = mutableListOf(label), + data = Buffer.withCapacity(0u), + srcLoc = SegmentSrcLoc(labels = mutableListOf(srcLoc)), + ) + objectCode.add(segment!!)
+ } + + if (nextToken != null) { + if (nextToken is IntToken) { + parseBytes(nextToken) + } else { + addError(nextToken, "Expected bytes.") + } + } + } + + SegmentType.String -> { + if (!prevLineHadLabel) { + segment = StringSegment( + labels = mutableListOf(label), + value = "", + srcLoc = SegmentSrcLoc(labels = mutableListOf(srcLoc)), + ) + objectCode.add(segment!!) + } + + if (nextToken != null) { + if (nextToken is StringToken) { + parseString(nextToken) + } else { + addError(nextToken, "Expected a string.") + } + } + } + } + } + + /** + * Switches the current section to the one named by [token], warns when a non-first marker repeats the current section and flags any token that follows the marker. + */ + private fun parseSection(token: SectionToken) { + val section = when (token) { + is CodeSectionToken -> SegmentType.Instructions + is DataSectionToken -> SegmentType.Data + is StringSectionToken -> SegmentType.String + } + + if (this.section == section && !firstSectionMarker) { + addWarning(token, "Unnecessary section marker.") + } + + this.section = section + firstSectionMarker = false + + tokens.removeFirstOrNull()?.let { nextToken -> + addUnexpectedTokenError(nextToken) + } + } + + /** + * Looks up the opcode for [identToken], validates the argument count and adds the instruction, preceded by argument push instructions when the opcode pops its arguments from the stack. + */ + private fun parseInstruction(identToken: IdentToken) { + val opcode = mnemonicToOpcode(identToken.value) + + if (opcode == null) { + addError(identToken, "Unknown instruction.") + } else { + val varargs = opcode.params.any { + it.type is ILabelVarType || it.type is RegRefVarType + } + + val paramCount = + if (manualStack && opcode.stack == StackInteraction.Pop) 0 + else opcode.params.size + + val argCount = tokens.count { it !is ArgSeparatorToken } + + val lastToken = tokens.lastOrNull() + val errorLength = lastToken?.let { it.col + it.len - identToken.col } ?: 0 + // Inline arguments. + val insArgAndTokens = mutableListOf>() + // Stack arguments. + val stackArgAndTokens = mutableListOf>() + + if (!varargs && argCount != paramCount) { + addError( + identToken.col, + errorLength, + "Expected $paramCount argument${if (paramCount == 1) "" else "s"}, got $argCount."
+ ) + + return + } else if (varargs && argCount < paramCount) { + addError( + identToken.col, + errorLength, + "Expected at least $paramCount argument${if (paramCount == 1) "" else "s"}, got $argCount.", + ) + + return + } else if (opcode.stack !== StackInteraction.Pop) { + // Inline arguments. + if (!parseArgs(opcode.params, insArgAndTokens, stack = false)) { + return + } + } else { + if (!this.parseArgs(opcode.params, stackArgAndTokens, stack = true)) { + return + } + + for (i in opcode.params.indices) { + val param = opcode.params[i] + val argAndToken = stackArgAndTokens.getOrNull(i) ?: continue + val (arg, argToken) = argAndToken + + if (argToken is RegisterToken) { + if (param.type is RegTupRefType) { + addInstruction( + OP_ARG_PUSHB, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } else { + addInstruction( + OP_ARG_PUSHR, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + } else { + when (param.type) { + is ByteType, + is RegRefType, + is RegTupRefType, + -> { + addInstruction( + OP_ARG_PUSHB, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + + is WordType, + is LabelType, + is ILabelType, + is DLabelType, + is SLabelType, + -> { + addInstruction( + OP_ARG_PUSHW, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + + is DWordType -> { + addInstruction( + OP_ARG_PUSHL, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + + is FloatType -> { + // Integer literals are accepted for float parameters, so the value can be an Int as well as a Float. + addInstruction( + OP_ARG_PUSHL, + listOf(Arg((arg.value as Number).toFloat().toRawBits())), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + + is StringType -> { + addInstruction( + OP_ARG_PUSHS, + listOf(arg), + emptyList(), + null, + listOf(argToken), + emptyList(), + ) + } + + else -> { + logger.error { + "Line $lineNo: Type ${param.type::class} not implemented."
+ } + } + } + } + } + } + + val (args, argTokens) = insArgAndTokens.unzip() + val (stackArgs, stackArgTokens) = stackArgAndTokens.unzip() + + addInstruction( + opcode, + args, + stackArgs, + identToken, + argTokens, + stackArgTokens, + ) + } + } + + /** + * @returns true if arguments can be translated to object code, possibly after truncation. False otherwise. + */ + private fun parseArgs( + params: List, + argAndTokens: MutableList>, + stack: Boolean, + ): Boolean { + var semiValid = true + var shouldBeArg = true + var paramI = 0 + + for (i in 0 until tokens.size) { + val token = tokens[i] + // Stray tokens after the last parameter (e.g. a doubled trailing comma) have no parameter to be matched against. + val param = params.getOrNull(paramI) + + if (param == null) { + addUnexpectedTokenError(token) + continue + } + + if (token is ArgSeparatorToken) { + if (shouldBeArg) { + addError(token, "Expected an argument.") + } else if ( + param.type !is ILabelVarType && + param.type !is RegRefVarType + ) { + paramI++ + } + + shouldBeArg = true + } else { + if (!shouldBeArg) { + val prevToken = tokens[i - 1] + val col = prevToken.col + prevToken.len + + addError(col, token.col - col, "Expected a comma.") + } + + shouldBeArg = false + + var match: Boolean + + when (token) { + is IntToken -> { + when (param.type) { + is ByteType -> { + match = true + parseInt(1, token, argAndTokens) + } + is WordType, + is LabelType, + is ILabelType, + is DLabelType, + is SLabelType, + is ILabelVarType, + -> { + match = true + parseInt(2, token, argAndTokens) + } + is DWordType -> { + match = true + parseInt(4, token, argAndTokens) + } + is FloatType -> { + match = true + // Store a Float so that float arguments always carry the same value type, whichever literal was used. + argAndTokens.add(Pair(Arg(token.value.toFloat()), token)) + } + else -> { + match = false + } + } + } + + is FloatToken -> { + match = param.type == FloatType + + if (match) { + argAndTokens.add(Pair(Arg(token.value), token)) + } + } + + is RegisterToken -> { + match = stack || + param.type is RegRefType || + param.type is RegRefVarType || + param.type is RegTupRefType + + parseRegister(token, argAndTokens) + } + + is StringToken -> { + match = param.type is StringType + + if (match) { + argAndTokens.add(Pair(Arg(token.value),
token)) + } + } + + else -> { + match = false + } + } + + if (!match) { + semiValid = false + + val typeStr: String? = when (param.type) { + is ByteType -> "an 8-bit integer" + is WordType -> "a 16-bit integer" + is DWordType -> "a 32-bit integer" + is FloatType -> "a float" + is LabelType -> "a label" + + is ILabelType, + is ILabelVarType, + -> "an instruction label" + + is DLabelType -> "a data label" + is SLabelType -> "a string label" + is StringType -> "a string" + + is RegRefType, + is RegRefVarType, + is RegTupRefType, + -> "a register reference" + + else -> null + } + + addError( + token, + if (typeStr == null) "Unexpected token." else "Expected ${typeStr}." + ) + } + } + } + + tokens.clear() + return semiValid + } + + private fun parseInt(size: Int, token: IntToken, argAndTokens: MutableList>) { + val value = token.value + val bitSize = 8 * size + // Minimum of the signed version of this integer type. + val minValue = -(1 shl (bitSize - 1)) + // Maximum of the unsigned version of this integer type. 
+ // Computed as a Long because `1 shl 32` wraps around to 1 for Int operands, which made the 32-bit maximum 0. + val maxValue = (1L shl bitSize) - 1 + + when { + value < minValue -> { + addError(token, "${bitSize}-Bit integer can't be less than ${minValue}.") + } + value > maxValue -> { + addError(token, "${bitSize}-Bit integer can't be greater than ${maxValue}.") + } + else -> { + argAndTokens.add(Pair(Arg(value), token)) + } + } + } + + /** + * Adds [token] as a register argument, or reports an error when its register number exceeds 255. + */ + private fun parseRegister(token: RegisterToken, argAndTokens: MutableList>) { + val value = token.value + + if (value > 255) { + addError(token, "Invalid register reference, expected r0-r255.") + } else { + argAndTokens.add(Pair(Arg(value), token)) + } + } + + /** + * Collects the integer tokens on the current line, starting with [firstToken], reports values outside 0-255 and appends the bytes to the current data segment. + */ + private fun parseBytes(firstToken: IntToken) { + val bytes = mutableListOf() + var token: Token = firstToken + var i = 0 + + while (token is IntToken) { + if (token.value < 0) { + addError(token, "Unsigned 8-bit integer can't be less than 0.") + } else if (token.value > 255) { + addError(token, "Unsigned 8-bit integer can't be greater than 255.") + } + + bytes.add(token.value.toByte()) + + if (i < tokens.size) { + token = tokens[i++] + } else { + break + } + } + + // The loop only ends on a non-integer token when one was found, even if it is the last token of the line. + if (token !is IntToken) { + addError(token, "Expected an unsigned 8-bit integer.") + } + + addBytes(bytes.toByteArray()) + } + + /** + * Appends the string in [token] to the current string segment and flags any token that follows it. + */ + private fun parseString(token: StringToken) { + tokens.removeFirstOrNull()?.let { nextToken -> + addUnexpectedTokenError(nextToken) + } + + addString(token.value.replace("\n", "")) + } +} diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt new file mode 100644 index 00000000..0826ffc7 --- /dev/null +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt @@ -0,0 +1,353 @@ +package world.phantasmal.lib.assembly + +import world.phantasmal.core.isDigit + +private val HEX_INT_REGEX = Regex("""^0x[\da-fA-F]+$""") +private val FLOAT_REGEX = Regex("""^-?\d+(\.\d+)?(e-?\d+)?$""") +private val IDENT_REGEX = Regex("""^[a-z][a-z0-9_=<>!]*$""") + +sealed class
Token( + val col: Int, + val len: Int, +) + +class IntToken( + col: Int, + len: Int, + val value: Int, +) : Token(col, len) + +class FloatToken( + col: Int, + len: Int, + val value: Float, +) : Token(col, len) + +class InvalidNumberToken( + col: Int, + len: Int, +) : Token(col, len) + +class RegisterToken( + col: Int, + len: Int, + val value: Int, +) : Token(col, len) + +class LabelToken( + col: Int, + len: Int, + val value: Int, +) : Token(col, len) + +sealed class SectionToken(col: Int, len: Int) : Token(col, len) + +class CodeSectionToken( + col: Int, + len: Int, +) : SectionToken(col, len) + +class DataSectionToken( + col: Int, + len: Int, +) : SectionToken(col, len) + +class StringSectionToken( + col: Int, + len: Int, +) : SectionToken(col, len) + +class InvalidSectionToken( + col: Int, + len: Int, +) : Token(col, len) + +class StringToken( + col: Int, + len: Int, + val value: String, +) : Token(col, len) + +class UnterminatedStringToken( + col: Int, + len: Int, + val value: String, +) : Token(col, len) + +class IdentToken( + col: Int, + len: Int, + val value: String, +) : Token(col, len) + +class InvalidIdentToken( + col: Int, + len: Int, +) : Token(col, len) + +class ArgSeparatorToken( + col: Int, + len: Int, +) : Token(col, len) + +/** + * Splits a single line of assembly into tokens. + */ +fun tokenizeLine(line: String): MutableList = + LineTokenizer(line).tokenize() + +private class LineTokenizer(private var line: String) { + private var index = 0 + + private val col: Int + get() = index + 1 + + private var mark = 0 + + /** + * Tokenizes the whole line, stopping at the first comment. + */ + fun tokenize(): MutableList { + val tokens = mutableListOf() + + while (hasNext()) { + val char = peek() + var token: Token + + if (char == '/') { + skip() + + // The slash can be the last character of the line, in which case there is nothing left to peek at. + if (hasNext() && peek() == '/') { + // It's a comment.
+ break + } else { + back() + } + } + + if (char.isWhitespace()) { + skip() + continue + } else if (char == '-' || char.isDigit()) { + token = tokenizeNumberOrLabel() + } else if (char == ',') { + token = ArgSeparatorToken(col, 1) + skip() + } else if (char == '.') { + token = tokenizeSection() + } else if (char == '"') { + token = tokenizeString() + } else if (char == 'r') { + token = tokenizeRegisterOrIdent() + } else { + token = tokenizeIdent() + } + + tokens.add(token) + } + + return tokens + } + + private fun hasNext(): Boolean = index < line.length + + private fun next(): Char = line[index++] + + private fun peek(): Char = line[index] + + private fun skip() { + index++ + } + + private fun back() { + index-- + } + + private fun mark() { + mark = index + } + + private fun markedLen(): Int = index - mark + + private fun slice(): String = line.substring(mark, index) + + private fun eatRestOfToken() { + while (hasNext()) { + val char = next() + + if (char == ',' || char.isWhitespace()) { + back() + break + } + } + } + + private fun tokenizeNumberOrLabel(): Token { + mark() + val col = this.col + skip() + var isLabel = false + + while (hasNext()) { + val char = peek() + + if (char == '.' 
|| char == 'e') { + return tokenizeFloat(col) + } else if (char == 'x') { + return tokenizeHexNumber(col) + } else if (char == ':') { + isLabel = true + skip() + break + } else if (char == ',' || char.isWhitespace()) { + break + } else { + skip() + } + } + + // For labels the marked text ends with the ':' that was just consumed, strip it before parsing the number. + val digits = if (isLabel) slice().dropLast(1) else slice() + val value = digits.toIntOrNull() + ?: return InvalidNumberToken(col, markedLen()) + + return if (isLabel) { + LabelToken(col, markedLen(), value) + } else { + IntToken(col, markedLen(), value) + } + } + + /** + * Consumes the rest of the marked token and parses it as a hexadecimal integer with a 0x prefix. + */ + private fun tokenizeHexNumber(col: Int): Token { + eatRestOfToken() + val hexStr = slice() + + if (HEX_INT_REGEX.matches(hexStr)) { + // toIntOrNull(16) does not accept the "0x" prefix that the regex requires, so strip it first. + hexStr.substring(2).toIntOrNull(16)?.let { value -> + return IntToken(col, markedLen(), value) + } + } + + return InvalidNumberToken(col, markedLen()) + } + + /** + * Consumes the rest of the marked token and parses it as a float. + */ + private fun tokenizeFloat(col: Int): Token { + eatRestOfToken() + val floatStr = slice() + + if (FLOAT_REGEX.matches(floatStr)) { + floatStr.toFloatOrNull()?.let { value -> + return FloatToken(col, markedLen(), value) + } + } + + return InvalidNumberToken(col, markedLen()) + } + + /** + * Tokenizes an 'r' followed by digits as a register reference, anything else starting with 'r' as an identifier. + */ + private fun tokenizeRegisterOrIdent(): Token { + val col = this.col + skip() + mark() + var isRegister = false + + while (hasNext()) { + val char = peek() + + if (char.isDigit()) { + isRegister = true + skip() + } else { + break + } + } + + return if (isRegister) { + // A number too large for an Int can't be a valid register, Int.MAX_VALUE keeps it recognizable as out of range instead of throwing. + val value = slice().toIntOrNull() ?: Int.MAX_VALUE + + RegisterToken(col, markedLen() + 1, value) + } else { + back() + tokenizeIdent() + } + } + + /** + * Tokenizes a section marker, i.e. everything from the current position up to the next whitespace. + */ + private fun tokenizeSection(): Token { + val col = this.col + mark() + + while (hasNext()) { + if (peek().isWhitespace()) { + break + } else { + skip() + } + } + + return when (slice()) { + ".code" -> CodeSectionToken(col, 5) + ".data" -> DataSectionToken(col, 5) + ".string" -> StringSectionToken(col, 7) + else -> InvalidSectionToken(col, markedLen()) + } + } + + private fun tokenizeString(): Token { + val col = this.col + skip() + mark() + var prevWasBackSpace = false + var terminated = false + + while (hasNext()) { + when (peek()) { + '\\' -> { + prevWasBackSpace = true
+ } + '"' -> { + if (!prevWasBackSpace) { + terminated = true + break + } + + prevWasBackSpace = false + } + else -> { + prevWasBackSpace = false + } + } + + next() + } + + val value = slice().replace("\\\"", "\"").replace("\\n", "\n") + + return if (terminated) { + next() + // markedLen() now includes the closing quote, so only the opening quote has to be added. + StringToken(col, markedLen() + 1, value) + } else { + UnterminatedStringToken(col, markedLen() + 1, value) + } + } + + /** + * Tokenizes everything up to the next comma, whitespace or comment as an identifier, yielding an invalid identifier token when the text doesn't match the identifier pattern. + */ + private fun tokenizeIdent(): Token { + val col = this.col + mark() + + while (hasNext()) { + val char = peek() + + if (char == ',' || char.isWhitespace()) { + break + } else if (char == '/') { + skip() + + // The slash can be the last character of the line, in which case there is nothing left to peek at. + if (hasNext() && peek() == '/') { + back() + break + } + } else { + skip() + } + } + + val value = slice() + + return if (IDENT_REGEX.matches(value)) { + IdentToken(col, markedLen(), value) + } else { + InvalidIdentToken(col, markedLen()) + } + } +} diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt index d74be29b..1685bed9 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt @@ -103,25 +103,29 @@ enum class SegmentType { * referenced by one or more labels. The segment ends right before the next instruction, byte or * string character that is referenced by a label.
*/ -sealed class Segment(val type: SegmentType, val labels: MutableList) +sealed class Segment( + val type: SegmentType, + val labels: MutableList, + val srcLoc: SegmentSrcLoc, +) class InstructionSegment( labels: MutableList, - val instructions: List, - val srcLoc: SegmentSrcLoc, -) : Segment(SegmentType.Instructions, labels) + val instructions: MutableList, + srcLoc: SegmentSrcLoc, +) : Segment(SegmentType.Instructions, labels, srcLoc) class DataSegment( labels: MutableList, val data: Buffer, - val srcLoc: SegmentSrcLoc, -) : Segment(SegmentType.Data, labels) + srcLoc: SegmentSrcLoc, +) : Segment(SegmentType.Data, labels, srcLoc) class StringSegment( labels: MutableList, - val value: String, - val srcLoc: SegmentSrcLoc, -) : Segment(SegmentType.String, labels) + var value: String, + srcLoc: SegmentSrcLoc, +) : Segment(SegmentType.String, labels, srcLoc) /** * Position and length of related source assembly code. @@ -144,9 +148,9 @@ class InstructionSrcLoc( /** * Locations of an instruction's stack arguments in the source assembly code. */ -class StackArgSrcLoc(lineNo: Int, col: Int, len: Int, val value: Int) : SrcLoc(lineNo, col, len) +class StackArgSrcLoc(lineNo: Int, col: Int, len: Int, val value: Any) : SrcLoc(lineNo, col, len) /** * Locations of a segment's labels in the source assembly code. 
*/ -class SegmentSrcLoc(val labels: List) +class SegmentSrcLoc(val labels: MutableList = mutableListOf()) diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Opcode.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Opcode.kt index 815b6f3a..72da7b9f 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Opcode.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Opcode.kt @@ -1,5 +1,16 @@ package world.phantasmal.lib.assembly +private val MNEMONIC_TO_OPCODES: MutableMap by lazy { + val map = mutableMapOf() + + OPCODES.forEach { if (it != null) map[it.mnemonic] = it } + OPCODES_F8.forEach { if (it != null) map[it.mnemonic] = it } + OPCODES_F9.forEach { if (it != null) map[it.mnemonic] = it } + + map +} +private val UNKNOWN_OPCODE_MNEMONIC_REGEX = Regex("""^unknown_((f8|f9)?[0-9a-f]{2})$""") + /** * Abstract super type of all types. */ @@ -164,6 +175,20 @@ fun codeToOpcode(code: Int): Opcode = else -> getOpcode(code, code and 0xFF, OPCODES_F9) } +fun mnemonicToOpcode(mnemonic: String): Opcode? { + var opcode = MNEMONIC_TO_OPCODES[mnemonic] + + if (opcode == null) { + UNKNOWN_OPCODE_MNEMONIC_REGEX.matchEntire(mnemonic)?.destructured?.let { (codeStr) -> + val code = codeStr.toInt(16) + opcode = codeToOpcode(code) + MNEMONIC_TO_OPCODES[mnemonic] = opcode!! 
+ } + } + + return opcode +} + private fun getOpcode(code: Int, index: Int, opcodes: Array): Opcode { var opcode = opcodes[index] diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt index 4b496290..c4320069 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt @@ -1,10 +1,309 @@ package world.phantasmal.lib.assembly.dataFlowAnalysis -import world.phantasmal.lib.assembly.InstructionSegment +import world.phantasmal.lib.assembly.* + +// See https://en.wikipedia.org/wiki/Control-flow_graph. + +enum class BranchType { + None, + Return, + Jump, + ConditionalJump, + Call, +} + +/** + * Instruction sequence into which control flow only enters at the start and only leaves at the end. + * No code jumps/returns/calls into the middle of a basic block or branches out of a basic block + * from the middle. + */ +interface BasicBlock { + /** + * The instruction segment that this block is a part of. + */ + val segment: InstructionSegment + + /** + * Index of this block's first instruction. + */ + val start: Int + + /** + * Index of the instruction right after this block's last instruction. + */ + val end: Int + + /** + * The way control flow leaves this block. + */ + val branchType: BranchType + + /** + * Either jumps or calls when non-empty, depending on [branchType]. + */ + val branchLabels: List + + /** + * The blocks which branch to this block. + */ + val from: List + + /** + * The blocks this block branches to. + */ + val to: List +} + +/** + * Graph representing the flow of control through the [BasicBlock]s of a script. + */ +class ControlFlowGraph( + val blocks: List, + private val instructionsToBlock: Map, +) { + fun getBlockForInstruction(instruction: Instruction): BasicBlock? 
= + instructionsToBlock[instruction] -class ControlFlowGraph { companion object { - fun create(segments: List): ControlFlowGraph = - TODO() + fun create(segments: List): ControlFlowGraph { + val cfg = ControlFlowGraphBuilder() + + // Mapping of labels to basic blocks. + for (segment in segments) { + createBasicBlocks(cfg, segment) + } + + linkBlocks(cfg) + return cfg.build() + } + } +} + +private class ControlFlowGraphBuilder { + val blocks: MutableList = mutableListOf() + val instructionsToBlock: MutableMap = mutableMapOf() + val labelsToBlock: MutableMap = mutableMapOf() + + fun build(): ControlFlowGraph = + ControlFlowGraph(blocks, instructionsToBlock) +} + +class BasicBlockImpl( + override val segment: InstructionSegment, + override val start: Int, + override val end: Int, + override val branchType: BranchType, + override val branchLabels: List, +) : BasicBlock { + override val from: MutableList = mutableListOf() + override val to: MutableList = mutableListOf() + + fun linkTo(other: BasicBlockImpl) { + if (other !in to) { + to.add(other) + other.from.add(this) + } + } + + fun indexOfInstruction(instruction: Instruction): Int { + var index = -1 + + for (i in start until end) { + if (instruction == segment.instructions[i]) { + index = i + break + } + } + + return index + } +} + +private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: InstructionSegment) { + val len = segment.instructions.size + var start = 0 + var firstBlock = true + + for (i in 0 until len) { + val inst = segment.instructions[i] + + var branchType: BranchType + var branchLabels: List + + when (inst.opcode) { + // Return. + OP_RET -> { + branchType = BranchType.Return + branchLabels = emptyList() + } + + // Unconditional jump. + OP_JMP -> { + branchType = BranchType.Jump + branchLabels = listOf(inst.args[0].value as Int) + } + + // Conditional jumps. 
+ OP_JMP_ON, + OP_JMP_OFF, + -> { + branchType = BranchType.ConditionalJump + branchLabels = listOf(inst.args[0].value as Int) + } + OP_JMP_E, + OP_JMPI_E, + OP_JMP_NE, + OP_JMPI_NE, + OP_UJMP_G, + OP_UJMPI_G, + OP_JMP_G, + OP_JMPI_G, + OP_UJMP_L, + OP_UJMPI_L, + OP_JMP_L, + OP_JMPI_L, + OP_UJMP_GE, + OP_UJMPI_GE, + OP_JMP_GE, + OP_JMPI_GE, + OP_UJMP_LE, + OP_UJMPI_LE, + OP_JMP_LE, + OP_JMPI_LE, + -> { + branchType = BranchType.ConditionalJump + branchLabels = listOf(inst.args[2].value as Int) + } + OP_SWITCH_JMP -> { + branchType = BranchType.ConditionalJump + branchLabels = inst.args.drop(1).map { it.value as Int } + } + + // Calls. + OP_CALL, + OP_VA_CALL, + -> { + branchType = BranchType.Call + branchLabels = listOf(inst.args[0].value as Int) + } + OP_SWITCH_CALL -> { + branchType = BranchType.Call + branchLabels = inst.args.drop(1).map { it.value as Int } + } + + // All other opcodes. + else -> { + if (i == len - 1) { + // This is the last block of the segment. + branchType = BranchType.None + branchLabels = emptyList() + } else { + // Non-branching instruction, part of the current block. + continue + } + } + } + + val block = BasicBlockImpl(segment, start, i + 1, branchType, branchLabels) + + for (j in block.start until block.end) { + cfg.instructionsToBlock[block.segment.instructions[j]] = block + } + + cfg.blocks.add(block) + + if (firstBlock) { + for (label in segment.labels) { + cfg.labelsToBlock[label] = block + } + + firstBlock = false + } + + start = i + 1 + } +} + +private fun linkBlocks(cfg: ControlFlowGraphBuilder) { + // Pairs of calling block and block to which callees should return to. 
+ val callers = mutableListOf>() + + for (i in cfg.blocks.indices) { + val block = cfg.blocks[i] + val nextBlock = cfg.blocks.getOrNull(i + 1) + + when (block.branchType) { + BranchType.Return -> + continue + + BranchType.Call -> + nextBlock?.let { callers.add(block to nextBlock); } + + BranchType.None, + BranchType.ConditionalJump, + -> nextBlock?.let(block::linkTo) + + else -> { + // Ignore. + } + } + + for (label in block.branchLabels) { + cfg.labelsToBlock[label]?.let { toBlock -> + block.linkTo(toBlock) + } + } + } + + for ((caller, ret) in callers) { + linkReturningBlocks(cfg.labelsToBlock, ret, caller) + } +} + +/** + * Links returning blocks to their callers. + * + * @param labelBlocks Mapping of labels to basic blocks. + * @param ret Basic block the caller should return to. + * @param caller Calling basic block. + */ +private fun linkReturningBlocks( + labelBlocks: Map, + ret: BasicBlockImpl, + caller: BasicBlockImpl, +) { + for (label in caller.branchLabels) { + labelBlocks[label]?.let { callee -> + if (callee.branchType === BranchType.Return) { + callee.linkTo(ret) + } else { + linkReturningBlocksRecurse(mutableSetOf(), ret, callee) + } + } + } +} + +/** + * @param encountered For avoiding infinite loops. 
+ * @param ret + * @param block + */ +private fun linkReturningBlocksRecurse( + encountered: MutableSet, + ret: BasicBlockImpl, + block: BasicBlockImpl, +) { + if (block in encountered) { + return + } else { + encountered.add(block) + } + + for (toBlock in block.to) { + if (toBlock.branchType === BranchType.Return) { + toBlock.linkTo(ret) + } else { + linkReturningBlocksRecurse(encountered, ret, toBlock) + } } } diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/fileFormats/quest/ObjectCode.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/fileFormats/quest/ObjectCode.kt index f0650f3e..219713b4 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/fileFormats/quest/ObjectCode.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/fileFormats/quest/ObjectCode.kt @@ -368,7 +368,7 @@ private fun parseInstructionsSegment( val segment = InstructionSegment( labels, instructions, - SegmentSrcLoc(emptyList()) + SegmentSrcLoc() ) offsetToSegment[cursor.position.toInt()] = segment @@ -437,7 +437,7 @@ private fun parseDataSegment( val segment = DataSegment( labels, cursor.buffer(endOffset.toUInt() - startOffset), - SegmentSrcLoc(listOf()), + SegmentSrcLoc(), ) offsetToSegment[startOffset.toInt()] = segment } @@ -465,7 +465,7 @@ private fun parseStringSegment( dropRemaining = true ) }, - SegmentSrcLoc(listOf()) + SegmentSrcLoc() ) offsetToSegment[startOffset.toInt()] = segment }