netplay-lobby-server-go/domain/sessiondomain_test.go
2024-12-24 22:09:17 +08:00

352 lines
9.3 KiB
Go

package domain
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/libretro/netplay-lobby-server-go/model/entity"
)
// testIP is the client address handed to SessionDomain.Add by the tests in
// this file. It is the same address that testSession.IP is parsed from.
var testIP = net.ParseIP("192.168.178.2")
// testRequest is the baseline AddSessionRequest that the tests below copy and
// then tweak. Its user, core, game, port, frontend and subsystem values mirror
// the corresponding fields of testSession; keep the two fixtures in sync.
var testRequest = AddSessionRequest{
Username: "zelda",
CoreName: "bsnes",
CoreVersion: "0.2.1",
GameName: "supergame",
GameCRC: "FFFFFFFF",
Port: 55355,
MITMServer: "",
HasPassword: false,
HasSpectatePassword: false,
ForceMITM: false,
RetroArchVersion: "1.1.1",
Frontend: "retro",
SubsystemName: "subsub",
MITMSession: "",
MITMCustomServer: "",
MITMCustomPort: 0,
}
// testSession is the entity.Session fixture that corresponds to testRequest
// for a host at 192.168.178.2 (the same address as testIP). Tests copy it and
// call CalculateID/CalculateContentHash on the copy, so ID and ContentHash are
// deliberately left empty here.
var testSession = entity.Session{
ID: "",
RoomID: 100,
Username: "zelda",
Country: "en",
GameName: "supergame",
GameCRC: "FFFFFFFF",
CoreName: "bsnes",
CoreVersion: "0.2.1",
SubsystemName: "subsub",
RetroArchVersion: "1.1.1",
Frontend: "retro",
IP: net.ParseIP("192.168.178.2"),
Port: 55355,
MitmHandle: "",
MitmAddress: "",
MitmPort: 0,
MitmSession: "",
HostMethod: entity.HostMethodUnknown,
HasPassword: false,
HasSpectatePassword: false,
Connectable: true,
IsRetroArch: true,
// Both timestamps are evaluated once, at package initialization, and lie
// five minutes in the past relative to that moment.
CreatedAt: time.Now().Add(-5 * time.Minute),
UpdatedAt: time.Now().Add(-5 * time.Minute),
ContentHash: "",
}
// SessionRepositoryMock is a testify-based stand-in for the session repository
// that setupSessionDomain passes to NewSessionDomain. The embedded mock.Mock
// records every call and supplies the return values configured via On(...).
type SessionRepositoryMock struct {
mock.Mock
}
// Create records the call with the given session and returns the error
// configured as the first return value of the matching expectation.
func (m *SessionRepositoryMock) Create(s *entity.Session) error {
	return m.Called(s).Error(0)
}
// Update records the call with the given session and returns the error
// configured as the first return value of the matching expectation.
func (m *SessionRepositoryMock) Update(s *entity.Session) error {
	return m.Called(s).Error(0)
}
// Touch records the call with the given session ID and returns the error
// configured as the first return value of the matching expectation.
func (m *SessionRepositoryMock) Touch(id string) error {
	return m.Called(id).Error(0)
}
// GetByID records the call with the given ID and returns the configured
// session and error. When the first configured return value is not a
// *entity.Session (for example an untyped nil), a nil session is returned.
func (m *SessionRepositoryMock) GetByID(id string) (*entity.Session, error) {
	ret := m.Called(id)

	var found *entity.Session
	if s, ok := ret.Get(0).(*entity.Session); ok {
		found = s
	}
	return found, ret.Error(1)
}
// GetByRoomID records the call with the given room ID and returns the
// configured session and error. When the first configured return value is not
// a *entity.Session (for example an untyped nil), a nil session is returned.
func (m *SessionRepositoryMock) GetByRoomID(roomID int32) (*entity.Session, error) {
	ret := m.Called(roomID)

	var found *entity.Session
	if s, ok := ret.Get(0).(*entity.Session); ok {
		found = s
	}
	return found, ret.Error(1)
}
// GetAll records the call with the given deadline and returns the configured
// session slice and error. When the first configured return value is not a
// []entity.Session, a nil slice is returned.
func (m *SessionRepositoryMock) GetAll(deadline time.Time) ([]entity.Session, error) {
	ret := m.Called(deadline)

	var list []entity.Session
	if l, ok := ret.Get(0).([]entity.Session); ok {
		list = l
	}
	return list, ret.Error(1)
}
// PurgeOld records the call with the given deadline and returns the error
// configured as the first return value of the matching expectation.
func (m *SessionRepositoryMock) PurgeOld(deadline time.Time) error {
	return m.Called(deadline).Error(0)
}
// setupSessionDomain builds a SessionDomain for a test, wired to a fresh
// SessionRepositoryMock, a ValidationDomain created from the package's test
// blacklists, the GeoIP2 domain from setupGeoip2Domain and an empty MitmDomain.
// It returns the domain together with the mock so the caller can register
// expectations. The test is aborted if the ValidationDomain cannot be created.
func setupSessionDomain(t *testing.T) (*SessionDomain, *SessionRepositoryMock) {
	// Attribute failures raised here to the calling test's line.
	t.Helper()

	repoMock := &SessionRepositoryMock{}

	validationDomain, err := NewValidationDomain(testStringBlacklist, testIPBlacklist)
	require.NoError(t, err)

	geoip2Domain := setupGeoip2Domain(t)

	// NewSessionDomain returns no error, so there is nothing further to check.
	sessionDomain := NewSessionDomain(repoMock, geoip2Domain, validationDomain, &MitmDomain{})
	return sessionDomain, repoMock
}
// TestSessionDomainPurgeOld checks that SessionDomain.PurgeOld asks the
// repository to purge sessions using a deadline that lies SessionDeadline
// seconds in the past, with one second of tolerance on either side.
func TestSessionDomainPurgeOld(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	// Test the deadline duration: accept only deadlines strictly between
	// now-(SessionDeadline+1)s and now-(SessionDeadline-1)s.
	repoMock.On("PurgeOld", mock.MatchedBy(
		func(d time.Time) bool {
			before := time.Now().Add(-(SessionDeadline - 1) * time.Second)
			after := time.Now().Add(-(SessionDeadline + 1) * time.Second)
			return d.Before(before) && d.After(after)
		})).Return(nil)

	err := sessionDomain.PurgeOld()
	require.NoError(t, err, "Can't purge old sessions")

	// Without this, the test would pass even if the repository were never
	// called and the deadline matcher above never evaluated.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainList checks that SessionDomain.List queries the repository
// with a deadline SessionDeadline seconds in the past (one second of tolerance
// on either side) and hands back the three sessions the repository returns.
func TestSessionDomainList(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	// Test the deadline duration: accept only deadlines strictly between
	// now-(SessionDeadline+1)s and now-(SessionDeadline-1)s.
	repoMock.On("GetAll", mock.MatchedBy(
		func(d time.Time) bool {
			before := time.Now().Add(-(SessionDeadline - 1) * time.Second)
			after := time.Now().Add(-(SessionDeadline + 1) * time.Second)
			return d.Before(before) && d.After(after)
		})).Return(make([]entity.Session, 3), nil)

	sessions, err := sessionDomain.List()
	require.NoError(t, err, "Can't list sessions")
	require.NotNil(t, sessions)
	assert.Len(t, sessions, 3)

	// Make sure the repository was actually consulted with a matching deadline.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainValidateSessionAtCreate sends a request whose GameCRC is
// "123456789" while the repository reports no stored session for the computed
// ID. The test expects Add to fail with ErrSessionRejected and to return no
// session.
func TestSessionDomainValidateSessionAtCreate(t *testing.T) {
	domain, repo := setupSessionDomain(t)

	// Reference session used only to derive the ID the domain will look up.
	expected := testSession
	expected.CalculateID()
	expected.CalculateContentHash()

	req := testRequest
	req.GameCRC = "123456789"

	hasExpectedID := func(id string) bool { return id == expected.ID }
	repo.On("GetByID", mock.MatchedBy(hasExpectedID)).Return(nil, nil)

	got, err := domain.Add(&req, testIP)

	require.Error(t, err)
	assert.Nil(t, got)
	assert.True(t, errors.Is(err, ErrSessionRejected))
}
// TestSessionDomainValidateSessionAtUpdate sends a request with the
// RetroArchVersion "0123456789ABCDEF0123456789ABCDEF_INVALID" while the
// repository already holds a session for the computed ID. The test expects Add
// to fail with ErrSessionRejected and to return no session.
func TestSessionDomainValidateSessionAtUpdate(t *testing.T) {
	domain, repo := setupSessionDomain(t)

	// Stored session the repository hands back for the looked-up ID.
	stored := testSession
	stored.CalculateID()
	stored.CalculateContentHash()

	req := testRequest
	req.RetroArchVersion = "0123456789ABCDEF0123456789ABCDEF_INVALID"

	hasStoredID := func(id string) bool { return id == stored.ID }
	repo.On("GetByID", mock.MatchedBy(hasStoredID)).Return(&stored, nil)

	got, err := domain.Add(&req, testIP)

	require.Error(t, err)
	assert.Nil(t, got)
	assert.True(t, errors.Is(err, ErrSessionRejected))
}
// TestSessionDomainAddSessionTypeCreate checks the create path of Add: the
// repository reports no stored session for the computed ID, so the domain is
// expected to call Create with a session whose ID and ContentHash equal those
// computed from testSession, and to return that session.
func TestSessionDomainAddSessionTypeCreate(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	comp := testSession
	comp.CalculateID()
	comp.CalculateContentHash()

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(nil, nil)
	repoMock.On("Create", mock.MatchedBy(
		func(s *entity.Session) bool {
			return s.ID == comp.ID && s.ContentHash == comp.ContentHash
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.NoError(t, err)
	require.NotNil(t, newSession)
	assert.Equal(t, comp.ID, newSession.ID)
	assert.Equal(t, comp.ContentHash, newSession.ContentHash)

	// Verify that the session was really persisted through Create.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainAddSessionTypeCreateShouldSetDefaultUsername checks that a
// request with an empty Username is created as if the username were
// "Anonymous": the ID and ContentHash of the created session must equal those
// computed from testSession with Username set to "Anonymous".
func TestSessionDomainAddSessionTypeCreateShouldSetDefaultUsername(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	request.Username = ""

	comp := testSession
	comp.Username = "Anonymous"
	comp.CalculateID()
	comp.CalculateContentHash()

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(nil, nil)
	repoMock.On("Create", mock.MatchedBy(
		func(s *entity.Session) bool {
			return s.ID == comp.ID && s.ContentHash == comp.ContentHash
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.NoError(t, err)
	require.NotNil(t, newSession)
	assert.Equal(t, comp.ID, newSession.ID)
	assert.Equal(t, comp.ContentHash, newSession.ContentHash)

	// Verify that the session was really persisted through Create.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainAddSessionTypeUpdate checks the update path of Add: the
// repository already holds a session with the same ID, and the request differs
// in GameCRC, so the domain is expected to call Update with a session that
// keeps the ID but carries a different ContentHash.
func TestSessionDomainAddSessionTypeUpdate(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	comp := testSession
	comp.CalculateID()
	comp.CalculateContentHash()
	// Changed only after the reference hashes were computed from testSession.
	request.GameCRC = "88888888"

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(&comp, nil)
	repoMock.On("Update", mock.MatchedBy(
		func(s *entity.Session) bool {
			return s.ID == comp.ID && s.ContentHash != comp.ContentHash
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.NoError(t, err)
	require.NotNil(t, newSession)
	assert.Equal(t, comp.ID, newSession.ID)
	assert.NotEqual(t, comp.ContentHash, newSession.ContentHash)

	// Verify that the change was really persisted through Update.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainAddSessionTypeTouch checks the touch path of Add: the
// repository already holds a session with the same ID and the request is
// unchanged, so the domain is expected to call Touch with that ID and return a
// session with the same ID and ContentHash.
func TestSessionDomainAddSessionTypeTouch(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	comp := testSession
	comp.CalculateID()
	comp.CalculateContentHash()

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(&comp, nil)
	repoMock.On("Touch", mock.MatchedBy(
		func(id string) bool {
			return id == comp.ID
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.NoError(t, err)
	require.NotNil(t, newSession)
	assert.Equal(t, comp.ID, newSession.ID)
	assert.Equal(t, comp.ContentHash, newSession.ContentHash)

	// Verify that the stored session was really refreshed through Touch.
	repoMock.AssertExpectations(t)
}
// TestSessionDomainAddSessionTypeUpdateRateLimit checks that an update request
// for a session whose UpdatedAt is only four seconds old is refused: Add must
// return ErrRateLimited and no session, and must not write to the repository.
func TestSessionDomainAddSessionTypeUpdateRateLimit(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	comp := testSession
	// NOTE(review): four seconds is presumably inside the domain's rate-limit
	// window; the window itself is defined outside this file.
	comp.UpdatedAt = time.Now().Add(-4 * time.Second)
	comp.CalculateID()
	comp.CalculateContentHash()
	request.GameCRC = "88888888"

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(&comp, nil)
	// Registered so that an unexpected Update surfaces as a failed assertion
	// below instead of a panic inside the mock.
	repoMock.On("Update", mock.MatchedBy(
		func(s *entity.Session) bool {
			return s.ID == comp.ID && s.ContentHash != comp.ContentHash
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.Error(t, err)
	assert.True(t, errors.Is(err, ErrRateLimited))
	assert.Nil(t, newSession)

	// A rate-limited request must not reach the repository's write path.
	repoMock.AssertNotCalled(t, "Update", mock.Anything)
}
// TestSessionDomainAddSessionTypeTouchRateLimit checks that a touch request
// for a session whose UpdatedAt is only four seconds old is refused: Add must
// return ErrRateLimited and no session, and must not write to the repository.
func TestSessionDomainAddSessionTypeTouchRateLimit(t *testing.T) {
	sessionDomain, repoMock := setupSessionDomain(t)

	request := testRequest
	comp := testSession
	// NOTE(review): four seconds is presumably inside the domain's rate-limit
	// window; the window itself is defined outside this file.
	comp.UpdatedAt = time.Now().Add(-4 * time.Second)
	comp.CalculateID()
	comp.CalculateContentHash()

	repoMock.On("GetByID", mock.MatchedBy(
		func(s string) bool {
			return s == comp.ID
		})).Return(&comp, nil)
	// Registered so that an unexpected Touch surfaces as a failed assertion
	// below instead of a panic inside the mock.
	repoMock.On("Touch", mock.MatchedBy(
		func(id string) bool {
			return id == comp.ID
		})).Return(nil)

	newSession, err := sessionDomain.Add(&request, testIP)
	require.Error(t, err)
	assert.True(t, errors.Is(err, ErrRateLimited))
	assert.Nil(t, newSession)

	// A rate-limited request must not reach the repository's write path.
	repoMock.AssertNotCalled(t, "Touch", mock.Anything)
}