netplay-lobby-server-go/domain/sessiondomain.go
2024-12-24 22:09:17 +08:00

333 lines
9.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package domain
import (
"bytes"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/libretro/netplay-lobby-server-go/model/entity"
)
const (
	// SessionDeadline is the lifetime, in seconds, of a session that receives
	// no further updates. Sessions older than this are excluded from listings
	// and purged.
	SessionDeadline = 60

	// RateLimit is the minimum interval, in seconds, that a client must wait
	// between two updates (or touches) of the same session.
	RateLimit = 5
)
// requestType classifies an incoming Add request by what it does to the
// stored session.
type requestType int

// Values of requestType, as chosen by SessionDomain.Add.
const (
	// SessionCreate: no session with the request's ID is stored yet.
	SessionCreate requestType = 0
	// SessionUpdate: a session is stored but its content hash differs.
	SessionUpdate requestType = 1
	// SessionTouch: a session is stored and its content hash is unchanged.
	SessionTouch requestType = 2
)
// AddSessionRequest is the request structure for SessionDomain.Add().
// Each field is bound from the form value named in its `form` struct tag.
type AddSessionRequest struct {
	// Username of the host; parseSession substitutes "匿名" ("anonymous") when empty.
	Username    string `form:"username"`
	CoreName    string `form:"core_name"`
	CoreVersion string `form:"core_version"`
	GameName    string `form:"game_name"`
	// GameCRC is upper-cased by parseSession; validateSession requires exactly 8 bytes.
	GameCRC string `form:"game_crc"`
	Port    uint16 `form:"port"`
	// MITMServer names the MITM tunnel server. The value "custom" selects
	// MITMCustomServer/MITMCustomPort; any other name is resolved via GetTunnel.
	MITMServer string `form:"mitm_server"`
	// NOTE(review): the original author asked whether "1"/"0" bind to bool.
	// That depends on the form binder, which is not visible here; confirm.
	HasPassword         bool `form:"has_password"`
	HasSpectatePassword bool `form:"has_spectate_password"`
	// ForceMITM requests MITM hosting. parseSession honours it only when
	// MITMServer and MITMSession are both non-empty.
	ForceMITM        bool   `form:"force_mitm"`
	RetroArchVersion string `form:"retroarch_version"`
	Frontend         string `form:"frontend"`
	SubsystemName    string `form:"subsystem_name"`
	MITMSession      string `form:"mitm_session"`
	// MITMCustomServer is the custom tunnel address; note its form key is "mitm_custom_addr".
	MITMCustomServer string `form:"mitm_custom_addr"`
	MITMCustomPort   uint16 `form:"mitm_custom_port"`
}
// Sentinel errors returned by SessionDomain.Add.
var (
	// ErrSessionRejected is returned when the domain logic rejects a session.
	ErrSessionRejected = errors.New("会话被拒绝")

	// ErrRateLimited is returned when a session is updated more often than
	// its rate limit allows.
	ErrRateLimited = errors.New("速率限制已达到")
)
// SessionRepository decouples the session domain logic from the storage code.
type SessionRepository interface {
	// Create stores a brand-new session.
	Create(s *entity.Session) error
	// Update persists changed fields of an existing session.
	Update(s *entity.Session) error
	// Touch refreshes the stored session with the given ID without changing
	// its content. NOTE(review): presumably only bumps UpdatedAt — confirm.
	Touch(id string) error

	// GetByID looks up a session by its computed ID. Add treats a (nil, nil)
	// result as "no such session" and aborts on any non-nil error.
	GetByID(id string) (*entity.Session, error)
	// GetByRoomID looks up a session by its room ID.
	GetByRoomID(roomID int32) (*entity.Session, error)
	// GetAll returns the sessions that are still live relative to deadline.
	GetAll(deadline time.Time) ([]entity.Session, error)

	// PurgeOld removes sessions that have expired relative to deadline.
	PurgeOld(deadline time.Time) error
}
// SessionDomain abstracts the domain logic for netplay session handling and
// holds the collaborators that logic depends on.
type SessionDomain struct {
	// sessionRepo stores and loads sessions.
	sessionRepo SessionRepository
	// geopip2Domain resolves the country code of a new session's IP in Add.
	// The name looks like a typo of "geoip2Domain"; renaming it requires
	// updating every reference to the field.
	geopip2Domain *GeoIP2Domain
	// validationDomain performs the string checks in validateSession.
	validationDomain *ValidationDomain
	// mitmDomain resolves MITM tunnel names in GetTunnel.
	mitmDomain *MitmDomain
}
// NewSessionDomain wires a SessionDomain to its repository and helper
// domains and returns it.
func NewSessionDomain(
	sessionRepo SessionRepository,
	geoIP2Domain *GeoIP2Domain,
	validationDomain *ValidationDomain,
	mitmDomain *MitmDomain,
) *SessionDomain {
	d := &SessionDomain{
		sessionRepo:      sessionRepo,
		geopip2Domain:    geoIP2Domain,
		validationDomain: validationDomain,
		mitmDomain:       mitmDomain,
	}
	return d
}
// Add creates, updates or refreshes ("touches") a session from an announce
// request received from ip.
//
// The session ID computed from the request decides the operation:
//   - no stored session with that ID        -> CREATE
//   - stored session with different content -> UPDATE
//   - stored session with identical content -> TOUCH
//
// UPDATE and TOUCH are rate limited: they fail while the stored session was
// updated less than RateLimit seconds ago. CREATE and UPDATE are validated.
//
// It returns ErrSessionRejected if validation fails, ErrRateLimited if the
// rate limit is hit, and a wrapped error for repository or GeoIP failures.
func (d *SessionDomain) Add(request *AddSessionRequest, ip net.IP) (*entity.Session, error) {
	session := d.parseSession(request, ip)
	if session.IP == nil || session.Port == 0 {
		return nil, errors.New("IP 或端口未设置")
	}

	// Decide whether this is a CREATE, UPDATE or TOUCH request.
	session.CalculateID()
	session.CalculateContentHash()

	savedSession, err := d.sessionRepo.GetByID(session.ID)
	if err != nil {
		return nil, fmt.Errorf("无法获取已保存的会话: %w", err)
	}

	reqType := SessionCreate
	if savedSession != nil {
		// Keep the server-side state that the client does not send.
		session.RoomID = savedSession.RoomID
		session.Country = savedSession.Country
		session.Connectable = savedSession.Connectable
		session.IsRetroArch = savedSession.IsRetroArch
		session.CreatedAt = savedSession.CreatedAt
		session.UpdatedAt = savedSession.UpdatedAt

		if savedSession.ContentHash != session.ContentHash {
			reqType = SessionUpdate
		} else {
			reqType = SessionTouch
		}

		// Rate-limit UPDATE and TOUCH requests using the shared RateLimit
		// constant rather than a duplicated literal.
		threshold := time.Now().Add(-RateLimit * time.Second)
		if savedSession.UpdatedAt.After(threshold) {
			return nil, ErrRateLimited
		}
	}

	// Validate CREATE and UPDATE requests; a TOUCH carries unchanged content.
	if reqType != SessionTouch && !d.validateSession(session) {
		return nil, ErrSessionRejected
	}

	// Persist the changes. The error returned by trySessionConnect is ignored
	// on purpose: the probe outcome is recorded on session.Connectable and
	// session.IsRetroArch.
	switch reqType {
	case SessionCreate:
		if session.Country, err = d.geopip2Domain.GetCountryCodeForIP(session.IP); err != nil {
			return nil, fmt.Errorf("无法找到给定 IP %s 的国家: %w", session.IP, err)
		}
		d.trySessionConnect(session)
		if err = d.sessionRepo.Create(session); err != nil {
			return nil, fmt.Errorf("无法创建新会话: %w", err)
		}
	case SessionUpdate:
		d.trySessionConnect(session)
		if err = d.sessionRepo.Update(session); err != nil {
			return nil, fmt.Errorf("无法更新旧会话: %w", err)
		}
	case SessionTouch:
		// Re-probe hosts that were not reachable before. If one has become
		// connectable, the new state must be written with a full Update
		// instead of a plain Touch.
		if !session.Connectable {
			d.trySessionConnect(session)
			if session.Connectable {
				if err = d.sessionRepo.Update(session); err != nil {
					return nil, fmt.Errorf("无法更新旧会话: %w", err)
				}
				return session, nil
			}
		}
		if err = d.sessionRepo.Touch(session.ID); err != nil {
			return nil, fmt.Errorf("无法触碰旧会话: %w", err)
		}
	}

	return session, nil
}
// Get returns the session hosted in the room with the given roomID.
// On a repository error, any session value is discarded and only the error
// is returned.
func (d *SessionDomain) Get(roomID int32) (*entity.Session, error) {
	found, lookupErr := d.sessionRepo.GetByRoomID(roomID)
	if lookupErr == nil {
		return found, nil
	}
	return nil, lookupErr
}
// List returns every session that is currently being hosted, i.e. those the
// repository reports as live relative to getDeadline(). On a repository
// error, the partial result is discarded and only the error is returned.
func (d *SessionDomain) List() ([]entity.Session, error) {
	cutoff := d.getDeadline()
	live, listErr := d.sessionRepo.GetAll(cutoff)
	if listErr == nil {
		return live, nil
	}
	return nil, listErr
}
// PurgeOld asks the repository to remove every session that has not been
// updated within the last SessionDeadline seconds (the cut-off returned by
// getDeadline). Repository errors are returned unchanged.
func (d *SessionDomain) PurgeOld() error {
	return d.sessionRepo.PurgeOld(d.getDeadline())
}
// parseSession converts an announce request into a Session value that can be
// compared with the persisted one. GameCRC is upper-cased.
//
// MITM hosting is recorded only when ForceMITM is set and both MITMServer and
// MITMSession are non-empty, and only if the tunnel can be resolved:
//   - "custom" uses MITMCustomServer/MITMCustomPort (both must be set);
//   - any other name is looked up with GetTunnel.
//
// Otherwise HostMethod stays entity.HostMethodUnknown and the MITM fields are
// left empty.
//
// Side effect: an empty req.Username is replaced by the default name in the
// request itself, not only in the returned session.
func (d *SessionDomain) parseSession(req *AddSessionRequest, ip net.IP) *entity.Session {
	if req.Username == "" {
		req.Username = "匿名"
	}

	s := &entity.Session{
		Username:            req.Username,
		GameName:            req.GameName,
		GameCRC:             strings.ToUpper(req.GameCRC),
		CoreName:            req.CoreName,
		CoreVersion:         req.CoreVersion,
		SubsystemName:       req.SubsystemName,
		RetroArchVersion:    req.RetroArchVersion,
		Frontend:            req.Frontend,
		IP:                  ip,
		Port:                req.Port,
		HostMethod:          entity.HostMethodUnknown,
		HasPassword:         req.HasPassword,
		HasSpectatePassword: req.HasSpectatePassword,
	}

	if !req.ForceMITM || req.MITMServer == "" || req.MITMSession == "" {
		return s
	}

	var tunnelAddr string
	var tunnelPort uint16
	if req.MITMServer == "custom" {
		if req.MITMCustomServer == "" || req.MITMCustomPort == 0 {
			return s
		}
		tunnelAddr, tunnelPort = req.MITMCustomServer, req.MITMCustomPort
	} else {
		info := d.GetTunnel(req.MITMServer)
		if info == nil {
			return s
		}
		tunnelAddr, tunnelPort = info.Address, info.Port
	}

	s.HostMethod = entity.HostMethodMITM
	s.MitmHandle = req.MITMServer
	s.MitmAddress = tunnelAddr
	s.MitmPort = tunnelPort
	s.MitmSession = req.MITMSession
	return s
}
// validateSession reports whether s passes the length limits and the
// character checks that Add applies to CREATE and UPDATE requests.
//
// Length limits are in bytes, and GameCRC must be exactly 8 bytes long.
// GameName, GameCRC and MitmSession are only length-checked; they are not
// passed to ValidationDomain.ValidateString.
func (d *SessionDomain) validateSession(s *entity.Session) bool {
	withinLimits := len(s.Username) <= 32 &&
		len(s.CoreName) <= 255 &&
		len(s.GameName) <= 255 &&
		len(s.GameCRC) == 8 &&
		len(s.RetroArchVersion) <= 32 &&
		len(s.CoreVersion) <= 255 &&
		len(s.SubsystemName) <= 255 &&
		len(s.Frontend) <= 255 &&
		len(s.MitmSession) <= 32
	if !withinLimits {
		return false
	}

	// The first field that fails ValidateString rejects the session.
	checked := []string{
		s.Username,
		s.CoreName,
		s.CoreVersion,
		s.Frontend,
		s.SubsystemName,
		s.RetroArchVersion,
	}
	for _, field := range checked {
		if !d.validationDomain.ValidateString(field) {
			return false
		}
	}
	return true
}
// trySessionConnect probes whether the session's host accepts TCP connections
// and whether it answers like a RetroArch netplay host. The outcome is written
// to s.Connectable and s.IsRetroArch; the returned error is informational.
//
// Outcomes:
//   - MITM sessions are assumed connectable and RetroArch without probing.
//   - Dial failure: Connectable=false, IsRetroArch=true, the dial error is returned.
//   - No reply to "POKE" (timeout, reset, close): assumed RetroArch, the read error is returned.
//   - A truncated reply, or one other than "RANP"/"FULL": IsRetroArch=false.
func (d *SessionDomain) trySessionConnect(s *entity.Session) error {
	s.Connectable = true
	s.IsRetroArch = true

	// MITM-hosted sessions are reached through the relay; assume both.
	if s.HostMethod == entity.HostMethodMITM {
		return nil
	}

	// net.JoinHostPort brackets IPv6 literals ("[2001:db8::1]:55435").
	// A plain "%s:%d" format yields an undialable address for IPv6 hosts.
	address := net.JoinHostPort(s.IP.String(), fmt.Sprint(s.Port))
	conn, err := net.DialTimeout("tcp", address, 3*time.Second)
	if err != nil {
		s.Connectable = false
		return err
	}
	defer conn.Close()

	ranp := []byte("RANP")
	full := []byte("FULL")
	poke := []byte("POKE")

	// Write errors are deliberately ignored: a failed write simply shows up
	// as a missing reply below.
	_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Second))
	_, _ = conn.Write(poke)

	// Read the 4-byte magic reply. A single Read may return fewer bytes than
	// the peer sent, so keep reading until the buffer is full or an error
	// (including the read deadline) stops us.
	magic := make([]byte, 4)
	_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	got := 0
	for got < len(magic) {
		n, readErr := conn.Read(magic[got:])
		got += n
		if readErr != nil {
			err = readErr
			break
		}
	}

	// No reply at all: assume it is RetroArch.
	if got == 0 {
		return err
	}
	// A truncated or unrecognised magic means it is not RetroArch.
	if got < len(magic) || (!bytes.Equal(magic, ranp) && !bytes.Equal(magic, full)) {
		s.IsRetroArch = false
	}
	return nil
}
// getDeadline returns the instant SessionDeadline seconds before now. It is
// the cut-off passed to the repository by List (GetAll) and PurgeOld.
func (d *SessionDomain) getDeadline() time.Time {
	maxAge := time.Duration(SessionDeadline) * time.Second
	return time.Now().Add(-maxAge)
}
// GetTunnel returns the address/port information of the MITM tunnel named
// tunnelName, exactly as reported by the MITM domain. parseSession treats a
// nil result as "unknown tunnel".
func (d *SessionDomain) GetTunnel(tunnelName string) *MitmInfo {
	info := d.mitmDomain.GetInfo(tunnelName)
	return info
}