Added some more tests and getRegisterValue.

This commit is contained in:
Daan Vanden Bosch 2020-10-22 23:03:23 +02:00
parent 78690f9588
commit 532a608e7a
12 changed files with 882 additions and 90 deletions

View File

@ -217,11 +217,12 @@ private class Assembler(private val assembly: List<String>, private val manualSt
} }
} }
private fun addError(col: Int, length: Int, message: String) { private fun addError(col: Int, length: Int, uiMessage: String, message: String? = null) {
result.addProblem( result.addProblem(
AssemblyProblem( AssemblyProblem(
Severity.Error, Severity.Error,
message, uiMessage,
message ?: "$uiMessage At $lineNo:$col.",
lineNo = lineNo, lineNo = lineNo,
col = col, col = col,
length = length length = length
@ -229,19 +230,19 @@ private class Assembler(private val assembly: List<String>, private val manualSt
) )
} }
private fun addError(token: Token, message: String) { private fun addError(token: Token, uiMessage: String, message: String? = null) {
addError(token.col, token.len, message) addError(token.col, token.len, uiMessage, message)
} }
private fun addUnexpectedTokenError(token: Token) { private fun addUnexpectedTokenError(token: Token) {
addError(token, "Unexpected token.") addError(token, "Unexpected token.", "Unexpected ${token::class.simpleName} at ${token.srcLoc()}.")
} }
private fun addWarning(token: Token, message: String) { private fun addWarning(token: Token, uiMessage: String) {
result.addProblem( result.addProblem(
AssemblyProblem( AssemblyProblem(
Severity.Warning, Severity.Warning,
message, uiMessage,
lineNo = lineNo, lineNo = lineNo,
col = token.col, col = token.col,
length = token.len length = token.len
@ -252,7 +253,7 @@ private class Assembler(private val assembly: List<String>, private val manualSt
private fun parseLabel(token: LabelToken) { private fun parseLabel(token: LabelToken) {
val label = token.value val label = token.value
if (labels.add(label)) { if (!labels.add(label)) {
addError(token, "Duplicate label.") addError(token, "Duplicate label.")
} }
@ -653,7 +654,7 @@ private class Assembler(private val assembly: List<String>, private val manualSt
// Minimum of the signed version of this integer type. // Minimum of the signed version of this integer type.
val minValue = -(1 shl (bitSize - 1)) val minValue = -(1 shl (bitSize - 1))
// Maximum of the unsigned version of this integer type. // Maximum of the unsigned version of this integer type.
val maxValue = (1 shl (bitSize)) - 1 val maxValue = (1L shl (bitSize)) - 1L
when { when {
value < minValue -> { value < minValue -> {
@ -713,4 +714,6 @@ private class Assembler(private val assembly: List<String>, private val manualSt
addString(token.value.replace("\n", "<cr>")) addString(token.value.replace("\n", "<cr>"))
} }
private fun Token.srcLoc(): String = "$lineNo:$col"
} }

View File

@ -191,7 +191,6 @@ private class LineTokenizer(private var line: String) {
return tokenizeHexNumber(col) return tokenizeHexNumber(col)
} else if (char == ':') { } else if (char == ':') {
isLabel = true isLabel = true
skip()
break break
} else if (char == ',' || char.isWhitespace()) { } else if (char == ',' || char.isWhitespace()) {
break break
@ -201,7 +200,14 @@ private class LineTokenizer(private var line: String) {
} }
val value = slice().toIntOrNull() val value = slice().toIntOrNull()
?: return InvalidNumberToken(col, markedLen())
if (isLabel) {
skip()
}
if (value == null) {
return InvalidNumberToken(col, markedLen())
}
return if (isLabel) { return if (isLabel) {
LabelToken(col, markedLen(), value) LabelToken(col, markedLen(), value)

View File

@ -23,14 +23,15 @@ class Instruction(
for (i in 0 until len) { for (i in 0 until len) {
val type = opcode.params[i].type val type = opcode.params[i].type
val arg = args[i] val arg = args[i]
paramToArgs[i] = mutableListOf() val pArgs = mutableListOf<Arg>()
paramToArgs.add(pArgs)
if (type is ILabelVarType || type is RegRefVarType) { if (type is ILabelVarType || type is RegRefVarType) {
for (j in i until args.size) { for (j in i until args.size) {
paramToArgs[i].add(args[j]) pArgs.add(args[j])
} }
} else { } else {
paramToArgs[i].add(arg) pArgs.add(arg)
} }
} }

View File

@ -52,6 +52,8 @@ interface BasicBlock {
* The blocks this block branches to. * The blocks this block branches to.
*/ */
val to: List<BasicBlock> val to: List<BasicBlock>
fun indexOfInstruction(instruction: Instruction): Int
} }
/** /**
@ -59,10 +61,13 @@ interface BasicBlock {
*/ */
class ControlFlowGraph( class ControlFlowGraph(
val blocks: List<BasicBlock>, val blocks: List<BasicBlock>,
private val instructionsToBlock: Map<Instruction, BasicBlock>, private val instructionToBlock: Map<Instruction, BasicBlock>,
) { ) {
fun getBlockForInstruction(instruction: Instruction): BasicBlock? = fun getBlockForInstruction(instruction: Instruction): BasicBlock {
instructionsToBlock[instruction] val block = instructionToBlock[instruction]
requireNotNull(block) { "Instruction is not part of the control-flow graph." }
return block
}
companion object { companion object {
fun create(segments: List<InstructionSegment>): ControlFlowGraph { fun create(segments: List<InstructionSegment>): ControlFlowGraph {
@ -88,7 +93,7 @@ private class ControlFlowGraphBuilder {
ControlFlowGraph(blocks, instructionsToBlock) ControlFlowGraph(blocks, instructionsToBlock)
} }
class BasicBlockImpl( private class BasicBlockImpl(
override val segment: InstructionSegment, override val segment: InstructionSegment,
override val start: Int, override val start: Int,
override val end: Int, override val end: Int,
@ -98,14 +103,7 @@ class BasicBlockImpl(
override val from: MutableList<BasicBlockImpl> = mutableListOf() override val from: MutableList<BasicBlockImpl> = mutableListOf()
override val to: MutableList<BasicBlockImpl> = mutableListOf() override val to: MutableList<BasicBlockImpl> = mutableListOf()
fun linkTo(other: BasicBlockImpl) { override fun indexOfInstruction(instruction: Instruction): Int {
if (other !in to) {
to.add(other)
other.from.add(this)
}
}
fun indexOfInstruction(instruction: Instruction): Int {
var index = -1 var index = -1
for (i in start until end) { for (i in start until end) {
@ -117,6 +115,13 @@ class BasicBlockImpl(
return index return index
} }
fun linkTo(other: BasicBlockImpl) {
if (other !in to) {
to.add(other)
other.from.add(this)
}
}
} }
private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: InstructionSegment) { private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: InstructionSegment) {

View File

@ -26,8 +26,10 @@ fun getMapDesignations(
val areaId = getRegisterValue(cfg, inst, inst.args[0].value as Int) val areaId = getRegisterValue(cfg, inst, inst.args[0].value as Int)
if (areaId.size != 1) { if (areaId.size > 1) {
logger.warn { "Couldn't determine area ID for mapDesignate instruction." } logger.warn {
"Couldn't determine area ID for ${inst.opcode.mnemonic} instruction."
}
continue continue
} }
@ -35,9 +37,9 @@ fun getMapDesignations(
inst.args[0].value as Int + (if (inst.opcode == OP_MAP_DESIGNATE) 2 else 3) inst.args[0].value as Int + (if (inst.opcode == OP_MAP_DESIGNATE) 2 else 3)
val variantId = getRegisterValue(cfg, inst, variantIdRegister) val variantId = getRegisterValue(cfg, inst, variantIdRegister)
if (variantId.size != 1) { if (variantId.size > 1) {
logger.warn { logger.warn {
"Couldn't determine area variant ID for mapDesignate instruction." "Couldn't determine area variant ID for ${inst.opcode.mnemonic} instruction."
} }
continue continue
} }

View File

@ -1,9 +1,226 @@
package world.phantasmal.lib.assembly.dataFlowAnalysis package world.phantasmal.lib.assembly.dataFlowAnalysis
import world.phantasmal.lib.assembly.Instruction import mu.KotlinLogging
import world.phantasmal.lib.assembly.*
import kotlin.math.max
import kotlin.math.min
private val logger = KotlinLogging.logger {}
const val MIN_REGISTER_VALUE = MIN_SIGNED_DWORD_VALUE
const val MAX_REGISTER_VALUE = MAX_SIGNED_DWORD_VALUE
/** /**
* Computes the possible values of a register right before a specific instruction. * Computes the possible values of a register right before a specific instruction.
*/ */
fun getRegisterValue(cfg: ControlFlowGraph, instruction: Instruction, register: Int): ValueSet = fun getRegisterValue(cfg: ControlFlowGraph, instruction: Instruction, register: Int): ValueSet {
TODO() val block = cfg.getBlockForInstruction(instruction)
return findValues(
Context(),
mutableSetOf(),
block,
block.indexOfInstruction(instruction),
register
)
}
private class Context {
var iterations = 0
}
private fun findValues(
ctx: Context,
path: MutableSet<BasicBlock>,
block: BasicBlock,
end: Int,
register: Int,
): ValueSet {
if (++ctx.iterations > 100) {
logger.warn { "Too many iterations." }
return ValueSet.ofInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE)
}
for (i in end - 1 downTo block.start) {
val instruction = block.segment.instructions[i]
val args = instruction.args
when (instruction.opcode) {
OP_LET -> {
if (args[0].value == register) {
return findValues(ctx, LinkedHashSet(path), block, i, args[1].value as Int)
}
}
OP_LETI,
OP_LETB,
OP_LETW,
OP_SYNC_LETI,
-> {
if (args[0].value == register) {
return ValueSet.of(args[1].value as Int)
}
}
OP_SET -> {
if (args[0].value == register) {
return ValueSet.of(1)
}
}
OP_CLEAR -> {
if (args[0].value == register) {
return ValueSet.of(0)
}
}
OP_REV -> {
if (args[0].value == register) {
val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register)
return if (prevVals.size == 1L && prevVals[0] == 0) {
ValueSet.of(1)
} else if (0 in prevVals) {
ValueSet.ofInterval(0, 1)
} else {
ValueSet.of(0)
}
}
}
OP_ADDI -> {
if (args[0].value == register) {
val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register)
prevVals += args[1].value as Int
return prevVals
}
}
OP_SUBI -> {
if (args[0].value == register) {
val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register)
prevVals -= args[1].value as Int
return prevVals
}
}
OP_MULI -> {
if (args[0].value == register) {
val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register)
prevVals *= args[1].value as Int
return prevVals
}
}
OP_DIVI -> {
if (args[0].value == register) {
val prevVals = findValues(ctx, LinkedHashSet(path), block, i, register)
prevVals /= args[1].value as Int
return prevVals
}
}
OP_IF_ZONE_CLEAR -> {
if (args[0].value == register) {
return ValueSet.ofInterval(0, 1)
}
}
OP_GET_DIFFLVL -> {
if (args[0].value == register) {
return ValueSet.ofInterval(0, 2)
}
}
OP_GET_SLOTNUMBER -> {
if (args[0].value == register) {
return ValueSet.ofInterval(0, 3)
}
}
OP_GET_RANDOM -> {
if (args[1].value == register) {
// TODO: undefined values.
val min = findValues(
ctx,
LinkedHashSet(path),
block,
i,
args[0].value as Int
).minOrNull()!!
val max = max(
findValues(
ctx,
LinkedHashSet(path),
block,
i,
args[0].value as Int + 1
).maxOrNull()!!,
min + 1,
)
return ValueSet.ofInterval(min, max - 1)
}
}
OP_STACK_PUSHM,
OP_STACK_POPM,
-> {
val minReg = args[0].value as Int
val maxReg = args[0].value as Int + args[1].value as Int
if (register in minReg until maxReg) {
return ValueSet.ofInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE)
}
}
else -> {
// Assume any other opcodes that write to the register can produce any value.
val params = instruction.opcode.params
val argLen = min(args.size, params.size)
for (j in 0 until argLen) {
val param = params[j]
if (param.type is RegTupRefType) {
val regRef = args[j].value as Int
for ((k, reg_param) in param.type.registerTuples.withIndex()) {
if ((reg_param.access == ParamAccess.Write ||
reg_param.access == ParamAccess.ReadWrite) &&
regRef + k == register
) {
return ValueSet.ofInterval(
MIN_REGISTER_VALUE,
MAX_REGISTER_VALUE,
)
}
}
}
}
}
}
}
val values = ValueSet.empty()
path.add(block)
for (from in block.from) {
// Bail out from loops.
if (from in path) {
values.setInterval(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE)
break
}
values.union(findValues(ctx, LinkedHashSet(path), from, from.end, register))
}
// If values is empty at this point, we know nothing ever sets the register's value and it still
// has its initial value of 0.
if (values.isEmpty()) {
values.setValue(0)
}
return values
}

View File

@ -6,12 +6,10 @@ import kotlin.math.min
/** /**
* Represents a sorted set of integers. * Represents a sorted set of integers.
*/ */
class ValueSet : Iterable<Int> { class ValueSet private constructor(private val intervals: MutableList<Interval>) : Iterable<Int> {
private val intervals: MutableList<Interval> = mutableListOf() val size: Long
val size: Int
get() = get() =
intervals.fold(0) { acc, i -> acc + i.end - i.start + 1 } intervals.fold(0L) { acc, i -> acc + i.end - i.start + 1L }
operator fun get(i: Int): Int? { operator fun get(i: Int): Int? {
var idx = i var idx = i
@ -74,46 +72,99 @@ class ValueSet : Iterable<Int> {
} }
/** /**
* Doesn't take into account integer overflow. * Scalar addition.
*/ */
fun scalarAdd(s: Int): ValueSet { operator fun plusAssign(scalar: Int) {
for (int in intervals) { if (scalar >= 0) {
int.start += s var i = 0
int.end += s var addI = 0
while (i < intervals.size) {
val int = intervals[i]
val oldStart = int.start
val oldEnd = int.end
int.start += scalar
int.end += scalar
if (int.start < oldStart) {
// Integer overflow of both start and end.
intervals.removeAt(i)
intervals.add(addI++, int)
} else if (int.end < oldEnd) {
// Integer overflow of end.
val newEnd = int.end
int.end = Int.MAX_VALUE
if (newEnd + 1 == intervals.first().start) {
intervals.first().start = Int.MIN_VALUE
} else {
intervals.add(0, Interval(Int.MIN_VALUE, newEnd))
addI++
// Increment i twice because we left this interval and inserted a new one.
i++
}
}
i++
}
} else {
var i = intervals.lastIndex
var addI = 0
while (i >= 0) {
val int = intervals[i]
val oldStart = int.start
val oldEnd = int.end
int.start += scalar
int.end += scalar
if (int.end > oldEnd) {
// Integer underflow of both start and end.
intervals.removeAt(i)
intervals.add(intervals.size - addI++, int)
} else if (int.start > oldStart) {
// Integer underflow of start.
val newStart = int.start
int.start = Int.MIN_VALUE
if (newStart - 1 == intervals.last().end) {
intervals.last().end = Int.MAX_VALUE
} else {
intervals.add(Interval(newStart, Int.MAX_VALUE))
addI++
}
}
i--
}
} }
}
return this /**
* Scalar subtraction.
*/
operator fun minusAssign(scalar: Int) {
plusAssign(-scalar)
} }
/** /**
* Doesn't take into account integer overflow. * Doesn't take into account integer overflow.
*/ */
fun scalarSub(s: Int): ValueSet { operator fun timesAssign(s: Int) {
return scalarAdd(-s)
}
/**
* Doesn't take into account integer overflow.
*/
fun scalarMul(s: Int): ValueSet {
for (int in intervals) { for (int in intervals) {
int.start *= s int.start *= s
int.end *= s int.end *= s
} }
return this
} }
/** /**
* Integer division. * Integer division.
*/ */
fun scalarDiv(s: Int): ValueSet { operator fun divAssign(s: Int) {
for (int in intervals) { for (int in intervals) {
int.start = int.start / s int.start = int.start / s
int.end = int.end / s int.end = int.end / s
} }
return this
} }
fun union(other: ValueSet): ValueSet { fun union(other: ValueSet): ValueSet {
@ -123,12 +174,12 @@ class ValueSet : Iterable<Int> {
while (i < intervals.size) { while (i < intervals.size) {
val a = intervals[i] val a = intervals[i]
if (b.end < a.start - 1) { if (b.end < a.start - 1L) {
// b lies entirely before a, insert it right before a. // b lies entirely before a, insert it right before a.
intervals.add(i, b.copy()) intervals.add(i, b.copy())
i++ i++
continue@outer continue@outer
} else if (b.start <= a.end + 1) { } else if (b.start <= a.end + 1L) {
// a and b overlap or form a continuous interval (e.g. [1, 2] and [3, 4]). // a and b overlap or form a continuous interval (e.g. [1, 2] and [3, 4]).
a.start = min(a.start, b.start) a.start = min(a.start, b.start)
@ -136,7 +187,7 @@ class ValueSet : Iterable<Int> {
val j = i + 1 val j = i + 1
while (j < intervals.size) { while (j < intervals.size) {
if (b.end >= intervals[j].start - 1) { if (b.end >= intervals[j].start - 1L) {
a.end = intervals[j].end a.end = intervals[j].end
intervals.removeAt(j) intervals.removeAt(j)
} else { } else {
@ -187,6 +238,24 @@ class ValueSet : Iterable<Int> {
return v return v
} }
} }
companion object {
/**
* Returns an empty [ValueSet].
*/
fun empty(): ValueSet = ValueSet(mutableListOf())
/**
* Returns a [ValueSet] with a single initial [value].
*/
fun of(value: Int): ValueSet = ValueSet(mutableListOf(Interval(value, value)))
/**
* Returns a [ValueSet] with all values between [start] and [end], inclusively.
*/
fun ofInterval(start: Int, end: Int): ValueSet =
ValueSet(mutableListOf(Interval(start, end)))
}
} }
/** /**

View File

@ -0,0 +1,35 @@
package world.phantasmal.lib.assembly
import world.phantasmal.core.Success
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class AssemblyTests {
@Test
fun assemble_basic_script() {
val result = assemble("""
0:
set_episode 0
set_floor_handler 0, 150
set_floor_handler 1, 151
set_qt_success 250
bb_map_designate 0, 0, 0, 0
bb_map_designate 1, 1, 0, 0
ret
1:
ret
250:
gset 101
window_msg "You've been awarded 500 Meseta."
bgm 1
winend
pl_add_meseta 0, 500
ret
""".trimIndent().split('\n'))
assertTrue(result is Success)
assertTrue(result.problems.isEmpty())
assertEquals(3, result.value.size)
}
}

View File

@ -0,0 +1,158 @@
package world.phantasmal.lib.assembly.dataFlowAnalysis
import world.phantasmal.lib.test.toInstructions
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class ControlFlowGraphTests {
@Test
fun single_instruction() {
val im = toInstructions("""
0:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
assertEquals(1, cfg.blocks.size)
assertEquals(0, cfg.blocks[0].start)
assertEquals(1, cfg.blocks[0].end)
assertEquals(BranchType.Return, cfg.blocks[0].branchType)
assertTrue(cfg.blocks[0].from.isEmpty())
assertTrue(cfg.blocks[0].to.isEmpty())
assertTrue(cfg.blocks[0].branchLabels.isEmpty())
}
@Test
fun single_unconditional_jump() {
val im = toInstructions("""
0:
jmp 1
1:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
assertEquals(2, cfg.blocks.size)
assertEquals(0, cfg.blocks[0].start)
assertEquals(1, cfg.blocks[0].end)
assertEquals(BranchType.Jump, cfg.blocks[0].branchType)
assertEquals(0, cfg.blocks[0].from.size)
assertEquals(1, cfg.blocks[0].to.size)
assertEquals(1, cfg.blocks[0].branchLabels.size)
assertEquals(0, cfg.blocks[1].start)
assertEquals(1, cfg.blocks[1].end)
assertEquals(BranchType.Return, cfg.blocks[1].branchType)
assertEquals(1, cfg.blocks[1].from.size)
assertEquals(0, cfg.blocks[1].to.size)
assertEquals(0, cfg.blocks[1].branchLabels.size)
}
@Test
fun single_conditional_jump() {
val im = toInstructions("""
0:
jmp_= r1, r2, 1
ret
1:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
assertEquals(3, cfg.blocks.size)
assertEquals(0, cfg.blocks[0].start)
assertEquals(1, cfg.blocks[0].end)
assertEquals(BranchType.ConditionalJump, cfg.blocks[0].branchType)
assertEquals(0, cfg.blocks[0].from.size)
assertEquals(2, cfg.blocks[0].to.size)
assertEquals(1, cfg.blocks[0].branchLabels.size)
assertEquals(1, cfg.blocks[1].start)
assertEquals(2, cfg.blocks[1].end)
assertEquals(BranchType.Return, cfg.blocks[1].branchType)
assertEquals(1, cfg.blocks[1].from.size)
assertEquals(0, cfg.blocks[1].to.size)
assertEquals(0, cfg.blocks[1].branchLabels.size)
assertEquals(0, cfg.blocks[2].start)
assertEquals(1, cfg.blocks[2].end)
assertEquals(BranchType.Return, cfg.blocks[2].branchType)
assertEquals(1, cfg.blocks[2].from.size)
assertEquals(0, cfg.blocks[2].to.size)
assertEquals(0, cfg.blocks[2].branchLabels.size)
}
@Test
fun single_call() {
val im = toInstructions("""
0:
call 1
ret
1:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
assertEquals(3, cfg.blocks.size)
assertEquals(0, cfg.blocks[0].start)
assertEquals(1, cfg.blocks[0].end)
assertEquals(BranchType.Call, cfg.blocks[0].branchType)
assertEquals(0, cfg.blocks[0].from.size)
assertEquals(1, cfg.blocks[0].to.size)
assertEquals(1, cfg.blocks[0].branchLabels.size)
assertEquals(1, cfg.blocks[1].start)
assertEquals(2, cfg.blocks[1].end)
assertEquals(BranchType.Return, cfg.blocks[1].branchType)
assertEquals(1, cfg.blocks[1].from.size)
assertEquals(0, cfg.blocks[1].to.size)
assertEquals(0, cfg.blocks[1].branchLabels.size)
assertEquals(0, cfg.blocks[2].start)
assertEquals(1, cfg.blocks[2].end)
assertEquals(BranchType.Return, cfg.blocks[2].branchType)
assertEquals(1, cfg.blocks[2].from.size)
assertEquals(1, cfg.blocks[2].to.size)
assertEquals(0, cfg.blocks[2].branchLabels.size)
}
@Test
fun conditional_jump_with_fall_through() {
val im = toInstructions("""
0:
jmp_> r1, r2, 1
nop
1:
nop
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
assertEquals(3, cfg.blocks.size)
assertEquals(0, cfg.blocks[0].start)
assertEquals(1, cfg.blocks[0].end)
assertEquals(BranchType.ConditionalJump, cfg.blocks[0].branchType)
assertEquals(0, cfg.blocks[0].from.size)
assertEquals(2, cfg.blocks[0].to.size)
assertEquals(1, cfg.blocks[0].branchLabels.size)
assertEquals(1, cfg.blocks[1].start)
assertEquals(2, cfg.blocks[1].end)
assertEquals(BranchType.None, cfg.blocks[1].branchType)
assertEquals(1, cfg.blocks[1].from.size)
assertEquals(1, cfg.blocks[1].to.size)
assertEquals(0, cfg.blocks[1].branchLabels.size)
assertEquals(0, cfg.blocks[2].start)
assertEquals(2, cfg.blocks[2].end)
assertEquals(BranchType.Return, cfg.blocks[2].branchType)
assertEquals(2, cfg.blocks[2].from.size)
assertEquals(0, cfg.blocks[2].to.size)
assertEquals(0, cfg.blocks[2].branchLabels.size)
}
}

View File

@ -0,0 +1,203 @@
package world.phantasmal.lib.assembly.dataFlowAnalysis
import world.phantasmal.lib.assembly.*
import world.phantasmal.lib.test.toInstructions
import kotlin.test.Test
import kotlin.test.assertEquals
private const val MAX_REGISTER_VALUES_SIZE: Long = 1L shl 32
class GetRegisterValueTests {
@Test
fun when_no_instruction_sets_the_register_zero_is_returned() {
val im = toInstructions("""
0:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val values = getRegisterValue(cfg, im[0].instructions[0], 6)
assertEquals(1L, values.size)
assertEquals(0, values[0])
}
@Test
fun a_single_register_assignment_results_in_one_value() {
val im = toInstructions("""
0:
leti r6, 1337
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val values = getRegisterValue(cfg, im[0].instructions[1], 6)
assertEquals(1L, values.size)
assertEquals(1337, values[0])
}
@Test
fun two_assignments_in_separate_code_paths_results_in_two_values() {
val im = toInstructions("""
0:
jmp_> r1, r2, 1
leti r10, 111
jmp 2
1:
leti r10, 222
2:
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val values = getRegisterValue(cfg, im[2].instructions[0], 10)
assertEquals(2L, values.size)
assertEquals(111, values[0])
assertEquals(222, values[1])
}
@Test
fun bail_out_from_loops() {
val im = toInstructions("""
0:
addi r10, 5
jmpi_< r10, 500, 0
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val values = getRegisterValue(cfg, im[0].instructions[2], 10)
assertEquals(MAX_REGISTER_VALUES_SIZE, values.size)
}
@Test
fun leta_and_leto() {
val im = toInstructions("""
0:
leta r0, r100
leto r1, 100
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val r0 = getRegisterValue(cfg, im[0].instructions[2], 0)
assertEquals(MAX_REGISTER_VALUES_SIZE, r0.size)
assertEquals(MIN_REGISTER_VALUE, r0.minOrNull())
assertEquals(MAX_REGISTER_VALUE, r0.maxOrNull())
val r1 = getRegisterValue(cfg, im[0].instructions[2], 1)
assertEquals(MAX_REGISTER_VALUES_SIZE, r1.size)
assertEquals(MIN_REGISTER_VALUE, r1.minOrNull())
assertEquals(MAX_REGISTER_VALUE, r1.maxOrNull())
}
@Test
fun rev() {
val im = toInstructions("""
0:
leti r0, 10
leti r1, 50
get_random r0, r10
rev r10
leti r0, -10
leti r1, 50
get_random r0, r10
rev r10
leti r10, 0
rev r10
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val v0 = getRegisterValue(cfg, im[0].instructions[4], 10)
assertEquals(1L, v0.size)
assertEquals(0, v0[0])
val v1 = getRegisterValue(cfg, im[0].instructions[8], 10)
assertEquals(2L, v1.size)
assertEquals(0, v1[0])
assertEquals(1, v1[1])
val v2 = getRegisterValue(cfg, im[0].instructions[10], 10)
assertEquals(1L, v2.size)
assertEquals(1, v2[0])
}
@Test
fun addi() {
testBranched(OP_ADDI, 25, 35)
}
@Test
fun subi() {
testBranched(OP_SUBI, -5, 5)
}
@Test
fun muli() {
testBranched(OP_MULI, 150, 300)
}
@Test
fun divi() {
testBranched(OP_DIVI, 0, 1)
}
/**
* Test an instruction taking a register and an integer.
* The instruction will be called with arguments r99, 15. r99 will be set to 10 or 20.
*/
private fun testBranched(opcode: Opcode, expected1: Int, expected2: Int) {
val im = toInstructions("""
0:
leti r99, 10
jmpi_= r0, 100, 1
leti r99, 20
1:
${opcode.mnemonic} r99, 15
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val values = getRegisterValue(cfg, im[1].instructions[1], 99)
assertEquals(2L, values.size)
assertEquals(expected1, values[0])
assertEquals(expected2, values[1])
}
@Test
fun get_random() {
val im = toInstructions("""
0:
leti r0, 20
leti r1, 20
get_random r0, r10
leti r1, 19
get_random r0, r10
leti r1, 25
get_random r0, r10
ret
""".trimIndent())
val cfg = ControlFlowGraph.create(im)
val v0 = getRegisterValue(cfg, im[0].instructions[3], 10)
assertEquals(1L, v0.size)
assertEquals(20, v0[0])
val v1 = getRegisterValue(cfg, im[0].instructions[5], 10)
assertEquals(1L, v1.size)
assertEquals(20, v1[0])
val v2 = getRegisterValue(cfg, im[0].instructions[7], 10)
assertEquals(5L, v2.size)
assertEquals(20, v2[0])
assertEquals(21, v2[1])
assertEquals(22, v2[2])
assertEquals(23, v2[3])
assertEquals(24, v2[4])
}
}

View File

@ -8,17 +8,17 @@ import kotlin.test.assertTrue
class ValueSetTests { class ValueSetTests {
@Test @Test
fun empty_set_has_size_0() { fun empty_set_has_size_0() {
val vs = ValueSet() val vs = ValueSet.empty()
assertEquals(0, vs.size) assertEquals(0L, vs.size)
} }
@Test @Test
fun get() { fun get() {
val vs = ValueSet().setInterval(10, 13) val vs = ValueSet.ofInterval(10, 13)
.union(ValueSet().setInterval(20, 22)) .union(ValueSet.ofInterval(20, 22))
assertEquals(7, vs.size) assertEquals(7L, vs.size)
assertEquals(10, vs[0]) assertEquals(10, vs[0])
assertEquals(11, vs[1]) assertEquals(11, vs[1])
assertEquals(12, vs[2]) assertEquals(12, vs[2])
@ -30,10 +30,10 @@ class ValueSetTests {
@Test @Test
fun contains() { fun contains() {
val vs = ValueSet().setInterval(-20, 13) val vs = ValueSet.ofInterval(-20, 13)
.union(ValueSet().setInterval(20, 22)) .union(ValueSet.ofInterval(20, 22))
assertEquals(37, vs.size) assertEquals(37L, vs.size)
assertFalse(-9001 in vs) assertFalse(-9001 in vs)
assertFalse(-21 in vs) assertFalse(-21 in vs)
assertTrue(-20 in vs) assertTrue(-20 in vs)
@ -48,38 +48,116 @@ class ValueSetTests {
@Test @Test
fun setValue() { fun setValue() {
val vs = ValueSet() val vs = ValueSet.empty()
vs.setValue(100) vs.setValue(100)
vs.setValue(4) vs.setValue(4)
vs.setValue(24324) vs.setValue(24324)
assertEquals(1, vs.size) assertEquals(1L, vs.size)
assertEquals(24324, vs[0]) assertEquals(24324, vs[0])
} }
@Test @Test
fun union() { fun plusAssign_integer_overflow() {
val vs = ValueSet() // The set of all integers should stay the same after adding any integer.
.union(ValueSet().setValue(21)) for (i in Int.MIN_VALUE..Int.MAX_VALUE step 10_000_000) {
.union(ValueSet().setValue(4968)) val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MAX_VALUE)
vs += i
assertEquals(2, vs.size) assertEquals(1L shl 32, vs.size)
assertEquals(Int.MIN_VALUE, vs.minOrNull())
assertEquals(Int.MAX_VALUE, vs.maxOrNull())
}
// Cause two intervals to split into three intervals.
val vs = ValueSet.ofInterval(5, 7)
vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE))
vs += 1
assertEquals(6L, vs.size)
assertEquals(Int.MIN_VALUE, vs[0])
assertEquals(6, vs[1])
assertEquals(7, vs[2])
assertEquals(8, vs[3])
assertEquals(Int.MAX_VALUE - 1, vs[4])
assertEquals(Int.MAX_VALUE, vs[5])
// Cause part of one interval to be joined to another.
vs.setInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2)
vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE))
vs += 1
assertEquals(6L, vs.size)
assertEquals(Int.MIN_VALUE, vs[0])
assertEquals(Int.MIN_VALUE + 1, vs[1])
assertEquals(Int.MIN_VALUE + 2, vs[2])
assertEquals(Int.MIN_VALUE + 3, vs[3])
assertEquals(Int.MAX_VALUE - 1, vs[4])
assertEquals(Int.MAX_VALUE, vs[5])
}
@Test
fun minusAssign_integer_underflow() {
// The set of all integers should stay the same after subtracting any integer.
for (i in Int.MIN_VALUE..Int.MAX_VALUE step 10_000_000) {
val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MAX_VALUE)
vs -= i
assertEquals(1L shl 32, vs.size)
assertEquals(Int.MIN_VALUE, vs.minOrNull())
assertEquals(Int.MAX_VALUE, vs.maxOrNull())
}
// Cause two intervals to split into three intervals.
val vs = ValueSet.ofInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2)
vs.union(ValueSet.ofInterval(5, 7))
vs -= 1
assertEquals(6L, vs.size)
assertEquals(Int.MIN_VALUE, vs[0])
assertEquals(Int.MIN_VALUE + 1, vs[1])
assertEquals(4, vs[2])
assertEquals(5, vs[3])
assertEquals(6, vs[4])
assertEquals(Int.MAX_VALUE, vs[5])
// Cause part of one interval to be joined to another.
vs.setInterval(Int.MIN_VALUE, Int.MIN_VALUE + 2)
vs.union(ValueSet.ofInterval(Int.MAX_VALUE - 2, Int.MAX_VALUE))
vs -= 1
assertEquals(6L, vs.size)
assertEquals(Int.MIN_VALUE, vs[0])
assertEquals(Int.MIN_VALUE + 1, vs[1])
assertEquals(Int.MAX_VALUE - 3, vs[2])
assertEquals(Int.MAX_VALUE - 2, vs[3])
assertEquals(Int.MAX_VALUE - 1, vs[4])
assertEquals(Int.MAX_VALUE, vs[5])
}
@Test
fun union() {
val vs = ValueSet.empty()
.union(ValueSet.of(21))
.union(ValueSet.of(4968))
assertEquals(2L, vs.size)
assertEquals(21, vs[0]) assertEquals(21, vs[0])
assertEquals(4968, vs[1]) assertEquals(4968, vs[1])
} }
@Test @Test
fun union_of_intervals() { fun union_of_intervals() {
val vs = ValueSet() val vs = ValueSet.empty()
.union(ValueSet().setInterval(10, 12)) .union(ValueSet.ofInterval(10, 12))
.union(ValueSet().setInterval(14, 16)) .union(ValueSet.ofInterval(14, 16))
assertEquals(6, vs.size) assertEquals(6L, vs.size)
assertTrue(arrayOf(10, 11, 12, 14, 15, 16).all { it in vs }) assertTrue(arrayOf(10, 11, 12, 14, 15, 16).all { it in vs })
vs.union(ValueSet().setInterval(13, 13)) vs.union(ValueSet.ofInterval(13, 13))
assertEquals(7, vs.size) assertEquals(7L, vs.size)
assertEquals(10, vs[0]) assertEquals(10, vs[0])
assertEquals(11, vs[1]) assertEquals(11, vs[1])
assertEquals(12, vs[2]) assertEquals(12, vs[2])
@ -88,27 +166,27 @@ class ValueSetTests {
assertEquals(15, vs[5]) assertEquals(15, vs[5])
assertEquals(16, vs[6]) assertEquals(16, vs[6])
vs.union(ValueSet().setInterval(1, 2)) vs.union(ValueSet.ofInterval(1, 2))
assertEquals(9, vs.size) assertEquals(9L, vs.size)
assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16).all { it in vs }) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16).all { it in vs })
vs.union(ValueSet().setInterval(30, 32)) vs.union(ValueSet.ofInterval(30, 32))
assertEquals(12, vs.size) assertEquals(12L, vs.size)
assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 30, 31, 32).all { it in vs }) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 30, 31, 32).all { it in vs })
vs.union(ValueSet().setInterval(20, 21)) vs.union(ValueSet.ofInterval(20, 21))
assertEquals(14, vs.size) assertEquals(14L, vs.size)
assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 20, 21, 30, 31, 32).all { it in vs }) assertTrue(arrayOf(1, 2, 10, 11, 12, 13, 14, 15, 16, 20, 21, 30, 31, 32).all { it in vs })
} }
@Test @Test
fun iterator() { fun iterator() {
val vs = ValueSet() val vs = ValueSet.empty()
.union(ValueSet().setInterval(5, 7)) .union(ValueSet.ofInterval(5, 7))
.union(ValueSet().setInterval(14, 16)) .union(ValueSet.ofInterval(14, 16))
val iter = vs.iterator() val iter = vs.iterator()

View File

@ -0,0 +1,15 @@
package world.phantasmal.lib.test
import world.phantasmal.core.Success
import world.phantasmal.lib.assembly.InstructionSegment
import world.phantasmal.lib.assembly.assemble
import kotlin.test.assertTrue
fun toInstructions(assembly: String): List<InstructionSegment> {
val result = assemble(assembly.split('\n'))
assertTrue(result is Success)
assertTrue(result.problems.isEmpty())
return result.value.filterIsInstance<InstructionSegment>()
}