Added semantic highlighting for labels to the ASM editor.

This commit is contained in:
Daan Vanden Bosch 2021-04-26 15:38:34 +02:00
parent 797c5a298e
commit c093cb813e
9 changed files with 337 additions and 116 deletions

View File

@ -551,7 +551,9 @@ private class Assembler(private val asm: List<String>, private val inlineStackAr
addError(col, len, "Expected ${typeStr}.")
}
} else if (stack) {
}
if (stack) {
// Inject stack push instructions if necessary.
checkNotNull(paramType)

View File

@ -167,7 +167,7 @@ private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: Instruction
// Unconditional jump.
OP_JMP.code -> {
branchType = BranchType.Jump
branchLabels = listOf((inst.args[0] as IntArg).value)
branchLabels = listOfNotNull((inst.args[0] as? IntArg)?.value)
}
// Conditional jumps.
@ -175,7 +175,7 @@ private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: Instruction
OP_JMP_OFF.code,
-> {
branchType = BranchType.ConditionalJump
branchLabels = listOf((inst.args[0] as IntArg).value)
branchLabels = listOfNotNull((inst.args[0] as? IntArg)?.value)
}
OP_JMP_E.code,
OP_JMPI_E.code,
@ -199,11 +199,11 @@ private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: Instruction
OP_JMPI_LE.code,
-> {
branchType = BranchType.ConditionalJump
branchLabels = listOf((inst.args[2] as IntArg).value)
branchLabels = listOfNotNull((inst.args[2] as? IntArg)?.value)
}
OP_SWITCH_JMP.code -> {
branchType = BranchType.ConditionalJump
branchLabels = inst.args.drop(1).map { (it as IntArg).value }
branchLabels = inst.args.drop(1).mapNotNull { (it as? IntArg)?.value }
}
// Calls.
@ -211,11 +211,11 @@ private fun createBasicBlocks(cfg: ControlFlowGraphBuilder, segment: Instruction
OP_VA_CALL.code,
-> {
branchType = BranchType.Call
branchLabels = listOf((inst.args[0] as IntArg).value)
branchLabels = listOfNotNull((inst.args[0] as? IntArg)?.value)
}
OP_SWITCH_CALL.code -> {
branchType = BranchType.Call
branchLabels = inst.args.drop(1).map { (it as IntArg).value }
branchLabels = inst.args.drop(1).mapNotNull { (it as? IntArg)?.value }
}
// All other opcodes.
@ -255,8 +255,7 @@ private fun linkBlocks(cfg: ControlFlowGraphBuilder) {
// Pairs of calling block and block to which callees should return to.
val callers = mutableListOf<Pair<BasicBlockImpl, BasicBlockImpl>>()
for (i in cfg.blocks.indices) {
val block = cfg.blocks[i]
for ((i, block) in cfg.blocks.withIndex()) {
val nextBlock = cfg.blocks.getOrNull(i + 1)
when (block.branchType) {
@ -264,7 +263,7 @@ private fun linkBlocks(cfg: ControlFlowGraphBuilder) {
continue
BranchType.Call ->
nextBlock?.let { callers.add(block to nextBlock); }
nextBlock?.let { callers.add(block to nextBlock) }
BranchType.None,
BranchType.ConditionalJump,

View File

@ -28,6 +28,7 @@ fun getRegisterValue(cfg: ControlFlowGraph, instruction: Instruction, register:
private class RegisterValueFinder {
private var iterations = 0
// TODO: Deal with incorrect argument types.
fun find(
path: MutableSet<BasicBlock>,
block: BasicBlock,
@ -267,7 +268,8 @@ private class RegisterValueFinder {
OP_ARG_PUSHW.code,
OP_ARG_PUSHA.code,
OP_ARG_PUSHO.code,
OP_ARG_PUSHS.code -> stack.add(instruction)
OP_ARG_PUSHS.code,
-> stack.add(instruction)
}
}
}
@ -282,7 +284,8 @@ private class RegisterValueFinder {
OP_ARG_PUSHL.code,
OP_ARG_PUSHB.code,
OP_ARG_PUSHW.code -> ValueSet.of((arg as IntArg).value)
OP_ARG_PUSHW.code,
-> ValueSet.of((arg as IntArg).value)
// TODO: Deal with strings.
else -> ValueSet.all() // String or pointer

View File

@ -7,9 +7,14 @@ private val logger = KotlinLogging.logger {}
/**
* Computes the possible values of a stack element at the nth position from the top, right before a
* specific instruction.
* specific instruction. If the stack element's value can be traced back to a single push
* instruction, that instruction is also returned.
*/
fun getStackValue(cfg: ControlFlowGraph, instruction: Instruction, position: Int): ValueSet {
fun getStackValue(
cfg: ControlFlowGraph,
instruction: Instruction,
position: Int,
): Pair<ValueSet, Instruction?> {
val block = cfg.getBlockForInstruction(instruction)
return StackValueFinder().find(
@ -30,10 +35,10 @@ private class StackValueFinder {
block: BasicBlock,
end: Int,
position: Int,
): ValueSet {
): Pair<ValueSet, Instruction?> {
if (++iterations > 100) {
logger.warn { "Too many iterations." }
return ValueSet.all()
return Pair(ValueSet.all(), null)
}
var pos = position
@ -51,7 +56,13 @@ private class StackValueFinder {
when (instruction.opcode.code) {
OP_ARG_PUSHR.code -> {
if (pos == 0) {
return getRegisterValue(cfg, instruction, (args[0] as IntArg).value)
val arg = args[0]
return if (arg is IntArg) {
Pair(getRegisterValue(cfg, instruction, arg.value), instruction)
} else {
Pair(ValueSet.all(), instruction)
}
} else {
pos--
}
@ -62,7 +73,13 @@ private class StackValueFinder {
OP_ARG_PUSHW.code,
-> {
if (pos == 0) {
return ValueSet.of((args[0] as IntArg).value)
val arg = args[0]
return if (arg is IntArg) {
Pair(ValueSet.of(arg.value), instruction)
} else {
Pair(ValueSet.all(), instruction)
}
} else {
pos--
}
@ -73,7 +90,7 @@ private class StackValueFinder {
OP_ARG_PUSHS.code,
-> {
if (pos == 0) {
return ValueSet.all()
return Pair(ValueSet.all(), instruction)
} else {
pos--
}
@ -82,17 +99,29 @@ private class StackValueFinder {
}
val values = ValueSet.empty()
var instruction: Instruction? = null
var multipleInstructions = false
path.add(block)
for (from in block.from) {
// Bail out from loops.
if (from in path) {
return ValueSet.all()
return Pair(ValueSet.all(), null)
}
values.union(find(LinkedHashSet(path), cfg, from, from.end, pos))
val (fromValues, fromInstruction) = find(LinkedHashSet(path), cfg, from, from.end, pos)
values.union(fromValues)
if (!multipleInstructions) {
if (instruction == null) {
instruction = fromInstruction
} else if (instruction != fromInstruction) {
instruction = null
multipleInstructions = true
}
}
}
return values
return Pair(values, instruction)
}
}

View File

@ -331,7 +331,7 @@ private fun getArgLabelValues(
cfg,
instruction,
instruction.opcode.params.size - paramIdx - 1,
)
).first
if (stackValues.size <= 20) {
for (value in stackValues) {

View File

@ -3,6 +3,7 @@ package world.phantasmal.web.assemblyWorker
import world.phantasmal.core.*
import world.phantasmal.lib.asm.*
import world.phantasmal.lib.asm.dataFlowAnalysis.ControlFlowGraph
import world.phantasmal.lib.asm.dataFlowAnalysis.ValueSet
import world.phantasmal.lib.asm.dataFlowAnalysis.getMapDesignations
import world.phantasmal.lib.asm.dataFlowAnalysis.getStackValue
import world.phantasmal.web.shared.messages.*
@ -198,9 +199,9 @@ class AsmAnalyser {
var signature: Signature? = null
var activeParam = -1
getInstructionForSrcLoc(lineNo, col)?.let { (inst, paramIdx) ->
signature = getSignature(inst.opcode)
activeParam = paramIdx
getInstructionForSrcLoc(lineNo, col)?.let { result ->
signature = getSignature(result.inst.opcode)
activeParam = result.paramIdx
}
return signature?.let { sig ->
@ -263,56 +264,28 @@ class AsmAnalyser {
var result = emptyList<AsmRange>()
getInstructionForSrcLoc(lineNo, col)?.inst?.let { inst ->
loop@
for ((paramIdx, param) in inst.opcode.params.withIndex()) {
if (param.type is LabelType) {
if (inst.opcode.stack != StackInteraction.Pop) {
// Immediate arguments.
val args = inst.getArgs(paramIdx)
val argSrcLocs = inst.getArgSrcLocs(paramIdx)
for (i in 0 until min(args.size, argSrcLocs.size)) {
val arg = args[i]
val srcLoc = argSrcLocs[i].coarse
if (positionInside(lineNo, col, srcLoc)) {
val label = (arg as IntArg).value
result = getLabelDefinitions(label)
break@loop
}
}
} else {
// Stack arguments.
val argSrcLocs = inst.getArgSrcLocs(paramIdx)
for ((i, argSrcLoc) in argSrcLocs.withIndex()) {
if (positionInside(lineNo, col, argSrcLoc.coarse)) {
val labelValues = getStackValue(cfg, inst, argSrcLocs.lastIndex - i)
if (labelValues.size <= 5) {
result = labelValues.flatMap(::getLabelDefinitions)
}
break@loop
}
getLabelArguments(
inst,
doCheck = { argSrcLoc -> positionInside(lineNo, col, argSrcLoc.coarse) },
processImmediateArg = { label, _ ->
result = getLabelDefinitionsAndReferences(label, references = false)
false
},
processStackArg = { labels, _, _ ->
if (labels.size <= 5) {
result = labels.flatMap {
getLabelDefinitionsAndReferences(it, references = false)
}
}
}
}
false
},
)
}
return Response.GetDefinition(requestId, result)
}
private fun getLabelDefinitions(label: Int): List<AsmRange> =
bytecodeIr.segments.asSequence()
.filter { label in it.labels }
.mapNotNull { segment ->
val labelIdx = segment.labels.indexOf(label)
segment.srcLoc.labels.getOrNull(labelIdx)?.toAsmRange()
}
.toList()
fun getLabels(requestId: Int): Response.GetLabels {
val result = bytecodeIr.segments.asSequence()
.flatMap { segment ->
@ -327,9 +300,13 @@ class AsmAnalyser {
}
fun getHighlights(requestId: Int, lineNo: Int, col: Int): Response.GetHighlights {
val result = mutableListOf<AsmRange>()
val results = mutableListOf<AsmRange>()
when (val ir = getIrForSrcLoc(lineNo, col)) {
is Ir.Label -> {
results.addAll(getLabelDefinitionsAndReferences(ir.label))
}
is Ir.Inst -> {
val srcLoc = ir.inst.srcLoc?.mnemonic
@ -338,20 +315,42 @@ class AsmAnalyser {
// first whitespace character preceding the first argument.
(srcLoc != null && col <= srcLoc.col + srcLoc.len)
) {
// Find all instructions with the same opcode.
for (segment in bytecodeIr.segments) {
if (segment is InstructionSegment) {
for (inst in segment.instructions) {
if (inst.opcode.code == ir.inst.opcode.code) {
inst.srcLoc?.mnemonic?.toAsmRange()?.let(result::add)
inst.srcLoc?.mnemonic?.toAsmRange()?.let(results::add)
}
}
}
}
} else {
getLabelArguments(
ir.inst,
doCheck = { argSrcLoc -> positionInside(lineNo, col, argSrcLoc.coarse) },
processImmediateArg = { label, _ ->
results.addAll(getLabelDefinitionsAndReferences(label))
false
},
processStackArg = { labels, pushInst, _ ->
// Filter out arg_pushr labels, because register values could be
// used for anything.
if (pushInst != null &&
pushInst.opcode.code != OP_ARG_PUSHR.code &&
labels.size == 1L
) {
results.addAll(getLabelDefinitionsAndReferences(labels[0]!!))
}
false
},
)
}
}
}
return Response.GetHighlights(requestId, result)
return Response.GetHighlights(requestId, results)
}
private fun getInstructionForSrcLoc(lineNo: Int, col: Int): Ir.Inst? =
@ -359,6 +358,15 @@ class AsmAnalyser {
private fun getIrForSrcLoc(lineNo: Int, col: Int): Ir? {
for (segment in bytecodeIr.segments) {
for ((index, srcLoc) in segment.srcLoc.labels.withIndex()) {
if (srcLoc.lineNo == lineNo &&
col >= srcLoc.col &&
col < srcLoc.col + srcLoc.len
) {
return Ir.Label(segment.labels[index])
}
}
if (segment is InstructionSegment) {
// Loop over instructions in reverse order so stack popping instructions will be
// handled before the related stack pushing instructions when inlineStackArgs is on.
@ -403,6 +411,120 @@ class AsmAnalyser {
return null
}
/**
* Returns all labels arguments of [instruction] with their value.
*/
private fun getLabelArguments(
instruction: Instruction,
doCheck: (ArgSrcLoc) -> Boolean,
processImmediateArg: (label: Int, ArgSrcLoc) -> Boolean,
processStackArg: (label: ValueSet, Instruction?, ArgSrcLoc) -> Boolean,
) {
loop@
for ((paramIdx, param) in instruction.opcode.params.withIndex()) {
if (param.type is LabelType) {
if (instruction.opcode.stack != StackInteraction.Pop) {
// Immediate arguments.
val args = instruction.getArgs(paramIdx)
val argSrcLocs = instruction.getArgSrcLocs(paramIdx)
for (i in 0 until min(args.size, argSrcLocs.size)) {
val arg = args[i]
val srcLoc = argSrcLocs[i]
if (doCheck(srcLoc)) {
val label = (arg as IntArg).value
if (!processImmediateArg(label, srcLoc)) {
break@loop
}
}
}
} else {
// Stack arguments.
val argSrcLocs = instruction.getArgSrcLocs(paramIdx)
for ((i, srcLoc) in argSrcLocs.withIndex()) {
if (doCheck(srcLoc)) {
val (labelValues, pushInstruction) =
getStackValue(cfg, instruction, argSrcLocs.lastIndex - i)
if (!processStackArg(labelValues, pushInstruction, srcLoc)) {
break@loop
}
}
}
}
}
}
}
/**
* Returns all definitions and all arguments that references the given [label].
*/
private fun getLabelDefinitionsAndReferences(
label: Int,
definitions: Boolean = true,
references: Boolean = true,
): List<AsmRange> {
val results = mutableListOf<AsmRange>()
for (segment in bytecodeIr.segments) {
// Add label definitions to the results.
if (definitions) {
val labelIdx = segment.labels.indexOf(label)
if (labelIdx != -1) {
segment.srcLoc.labels.getOrNull(labelIdx)?.let { srcLoc ->
results.add(
AsmRange(
startLineNo = srcLoc.lineNo,
startCol = srcLoc.col,
endLineNo = srcLoc.lineNo,
// Exclude the trailing ":" character.
endCol = srcLoc.col + srcLoc.len - 1,
)
)
}
}
}
// Find all instruction arguments that reference the label.
if (references) {
if (segment is InstructionSegment) {
for (inst in segment.instructions) {
getLabelArguments(
inst,
doCheck = { true },
processImmediateArg = { labelArg, argSrcLoc ->
if (labelArg == label) {
results.add(argSrcLoc.precise.toAsmRange())
}
true
},
processStackArg = { labelArg, pushInst, argSrcLoc ->
// Filter out arg_pushr labels, because register values could be
// used for anything.
if (pushInst != null &&
pushInst.opcode.code != OP_ARG_PUSHR.code &&
labelArg.size == 1L &&
label in labelArg
) {
results.add(argSrcLoc.precise.toAsmRange())
}
true
},
)
}
}
}
}
return results
}
private fun positionInside(lineNo: Int, col: Int, srcLoc: SrcLoc?): Boolean =
if (srcLoc == null) {
false
@ -422,7 +544,8 @@ class AsmAnalyser {
)
private sealed class Ir {
data class Inst(val inst: Instruction, val paramIdx: Int) : Ir()
class Label(val label: Int) : Ir()
class Inst(val inst: Instruction, val paramIdx: Int) : Ir()
}
companion object {

View File

@ -2,10 +2,7 @@ package world.phantasmal.web.assemblyWorker
import mu.KotlinLogging
import world.phantasmal.web.shared.Throttle
import world.phantasmal.web.shared.messages.ClientMessage
import world.phantasmal.web.shared.messages.ClientNotification
import world.phantasmal.web.shared.messages.Request
import world.phantasmal.web.shared.messages.ServerMessage
import world.phantasmal.web.shared.messages.*
import kotlin.time.measureTime
class AsmServer(
@ -21,53 +18,65 @@ class AsmServer(
}
private fun processMessages() {
// Split messages into ASM changes and other messages. Remove useless/duplicate
// notifications.
val asmChanges = mutableListOf<ClientNotification>()
val otherMessages = mutableListOf<ClientMessage>()
try {
// Split messages into ASM changes and other messages. Remove useless/duplicate
// notifications.
val asmChanges = mutableListOf<ClientNotification>()
val otherMessages = mutableListOf<ClientMessage>()
for (message in messageQueue) {
when (message) {
is ClientNotification.SetAsm -> {
// All previous ASM change messages can be discarded when the entire ASM has
// changed.
asmChanges.clear()
asmChanges.add(message)
for (message in messageQueue) {
when (message) {
is ClientNotification.SetAsm -> {
// All previous ASM change messages can be discarded when the entire ASM has
// changed.
asmChanges.clear()
asmChanges.add(message)
}
is ClientNotification.UpdateAsm ->
asmChanges.add(message)
else ->
otherMessages.add(message)
}
is ClientNotification.UpdateAsm ->
asmChanges.add(message)
else ->
otherMessages.add(message)
}
messageQueue.clear()
// Process ASM changes first.
processAsmChanges(asmChanges)
otherMessages.forEach(::processMessage)
} catch (e: Throwable) {
logger.error(e) { "Exception while processing messages." }
messageQueue.clear()
}
messageQueue.clear()
// Process ASM changes first.
processAsmChanges(asmChanges)
otherMessages.forEach(::processMessage)
}
private fun processAsmChanges(messages: List<ClientNotification>) {
if (messages.isNotEmpty()) {
val time = measureTime {
for (message in messages) {
when (message) {
is ClientNotification.SetAsm ->
asmAnalyser.setAsm(message.asm, message.inlineStackArgs)
val responses = try {
for (message in messages) {
when (message) {
is ClientNotification.SetAsm ->
asmAnalyser.setAsm(message.asm, message.inlineStackArgs)
is ClientNotification.UpdateAsm ->
asmAnalyser.updateAsm(message.changes)
is ClientNotification.UpdateAsm ->
asmAnalyser.updateAsm(message.changes)
else ->
// Should be processed by processMessage.
logger.error { "Unexpected ${message::class.simpleName}." }
else ->
// Should be processed by processMessage.
logger.error { "Unexpected ${message::class.simpleName}." }
}
}
asmAnalyser.processAsm()
} catch (e: Throwable) {
logger.error(e) { "Exception while processing ASM changes." }
emptyList<Response<*>>()
}
asmAnalyser.processAsm().forEach(sendMessage)
responses.forEach(sendMessage)
}
logger.trace {
@ -78,13 +87,18 @@ class AsmServer(
private fun processMessage(message: ClientMessage) {
val time = measureTime {
when (message) {
is ClientNotification.SetAsm,
is ClientNotification.UpdateAsm ->
// Should have been processed by processAsmChanges.
logger.error { "Unexpected ${message::class.simpleName}." }
try {
when (message) {
is ClientNotification.SetAsm,
is ClientNotification.UpdateAsm,
->
// Should have been processed by processAsmChanges.
logger.error { "Unexpected ${message::class.simpleName}." }
is Request -> processRequest(message)
is Request -> processRequest(message)
}
} catch (e: Throwable) {
logger.error(e) { "Exception while processing ${message::class.simpleName}." }
}
}

View File

@ -193,6 +193,45 @@ class AsmAnalyserTests : AssemblyWorkerTestSuite() {
assertEquals(AsmRange(5, 5, 5, 9), response.result[2])
}
@Test
fun getHighlights_for_label() = test {
// The following ASM contains, in the given order:
// - Label 100
// - Reference to label 100 by arg_pushl instruction
// - Reference to label 100 by arg_pushr instruction (should not be highlighted, the value
// comes from a register and might be used for many things)
// - Immediate argument referencing label 100
val analyser = createAsmAnalyser(
"""
.code
100:
set_floor_handler 0, 100
leti r0, 100
set_floor_handler 0, r0
jmp 100
""".trimIndent()
)
val requestId = 2999
for ((lineNo, col) in listOf(
// Cursor is at center of label 100.
Pair(2, 2),
// Cursor is at center of the first set_floor_handler label argument.
Pair(3, 27),
// Cursor is at center of the jmp label argument.
Pair(6, 10),
)) {
val response = analyser.getHighlights(requestId, lineNo, col)
assertEquals(requestId, response.id)
assertEquals(3, response.result.size)
assertEquals(AsmRange(2, 1, 2, 4), response.result[0])
assertEquals(AsmRange(3, 26, 3, 29), response.result[1])
assertEquals(AsmRange(6, 9, 6, 12), response.result[2])
}
}
private fun createAsmAnalyser(asm: String): AsmAnalyser {
val analyser = AsmAnalyser()
analyser.setAsm(asm.split("\n"), inlineStackArgs = true)

View File

@ -120,9 +120,21 @@ sealed class Response<T> : ServerMessage() {
@Serializable
data class AsmRange(
/**
* Starting line of the range, inclusive.
*/
val startLineNo: Int,
/**
* Starting column of the range, inclusive.
*/
val startCol: Int,
/**
* Ending line of the range, exclusive.
*/
val endLineNo: Int,
/**
* Ending column of the range, exclusive.
*/
val endCol: Int,
)