Simplified proxy server code.

This commit is contained in:
Daan Vanden Bosch 2021-07-31 22:16:01 +02:00
parent 6daddcfd65
commit 089832c2fe
4 changed files with 39 additions and 65 deletions

View File

@ -62,13 +62,15 @@ class GuildCard(
val entries: List<GuildCardEntry> val entries: List<GuildCardEntry>
) )
sealed class BbMessage(override val buffer: Buffer) : Message(BB_HEADER_SIZE) { sealed class BbMessage(override val buffer: Buffer) : AbstractMessage(BB_HEADER_SIZE) {
override val code: Int get() = buffer.getUShort(BB_MSG_CODE_POS).toInt() override val code: Int get() = buffer.getUShort(BB_MSG_CODE_POS).toInt()
override val size: Int get() = buffer.getUShort(BB_MSG_SIZE_POS).toInt() override val size: Int get() = buffer.getUShort(BB_MSG_SIZE_POS).toInt()
class InitEncryption(buffer: Buffer) : BbMessage(buffer) { class InitEncryption(buffer: Buffer) : BbMessage(buffer), InitEncryptionMessage {
val serverKey: ByteArray get() = byteArray(INIT_MSG_SIZE, size = KEY_SIZE) override val serverKey: ByteArray
val clientKey: ByteArray get() = byteArray(INIT_MSG_SIZE + KEY_SIZE, size = KEY_SIZE) get() = byteArray(INIT_MSG_SIZE, size = KEY_SIZE)
override val clientKey: ByteArray
get() = byteArray(INIT_MSG_SIZE + KEY_SIZE, size = KEY_SIZE)
constructor(message: String, serverKey: ByteArray, clientKey: ByteArray) : this( constructor(message: String, serverKey: ByteArray, clientKey: ByteArray) : this(
buf(0x0003, INIT_MSG_SIZE + 2 * KEY_SIZE) { buf(0x0003, INIT_MSG_SIZE + 2 * KEY_SIZE) {
@ -87,14 +89,14 @@ sealed class BbMessage(override val buffer: Buffer) : Message(BB_HEADER_SIZE) {
constructor() : this(buf(0x0005)) constructor() : this(buf(0x0005))
} }
class Redirect(buffer: Buffer) : BbMessage(buffer) { class Redirect(buffer: Buffer) : BbMessage(buffer), RedirectMessage {
var ipAddress: ByteArray override var ipAddress: ByteArray
get() = byteArray(0, size = 4) get() = byteArray(0, size = 4)
set(value) { set(value) {
require(value.size == 4) require(value.size == 4)
setByteArray(0, value) setByteArray(0, value)
} }
var port: Int override var port: Int
get() = uShort(4).toInt() get() = uShort(4).toInt()
set(value) { set(value) {
require(value in 0..65535) require(value in 0..65535)

View File

@ -24,12 +24,25 @@ fun messageString(
data class Header(val code: Int, val size: Int) data class Header(val code: Int, val size: Int)
abstract class Message(val headerSize: Int) { interface Message {
abstract val buffer: Buffer val buffer: Buffer
abstract val code: Int val code: Int
abstract val size: Int val size: Int
val headerSize: Int
val bodySize: Int get() = size - headerSize val bodySize: Int get() = size - headerSize
}
interface InitEncryptionMessage : Message {
val serverKey: ByteArray
val clientKey: ByteArray
}
interface RedirectMessage : Message {
var ipAddress: ByteArray
var port: Int
}
abstract class AbstractMessage(override val headerSize: Int) : Message {
override fun toString(): String = messageString() override fun toString(): String = messageString()
protected fun uByte(offset: Int) = buffer.getUByte(headerSize + offset) protected fun uByte(offset: Int) = buffer.getUByte(headerSize + offset)
@ -58,5 +71,4 @@ abstract class Message(val headerSize: Int) {
protected fun messageString(vararg props: Pair<String, Any>): String = protected fun messageString(vararg props: Pair<String, Any>): String =
messageString(code, size, this::class.simpleName, *props) messageString(code, size, this::class.simpleName, *props)
} }

View File

@ -13,13 +13,15 @@ const val PC_HEADER_SIZE: Int = 4
const val PC_MSG_SIZE_POS: Int = 0 const val PC_MSG_SIZE_POS: Int = 0
const val PC_MSG_CODE_POS: Int = 2 const val PC_MSG_CODE_POS: Int = 2
sealed class PcMessage(override val buffer: Buffer) : Message(PC_HEADER_SIZE) { sealed class PcMessage(override val buffer: Buffer) : AbstractMessage(PC_HEADER_SIZE) {
override val code: Int get() = buffer.getUByte(PC_MSG_CODE_POS).toInt() override val code: Int get() = buffer.getUByte(PC_MSG_CODE_POS).toInt()
override val size: Int get() = buffer.getUShort(PC_MSG_SIZE_POS).toInt() override val size: Int get() = buffer.getUShort(PC_MSG_SIZE_POS).toInt()
class InitEncryption(buffer: Buffer) : PcMessage(buffer) { class InitEncryption(buffer: Buffer) : PcMessage(buffer), InitEncryptionMessage {
val serverKey: ByteArray get() = byteArray(INIT_MSG_SIZE, size = KEY_SIZE) override val serverKey: ByteArray
val clientKey: ByteArray get() = byteArray(INIT_MSG_SIZE + KEY_SIZE, size = KEY_SIZE) get() = byteArray(INIT_MSG_SIZE, size = KEY_SIZE)
override val clientKey: ByteArray
get() = byteArray(INIT_MSG_SIZE + KEY_SIZE, size = KEY_SIZE)
constructor(message: String, serverKey: ByteArray, clientKey: ByteArray) : this( constructor(message: String, serverKey: ByteArray, clientKey: ByteArray) : this(
buf(0x02, INIT_MSG_SIZE + 2 * KEY_SIZE) { buf(0x02, INIT_MSG_SIZE + 2 * KEY_SIZE) {
@ -60,14 +62,14 @@ sealed class PcMessage(override val buffer: Buffer) : Message(PC_HEADER_SIZE) {
) )
} }
class Redirect(buffer: Buffer) : PcMessage(buffer) { class Redirect(buffer: Buffer) : PcMessage(buffer), RedirectMessage {
var ipAddress: ByteArray override var ipAddress: ByteArray
get() = byteArray(0, size = 4) get() = byteArray(0, size = 4)
set(value) { set(value) {
require(value.size == 4) require(value.size == 4)
setByteArray(0, value) setByteArray(0, value)
} }
var port: Int override var port: Int
get() { get() {
buffer.endianness = Endianness.Big buffer.endianness = Endianness.Big
val p = uShort(4).toInt() val p = uShort(4).toInt()

View File

@ -4,10 +4,7 @@ import mu.KotlinLogging
import world.phantasmal.core.disposable.TrackedDisposable import world.phantasmal.core.disposable.TrackedDisposable
import world.phantasmal.psolib.buffer.Buffer import world.phantasmal.psolib.buffer.Buffer
import world.phantasmal.psoserv.encryption.Cipher import world.phantasmal.psoserv.encryption.Cipher
import world.phantasmal.psoserv.messages.BbMessage import world.phantasmal.psoserv.messages.*
import world.phantasmal.psoserv.messages.Header
import world.phantasmal.psoserv.messages.Message
import world.phantasmal.psoserv.messages.PcMessage
import java.net.* import java.net.*
class ProxyServer( class ProxyServer(
@ -119,7 +116,7 @@ class ProxyServer(
override fun processMessage(message: Message): ProcessResult { override fun processMessage(message: Message): ProcessResult {
when (message) { when (message) {
is PcMessage.InitEncryption -> if (decryptCipher == null) { is InitEncryptionMessage -> if (decryptCipher == null) {
decryptCipher = createCipher(message.serverKey) decryptCipher = createCipher(message.serverKey)
encryptCipher = createCipher(message.serverKey) encryptCipher = createCipher(message.serverKey)
@ -135,7 +132,7 @@ class ProxyServer(
clientSocket, clientSocket,
this, this,
clientDecryptCipher, clientDecryptCipher,
clientEncryptCipher clientEncryptCipher,
) )
this.clientHandler = clientListener this.clientHandler = clientListener
val thread = Thread(clientListener::listen) val thread = Thread(clientListener::listen)
@ -143,46 +140,7 @@ class ProxyServer(
thread.start() thread.start()
} }
is BbMessage.InitEncryption -> if (decryptCipher == null) { is RedirectMessage -> {
decryptCipher = createCipher(message.serverKey)
encryptCipher = createCipher(message.serverKey)
val clientDecryptCipher = createCipher(message.clientKey)
val clientEncryptCipher = createCipher(message.clientKey)
logger.info {
"Encryption initialized, start listening to client."
}
// Start listening to client on another thread.
val clientListener = ClientHandler(
clientSocket,
this,
clientDecryptCipher,
clientEncryptCipher
)
this.clientHandler = clientListener
val thread = Thread(clientListener::listen)
thread.name = "${ProxyServer::class.simpleName} client"
thread.start()
}
is PcMessage.Redirect -> {
val oldAddress = InetAddress.getByAddress(message.ipAddress)
redirectMap[Pair(oldAddress, message.port)]?.let { (newAddress, newPort) ->
logger.debug {
"Rewriting redirect from $oldAddress:${message.port} to $newAddress:$newPort."
}
message.ipAddress = newAddress.address
message.port = newPort
return ProcessResult.Changed
}
}
is BbMessage.Redirect -> {
val oldAddress = InetAddress.getByAddress(message.ipAddress) val oldAddress = InetAddress.getByAddress(message.ipAddress)
redirectMap[Pair(oldAddress, message.port)]?.let { (newAddress, newPort) -> redirectMap[Pair(oldAddress, message.port)]?.let { (newAddress, newPort) ->