netplay-lobby-server-go/domain/sessiondomain.go

333 lines
9.6 KiB
Go
Raw Permalink Normal View History

2024-12-24 22:09:17 +08:00
package domain
import (
"bytes"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/libretro/netplay-lobby-server-go/model/entity"
)
// Session timing limits, both expressed in whole seconds.
const (
	// SessionDeadline is how long (in seconds) a session stays alive without
	// receiving any update.
	SessionDeadline = 60
	// RateLimit is the rate-limit window in seconds: a client may send at most
	// one update for a session per window.
	RateLimit = 5
)
// requestType classifies what an Add call will do to the persisted session.
type requestType int

// Possible requestType values.
const (
	// SessionCreate: no session with the calculated ID is stored yet.
	SessionCreate requestType = 0
	// SessionUpdate: a stored session exists but its content hash differs.
	SessionUpdate requestType = 1
	// SessionTouch: a stored session exists with an identical content hash.
	SessionTouch requestType = 2
)
// AddSessionRequest is the request payload accepted by SessionDomain.Add().
// Each `form` tag names the form field the value is bound from.
type AddSessionRequest struct {
	Username            string `form:"username"` // parseSession substitutes a default name when empty
	CoreName            string `form:"core_name"`
	CoreVersion         string `form:"core_version"`
	GameName            string `form:"game_name"`
	GameCRC             string `form:"game_crc"` // upper-cased by parseSession; validateSession requires exactly 8 bytes
	Port                uint16 `form:"port"`
	MITMServer          string `form:"mitm_server"` // "custom" selects MITMCustomServer/MITMCustomPort; any other value is looked up as a named tunnel
	HasPassword         bool   `form:"has_password"` // NOTE(review): whether "1"/"0" bind to bool depends on the form binder (strconv.ParseBool accepts both) — confirm
	HasSpectatePassword bool   `form:"has_spectate_password"`
	ForceMITM           bool   `form:"force_mitm"` // MITM hosting is only used when this is set together with MITMServer and MITMSession
	RetroArchVersion    string `form:"retroarch_version"`
	Frontend            string `form:"frontend"`
	SubsystemName       string `form:"subsystem_name"`
	MITMSession         string `form:"mitm_session"`
	MITMCustomServer    string `form:"mitm_custom_addr"`
	MITMCustomPort      uint16 `form:"mitm_custom_port"`
}
// Sentinel errors returned by SessionDomain.Add; callers can match them with errors.Is.
var (
	// ErrSessionRejected is returned when the domain validation rejects a session.
	ErrSessionRejected = errors.New("会话被拒绝")
	// ErrRateLimited is returned when a session is updated again before its
	// rate-limit window has elapsed.
	ErrRateLimited = errors.New("速率限制已达到")
)
// SessionRepository is the storage port SessionDomain depends on; it decouples
// the domain logic from the concrete repository implementation.
type SessionRepository interface {
	// Reads. SessionDomain.Add treats (nil, nil) from GetByID as "no stored session".
	GetByID(id string) (*entity.Session, error)
	GetByRoomID(roomID int32) (*entity.Session, error)
	// GetAll receives the staleness cutoff computed by SessionDomain
	// (presumably sessions not updated since then are excluded — confirm).
	GetAll(deadline time.Time) ([]entity.Session, error)

	// Writes.
	Create(s *entity.Session) error
	Update(s *entity.Session) error
	Touch(id string) error
	PurgeOld(deadline time.Time) error
}
// SessionDomain encapsulates the domain logic for handling netplay sessions.
//
// NOTE: field order matters — NewSessionDomain builds this struct with an
// unkeyed composite literal.
type SessionDomain struct {
	sessionRepo      SessionRepository // persistence for sessions
	geopip2Domain    *GeoIP2Domain     // country lookup for new sessions; NOTE(review): name looks like a typo for geoIP2Domain — renaming requires updating every reference
	validationDomain *ValidationDomain // string validation used by validateSession
	mitmDomain       *MitmDomain       // resolves named MITM tunnels (see GetTunnel)
}
// NewSessionDomain wires the given repository and helper domains into a
// ready-to-use SessionDomain.
func NewSessionDomain(
	sessionRepo SessionRepository,
	geoIP2Domain *GeoIP2Domain,
	validationDomain *ValidationDomain,
	mitmDomain *MitmDomain) *SessionDomain {
	d := &SessionDomain{
		sessionRepo:      sessionRepo,
		geopip2Domain:    geoIP2Domain,
		validationDomain: validationDomain,
		mitmDomain:       mitmDomain,
	}
	return d
}
// Add creates or updates a session from an incoming request sent from ip.
//
// The request is classified by looking up the stored session with the same
// calculated ID:
//   - SessionCreate: nothing stored. The country is resolved from the IP, the
//     host is probed, and a new record is created.
//   - SessionUpdate: a stored session exists but its content hash differs. The
//     host is probed again and the record is updated.
//   - SessionTouch: a stored session exists with an identical content hash.
//     The record is only touched, unless a previously unreachable host is now
//     reachable, in which case the full record is updated.
//
// Errors:
//   - ErrRateLimited when an existing session was last updated less than
//     RateLimit seconds ago.
//   - ErrSessionRejected when a create/update fails validation.
//   - A wrapped error for missing IP/port or any repository/GeoIP failure.
func (d *SessionDomain) Add(request *AddSessionRequest, ip net.IP) (*entity.Session, error) {
	session := d.parseSession(request, ip)
	if session.IP == nil || session.Port == 0 {
		return nil, errors.New("IP 或端口未设置")
	}

	// Decide whether this is a CREATE, UPDATE or TOUCH operation.
	session.CalculateID()
	session.CalculateContentHash()

	savedSession, err := d.sessionRepo.GetByID(session.ID)
	if err != nil {
		return nil, fmt.Errorf("无法获取已保存的会话: %w", err)
	}

	// Named reqType so it does not shadow the requestType type.
	reqType := SessionCreate
	if savedSession != nil {
		// Carry over server-side state that the client does not send.
		session.RoomID = savedSession.RoomID
		session.Country = savedSession.Country
		session.Connectable = savedSession.Connectable
		session.IsRetroArch = savedSession.IsRetroArch
		session.CreatedAt = savedSession.CreatedAt
		session.UpdatedAt = savedSession.UpdatedAt

		if savedSession.ContentHash != session.ContentHash {
			reqType = SessionUpdate
		} else {
			reqType = SessionTouch
		}

		// Rate-limit UPDATE and TOUCH (the only kinds with a stored session)
		// using the shared RateLimit window instead of a hard-coded value.
		threshold := time.Now().Add(-RateLimit * time.Second)
		if savedSession.UpdatedAt.After(threshold) {
			return nil, ErrRateLimited
		}
	}

	// Validate the session on CREATE and UPDATE; a TOUCH carries unchanged content.
	if reqType == SessionCreate || reqType == SessionUpdate {
		if !d.validateSession(session) {
			return nil, ErrSessionRejected
		}
	}

	// Persist the session changes. Probe errors from trySessionConnect are
	// deliberately ignored: the outcome is recorded on session.Connectable and
	// session.IsRetroArch.
	switch reqType {
	case SessionCreate:
		if session.Country, err = d.geopip2Domain.GetCountryCodeForIP(session.IP); err != nil {
			return nil, fmt.Errorf("无法找到给定 IP %s 的国家: %w", session.IP, err)
		}
		_ = d.trySessionConnect(session)
		if err = d.sessionRepo.Create(session); err != nil {
			return nil, fmt.Errorf("无法创建新会话: %w", err)
		}
	case SessionUpdate:
		_ = d.trySessionConnect(session)
		if err = d.sessionRepo.Update(session); err != nil {
			return nil, fmt.Errorf("无法更新旧会话: %w", err)
		}
	case SessionTouch:
		// Re-probe hosts that were unreachable before. If the host is reachable
		// now, store the full record so the new flags are persisted.
		if !session.Connectable {
			_ = d.trySessionConnect(session)
			if session.Connectable {
				if err = d.sessionRepo.Update(session); err != nil {
					return nil, fmt.Errorf("无法更新旧会话: %w", err)
				}
				return session, nil
			}
		}
		if err = d.sessionRepo.Touch(session.ID); err != nil {
			return nil, fmt.Errorf("无法触碰旧会话: %w", err)
		}
	}

	return session, nil
}
// Get returns the session whose RoomID equals roomID, as reported by the
// repository. On a repository error the session result is always nil.
func (d *SessionDomain) Get(roomID int32) (*entity.Session, error) {
	found, lookupErr := d.sessionRepo.GetByRoomID(roomID)
	if lookupErr == nil {
		return found, nil
	}
	return nil, lookupErr
}
// List returns the sessions the repository reports for the current staleness
// cutoff (now minus SessionDeadline seconds), presumably the sessions that are
// still being hosted. On a repository error the slice result is nil.
func (d *SessionDomain) List() ([]entity.Session, error) {
	cutoff := d.getDeadline()
	all, listErr := d.sessionRepo.GetAll(cutoff)
	if listErr == nil {
		return all, nil
	}
	return nil, listErr
}
// PurgeOld asks the repository to delete stale sessions, using the cutoff
// from getDeadline: now minus SessionDeadline seconds. This is the same
// cutoff List uses. The old comment claimed 45 seconds, which did not match
// SessionDeadline.
// Any repository error is returned unchanged.
func (d *SessionDomain) PurgeOld() error {
	return d.sessionRepo.PurgeOld(d.getDeadline())
}
// parseSession converts an incoming request into a session entity that can be
// compared with the persisted copy. ID and content hash are not computed here.
//
// Side effect: an empty req.Username is overwritten in place with the default
// name, so the caller's request is modified.
//
// MITM hosting is only recorded when ForceMITM is set, MITMServer and
// MITMSession are both non-empty, and the relay resolves. A "custom" server
// needs its own address and port; any other name must be known to GetTunnel.
// Otherwise the session keeps HostMethodUnknown and empty MITM fields.
func (d *SessionDomain) parseSession(req *AddSessionRequest, ip net.IP) *entity.Session {
	if req.Username == "" {
		req.Username = "匿名"
	}

	session := &entity.Session{
		Username:            req.Username,
		GameName:            req.GameName,
		GameCRC:             strings.ToUpper(req.GameCRC),
		CoreName:            req.CoreName,
		CoreVersion:         req.CoreVersion,
		SubsystemName:       req.SubsystemName,
		RetroArchVersion:    req.RetroArchVersion,
		Frontend:            req.Frontend,
		IP:                  ip,
		Port:                req.Port,
		HostMethod:          entity.HostMethodUnknown,
		HasPassword:         req.HasPassword,
		HasSpectatePassword: req.HasSpectatePassword,
	}

	wantsMITM := req.ForceMITM && req.MITMServer != "" && req.MITMSession != ""
	if !wantsMITM {
		return session
	}

	var (
		relayAddr string
		relayPort uint16
	)
	if req.MITMServer == "custom" {
		// A custom relay must supply both its address and its port.
		if req.MITMCustomServer == "" || req.MITMCustomPort == 0 {
			return session
		}
		relayAddr, relayPort = req.MITMCustomServer, req.MITMCustomPort
	} else {
		// A named relay must be resolvable through the MITM domain.
		tunnel := d.GetTunnel(req.MITMServer)
		if tunnel == nil {
			return session
		}
		relayAddr, relayPort = tunnel.Address, tunnel.Port
	}

	session.HostMethod = entity.HostMethodMITM
	session.MitmHandle = req.MITMServer
	session.MitmAddress = relayAddr
	session.MitmPort = relayPort
	session.MitmSession = req.MITMSession
	return session
}
// validateSession reports whether an incoming session is acceptable.
//
// It checks two things:
//   - Byte-length bounds on each field. GameCRC must be exactly 8 bytes.
//   - The validation domain's string check on the free-text fields.
//
// GameName and MitmSession are only length-checked, not passed to
// ValidateString.
func (d *SessionDomain) validateSession(s *entity.Session) bool {
	if len(s.GameCRC) != 8 {
		return false
	}

	// Maximum allowed length in bytes for each field.
	bounds := []struct {
		value string
		max   int
	}{
		{s.Username, 32},
		{s.CoreName, 255},
		{s.GameName, 255},
		{s.RetroArchVersion, 32},
		{s.CoreVersion, 255},
		{s.SubsystemName, 255},
		{s.Frontend, 255},
		{s.MitmSession, 32},
	}
	for _, b := range bounds {
		if len(b.value) > b.max {
			return false
		}
	}

	// Checked in order; the first failing field rejects the session.
	textFields := []string{
		s.Username,
		s.CoreName,
		s.CoreVersion,
		s.Frontend,
		s.SubsystemName,
		s.RetroArchVersion,
	}
	for _, field := range textFields {
		if !d.validationDomain.ValidateString(field) {
			return false
		}
	}
	return true
}
// trySessionConnect probes whether the session's host accepts TCP connections
// and whether it answers like RetroArch. The outcome is recorded on
// s.Connectable and s.IsRetroArch, and the dial or read error (if any) is
// returned.
//
// Both flags start out true and are only cleared on evidence:
//   - MITM sessions are not probed at all.
//   - A failed dial clears Connectable.
//   - A read error or an empty read keeps IsRetroArch.
//   - A short reply, or a reply that is neither "RANP" nor "FULL", clears
//     IsRetroArch.
func (d *SessionDomain) trySessionConnect(s *entity.Session) error {
	s.Connectable = true
	s.IsRetroArch = true

	// For MITM sessions, assume both connectable and RetroArch.
	if s.HostMethod == entity.HostMethodMITM {
		return nil
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:55435"). The previous
	// "%s:%d" formatting produced an address DialTimeout cannot parse for
	// IPv6 hosts, which wrongly marked them as not connectable.
	address := net.JoinHostPort(s.IP.String(), fmt.Sprint(s.Port))
	conn, err := net.DialTimeout("tcp", address, 3*time.Second)
	if err != nil {
		s.Connectable = false
		return err
	}
	defer conn.Close()

	ranp := []byte("RANP")
	full := []byte("FULL")
	poke := []byte("POKE")

	// Write errors are deliberately ignored; the read below decides the outcome.
	_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Second))
	_, _ = conn.Write(poke)

	magic := make([]byte, 4)
	_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	read, err := conn.Read(magic)

	// On a receive error or an empty read, assume it is RetroArch.
	if err != nil || read == 0 {
		return err
	}

	// An incomplete or unrecognised magic means it is not RetroArch.
	if read != len(magic) || (!bytes.Equal(magic, ranp) && !bytes.Equal(magic, full)) {
		s.IsRetroArch = false
	}
	return nil
}
// getDeadline returns the staleness cutoff: the instant SessionDeadline
// seconds before now. List and PurgeOld pass it to the repository.
func (d *SessionDomain) getDeadline() time.Time {
	maxAge := SessionDeadline * time.Second
	return time.Now().Add(-maxAge)
}
// GetTunnel resolves a named MITM tunnel to its address/port info via the MITM
// domain. The result may be nil; parseSession treats nil as an unknown tunnel.
func (d *SessionDomain) GetTunnel(tunnelName string) *MitmInfo {
	info := d.mitmDomain.GetInfo(tunnelName)
	return info
}