From 532a608e7a0ec1369df589fd9128ffb3bfeffe06 Mon Sep 17 00:00:00 2001 From: Daan Vanden Bosch Date: Thu, 22 Oct 2020 23:03:23 +0200 Subject: [PATCH] Added some more tests and getRegisterValue. --- .../world/phantasmal/lib/assembly/Assembly.kt | 21 +- .../lib/assembly/AssemblyTokenization.kt | 10 +- .../phantasmal/lib/assembly/Instructions.kt | 7 +- .../dataFlowAnalysis/ControlFlowGraph.kt | 29 ++- .../dataFlowAnalysis/GetMapDesignations.kt | 10 +- .../dataFlowAnalysis/GetRegisterValue.kt | 223 +++++++++++++++++- .../lib/assembly/dataFlowAnalysis/ValueSet.kt | 123 +++++++--- .../phantasmal/lib/assembly/AssemblyTests.kt | 35 +++ .../dataFlowAnalysis/ControlFlowGraphTests.kt | 158 +++++++++++++ .../dataFlowAnalysis/GetRegisterValueTests.kt | 203 ++++++++++++++++ .../dataFlowAnalysis/ValueSetTests.kt | 138 ++++++++--- .../world/phantasmal/lib/test/TestUtils.kt | 15 ++ 12 files changed, 882 insertions(+), 90 deletions(-) create mode 100644 lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/AssemblyTests.kt create mode 100644 lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraphTests.kt create mode 100644 lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValueTests.kt create mode 100644 lib/src/commonTest/kotlin/world/phantasmal/lib/test/TestUtils.kt diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt index 98601e11..05ceb7ec 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Assembly.kt @@ -217,11 +217,12 @@ private class Assembler(private val assembly: List, private val manualSt } } - private fun addError(col: Int, length: Int, message: String) { + private fun addError(col: Int, length: Int, uiMessage: String, message: String? = null) { result.addProblem( AssemblyProblem( Severity.Error, - message, + uiMessage, + message ?: "$uiMessage At $lineNo:$col.", lineNo = lineNo, col = col, length = length @@ -229,19 +230,19 @@ private class Assembler(private val assembly: List, private val manualSt ) } - private fun addError(token: Token, message: String) { - addError(token.col, token.len, message) + private fun addError(token: Token, uiMessage: String, message: String? = null) { + addError(token.col, token.len, uiMessage, message) } private fun addUnexpectedTokenError(token: Token) { - addError(token, "Unexpected token.") + addError(token, "Unexpected token.", "Unexpected ${token::class.simpleName} at ${token.srcLoc()}.") } - private fun addWarning(token: Token, message: String) { + private fun addWarning(token: Token, uiMessage: String) { result.addProblem( AssemblyProblem( Severity.Warning, - message, + uiMessage, lineNo = lineNo, col = token.col, length = token.len @@ -252,7 +253,7 @@ private class Assembler(private val assembly: List, private val manualSt private fun parseLabel(token: LabelToken) { val label = token.value - if (labels.add(label)) { + if (!labels.add(label)) { addError(token, "Duplicate label.") } @@ -653,7 +654,7 @@ private class Assembler(private val assembly: List, private val manualSt // Minimum of the signed version of this integer type. val minValue = -(1 shl (bitSize - 1)) // Maximum of the unsigned version of this integer type. - val maxValue = (1 shl (bitSize)) - 1 + val maxValue = (1L shl (bitSize)) - 1L when { value < minValue -> { @@ -713,4 +714,6 @@ private class Assembler(private val assembly: List, private val manualSt addString(token.value.replace("\n", "")) } + + private fun Token.srcLoc(): String = "$lineNo:$col" } diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt index 0826ffc7..3c8bef2b 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/AssemblyTokenization.kt @@ -191,7 +191,6 @@ private class LineTokenizer(private var line: String) { return tokenizeHexNumber(col) } else if (char == ':') { isLabel = true - skip() break } else if (char == ',' || char.isWhitespace()) { break @@ -201,7 +200,14 @@ private class LineTokenizer(private var line: String) { } val value = slice().toIntOrNull() - ?: return InvalidNumberToken(col, markedLen()) + + if (isLabel) { + skip() + } + + if (value == null) { + return InvalidNumberToken(col, markedLen()) + } return if (isLabel) { LabelToken(col, markedLen(), value) diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt index 1685bed9..6d27576d 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/Instructions.kt @@ -23,14 +23,15 @@ class Instruction( for (i in 0 until len) { val type = opcode.params[i].type val arg = args[i] - paramToArgs[i] = mutableListOf() + val pArgs = mutableListOf() + paramToArgs.add(pArgs) if (type is ILabelVarType || type is RegRefVarType) { for (j in i until args.size) { - paramToArgs[i].add(args[j]) + pArgs.add(args[j]) } } else { - paramToArgs[i].add(arg) + pArgs.add(arg) } } diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt index c4320069..24112900 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraph.kt @@ -52,6 +52,8 @@ interface BasicBlock { * The blocks this block branches to. */ val to: List + + fun indexOfInstruction(instruction: Instruction): Int } /** @@ -59,10 +61,13 @@ interface BasicBlock { */ class ControlFlowGraph( val blocks: List, - private val instructionsToBlock: Map, + private val instructionToBlock: Map, ) { - fun getBlockForInstruction(instruction: Instruction): BasicBlock? = - instructionsToBlock[instruction] + fun getBlockForInstruction(instruction: Instruction): BasicBlock { + val block = instructionToBlock[instruction] + requireNotNull(block) { "Instruction is not part of the control-flow graph." } + return block + } companion object { fun create(segments: List): ControlFlowGraph { @@ -88,7 +93,7 @@ private class ControlFlowGraphBuilder { ControlFlowGraph(blocks, instructionsToBlock) } -class BasicBlockImpl( +private class BasicBlockImpl( override val segment: InstructionSegment, override val start: Int, override val end: Int, @@ -98,14 +103,7 @@ class BasicBlockImpl( override val from: MutableList = mutableListOf() override val to: MutableList = mutableListOf() - fun linkTo(other: BasicBlockImpl) { - if (other !in to) { - to.add(other) - other.from.add(this) - } - } - - fun indexOfInstruction(instruction: Instruction): Int { + override fun indexOfInstruction(instruction: Instruction): Int { var index = -1 for (i in start until end) { @@ -117,6 +115,13 @@ class BasicBlockImpl( return index } + + fun linkTo(other: BasicBlockImpl) { + if (other !in to) { + to.add(other) + other.from.add(this) + } + } } private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: InstructionSegment) { diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetMapDesignations.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetMapDesignations.kt index 4e11bd98..037c7529 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetMapDesignations.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetMapDesignations.kt @@ -26,8 +26,10 @@ fun getMapDesignations( val areaId = getRegisterValue(cfg, inst, inst.args[0].value as Int) - if (areaId.size != 1) { - logger.warn { "Couldn't determine area ID for mapDesignate instruction." } + if (areaId.size > 1) { + logger.warn { + "Couldn't determine area ID for ${inst.opcode.mnemonic} instruction." + } continue } @@ -35,9 +37,9 @@ fun getMapDesignations( inst.args[0].value as Int + (if (inst.opcode == OP_MAP_DESIGNATE) 2 else 3) val variantId = getRegisterValue(cfg, inst, variantIdRegister) - if (variantId.size != 1) { + if (variantId.size > 1) { logger.warn { - "Couldn't determine area variant ID for mapDesignate instruction." + "Couldn't determine area variant ID for ${inst.opcode.mnemonic} instruction." } continue } diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValue.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValue.kt index e0dc32aa..c416541b 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValue.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValue.kt @@ -1,9 +1,226 @@ package world.phantasmal.lib.assembly.dataFlowAnalysis -import world.phantasmal.lib.assembly.Instruction +import mu.KotlinLogging +import world.phantasmal.lib.assembly.* +import kotlin.math.max +import kotlin.math.min + +private val logger = KotlinLogging.logger {} + +const val MIN_REGISTER_VALUE = MIN_SIGNED_DWORD_VALUE +const val MAX_REGISTER_VALUE = MAX_SIGNED_DWORD_VALUE /** * Computes the possible values of a register right before a specific instruction. */ -fun getRegisterValue(cfg: ControlFlowGraph, instruction: Instruction, register: Int): ValueSet = - TODO() +fun getRegisterValue(cfg: ControlFlowGraph, instruction: Instruction, register: Int): ValueSet { + val block = cfg.getBlockForInstruction(instruction) + + return findValues( + Context(), + mutableSetOf(), + block, + block.indexOfInstruction(instruction), + register + ) +} + +private class Context { + var iterations = 0 +} + +private fun findValues( + ctx: Context, + path: MutableSet, + block: BasicBlock, + end: Int, + register: Int, +): ValueSet { + if (++ctx.iterations > 100) { + logger.warn { "Too many iterations." } + return ValueSet.ofInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE) + } + + for (i in end - 1 downTo block.start) { + val instruction = block.segment.instructions[i] + val args = instruction.args + + when (instruction.opcode) { + OP_LET -> { + if (args[0].value == register) { + return findValues(ctx, LinkedHashSet(path), block, i, args[1].value as Int) + } + } + + OP_LETI, + OP_LETB, + OP_LETW, + OP_SYNC_LETI, + -> { + if (args[0].value == register) { + return ValueSet.of(args[1].value as Int) + } + } + + OP_SET -> { + if (args[0].value == register) { + return ValueSet.of(1) + } + } + + OP_CLEAR -> { + if (args[0].value == register) { + return ValueSet.of(0) + } + } + + OP_REV -> { + if (args[0].value == register) { + val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register) + + return if (prevVals.size == 1L && prevVals[0] == 0) { + ValueSet.of(1) + } else if (0 in prevVals) { + ValueSet.ofInterval(0, 1) + } else { + ValueSet.of(0) + } + } + } + + OP_ADDI -> { + if (args[0].value == register) { + val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register) + prevVals += args[1].value as Int + return prevVals + } + } + + OP_SUBI -> { + if (args[0].value == register) { + val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register) + prevVals -= args[1].value as Int + return prevVals + } + } + + OP_MULI -> { + if (args[0].value == register) { + val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register) + prevVals *= args[1].value as Int + return prevVals + } + } + + OP_DIVI -> { + if (args[0].value == register) { + val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register) + prevVals /= args[1].value as Int + return prevVals + } + } + + OP_IF_ZONE_CLEAR -> { + if (args[0].value == register) { + return ValueSet.ofInterval(0, 1) + } + } + + OP_GET_DIFFLVL -> { + if (args[0].value == register) { + return ValueSet.ofInterval(0, 2) + } + } + + OP_GET_SLOTNUMBER -> { + if (args[0].value == register) { + return ValueSet.ofInterval(0, 3) + } + } + + OP_GET_RANDOM -> { + if (args[1].value == register) { + // TODO: undefined values. + val min = findValues( + ctx, + LinkedHashSet(path), + block, + i, + args[0].value as Int + ).minOrNull()!! + + val max = max( + findValues( + ctx, + LinkedHashSet(path), + block, + i, + args[0].value as Int + 1 + ).maxOrNull()!!, + min + 1, + ) + + return ValueSet.ofInterval(min, max - 1) + } + } + + OP_STACK_PUSHM, + OP_STACK_POPM, + -> { + val minReg = args[0].value as Int + val maxReg = args[0].value as Int + args[1].value as Int + + if (register in minReg until maxReg) { + return ValueSet.ofInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE) + } + } + + else -> { + // Assume any other opcodes that write to the register can produce any value. + val params = instruction.opcode.params + val argLen = min(args.size, params.size) + + for (j in 0 until argLen) { + val param = params[j] + + if (param.type is RegTupRefType) { + val regRef = args[j].value as Int + + for ((k, reg_param) in param.type.registerTuples.withIndex()) { + if ((reg_param.access == ParamAccess.Write || + reg_param.access == ParamAccess.ReadWrite) && + regRef + k == register + ) { + return ValueSet.ofInterval( + MIN_REGISTER_VALUE, + MAX_REGISTER_VALUE, + ) + } + } + } + } + } + } + } + + val values = ValueSet.empty() + path.add(block) + + for (from in block.from) { + // Bail out from loops. + if (from in path) { + values.setInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE) + break + } + + values.union(findValues(ctx, LinkedHashSet(path), from, from.end, register)) + } + + // If values is empty at this point, we know nothing ever sets the register's value and it still + // has its initial value of 0. + if (values.isEmpty()) { + values.setValue(0) + } + + return values +} diff --git a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSet.kt b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSet.kt index 8cad104a..3b02b564 100644 --- a/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSet.kt +++ b/lib/src/commonMain/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSet.kt @@ -6,12 +6,10 @@ import kotlin.math.min /** * Represents a sorted set of integers. */ -class ValueSet : Iterable { - private val intervals: MutableList = mutableListOf() - - val size: Int +class ValueSet private constructor(private val intervals: MutableList) : Iterable { + val size: Long get() = - intervals.fold(0) { acc, i -> acc + i.end - i.start + 1 } + intervals.fold(0L) { acc, i -> acc + i.end - i.start + 1L } operator fun get(i: Int): Int? { var idx = i @@ -74,46 +72,99 @@ class ValueSet : Iterable { } /** - * Doesn't take into account integer overflow. + * Scalar addition. */ - fun scalarAdd(s: Int): ValueSet { - for (int in intervals) { - int.start += s - int.end += s + operator fun plusAssign(scalar: Int) { + if (scalar >= 0) { + var i = 0 + var addI = 0 + + while (i < intervals.size) { + val int = intervals[i] + val oldStart = int.start + val oldEnd = int.end + int.start += scalar + int.end += scalar + + if (int.start < oldStart) { + // Integer overflow of both start and end. + intervals.removeAt(i) + intervals.add(addI++, int) + } else if (int.end < oldEnd) { + // Integer overflow of end. + val newEnd = int.end + int.end = Int.MAX_VALUE + + if (newEnd + 1 == intervals.first().start) { + intervals.first().start = Int.MIN_VALUE + } else { + intervals.add(0, Interval(Int.MIN_VALUE, newEnd)) + addI++ + // Increment i twice because we left this interval and inserted a new one. + i++ + } + } + + i++ + } + } else { + var i = intervals.lastIndex + var addI = 0 + + while (i >= 0) { + val int = intervals[i] + val oldStart = int.start + val oldEnd = int.end + int.start += scalar + int.end += scalar + + if (int.end > oldEnd) { + // Integer underflow of both start and end. + intervals.removeAt(i) + intervals.add(intervals.size - addI++, int) + } else if (int.start > oldStart) { + // Integer underflow of start. + val newStart = int.start + int.start = Int.MIN_VALUE + + if (newStart - 1 == intervals.last().end) { + intervals.last().end = Int.MAX_VALUE + } else { + intervals.add(Interval(newStart, Int.MAX_VALUE)) + addI++ + } + } + + i-- + } } + } - return this + /** + * Scalar subtraction. + */ + operator fun minusAssign(scalar: Int) { + plusAssign(-scalar) } /** * Doesn't take into account integer overflow. */ - fun scalarSub(s: Int): ValueSet { - return scalarAdd(-s) - } - - /** - * Doesn't take into account integer overflow. - */ - fun scalarMul(s: Int): ValueSet { + operator fun timesAssign(s: Int) { for (int in intervals) { int.start *= s int.end *= s } - - return this } /** * Integer division. */ - fun scalarDiv(s: Int): ValueSet { + operator fun divAssign(s: Int) { for (int in intervals) { int.start = int.start / s int.end = int.end / s } - - return this } fun union(other: ValueSet): ValueSet { @@ -123,12 +174,12 @@ class ValueSet : Iterable { while (i < intervals.size) { val a = intervals[i] - if (b.end < a.start - 1) { + if (b.end < a.start - 1L) { // b lies entirely before a, insert it right before a. intervals.add(i, b.copy()) i++ continue@outer - } else if (b.start <= a.end + 1) { + } else if (b.start <= a.end + 1L) { // a and b overlap or form a continuous interval (e.g. [1, 2] and [3, 4]). a.start = min(a.start, b.start) @@ -136,7 +187,7 @@ class ValueSet : Iterable { val j = i + 1 while (j < intervals.size) { - if (b.end >= intervals[j].start - 1) { + if (b.end >= intervals[j].start - 1L) { a.end = intervals[j].end intervals.removeAt(j) } else { @@ -187,6 +238,24 @@ class ValueSet : Iterable { return v } } + + companion object { + /** + * Returns an empty [ValueSet]. + */ + fun empty(): ValueSet = ValueSet(mutableListOf()) + + /** + * Returns a [ValueSet] with a single initial [value]. + */ + fun of(value: Int): ValueSet = ValueSet(mutableListOf(Interval(value, value))) + + /** + * Returns a [ValueSet] with all values between [start] and [end], inclusively. + */ + fun ofInterval(start: Int, end: Int): ValueSet = + ValueSet(mutableListOf(Interval(start, end))) + } } /** diff --git a/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/AssemblyTests.kt b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/AssemblyTests.kt new file mode 100644 index 00000000..576774d6 --- /dev/null +++ b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/AssemblyTests.kt @@ -0,0 +1,35 @@ +package world.phantasmal.lib.assembly + +import world.phantasmal.core.Success +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class AssemblyTests { + @Test + fun assemble_basic_script() { + val result = assemble(""" + 0: + set_episode 0 + set_floor_handler 0, 150 + set_floor_handler 1, 151 + set_qt_success 250 + bb_map_designate 0, 0, 0, 0 + bb_map_designate 1, 1, 0, 0 + ret + 1: + ret + 250: + gset 101 + window_msg "You've been awarded 500 Meseta." + bgm 1 + winend + pl_add_meseta 0, 500 + ret + """.trimIndent().split('\n')) + + assertTrue(result is Success) + assertTrue(result.problems.isEmpty()) + assertEquals(3, result.value.size) + } +} diff --git a/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraphTests.kt b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraphTests.kt new file mode 100644 index 00000000..0a47e656 --- /dev/null +++ b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ControlFlowGraphTests.kt @@ -0,0 +1,158 @@ +package world.phantasmal.lib.assembly.dataFlowAnalysis + +import world.phantasmal.lib.test.toInstructions +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ControlFlowGraphTests { + @Test + fun single_instruction() { + val im = toInstructions(""" + 0: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + + assertEquals(1, cfg.blocks.size) + assertEquals(0, cfg.blocks[0].start) + assertEquals(1, cfg.blocks[0].end) + assertEquals(BranchType.Return, cfg.blocks[0].branchType) + assertTrue(cfg.blocks[0].from.isEmpty()) + assertTrue(cfg.blocks[0].to.isEmpty()) + assertTrue(cfg.blocks[0].branchLabels.isEmpty()) + } + + @Test + fun single_unconditional_jump() { + val im = toInstructions(""" + 0: + jmp 1 + 1: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + + assertEquals(2, cfg.blocks.size) + + assertEquals(0, cfg.blocks[0].start) + assertEquals(1, cfg.blocks[0].end) + assertEquals(BranchType.Jump, cfg.blocks[0].branchType) + assertEquals(0, cfg.blocks[0].from.size) + assertEquals(1, cfg.blocks[0].to.size) + assertEquals(1, cfg.blocks[0].branchLabels.size) + + assertEquals(0, cfg.blocks[1].start) + assertEquals(1, cfg.blocks[1].end) + assertEquals(BranchType.Return, cfg.blocks[1].branchType) + assertEquals(1, cfg.blocks[1].from.size) + assertEquals(0, cfg.blocks[1].to.size) + assertEquals(0, cfg.blocks[1].branchLabels.size) + } + + @Test + fun single_conditional_jump() { + val im = toInstructions(""" + 0: + jmp_= r1, r2, 1 + ret + 1: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + + assertEquals(3, cfg.blocks.size) + + assertEquals(0, cfg.blocks[0].start) + assertEquals(1, cfg.blocks[0].end) + assertEquals(BranchType.ConditionalJump, cfg.blocks[0].branchType) + assertEquals(0, cfg.blocks[0].from.size) + assertEquals(2, cfg.blocks[0].to.size) + assertEquals(1, cfg.blocks[0].branchLabels.size) + + assertEquals(1, cfg.blocks[1].start) + assertEquals(2, cfg.blocks[1].end) + assertEquals(BranchType.Return, cfg.blocks[1].branchType) + assertEquals(1, cfg.blocks[1].from.size) + assertEquals(0, cfg.blocks[1].to.size) + assertEquals(0, cfg.blocks[1].branchLabels.size) + + assertEquals(0, cfg.blocks[2].start) + assertEquals(1, cfg.blocks[2].end) + assertEquals(BranchType.Return, cfg.blocks[2].branchType) + assertEquals(1, cfg.blocks[2].from.size) + assertEquals(0, cfg.blocks[2].to.size) + assertEquals(0, cfg.blocks[2].branchLabels.size) + } + + @Test + fun single_call() { + val im = toInstructions(""" + 0: + call 1 + ret + 1: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + + assertEquals(3, cfg.blocks.size) + + assertEquals(0, cfg.blocks[0].start) + assertEquals(1, cfg.blocks[0].end) + assertEquals(BranchType.Call, cfg.blocks[0].branchType) + assertEquals(0, cfg.blocks[0].from.size) + assertEquals(1, cfg.blocks[0].to.size) + assertEquals(1, cfg.blocks[0].branchLabels.size) + + assertEquals(1, cfg.blocks[1].start) + assertEquals(2, cfg.blocks[1].end) + assertEquals(BranchType.Return, cfg.blocks[1].branchType) + assertEquals(1, cfg.blocks[1].from.size) + assertEquals(0, cfg.blocks[1].to.size) + assertEquals(0, cfg.blocks[1].branchLabels.size) + + assertEquals(0, cfg.blocks[2].start) + assertEquals(1, cfg.blocks[2].end) + assertEquals(BranchType.Return, cfg.blocks[2].branchType) + assertEquals(1, cfg.blocks[2].from.size) + assertEquals(1, cfg.blocks[2].to.size) + assertEquals(0, cfg.blocks[2].branchLabels.size) + } + + @Test + fun conditional_jump_with_fall_through() { + val im = toInstructions(""" + 0: + jmp_> r1, r2, 1 + nop + 1: + nop + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + + assertEquals(3, cfg.blocks.size) + + assertEquals(0, cfg.blocks[0].start) + assertEquals(1, cfg.blocks[0].end) + assertEquals(BranchType.ConditionalJump, cfg.blocks[0].branchType) + assertEquals(0, cfg.blocks[0].from.size) + assertEquals(2, cfg.blocks[0].to.size) + assertEquals(1, cfg.blocks[0].branchLabels.size) + + assertEquals(1, cfg.blocks[1].start) + assertEquals(2, cfg.blocks[1].end) + assertEquals(BranchType.None, cfg.blocks[1].branchType) + assertEquals(1, cfg.blocks[1].from.size) + assertEquals(1, cfg.blocks[1].to.size) + assertEquals(0, cfg.blocks[1].branchLabels.size) + + assertEquals(0, cfg.blocks[2].start) + assertEquals(2, cfg.blocks[2].end) + assertEquals(BranchType.Return, cfg.blocks[2].branchType) + assertEquals(2, cfg.blocks[2].from.size) + assertEquals(0, cfg.blocks[2].to.size) + assertEquals(0, cfg.blocks[2].branchLabels.size) + } +} diff --git a/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValueTests.kt b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValueTests.kt new file mode 100644 index 00000000..f72dbc6c --- /dev/null +++ b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/GetRegisterValueTests.kt @@ -0,0 +1,203 @@ +package world.phantasmal.lib.assembly.dataFlowAnalysis + +import world.phantasmal.lib.assembly.* +import world.phantasmal.lib.test.toInstructions +import kotlin.test.Test +import kotlin.test.assertEquals + +private const val MAX_REGISTER_VALUES_SIZE: Long = 1L shl 32 + +class GetRegisterValueTests { + @Test + fun when_no_instruction_sets_the_register_zero_is_returned() { + val im = toInstructions(""" + 0: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val values = getRegisterValue(cfg, im[0].instructions[0], 6) + + assertEquals(1L, values.size) + assertEquals(0, values[0]) + } + + @Test + fun a_single_register_assignment_results_in_one_value() { + val im = toInstructions(""" + 0: + leti r6, 1337 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val values = getRegisterValue(cfg, im[0].instructions[1], 6) + + assertEquals(1L, values.size) + assertEquals(1337, values[0]) + } + + @Test + fun two_assignments_in_separate_code_paths_results_in_two_values() { + val im = toInstructions(""" + 0: + jmp_> r1, r2, 1 + leti r10, 111 + jmp 2 + 1: + leti r10, 222 + 2: + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val values = getRegisterValue(cfg, im[2].instructions[0], 10) + + assertEquals(2L, values.size) + assertEquals(111, values[0]) + assertEquals(222, values[1]) + } + + @Test + fun bail_out_from_loops() { + val im = toInstructions(""" + 0: + addi r10, 5 + jmpi_< r10, 500, 0 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val values = getRegisterValue(cfg, im[0].instructions[2], 10) + + assertEquals(MAX_REGISTER_VALUES_SIZE, values.size) + } + + @Test + fun leta_and_leto() { + val im = toInstructions(""" + 0: + leta r0, r100 + leto r1, 100 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val r0 = getRegisterValue(cfg, im[0].instructions[2], 0) + + assertEquals(MAX_REGISTER_VALUES_SIZE, r0.size) + assertEquals(MIN_REGISTER_VALUE, r0.minOrNull()) + assertEquals(MAX_REGISTER_VALUE, r0.maxOrNull()) + + val r1 = getRegisterValue(cfg, im[0].instructions[2], 1) + + assertEquals(MAX_REGISTER_VALUES_SIZE, r1.size) + assertEquals(MIN_REGISTER_VALUE, r1.minOrNull()) + assertEquals(MAX_REGISTER_VALUE, r1.maxOrNull()) + } + + @Test + fun rev() { + val im = toInstructions(""" + 0: + leti r0, 10 + leti r1, 50 + get_random r0, r10 + rev r10 + leti r0, -10 + leti r1, 50 + get_random r0, r10 + rev r10 + leti r10, 0 + rev r10 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val v0 = getRegisterValue(cfg, im[0].instructions[4], 10) + + assertEquals(1L, v0.size) + assertEquals(0, v0[0]) + + val v1 = getRegisterValue(cfg, im[0].instructions[8], 10) + + assertEquals(2L, v1.size) + assertEquals(0, v1[0]) + assertEquals(1, v1[1]) + + val v2 = getRegisterValue(cfg, im[0].instructions[10], 10) + + assertEquals(1L, v2.size) + assertEquals(1, v2[0]) + } + + @Test + fun addi() { + testBranched(OP_ADDI, 25, 35) + } + + @Test + fun subi() { + testBranched(OP_SUBI, -5, 5) + } + + @Test + fun muli() { + testBranched(OP_MULI, 150, 300) + } + + @Test + fun divi() { + testBranched(OP_DIVI, 0, 1) + } + + /** + * Test an instruction taking a register and an integer. + * The instruction will be called with arguments r99, 15. r99 will be set to 10 or 20. + */ + private fun testBranched(opcode: Opcode, expected1: Int, expected2: Int) { + val im = toInstructions(""" + 0: + leti r99, 10 + jmpi_= r0, 100, 1 + leti r99, 20 + 1: + ${opcode.mnemonic} r99, 15 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val values = getRegisterValue(cfg, im[1].instructions[1], 99) + + assertEquals(2L, values.size) + assertEquals(expected1, values[0]) + assertEquals(expected2, values[1]) + } + + @Test + fun get_random() { + val im = toInstructions(""" + 0: + leti r0, 20 + leti r1, 20 + get_random r0, r10 + leti r1, 19 + get_random r0, r10 + leti r1, 25 + get_random r0, r10 + ret + """.trimIndent()) + val cfg = ControlFlowGraph.create(im) + val v0 = getRegisterValue(cfg, im[0].instructions[3], 10) + + assertEquals(1L, v0.size) + assertEquals(20, v0[0]) + + val v1 = getRegisterValue(cfg, im[0].instructions[5], 10) + + assertEquals(1L, v1.size) + assertEquals(20, v1[0]) + + val v2 = getRegisterValue(cfg, im[0].instructions[7], 10) + + assertEquals(5L, v2.size) + assertEquals(20, v2[0]) + assertEquals(21, v2[1]) + assertEquals(22, v2[2]) + assertEquals(23, v2[3]) + assertEquals(24, v2[4]) + } +} diff --git a/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSetTests.kt b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSetTests.kt index 66c0a561..cd573710 100644 --- a/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSetTests.kt +++ b/lib/src/commonTest/kotlin/world/phantasmal/lib/assembly/dataFlowAnalysis/ValueSetTests.kt @@ -8,17 +8,17 @@ import kotlin.test.assertTrue class ValueSetTests { @Test fun empty_set_has_size_0() { - val vs = ValueSet() + val vs = ValueSet.empty() - assertEquals(0, vs.size) + assertEquals(0L, vs.size) } @Test fun get() { - val vs = ValueSet().setInterval(10, 13) - .union(ValueSet().setInterval(20, 22)) + val vs = ValueSet.ofInterval(10, 13) + .union(ValueSet.ofInterval(20, 22)) - assertEquals(7, vs.size) + assertEquals(7L, vs.size) assertEquals(10, vs[0]) assertEquals(11, vs[1]) assertEquals(12, vs[2]) @@ -30,10 +30,10 @@ class ValueSetTests { @Test fun contains() { - val vs = ValueSet().setInterval(-20, 13) - .union(ValueSet().setInterval(20, 22)) + val vs = ValueSet.ofInterval(-20, 13) + .union(ValueSet.ofInterval(20, 22)) - assertEquals(37, vs.size) + assertEquals(37L, vs.size) assertFalse(-9001 in vs) assertFalse(-21 in vs) assertTrue(-20 in vs) @@ -48,38 +48,116 @@ class ValueSetTests { @Test fun setValue() { - val vs = ValueSet() + val vs = ValueSet.empty() vs.setValue(100) vs.setValue(4) vs.setValue(24324) - assertEquals(1, vs.size) + assertEquals(1L, vs.size) assertEquals(24324, vs[0]) } @Test - fun union() { - val vs = ValueSet() - .union(ValueSet().setValue(21)) - .union(ValueSet().setValue(4968)) + fun plusAssign_integer_overflow() { + // The set of all integers should stay the same after adding any integer. + for (i in Int.MIN_VALUE..Int.MAX_VALUE step 10_000_000) { + val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MAX_VALUE) + vs += i - assertEquals(2, vs.size) + assertEquals(1L shl 32, vs.size) + assertEquals(Int.MIN_VALUE, vs.minOrNull()) + assertEquals(Int.MAX_VALUE, vs.maxOrNull()) + } + + // Cause two intervals to split into three intervals. + val vs = ValueSet.ofInterval(5, 7) + vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE)) + vs += 1 + + assertEquals(6L, vs.size) + assertEquals(Int.MIN_VALUE, vs[0]) + assertEquals(6, vs[1]) + assertEquals(7, vs[2]) + assertEquals(8, vs[3]) + assertEquals(Int.MAX_VALUE - 1, vs[4]) + assertEquals(Int.MAX_VALUE, vs[5]) + + // Cause part of one interval to be joined to another. + vs.setInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2) + vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE)) + vs += 1 + + assertEquals(6L, vs.size) + assertEquals(Int.MIN_VALUE, vs[0]) + assertEquals(Int.MIN_VALUE + 1, vs[1]) + assertEquals(Int.MIN_VALUE + 2, vs[2]) + assertEquals(Int.MIN_VALUE + 3, vs[3]) + assertEquals(Int.MAX_VALUE - 1, vs[4]) + assertEquals(Int.MAX_VALUE, vs[5]) + } + + @Test + fun minusAssign_integer_underflow() { + // The set of all integers should stay the same after subtracting any integer. + for (i in Int.MIN_VALUE..Int.MAX_VALUE step 10_000_000) { + val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MAX_VALUE) + vs -= i + + assertEquals(1L shl 32, vs.size) + assertEquals(Int.MIN_VALUE, vs.minOrNull()) + assertEquals(Int.MAX_VALUE, vs.maxOrNull()) + } + + // Cause two intervals to split into three intervals. + val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2) + vs.union(ValueSet.ofInterval(5, 7)) + vs -= 1 + + assertEquals(6L, vs.size) + assertEquals(Int.MIN_VALUE, vs[0]) + assertEquals(Int.MIN_VALUE + 1, vs[1]) + assertEquals(4, vs[2]) + assertEquals(5, vs[3]) + assertEquals(6, vs[4]) + assertEquals(Int.MAX_VALUE, vs[5]) + + // Cause part of one interval to be joined to another. + vs.setInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2) + vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE)) + vs -= 1 + + assertEquals(6L, vs.size) + assertEquals(Int.MIN_VALUE, vs[0]) + assertEquals(Int.MIN_VALUE + 1, vs[1]) + assertEquals(Int.MAX_VALUE - 3, vs[2]) + assertEquals(Int.MAX_VALUE - 2, vs[3]) + assertEquals(Int.MAX_VALUE - 1, vs[4]) + assertEquals(Int.MAX_VALUE, vs[5]) + } + + @Test + fun union() { + val vs = ValueSet.empty() + .union(ValueSet.of(21)) + .union(ValueSet.of(4968)) + + assertEquals(2L, vs.size) assertEquals(21, vs[0]) assertEquals(4968, vs[1]) } @Test fun union_of_intervals() { - val vs = ValueSet() - .union(ValueSet().setInterval(10, 12)) - .union(ValueSet().setInterval(14, 16)) + val vs = ValueSet.empty() + .union(ValueSet.ofInterval(10, 12)) + .union(ValueSet.ofInterval(14, 16)) - assertEquals(6, vs.size) + assertEquals(6L, vs.size) assertTrue(arrayOf(10, 11, 12, 14, 15, 16).all { it in vs }) - vs.union(ValueSet().setInterval(13, 13)) + vs.union(ValueSet.ofInterval(13, 13)) - assertEquals(7, vs.size) + assertEquals(7L, vs.size) assertEquals(10, vs[0]) assertEquals(11, vs[1]) assertEquals(12, vs[2]) @@ -88,27 +166,27 @@ class ValueSetTests { assertEquals(15, vs[5]) assertEquals(16, vs[6]) - vs.union(ValueSet().setInterval(1, 2)) + vs.union(ValueSet.ofInterval(1, 2)) - assertEquals(9, vs.size) + assertEquals(9L, vs.size) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16).all { it in vs }) - vs.union(ValueSet().setInterval(30, 32)) + vs.union(ValueSet.ofInterval(30, 32)) - assertEquals(12, vs.size) + assertEquals(12L, vs.size) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 30, 31, 32).all { it in vs }) - vs.union(ValueSet().setInterval(20, 21)) + vs.union(ValueSet.ofInterval(20, 21)) - assertEquals(14, vs.size) + assertEquals(14L, vs.size) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 20, 21, 30, 31, 32).all { it in vs }) } @Test fun iterator() { - val vs = ValueSet() - .union(ValueSet().setInterval(5, 7)) - .union(ValueSet().setInterval(14, 16)) + val vs = ValueSet.empty() + .union(ValueSet.ofInterval(5, 7)) + .union(ValueSet.ofInterval(14, 16)) val iter = vs.iterator() diff --git a/lib/src/commonTest/kotlin/world/phantasmal/lib/test/TestUtils.kt b/lib/src/commonTest/kotlin/world/phantasmal/lib/test/TestUtils.kt new file mode 100644 index 00000000..1efb4465 --- /dev/null +++ b/lib/src/commonTest/kotlin/world/phantasmal/lib/test/TestUtils.kt @@ -0,0 +1,15 @@ +package world.phantasmal.lib.test + +import world.phantasmal.core.Success +import world.phantasmal.lib.assembly.InstructionSegment +import world.phantasmal.lib.assembly.assemble +import kotlin.test.assertTrue + +fun toInstructions(assembly: String): List { + val result = assemble(assembly.split('\n')) + + assertTrue(result is Success) + assertTrue(result.problems.isEmpty()) + + return result.value.filterIsInstance() +}