333 lines
9.6 KiB
Go
333 lines
9.6 KiB
Go
package domain
|
||
|
||
import (
|
||
"bytes"
|
||
"errors"
|
||
"fmt"
|
||
"net"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/libretro/netplay-lobby-server-go/model/entity"
|
||
)
|
||
|
||
const (
	// SessionDeadline is the lifetime, in seconds, of a session that receives
	// no further updates. List and PurgeOld use it (via getDeadline) as the
	// staleness cutoff.
	SessionDeadline = 60

	// RateLimit is the minimum interval, in seconds, a client has to wait
	// between two updates or touches of the same session.
	RateLimit = 5
)

// requestType classifies an incoming SessionDomain.Add request.
type requestType int

// requestType values, selected by SessionDomain.Add.
const (
	// SessionCreate: no session with the same ID is stored yet.
	SessionCreate requestType = iota
	// SessionUpdate: a stored session exists but its content hash changed.
	SessionUpdate
	// SessionTouch: a stored session exists with identical content.
	SessionTouch
)
|
||
|
||
// AddSessionRequest is the request payload for SessionDomain.Add. Fields are
// bound from form values named by the `form` struct tags.
type AddSessionRequest struct {
	// Username of the host; parseSession substitutes "匿名" when empty.
	Username    string `form:"username"`
	CoreName    string `form:"core_name"`
	CoreVersion string `form:"core_version"`
	GameName    string `form:"game_name"`
	// GameCRC is upper-cased by parseSession; validation requires exactly
	// 8 bytes.
	GameCRC string `form:"game_crc"`
	Port    uint16 `form:"port"`
	// MITMServer is the name of a known tunnel server, or "custom" to use
	// MITMCustomServer/MITMCustomPort instead.
	MITMServer string `form:"mitm_server"`
	// HasPassword is sent by clients as 1/0. NOTE(review): this binds to a
	// bool only if the form binder parses booleans like strconv.ParseBool,
	// which accepts "1"/"0" — confirm for the binder in use.
	HasPassword         bool   `form:"has_password"`
	HasSpectatePassword bool   `form:"has_spectate_password"`
	ForceMITM           bool   `form:"force_mitm"`
	RetroArchVersion    string `form:"retroarch_version"`
	Frontend            string `form:"frontend"`
	SubsystemName       string `form:"subsystem_name"`
	MITMSession         string `form:"mitm_session"`
	// MITMCustomServer is bound from the "mitm_custom_addr" form field.
	MITMCustomServer string `form:"mitm_custom_addr"`
	MITMCustomPort   uint16 `form:"mitm_custom_port"`
}
|
||
|
||
var (
	// ErrSessionRejected is returned by SessionDomain.Add when a newly created
	// or changed session fails validation.
	ErrSessionRejected = errors.New("会话被拒绝")

	// ErrRateLimited is returned by SessionDomain.Add when an already stored
	// session is updated or touched again within the rate-limit window.
	ErrRateLimited = errors.New("速率限制已达到")
)
|
||
|
||
// SessionRepository 接口用于将域逻辑与存储库代码解耦。
|
||
type SessionRepository interface {
|
||
Create(s *entity.Session) error
|
||
GetByID(id string) (*entity.Session, error)
|
||
GetByRoomID(roomID int32) (*entity.Session, error)
|
||
GetAll(deadline time.Time) ([]entity.Session, error)
|
||
Update(s *entity.Session) error
|
||
Touch(id string) error
|
||
PurgeOld(deadline time.Time) error
|
||
}
|
||
|
||
// SessionDomain 抽象了 netplay 会话处理的域逻辑。
|
||
type SessionDomain struct {
|
||
sessionRepo SessionRepository
|
||
geopip2Domain *GeoIP2Domain
|
||
validationDomain *ValidationDomain
|
||
mitmDomain *MitmDomain
|
||
}
|
||
|
||
// NewSessionDomain 返回一个初始化的 SessionDomain 结构。
|
||
func NewSessionDomain(
|
||
sessionRepo SessionRepository,
|
||
geoIP2Domain *GeoIP2Domain,
|
||
validationDomain *ValidationDomain,
|
||
mitmDomain *MitmDomain) *SessionDomain {
|
||
return &SessionDomain{sessionRepo, geoIP2Domain, validationDomain, mitmDomain}
|
||
}
|
||
|
||
// Add 添加或更新会话,基于来自给定 IP 的传入请求。
|
||
// 如果会话被拒绝,则返回 ErrSessionRejected。
|
||
// 如果达到会话的速率限制,则返回 ErrRateLimited。
|
||
func (d *SessionDomain) Add(request *AddSessionRequest, ip net.IP) (*entity.Session, error) {
|
||
var err error
|
||
var savedSession *entity.Session
|
||
var requestType requestType = SessionCreate
|
||
|
||
session := d.parseSession(request, ip)
|
||
|
||
if session.IP == nil || session.Port == 0 {
|
||
return nil, errors.New("IP 或端口未设置")
|
||
}
|
||
|
||
// 决定这是 CREATE、UPDATE 还是 TOUCH 操作
|
||
session.CalculateID()
|
||
session.CalculateContentHash()
|
||
if savedSession, err = d.sessionRepo.GetByID(session.ID); err != nil {
|
||
return nil, fmt.Errorf("无法获取已保存的会话: %w", err)
|
||
}
|
||
if savedSession != nil {
|
||
session.RoomID = savedSession.RoomID
|
||
session.Country = savedSession.Country
|
||
session.Connectable = savedSession.Connectable
|
||
session.IsRetroArch = savedSession.IsRetroArch
|
||
session.CreatedAt = savedSession.CreatedAt
|
||
session.UpdatedAt = savedSession.UpdatedAt
|
||
if savedSession.ContentHash != session.ContentHash {
|
||
requestType = SessionUpdate
|
||
} else {
|
||
requestType = SessionTouch
|
||
}
|
||
}
|
||
|
||
// 在 UPDATE 或 TOUCH 上进行速率限制
|
||
if requestType == SessionUpdate || requestType == SessionTouch {
|
||
threshold := time.Now().Add(-5 * time.Second)
|
||
if savedSession.UpdatedAt.After(threshold) {
|
||
return nil, ErrRateLimited
|
||
}
|
||
}
|
||
|
||
if requestType == SessionCreate || requestType == SessionUpdate {
|
||
// 在 CREATE 和 UPDATE 上验证会话
|
||
if !d.validateSession(session) {
|
||
return nil, ErrSessionRejected
|
||
}
|
||
}
|
||
|
||
// 持久化会话更改
|
||
switch requestType {
|
||
case SessionCreate:
|
||
if session.Country, err = d.geopip2Domain.GetCountryCodeForIP(session.IP); err != nil {
|
||
return nil, fmt.Errorf("无法找到给定 IP %s 的国家: %w", session.IP, err)
|
||
}
|
||
|
||
d.trySessionConnect(session)
|
||
|
||
if err = d.sessionRepo.Create(session); err != nil {
|
||
return nil, fmt.Errorf("无法创建新会话: %w", err)
|
||
}
|
||
case SessionUpdate:
|
||
d.trySessionConnect(session)
|
||
|
||
if err = d.sessionRepo.Update(session); err != nil {
|
||
return nil, fmt.Errorf("无法更新旧会话: %w", err)
|
||
}
|
||
case SessionTouch:
|
||
if !session.Connectable {
|
||
d.trySessionConnect(session)
|
||
if session.Connectable {
|
||
if err = d.sessionRepo.Update(session); err != nil {
|
||
return nil, fmt.Errorf("无法更新旧会话: %w", err)
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
if err = d.sessionRepo.Touch(session.ID); err != nil {
|
||
return nil, fmt.Errorf("无法触碰旧会话: %w", err)
|
||
}
|
||
}
|
||
|
||
return session, nil
|
||
}
|
||
|
||
// Get 返回具有给定 RoomID 的会话
|
||
func (d *SessionDomain) Get(roomID int32) (*entity.Session, error) {
|
||
session, err := d.sessionRepo.GetByRoomID(roomID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return session, nil
|
||
}
|
||
|
||
// List 返回当前正在托管的所有会话的列表
|
||
func (d *SessionDomain) List() ([]entity.Session, error) {
|
||
sessions, err := d.sessionRepo.GetAll(d.getDeadline())
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return sessions, nil
|
||
}
|
||
|
||
// PurgeOld 删除所有超过 45 秒未更新的会话。
|
||
func (d *SessionDomain) PurgeOld() error {
|
||
if err := d.sessionRepo.PurgeOld(d.getDeadline()); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// parseSession 将请求转换为可以与持久化会话进行比较的会话信息
|
||
func (d *SessionDomain) parseSession(req *AddSessionRequest, ip net.IP) *entity.Session {
|
||
var hostMethod entity.HostMethod = entity.HostMethodUnknown
|
||
var mitmHandle string = ""
|
||
var mitmAddress string = ""
|
||
var mitmPort uint16 = 0
|
||
var mitmSession string = ""
|
||
|
||
// 设置默认用户名
|
||
if req.Username == "" {
|
||
req.Username = "匿名"
|
||
}
|
||
|
||
if req.ForceMITM && req.MITMServer != "" && req.MITMSession != "" {
|
||
if req.MITMServer == "custom" {
|
||
if req.MITMCustomServer != "" && req.MITMCustomPort != 0 {
|
||
hostMethod = entity.HostMethodMITM
|
||
mitmHandle = req.MITMServer
|
||
mitmAddress = req.MITMCustomServer
|
||
mitmPort = req.MITMCustomPort
|
||
mitmSession = req.MITMSession
|
||
}
|
||
} else {
|
||
if info := d.GetTunnel(req.MITMServer); info != nil {
|
||
hostMethod = entity.HostMethodMITM
|
||
mitmHandle = req.MITMServer
|
||
mitmAddress = info.Address
|
||
mitmPort = info.Port
|
||
mitmSession = req.MITMSession
|
||
}
|
||
}
|
||
}
|
||
|
||
return &entity.Session{
|
||
Username: req.Username,
|
||
GameName: req.GameName,
|
||
GameCRC: strings.ToUpper(req.GameCRC),
|
||
CoreName: req.CoreName,
|
||
CoreVersion: req.CoreVersion,
|
||
SubsystemName: req.SubsystemName,
|
||
RetroArchVersion: req.RetroArchVersion,
|
||
Frontend: req.Frontend,
|
||
IP: ip,
|
||
Port: req.Port,
|
||
MitmHandle: mitmHandle,
|
||
MitmAddress: mitmAddress,
|
||
MitmPort: mitmPort,
|
||
MitmSession: mitmSession,
|
||
HostMethod: hostMethod,
|
||
HasPassword: req.HasPassword,
|
||
HasSpectatePassword: req.HasSpectatePassword,
|
||
}
|
||
}
|
||
|
||
// validateSession 验证传入的会话
|
||
func (d *SessionDomain) validateSession(s *entity.Session) bool {
|
||
if len(s.Username) > 32 ||
|
||
len(s.CoreName) > 255 ||
|
||
len(s.GameName) > 255 ||
|
||
len(s.GameCRC) != 8 ||
|
||
len(s.RetroArchVersion) > 32 ||
|
||
len(s.CoreVersion) > 255 ||
|
||
len(s.SubsystemName) > 255 ||
|
||
len(s.Frontend) > 255 ||
|
||
len(s.MitmSession) > 32 {
|
||
return false
|
||
}
|
||
|
||
if !d.validationDomain.ValidateString(s.Username) ||
|
||
!d.validationDomain.ValidateString(s.CoreName) ||
|
||
!d.validationDomain.ValidateString(s.CoreVersion) ||
|
||
!d.validationDomain.ValidateString(s.Frontend) ||
|
||
!d.validationDomain.ValidateString(s.SubsystemName) ||
|
||
!d.validationDomain.ValidateString(s.RetroArchVersion) {
|
||
return false
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// trySessionConnect 测试会话是否可连接以及是否为 RetroArch
|
||
func (d *SessionDomain) trySessionConnect(s *entity.Session) error {
|
||
s.Connectable = true
|
||
s.IsRetroArch = true
|
||
|
||
// 如果是 MITM,假设既可连接又是 RetroArch
|
||
if s.HostMethod == entity.HostMethodMITM {
|
||
return nil
|
||
}
|
||
|
||
address := fmt.Sprintf("%s:%d", s.IP, s.Port)
|
||
conn, err := net.DialTimeout("tcp", address, time.Second*3)
|
||
if err != nil {
|
||
s.Connectable = false
|
||
return err
|
||
}
|
||
|
||
ranp := []byte{0x52, 0x41, 0x4E, 0x50} // RANP
|
||
full := []byte{0x46, 0x55, 0x4C, 0x4C} // FULL
|
||
poke := []byte{0x50, 0x4F, 0x4B, 0x45} // POKE
|
||
magic := make([]byte, 4)
|
||
|
||
// 忽略写入错误
|
||
conn.SetWriteDeadline(time.Now().Add(time.Second * 3))
|
||
conn.Write(poke)
|
||
|
||
conn.SetReadDeadline(time.Now().Add(time.Second * 3))
|
||
read, err := conn.Read(magic)
|
||
|
||
conn.Close()
|
||
|
||
// 在接收错误时假设它是 RetroArch
|
||
if err != nil || read == 0 {
|
||
return err
|
||
}
|
||
|
||
// 在不完整的魔术上假设它不是 RetroArch
|
||
if read != len(magic) {
|
||
s.IsRetroArch = false
|
||
} else if !bytes.Equal(magic, ranp) && !bytes.Equal(magic, full) {
|
||
s.IsRetroArch = false
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (d *SessionDomain) getDeadline() time.Time {
|
||
return time.Now().Add(-SessionDeadline * time.Second)
|
||
}
|
||
|
||
// GetTunnel 返回隧道的地址/端口对。
|
||
func (d *SessionDomain) GetTunnel(tunnelName string) *MitmInfo {
|
||
return d.mitmDomain.GetInfo(tunnelName)
|
||
}
|