333 lines
9.6 KiB
Go
333 lines
9.6 KiB
Go
|
package domain
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"net"
|
|||
|
"strings"
|
|||
|
"time"
|
|||
|
|
|||
|
"github.com/libretro/netplay-lobby-server-go/model/entity"
|
|||
|
)
|
|||
|
|
|||
|
// SessionDeadline is the lifetime of a session, in seconds, when no update
// for it has been received.
const SessionDeadline = 60 // seconds
|
|||
|
|
|||
|
// RateLimit is the maximum rate at which a client can send updates (every
// five seconds).
const RateLimit = 5
|
|||
|
|
|||
|
// requestType enumerates the kinds of operation an incoming session request
// can result in.
type requestType int

// Values of requestType.
const (
	// SessionCreate is used when no stored session matches the request.
	SessionCreate requestType = iota
	// SessionUpdate is used when a stored session matches the request but
	// its content hash differs.
	SessionUpdate
	// SessionTouch is used when a stored session matches the request and
	// its content hash is unchanged.
	SessionTouch
)
|
|||
|
|
|||
|
// AddSessionRequest is the request structure accepted by SessionDomain.Add().
// The `form` tags name the form fields each member is bound from.
type AddSessionRequest struct {
	Username            string `form:"username"`
	CoreName            string `form:"core_name"`
	CoreVersion         string `form:"core_version"`
	GameName            string `form:"game_name"`
	GameCRC             string `form:"game_crc"`
	Port                uint16 `form:"port"`
	MITMServer          string `form:"mitm_server"`
	HasPassword         bool   `form:"has_password"` // 1/0 (can this be bound to bool?)
	HasSpectatePassword bool   `form:"has_spectate_password"`
	ForceMITM           bool   `form:"force_mitm"`
	RetroArchVersion    string `form:"retroarch_version"`
	Frontend            string `form:"frontend"`
	SubsystemName       string `form:"subsystem_name"`
	MITMSession         string `form:"mitm_session"`
	MITMCustomServer    string `form:"mitm_custom_addr"`
	MITMCustomPort      uint16 `form:"mitm_custom_port"`
}
|
|||
|
|
|||
|
// ErrSessionRejected is returned when a session is rejected by the domain
// logic.
var ErrSessionRejected = errors.New("会话被拒绝")
|
|||
|
|
|||
|
// ErrRateLimited is returned when the rate limit for a particular session
// has been reached.
var ErrRateLimited = errors.New("速率限制已达到")
|
|||
|
|
|||
|
// SessionRepository 接口用于将域逻辑与存储库代码解耦。
|
|||
|
type SessionRepository interface {
|
|||
|
Create(s *entity.Session) error
|
|||
|
GetByID(id string) (*entity.Session, error)
|
|||
|
GetByRoomID(roomID int32) (*entity.Session, error)
|
|||
|
GetAll(deadline time.Time) ([]entity.Session, error)
|
|||
|
Update(s *entity.Session) error
|
|||
|
Touch(id string) error
|
|||
|
PurgeOld(deadline time.Time) error
|
|||
|
}
|
|||
|
|
|||
|
// SessionDomain 抽象了 netplay 会话处理的域逻辑。
|
|||
|
type SessionDomain struct {
|
|||
|
sessionRepo SessionRepository
|
|||
|
geopip2Domain *GeoIP2Domain
|
|||
|
validationDomain *ValidationDomain
|
|||
|
mitmDomain *MitmDomain
|
|||
|
}
|
|||
|
|
|||
|
// NewSessionDomain 返回一个初始化的 SessionDomain 结构。
|
|||
|
func NewSessionDomain(
|
|||
|
sessionRepo SessionRepository,
|
|||
|
geoIP2Domain *GeoIP2Domain,
|
|||
|
validationDomain *ValidationDomain,
|
|||
|
mitmDomain *MitmDomain) *SessionDomain {
|
|||
|
return &SessionDomain{sessionRepo, geoIP2Domain, validationDomain, mitmDomain}
|
|||
|
}
|
|||
|
|
|||
|
// Add 添加或更新会话,基于来自给定 IP 的传入请求。
|
|||
|
// 如果会话被拒绝,则返回 ErrSessionRejected。
|
|||
|
// 如果达到会话的速率限制,则返回 ErrRateLimited。
|
|||
|
func (d *SessionDomain) Add(request *AddSessionRequest, ip net.IP) (*entity.Session, error) {
|
|||
|
var err error
|
|||
|
var savedSession *entity.Session
|
|||
|
var requestType requestType = SessionCreate
|
|||
|
|
|||
|
session := d.parseSession(request, ip)
|
|||
|
|
|||
|
if session.IP == nil || session.Port == 0 {
|
|||
|
return nil, errors.New("IP 或端口未设置")
|
|||
|
}
|
|||
|
|
|||
|
// 决定这是 CREATE、UPDATE 还是 TOUCH 操作
|
|||
|
session.CalculateID()
|
|||
|
session.CalculateContentHash()
|
|||
|
if savedSession, err = d.sessionRepo.GetByID(session.ID); err != nil {
|
|||
|
return nil, fmt.Errorf("无法获取已保存的会话: %w", err)
|
|||
|
}
|
|||
|
if savedSession != nil {
|
|||
|
session.RoomID = savedSession.RoomID
|
|||
|
session.Country = savedSession.Country
|
|||
|
session.Connectable = savedSession.Connectable
|
|||
|
session.IsRetroArch = savedSession.IsRetroArch
|
|||
|
session.CreatedAt = savedSession.CreatedAt
|
|||
|
session.UpdatedAt = savedSession.UpdatedAt
|
|||
|
if savedSession.ContentHash != session.ContentHash {
|
|||
|
requestType = SessionUpdate
|
|||
|
} else {
|
|||
|
requestType = SessionTouch
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 在 UPDATE 或 TOUCH 上进行速率限制
|
|||
|
if requestType == SessionUpdate || requestType == SessionTouch {
|
|||
|
threshold := time.Now().Add(-5 * time.Second)
|
|||
|
if savedSession.UpdatedAt.After(threshold) {
|
|||
|
return nil, ErrRateLimited
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if requestType == SessionCreate || requestType == SessionUpdate {
|
|||
|
// 在 CREATE 和 UPDATE 上验证会话
|
|||
|
if !d.validateSession(session) {
|
|||
|
return nil, ErrSessionRejected
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 持久化会话更改
|
|||
|
switch requestType {
|
|||
|
case SessionCreate:
|
|||
|
if session.Country, err = d.geopip2Domain.GetCountryCodeForIP(session.IP); err != nil {
|
|||
|
return nil, fmt.Errorf("无法找到给定 IP %s 的国家: %w", session.IP, err)
|
|||
|
}
|
|||
|
|
|||
|
d.trySessionConnect(session)
|
|||
|
|
|||
|
if err = d.sessionRepo.Create(session); err != nil {
|
|||
|
return nil, fmt.Errorf("无法创建新会话: %w", err)
|
|||
|
}
|
|||
|
case SessionUpdate:
|
|||
|
d.trySessionConnect(session)
|
|||
|
|
|||
|
if err = d.sessionRepo.Update(session); err != nil {
|
|||
|
return nil, fmt.Errorf("无法更新旧会话: %w", err)
|
|||
|
}
|
|||
|
case SessionTouch:
|
|||
|
if !session.Connectable {
|
|||
|
d.trySessionConnect(session)
|
|||
|
if session.Connectable {
|
|||
|
if err = d.sessionRepo.Update(session); err != nil {
|
|||
|
return nil, fmt.Errorf("无法更新旧会话: %w", err)
|
|||
|
}
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if err = d.sessionRepo.Touch(session.ID); err != nil {
|
|||
|
return nil, fmt.Errorf("无法触碰旧会话: %w", err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return session, nil
|
|||
|
}
|
|||
|
|
|||
|
// Get 返回具有给定 RoomID 的会话
|
|||
|
func (d *SessionDomain) Get(roomID int32) (*entity.Session, error) {
|
|||
|
session, err := d.sessionRepo.GetByRoomID(roomID)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
return session, nil
|
|||
|
}
|
|||
|
|
|||
|
// List 返回当前正在托管的所有会话的列表
|
|||
|
func (d *SessionDomain) List() ([]entity.Session, error) {
|
|||
|
sessions, err := d.sessionRepo.GetAll(d.getDeadline())
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
return sessions, nil
|
|||
|
}
|
|||
|
|
|||
|
// PurgeOld 删除所有超过 45 秒未更新的会话。
|
|||
|
func (d *SessionDomain) PurgeOld() error {
|
|||
|
if err := d.sessionRepo.PurgeOld(d.getDeadline()); err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
// parseSession 将请求转换为可以与持久化会话进行比较的会话信息
|
|||
|
func (d *SessionDomain) parseSession(req *AddSessionRequest, ip net.IP) *entity.Session {
|
|||
|
var hostMethod entity.HostMethod = entity.HostMethodUnknown
|
|||
|
var mitmHandle string = ""
|
|||
|
var mitmAddress string = ""
|
|||
|
var mitmPort uint16 = 0
|
|||
|
var mitmSession string = ""
|
|||
|
|
|||
|
// 设置默认用户名
|
|||
|
if req.Username == "" {
|
|||
|
req.Username = "匿名"
|
|||
|
}
|
|||
|
|
|||
|
if req.ForceMITM && req.MITMServer != "" && req.MITMSession != "" {
|
|||
|
if req.MITMServer == "custom" {
|
|||
|
if req.MITMCustomServer != "" && req.MITMCustomPort != 0 {
|
|||
|
hostMethod = entity.HostMethodMITM
|
|||
|
mitmHandle = req.MITMServer
|
|||
|
mitmAddress = req.MITMCustomServer
|
|||
|
mitmPort = req.MITMCustomPort
|
|||
|
mitmSession = req.MITMSession
|
|||
|
}
|
|||
|
} else {
|
|||
|
if info := d.GetTunnel(req.MITMServer); info != nil {
|
|||
|
hostMethod = entity.HostMethodMITM
|
|||
|
mitmHandle = req.MITMServer
|
|||
|
mitmAddress = info.Address
|
|||
|
mitmPort = info.Port
|
|||
|
mitmSession = req.MITMSession
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return &entity.Session{
|
|||
|
Username: req.Username,
|
|||
|
GameName: req.GameName,
|
|||
|
GameCRC: strings.ToUpper(req.GameCRC),
|
|||
|
CoreName: req.CoreName,
|
|||
|
CoreVersion: req.CoreVersion,
|
|||
|
SubsystemName: req.SubsystemName,
|
|||
|
RetroArchVersion: req.RetroArchVersion,
|
|||
|
Frontend: req.Frontend,
|
|||
|
IP: ip,
|
|||
|
Port: req.Port,
|
|||
|
MitmHandle: mitmHandle,
|
|||
|
MitmAddress: mitmAddress,
|
|||
|
MitmPort: mitmPort,
|
|||
|
MitmSession: mitmSession,
|
|||
|
HostMethod: hostMethod,
|
|||
|
HasPassword: req.HasPassword,
|
|||
|
HasSpectatePassword: req.HasSpectatePassword,
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// validateSession 验证传入的会话
|
|||
|
func (d *SessionDomain) validateSession(s *entity.Session) bool {
|
|||
|
if len(s.Username) > 32 ||
|
|||
|
len(s.CoreName) > 255 ||
|
|||
|
len(s.GameName) > 255 ||
|
|||
|
len(s.GameCRC) != 8 ||
|
|||
|
len(s.RetroArchVersion) > 32 ||
|
|||
|
len(s.CoreVersion) > 255 ||
|
|||
|
len(s.SubsystemName) > 255 ||
|
|||
|
len(s.Frontend) > 255 ||
|
|||
|
len(s.MitmSession) > 32 {
|
|||
|
return false
|
|||
|
}
|
|||
|
|
|||
|
if !d.validationDomain.ValidateString(s.Username) ||
|
|||
|
!d.validationDomain.ValidateString(s.CoreName) ||
|
|||
|
!d.validationDomain.ValidateString(s.CoreVersion) ||
|
|||
|
!d.validationDomain.ValidateString(s.Frontend) ||
|
|||
|
!d.validationDomain.ValidateString(s.SubsystemName) ||
|
|||
|
!d.validationDomain.ValidateString(s.RetroArchVersion) {
|
|||
|
return false
|
|||
|
}
|
|||
|
|
|||
|
return true
|
|||
|
}
|
|||
|
|
|||
|
// trySessionConnect 测试会话是否可连接以及是否为 RetroArch
|
|||
|
func (d *SessionDomain) trySessionConnect(s *entity.Session) error {
|
|||
|
s.Connectable = true
|
|||
|
s.IsRetroArch = true
|
|||
|
|
|||
|
// 如果是 MITM,假设既可连接又是 RetroArch
|
|||
|
if s.HostMethod == entity.HostMethodMITM {
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
address := fmt.Sprintf("%s:%d", s.IP, s.Port)
|
|||
|
conn, err := net.DialTimeout("tcp", address, time.Second*3)
|
|||
|
if err != nil {
|
|||
|
s.Connectable = false
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
ranp := []byte{0x52, 0x41, 0x4E, 0x50} // RANP
|
|||
|
full := []byte{0x46, 0x55, 0x4C, 0x4C} // FULL
|
|||
|
poke := []byte{0x50, 0x4F, 0x4B, 0x45} // POKE
|
|||
|
magic := make([]byte, 4)
|
|||
|
|
|||
|
// 忽略写入错误
|
|||
|
conn.SetWriteDeadline(time.Now().Add(time.Second * 3))
|
|||
|
conn.Write(poke)
|
|||
|
|
|||
|
conn.SetReadDeadline(time.Now().Add(time.Second * 3))
|
|||
|
read, err := conn.Read(magic)
|
|||
|
|
|||
|
conn.Close()
|
|||
|
|
|||
|
// 在接收错误时假设它是 RetroArch
|
|||
|
if err != nil || read == 0 {
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
// 在不完整的魔术上假设它不是 RetroArch
|
|||
|
if read != len(magic) {
|
|||
|
s.IsRetroArch = false
|
|||
|
} else if !bytes.Equal(magic, ranp) && !bytes.Equal(magic, full) {
|
|||
|
s.IsRetroArch = false
|
|||
|
}
|
|||
|
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (d *SessionDomain) getDeadline() time.Time {
|
|||
|
return time.Now().Add(-SessionDeadline * time.Second)
|
|||
|
}
|
|||
|
|
|||
|
// GetTunnel 返回隧道的地址/端口对。
|
|||
|
func (d *SessionDomain) GetTunnel(tunnelName string) *MitmInfo {
|
|||
|
return d.mitmDomain.GetInfo(tunnelName)
|
|||
|
}
|