diff --git a/web/assembly-worker/src/main/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyser.kt b/web/assembly-worker/src/main/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyser.kt index 6136d9e7..f1205202 100644 --- a/web/assembly-worker/src/main/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyser.kt +++ b/web/assembly-worker/src/main/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyser.kt @@ -326,25 +326,56 @@ class AsmAnalyser { } } } else { - visitLabelArguments( + visitArgs( ir.inst, - accept = { argSrcLoc -> positionInside(lineNo, col, argSrcLoc.coarse) }, - processImmediateArg = { label, _ -> - results.addAll(getLabelDefinitionsAndReferences(label)) - VisitAction.Return + processParam = { VisitAction.Go }, + processImmediateArg = { param, arg, argSrcLoc -> + if (positionInside(lineNo, col, argSrcLoc.coarse)) { + (arg as? IntArg)?.let { + when (param.type) { + is LabelType -> { + results.addAll( + getLabelDefinitionsAndReferences(arg.value) + ) + } + is RegRefType -> { + results.addAll(getRegisterReferences(arg.value)) + } + else -> Unit + } + } + + VisitAction.Return + } else { + VisitAction.Continue + } }, - processStackArg = { labels, pushInst, _ -> - // Filter out arg_pushr labels, because register values could be - // used for anything. - if (pushInst != null && - pushInst.opcode.code != OP_ARG_PUSHR.code && - labels.size == 1L - ) { - results.addAll(getLabelDefinitionsAndReferences(labels[0]!!)) + processStackArgSrcLoc = { _, argSrcLoc -> + if (positionInside(lineNo, col, argSrcLoc.coarse)) { + VisitAction.Go + } else { + VisitAction.Continue + } + }, + processStackArg = { param, _, pushInst, _ -> + if (pushInst != null) { + val pushArg = pushInst.args.firstOrNull() + + if (pushArg is IntArg) { + if (pushInst.opcode.code == OP_ARG_PUSHR.code || + param.type is RegRefType + ) { + results.addAll(getRegisterReferences(pushArg.value)) + } else if (param.type is LabelType) { + results.addAll( + getLabelDefinitionsAndReferences(pushArg.value) + ) + } + } } VisitAction.Return - }, + } ) } } @@ -411,95 +442,48 @@ class AsmAnalyser { return null } - /** - * Visits all label arguments of [instruction] with their value. - */ - private fun visitLabelArguments( - instruction: Instruction, - accept: (ArgSrcLoc) -> Boolean, - processImmediateArg: (label: Int, ArgSrcLoc) -> VisitAction, - processStackArg: (label: ValueSet, Instruction?, ArgSrcLoc) -> VisitAction, - ) { - visitArgs( - instruction, - processParam = { if (it.type is LabelType) VisitAction.Go else VisitAction.Continue }, - processImmediateArg = { arg, srcLoc -> - if (accept(srcLoc)) { - processImmediateArg((arg as IntArg).value, srcLoc) - } else VisitAction.Continue - }, - processStackArgSrcLoc = { srcLoc -> - if (accept(srcLoc)) VisitAction.Go - else VisitAction.Continue - }, - processStackArg = { value, pushInst, srcLoc -> - processStackArg(value, pushInst, srcLoc) - } - ) - } + private fun getRegisterReferences(register: Int): List { + val results = mutableListOf() - private enum class VisitAction { - Go, Break, Continue, Return - } + for (segment in bytecodeIr.segments) { + if (segment is InstructionSegment) { + for (inst in segment.instructions) { + visitArgs( + inst, + processParam = { VisitAction.Go }, + processImmediateArg = { param, arg, argSrcLoc -> + if (param.type is RegRefType && + arg is IntArg && + arg.value == register + ) { + results.add(argSrcLoc.precise.toAsmRange()) + } - /** - * Visits all arguments of [instruction], including stack arguments. - */ - private fun visitArgs( - instruction: Instruction, - processParam: (Param) -> VisitAction, - processImmediateArg: (Arg, ArgSrcLoc) -> VisitAction, - processStackArgSrcLoc: (ArgSrcLoc) -> VisitAction, - processStackArg: (ValueSet, Instruction?, ArgSrcLoc) -> VisitAction, - ) { - for ((paramIdx, param) in instruction.opcode.params.withIndex()) { - when (processParam(param)) { - VisitAction.Go -> Unit // Keep going. - VisitAction.Break -> break // Same as Stop. - VisitAction.Continue -> continue - VisitAction.Return -> return - } + VisitAction.Go + }, + processStackArgSrcLoc = { param, _ -> + if (param.type is RegRefType) VisitAction.Go + else VisitAction.Continue + }, + processStackArg = { _, _, pushInst, argSrcLoc -> + if (pushInst != null && + pushInst.opcode.code != OP_ARG_PUSHR.code + ) { + val pushArg = pushInst.args.firstOrNull() - if (instruction.opcode.stack !== StackInteraction.Pop) { - // Immediate arguments. - val args = instruction.getArgs(paramIdx) - val argSrcLocs = instruction.getArgSrcLocs(paramIdx) + if (pushArg is IntArg && pushArg.value == register) { + results.add(argSrcLoc.precise.toAsmRange()) + } + } - for (i in 0 until min(args.size, argSrcLocs.size)) { - val arg = args[i] - val srcLoc = argSrcLocs[i] - - when (processImmediateArg(arg, srcLoc)) { - VisitAction.Go -> Unit // Keep going. - VisitAction.Break -> break - VisitAction.Continue -> continue // Same as Down. - VisitAction.Return -> return - } - } - } else { - // Stack arguments. - val argSrcLocs = instruction.getArgSrcLocs(paramIdx) - - for ((i, srcLoc) in argSrcLocs.withIndex()) { - when (processStackArgSrcLoc(srcLoc)) { - VisitAction.Go -> Unit // Keep going. - VisitAction.Break -> break - VisitAction.Continue -> continue - VisitAction.Return -> return - } - - val (labelValues, pushInstruction) = - getStackValue(cfg, instruction, argSrcLocs.lastIndex - i) - - when (processStackArg(labelValues, pushInstruction, srcLoc)) { - VisitAction.Go -> Unit // Keep going. - VisitAction.Break -> break - VisitAction.Continue -> continue // Same as Down. - VisitAction.Return -> return - } + VisitAction.Go + } + ) } } } + + return results } /** @@ -568,6 +552,101 @@ class AsmAnalyser { return results } + /** + * Visits all label arguments of [instruction] with their value. + */ + private fun visitLabelArguments( + instruction: Instruction, + accept: (ArgSrcLoc) -> Boolean, + processImmediateArg: (label: Int, ArgSrcLoc) -> VisitAction, + processStackArg: (label: ValueSet, Instruction?, ArgSrcLoc) -> VisitAction, + ) { + visitArgs( + instruction, + processParam = { if (it.type is LabelType) VisitAction.Go else VisitAction.Continue }, + processImmediateArg = { _, arg, srcLoc -> + if (accept(srcLoc) && arg is IntArg) { + processImmediateArg(arg.value, srcLoc) + } else VisitAction.Continue + }, + processStackArgSrcLoc = { _, srcLoc -> + if (accept(srcLoc)) VisitAction.Go + else VisitAction.Continue + }, + processStackArg = { _, value, pushInst, srcLoc -> + processStackArg(value, pushInst, srcLoc) + } + ) + } + + private enum class VisitAction { + Go, Break, Continue, Return + } + + /** + * Visits all arguments of [instruction], including stack arguments. + */ + private fun visitArgs( + instruction: Instruction, + processParam: (Param) -> VisitAction, + processImmediateArg: (Param, Arg, ArgSrcLoc) -> VisitAction, + processStackArgSrcLoc: (Param, ArgSrcLoc) -> VisitAction, + processStackArg: (Param, ValueSet, Instruction?, ArgSrcLoc) -> VisitAction, + ) { + for ((paramIdx, param) in instruction.opcode.params.withIndex()) { + when (processParam(param)) { + VisitAction.Go -> Unit // Keep going. + VisitAction.Break -> break // Same as Stop. + VisitAction.Continue -> continue + VisitAction.Return -> return + } + + if (instruction.opcode.stack !== StackInteraction.Pop) { + // Immediate arguments. + val args = instruction.getArgs(paramIdx) + val argSrcLocs = instruction.getArgSrcLocs(paramIdx) + + for (i in 0 until min(args.size, argSrcLocs.size)) { + val arg = args[i] + val srcLoc = argSrcLocs[i] + + when (processImmediateArg(param, arg, srcLoc)) { + VisitAction.Go -> Unit // Keep going. + VisitAction.Break -> break + VisitAction.Continue -> continue // Same as Down. + VisitAction.Return -> return + } + } + } else { + // Stack arguments. + val argSrcLocs = instruction.getArgSrcLocs(paramIdx) + + // Never varargs. + for (srcLoc in argSrcLocs) { + when (processStackArgSrcLoc(param, srcLoc)) { + VisitAction.Go -> Unit // Keep going. + VisitAction.Break -> break + VisitAction.Continue -> continue + VisitAction.Return -> return + } + + val (labelValues, pushInstruction) = getStackValue( + cfg, + instruction, + instruction.opcode.params.lastIndex - paramIdx, + ) + + when (processStackArg(param, labelValues, pushInstruction, srcLoc)) { + VisitAction.Go -> Unit // Keep going. + VisitAction.Break -> break + VisitAction.Continue -> continue // Same as Down. + VisitAction.Return -> return + } + } + } + } + } + private fun positionInside(lineNo: Int, col: Int, srcLoc: SrcLoc?): Boolean = if (srcLoc == null) { false diff --git a/web/assembly-worker/src/test/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyserTests.kt b/web/assembly-worker/src/test/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyserTests.kt index b291e7ed..969f306c 100644 --- a/web/assembly-worker/src/test/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyserTests.kt +++ b/web/assembly-worker/src/test/kotlin/world/phantasmal/web/assemblyWorker/AsmAnalyserTests.kt @@ -232,6 +232,36 @@ class AsmAnalyserTests : AssemblyWorkerTestSuite() { } } + @Test + fun getHighlights_for_register() = test { + val analyser = createAsmAnalyser( + """ + .code + 100: + leti r13, 4031 + set_floor_handler 0, r13 + leti r17, 5379 + npc_param_v3 r13, 4 + """.trimIndent() + ) + + val requestId = 2999 + + for ((lineNo, col) in listOf( + Pair(3, 11), + Pair(4, 27), + Pair(6, 19), + )) { + val response = analyser.getHighlights(requestId, lineNo, col) + + assertEquals(requestId, response.id) + assertEquals(3, response.result.size) + assertEquals(AsmRange(3, 10, 3, 13), response.result[0]) + assertEquals(AsmRange(4, 26, 4, 29), response.result[1]) + assertEquals(AsmRange(6, 18, 6, 21), response.result[2]) + } + } + private fun createAsmAnalyser(asm: String): AsmAnalyser { val analyser = AsmAnalyser() analyser.setAsm(asm.split("\n"), inlineStackArgs = true)