fluxer/fluxer_gateway/src/gateway/gateway_handler.erl

958 lines
36 KiB
Erlang

%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_handler).
-behaviour(cowboy_websocket).
-export([init/2, websocket_init/1, websocket_handle/2, websocket_info/2, terminate/3]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
%% Number of voice state updates a session may forward within one
%% ?VOICE_RATE_LIMIT_WINDOW before further updates are queued.
-define(VOICE_UPDATE_RATE_LIMIT, 10).
%% Length of the voice update rate-limit window, in milliseconds.
-define(VOICE_RATE_LIMIT_WINDOW, 1000).
%% Name of the ETS table that stores, per session pid, the queue of deferred voice updates.
-define(VOICE_QUEUE_TABLE, voice_update_queue).
%% Name of the ETS table that stores, per session pid, timestamps of recent voice updates.
-define(VOICE_RATE_LIMIT_TABLE, voice_update_rate_limit).
%% Delay, in milliseconds, before the {process_voice_queue} message is delivered.
-define(VOICE_QUEUE_PROCESS_INTERVAL, 100).
%% Queue length at which the oldest queued voice update is dropped before a new one is added.
-define(MAX_VOICE_QUEUE_LENGTH, 64).
%% Per-connection state carried through the cowboy_websocket callbacks.
-type state() :: #{
%% Gateway version from the "v" query parameter; undefined unless it is <<"1">>.
version := 1 | undefined,
%% Wire encoding chosen from the "encoding" query parameter.
encoding := gateway_codec:encoding(),
%% Compression context built from the "compress" query parameter.
compress_ctx := gateway_compress:compress_ctx() | undefined,
%% Pid of the attached session process; undefined until identify/resume succeeds.
session_pid := pid() | undefined,
%% Map holding last_ack (milliseconds) and waiting_for_ack (boolean).
heartbeat_state := map(),
%% Pid recorded via self() in init/2.
socket_pid := pid() | undefined,
%% Client address as text, taken from X-Forwarded-For or the socket peer.
peer_ip := binary() | undefined,
%% Map holding events (timestamps of recent inbound payloads) and window_start.
rate_limit_state := map(),
%% Value returned by gateway_tracing:start_connection_span/3.
otel_span_ctx := term(),
%% Reference of the pending {process_voice_queue} timer, if one is armed.
voice_queue_timer := reference() | undefined
}.
%% Builds the initial handler state: nothing negotiated yet, JSON encoding,
%% no session attached and an empty rate-limit history.
-spec new_state() -> state().
new_state() ->
    EmptyRateLimit = #{events => [], window_start => undefined},
    #{
        socket_pid => undefined,
        peer_ip => undefined,
        version => undefined,
        encoding => json,
        compress_ctx => undefined,
        session_pid => undefined,
        heartbeat_state => #{},
        rate_limit_state => EmptyRateLimit,
        voice_queue_timer => undefined,
        otel_span_ctx => undefined
    }.
%% An outbound data frame handed to Cowboy.
-type ws_frame() :: {text, binary()} | {binary, binary()}.
%% Return shape of the websocket callbacks: either no frames, or a list of
%% data/close frames to send, each paired with the next state.
-type ws_result() :: {ok, state()} | {[ws_frame() | {close, integer(), binary()}], state()}.
%% HTTP entry point. Reads the "v", "encoding" and "compress" query parameters,
%% resolves the client address and upgrades the request to a WebSocket with a
%% freshly populated state.
-spec init(cowboy_req:req(), term()) -> {cowboy_websocket, cowboy_req:req(), state()}.
init(Req, _Opts) ->
    Params = cowboy_req:parse_qs(Req),
    Param = fun(Name) -> proplists:get_value(Name, Params) end,
    ApiVersion = parse_version(Param(<<"v">>)),
    WireEncoding = gateway_codec:parse_encoding(Param(<<"encoding">>)),
    CompressionKind = gateway_compress:parse_compression(Param(<<"compress">>)),
    Ctx = gateway_compress:new_context(CompressionKind),
    ClientIP = extract_client_ip(Req),
    Initial = new_state(),
    Negotiated = Initial#{
        version => ApiVersion,
        encoding => WireEncoding,
        compress_ctx => Ctx,
        socket_pid => self(),
        peer_ip => ClientIP
    },
    {cowboy_websocket, Req, Negotiated}.
%% Maps the raw "v" query parameter to a supported gateway version.
%% Only <<"1">> is recognised; anything else (including a missing parameter)
%% yields undefined.
-spec parse_version(binary() | undefined) -> 1 | undefined.
parse_version(Raw) when Raw =:= <<"1">> ->
    1;
parse_version(_Unsupported) ->
    undefined.
%% Called once the WebSocket upgrade completes.
%% For version 1 connections: records the connection metric, opens the tracing
%% span, replaces the compression context with a fresh one of the same type,
%% schedules the first heartbeat check and sends the HELLO payload.
%% Any other version is counted, traced and closed immediately.
-spec websocket_init(state()) -> ws_result().
websocket_init(State = #{version := 1}) ->
    gateway_metrics_collector:inc_connections(),
    SpanCtx = start_websocket_connect_span(1, State),
    Traced = State#{otel_span_ctx => SpanCtx},
    Compression = gateway_compress:get_type(maps:get(compress_ctx, Traced)),
    Base = Traced#{compress_ctx => gateway_compress:new_context(Compression)},
    Interval = constants:heartbeat_interval(),
    Hello = #{
        <<"op">> => constants:opcode_to_num(hello),
        <<"d">> => #{<<"heartbeat_interval">> => Interval}
    },
    schedule_heartbeat_check(),
    FreshHeartbeat = #{
        last_ack => erlang:system_time(millisecond),
        waiting_for_ack => false
    },
    Ready = Base#{heartbeat_state => FreshHeartbeat},
    %% On failure the close is issued with the state that precedes the
    %% heartbeat bookkeeping, as the HELLO was never delivered.
    case encode_and_compress(Hello, Ready) of
        {ok, Frame, Sent} ->
            {[Frame], Sent};
        {error, {compress_failed, Kind, _Why}} ->
            close_with_reason(decode_error, compression_error_reason(Kind), Base);
        {error, _Why} ->
            close_with_reason(decode_error, <<"Encode failed">>, Base)
    end;
websocket_init(State) ->
    gateway_metrics_collector:inc_connections(),
    SpanCtx = start_websocket_connect_span(undefined, State),
    close_with_reason(
        invalid_api_version,
        <<"Invalid API version">>,
        State#{otel_span_ctx => SpanCtx}
    ).
%% Inbound frame callback. Text and binary frames are decoded alike; every
%% other frame type is ignored.
-spec websocket_handle({text, binary()} | {binary, binary()} | term(), state()) -> ws_result().
websocket_handle({FrameKind, Payload}, State) when FrameKind =:= text; FrameKind =:= binary ->
    handle_incoming_data(Payload, State);
websocket_handle(_OtherFrame, State) ->
    {ok, State}.
%% Enforces the maximum payload size, then decodes the frame with the
%% connection's negotiated encoding. Oversized payloads close the connection.
-spec handle_incoming_data(binary(), state()) -> ws_result().
handle_incoming_data(Data, State = #{encoding := Encoding, compress_ctx := _}) ->
    Limit = constants:max_payload_size(),
    case byte_size(Data) > Limit of
        true ->
            close_with_reason(decode_error, <<"Payload too large">>, State);
        false ->
            Decoded = gateway_codec:decode(Data, Encoding),
            handle_decode(Decoded, State)
    end.
%% Acts on the decoder result. A payload carrying an "op" field is rate-limit
%% checked and routed by opcode; a payload without "op" or a decode failure
%% closes the connection with a decode error.
-spec handle_decode({ok, map()} | {error, term()}, state()) -> ws_result().
handle_decode({ok, Payload = #{<<"op">> := RawOp}}, State) ->
    Opcode = constants:gateway_opcode(RawOp),
    case check_rate_limit(State) of
        rate_limited ->
            close_with_reason(rate_limited, <<"Rate limited">>, State);
        {ok, Counted} ->
            handle_gateway_payload(Opcode, Payload, Counted)
    end;
handle_decode({ok, _NoOpcode}, State) ->
    close_with_reason(decode_error, <<"Invalid payload">>, State);
handle_decode({error, _Why}, State) ->
    close_with_reason(decode_error, <<"Decode failed">>, State).
%% Erlang-message callback. Handles heartbeat ticks, events dispatched to this
%% socket, backpressure errors, the monitored session going down and voice
%% queue ticks. Unrecognised messages are dropped.
-spec websocket_info(term(), state()) -> ws_result().
websocket_info({heartbeat_check}, State = #{heartbeat_state := Heartbeat}) ->
    handle_heartbeat_check(State, Heartbeat);
websocket_info({dispatch, Event, Data, Seq}, State) ->
    handle_dispatch(Event, Data, Seq, State);
websocket_info({session_backpressure_error, Details}, State) ->
    handle_session_backpressure_error(Details, State);
websocket_info({'DOWN', _Ref, process, Pid, _Why}, State = #{session_pid := Pid}) ->
    %% Only the currently attached session matters; other monitors fall through.
    handle_session_down(State);
websocket_info({process_voice_queue}, State) ->
    %% The timer that delivered this message has fired, so clear its reference
    %% before processing; processing may arm a new one.
    Disarmed = State#{voice_queue_timer => undefined},
    {ok, process_queued_voice_updates(Disarmed)};
websocket_info(_Unhandled, State) ->
    {ok, State}.
%% Gathers the inputs for a heartbeat tick (current time, last ack, whether an
%% ack is pending, configured timeout and interval) and delegates the decision
%% to handle_heartbeat_state/6. The second argument is ignored; the heartbeat
%% map is read from the state.
-spec handle_heartbeat_check(state(), map()) -> ws_result().
handle_heartbeat_check(State = #{heartbeat_state := Heartbeat}, _Ignored) ->
    Now = erlang:system_time(millisecond),
    LastAck = maps:get(last_ack, Heartbeat, Now),
    Pending = maps:get(waiting_for_ack, Heartbeat, false),
    Timeout = constants:heartbeat_timeout(),
    Interval = constants:heartbeat_interval(),
    handle_heartbeat_state(State, Now, LastAck, Pending, Timeout, Interval).
%% Decides what a heartbeat tick does:
%%   * an ack is pending and the timeout has elapsed -> close the connection;
%%   * at least 90% of the heartbeat interval has passed since the last ack ->
%%     send a heartbeat request and mark an ack as pending;
%%   * otherwise -> just schedule the next tick.
-spec handle_heartbeat_state(state(), integer(), integer(), boolean(), integer(), integer()) ->
    ws_result().
handle_heartbeat_state(State, Now, LastAck, true, HeartbeatTimeout, _Interval) when
    (Now - LastAck) > HeartbeatTimeout
->
    gateway_metrics_collector:inc_heartbeat_failure(),
    close_with_reason(session_timeout, <<"Heartbeat timeout">>, State);
handle_heartbeat_state(
    State = #{heartbeat_state := Heartbeat}, Now, LastAck, _Pending, _Timeout, HeartbeatInterval
) when (Now - LastAck) >= (HeartbeatInterval * 0.9) ->
    Probe = #{<<"op">> => constants:opcode_to_num(heartbeat), <<"d">> => null},
    schedule_heartbeat_check(),
    Awaiting = State#{heartbeat_state => Heartbeat#{waiting_for_ack => true}},
    %% An encode failure is tolerated: the ack is still marked pending.
    case encode_and_compress(Probe, Awaiting) of
        {error, _} -> {ok, Awaiting};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end;
handle_heartbeat_state(State, _Now, _LastAck, _Pending, _Timeout, _Interval) ->
    schedule_heartbeat_check(),
    {ok, State}.
%% Wraps an event in a DISPATCH payload (op, t, d, s) and sends it to the
%% client. If encoding or compression fails the event is silently dropped and
%% the state is left unchanged.
-spec handle_dispatch(atom() | binary(), map() | null, integer(), state()) -> ws_result().
handle_dispatch(Event, Data, Seq, State) ->
    Name = dispatch_event_name(Event),
    Envelope = #{
        <<"op">> => constants:opcode_to_num(dispatch),
        <<"t">> => Name,
        <<"d">> => Data,
        <<"s">> => Seq
    },
    case encode_and_compress(Envelope, State) of
        {error, _Why} -> {ok, State};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end.
%% Resolves the wire name of an event: binaries are used verbatim, atoms are
%% translated through constants:dispatch_event_atom/1, and anything else
%% becomes <<"UNKNOWN">>.
-spec dispatch_event_name(atom() | binary()) -> binary().
dispatch_event_name(Name) when is_binary(Name) ->
    Name;
dispatch_event_name(Tag) when is_atom(Tag) ->
    constants:dispatch_event_atom(Tag);
dispatch_event_name(_Other) ->
    <<"UNKNOWN">>.
%% Reacts to the attached session process terminating: detaches it from the
%% state and tells the client the session is invalid (d = false). The session
%% is detached even if the notification cannot be encoded.
-spec handle_session_down(state()) -> ws_result().
handle_session_down(State) ->
    Detached = State#{session_pid => undefined},
    Notice = #{<<"op">> => constants:opcode_to_num(invalid_session), <<"d">> => false},
    case encode_and_compress(Notice, Detached) of
        {error, _} -> {ok, Detached};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end.
%% Closes the connection with the ack_backpressure reason, using a close
%% message that summarises the supplied backlog details.
-spec handle_session_backpressure_error(map(), state()) -> ws_result().
handle_session_backpressure_error(Details, State) ->
    close_with_reason(ack_backpressure, format_backpressure_close_message(Details), State).
%% Renders the backlog details as a single-line close message. Missing fields
%% default to <<"unknown">> (kind) or 0 (counters); the counters are formatted
%% with integer_to_binary/1 and are therefore expected to be integers.
-spec format_backpressure_close_message(map()) -> binary().
format_backpressure_close_message(Details) ->
    Kind = detail_value_to_binary(map_utils:get_safe(Details, kind, <<"unknown">>)),
    Unacked = map_utils:get_safe(Details, unacked_events, 0),
    CurrentSize = map_utils:get_safe(Details, current_size, 0),
    Limit = map_utils:get_safe(Details, limit, 0),
    Seq = map_utils:get_safe(Details, seq, 0),
    AckSeq = map_utils:get_safe(Details, ack_seq, 0),
    iolist_to_binary([
        <<"Acknowledgement backlog exceeded: kind=">>,
        Kind,
        <<" unacked=">>,
        integer_to_binary(Unacked),
        <<" current=">>,
        integer_to_binary(CurrentSize),
        <<" limit=">>,
        integer_to_binary(Limit),
        <<" seq=">>,
        integer_to_binary(Seq),
        <<" ack_seq=">>,
        integer_to_binary(AckSeq)
    ]).
%% Converts a detail value to text: binaries pass through, atoms are rendered
%% as UTF-8, and every other term is reported as <<"unknown">>.
-spec detail_value_to_binary(term()) -> binary().
detail_value_to_binary(Value) ->
    if
        is_binary(Value) -> Value;
        is_atom(Value) -> atom_to_binary(Value, utf8);
        true -> <<"unknown">>
    end.
%% Connection teardown. Always records the disconnection metric; when the
%% state is a handler state map it also ends the tracing span and closes the
%% compression context.
-spec terminate(term(), cowboy_req:req(), state() | term()) -> ok.
terminate(Reason, _Req, #{compress_ctx := Ctx, otel_span_ctx := Span}) ->
    gateway_metrics_collector:inc_disconnections(),
    end_websocket_disconnect_span(Reason, Span),
    gateway_compress:close_context(Ctx),
    ok;
terminate(_Reason, _Req, _NoHandlerState) ->
    gateway_metrics_collector:inc_disconnections(),
    ok.
%% Validates the "d" field of an IDENTIFY payload received from the client.
%%
%% Returns {ok, Token, Properties, Presence, IgnoredEvents, Flags, InitialGuildId}
%% on success, or:
%%   {error, missing_required_field} when "token", "properties" or one of the
%%       required property keys is absent;
%%   {error, invalid_payload} when the payload is not a map at all;
%%   any error produced by the downstream validators.
%%
%% The payload comes straight from the wire, so a non-map value (for example
%% a number or a list) must be rejected here: maps:get/2 would otherwise raise
%% {badmap, _}, which the badkey handler below does not cover.
-spec validate_identify_data(map()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_identify_data(Data) when is_map(Data) ->
    try
        Token = maps:get(<<"token">>, Data),
        Properties = maps:get(<<"properties">>, Data),
        IgnoredEventsRaw = maps:get(<<"ignored_events">>, Data, []),
        InitialGuildIdRaw = maps:get(<<"initial_guild_id">>, Data, undefined),
        %% Called inside the try on purpose: validate_properties/5 uses
        %% maps:get/2 for the required property keys.
        validate_properties(Token, Properties, IgnoredEventsRaw, InitialGuildIdRaw, Data)
    catch
        error:{badkey, _} -> {error, missing_required_field}
    end;
validate_identify_data(_NotAMap) ->
    {error, invalid_payload}.
%% Checks that "properties" is a map whose "os", "browser" and "device" values
%% are all binaries, then continues with the ignored-events validation.
%% A missing key raises badkey, which the caller is expected to handle.
-spec validate_properties(binary(), term(), term(), term(), map()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_properties(Token, Properties, IgnoredEventsRaw, InitialGuildIdRaw, Data) when
    is_map(Properties)
->
    Os = maps:get(<<"os">>, Properties),
    Browser = maps:get(<<"browser">>, Properties),
    Device = maps:get(<<"device">>, Properties),
    case lists:all(fun erlang:is_binary/1, [Os, Browser, Device]) of
        false ->
            {error, invalid_properties};
        true ->
            Presence = maps:get(<<"presence">>, Data, null),
            validate_ignored_events(
                Token, Properties, Presence, IgnoredEventsRaw, InitialGuildIdRaw, Data
            )
    end;
validate_properties(_Token, _NotAMap, _IgnoredEventsRaw, _InitialGuildIdRaw, _Data) ->
    {error, invalid_properties}.
%% Normalises the "ignored_events" list and, if it is valid, moves on to the
%% "flags" check (defaulting to 0 when absent). Parser errors are returned
%% unchanged.
-spec validate_ignored_events(binary(), map(), term(), term(), term(), map()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_ignored_events(Token, Properties, Presence, IgnoredEventsRaw, InitialGuildIdRaw, Data) ->
    case parse_ignored_events(IgnoredEventsRaw) of
        {error, _} = Rejected ->
            Rejected;
        {ok, IgnoredEvents} ->
            Flags = maps:get(<<"flags">>, Data, 0),
            validate_flags(Token, Properties, Presence, IgnoredEvents, InitialGuildIdRaw, Flags)
    end.
%% Final identify validation step: "flags" must be a non-negative integer.
%% On success the full result tuple is assembled, with the initial guild id
%% parsed from its raw form.
-spec validate_flags(binary(), map(), term(), [binary()], term(), term()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_flags(Token, Properties, Presence, IgnoredEvents, InitialGuildIdRaw, Flags) ->
    case is_integer(Flags) andalso Flags >= 0 of
        true ->
            GuildId = parse_initial_guild_id(InitialGuildIdRaw),
            {ok, Token, Properties, Presence, IgnoredEvents, Flags, GuildId};
        false ->
            {error, invalid_properties}
    end.
%% Parses the "ignored_events" field. undefined and null mean "nothing
%% ignored"; a list must contain only binaries and is returned normalised,
%% sorted and de-duplicated. Anything else is rejected.
-spec parse_ignored_events(term()) -> {ok, [binary()]} | {error, invalid_ignored_events}.
parse_ignored_events(Absent) when Absent =:= undefined; Absent =:= null ->
    {ok, []};
parse_ignored_events(Names) when is_list(Names) ->
    case lists:all(fun erlang:is_binary/1, Names) of
        false ->
            {error, invalid_ignored_events};
        true ->
            Normalized = lists:map(fun normalize_event_name/1, Names),
            {ok, lists:usort(Normalized)}
    end;
parse_ignored_events(_Other) ->
    {error, invalid_ignored_events}.
%% Parses the optional "initial_guild_id". Only a binary that passes
%% validation:validate_snowflake/2 yields an id; every other value (including
%% undefined, null and invalid snowflakes) yields undefined.
-spec parse_initial_guild_id(term()) -> integer() | undefined.
parse_initial_guild_id(Raw) when is_binary(Raw) ->
    case validation:validate_snowflake(<<"initial_guild_id">>, Raw) of
        {error, _, _} -> undefined;
        {ok, GuildId} -> GuildId
    end;
parse_initial_guild_id(_NotABinary) ->
    undefined.
%% Upper-cases an event name supplied by the client.
%%
%% Only the ASCII letters a-z are folded; all other bytes are copied verbatim.
%% The conversion therefore never fails, whatever bytes the client sends.
%% Routing the binary through binary_to_list/1 and string:uppercase/1 would
%% treat each UTF-8 byte as a code point: some bytes upper-case to code points
%% above 255 (making list_to_binary/1 raise badarg) and others are rewritten,
%% corrupting multi-byte characters.
-spec normalize_event_name(binary()) -> binary().
normalize_event_name(Event) ->
    <<<<(ascii_uppercase_byte(Byte))>> || <<Byte>> <= Event>>.

%% Upper-cases a single byte if it is an ASCII lowercase letter.
-spec ascii_uppercase_byte(byte()) -> byte().
ascii_uppercase_byte(Byte) when Byte >= $a, Byte =< $z ->
    Byte - ($a - $A);
ascii_uppercase_byte(Byte) ->
    Byte.
%% Routes a decoded payload by opcode.
%%
%%   heartbeat             -> acknowledged (needs a "d" field).
%%   identify              -> starts a session when none is attached; rejected
%%                            as a decode error when "d" is missing, and as
%%                            "already authenticated" when a session exists.
%%   resume                -> attempts to re-attach to an existing session.
%%   presence_update,
%%   voice_state_update,
%%   request_guild_members,
%%   lazy_request          -> require an attached session; without one the
%%                            connection is closed as not authenticated.
%%
%% Anything else, including a known opcode whose payload lacks "d", is closed
%% as an unknown opcode. Clause order matters: the more specific heads come
%% first.
-spec handle_gateway_payload(atom(), map(), state()) -> ws_result().
handle_gateway_payload(
    heartbeat,
    #{<<"d">> := Seq},
    State = #{heartbeat_state := HeartbeatState, session_pid := SessionPid}
) ->
    handle_heartbeat(Seq, SessionPid, State, HeartbeatState);
handle_gateway_payload(
    identify, #{<<"d">> := Data}, State = #{session_pid := undefined, peer_ip := PeerIP}
) ->
    handle_identify(Data, PeerIP, State);
handle_gateway_payload(identify, _, State = #{session_pid := undefined}) ->
    %% No session is attached, so "already authenticated" would be wrong:
    %% the identify payload itself is malformed (no "d" field).
    close_with_reason(decode_error, <<"Invalid identify payload">>, State);
handle_gateway_payload(identify, _, State) ->
    close_with_reason(already_authenticated, <<"Already authenticated">>, State);
handle_gateway_payload(presence_update, #{<<"d">> := _}, State = #{session_pid := undefined}) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(presence_update, #{<<"d">> := Data}, State = #{session_pid := Pid}) when
    is_pid(Pid)
->
    handle_presence_update(Data, Pid, State);
handle_gateway_payload(resume, #{<<"d">> := Data}, State) ->
    handle_resume(Data, State);
handle_gateway_payload(
    voice_state_update, #{<<"d">> := _}, State = #{session_pid := undefined}
) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(
    voice_state_update, #{<<"d">> := Data}, State = #{session_pid := Pid}
) when is_pid(Pid) ->
    handle_voice_state_update(Pid, Data, State);
handle_gateway_payload(
    request_guild_members, #{<<"d">> := _}, State = #{session_pid := undefined}
) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(
    request_guild_members, #{<<"d">> := Data}, State = #{session_pid := Pid}
) when is_pid(Pid) ->
    handle_request_guild_members(Data, Pid, State);
handle_gateway_payload(lazy_request, #{<<"d">> := _}, State = #{session_pid := undefined}) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(lazy_request, #{<<"d">> := Data}, State = #{session_pid := Pid}) when
    is_pid(Pid)
->
    handle_lazy_request(Data, Pid, State);
handle_gateway_payload(_, _, State) ->
    close_with_reason(unknown_opcode, <<"Unknown opcode">>, State).
%% Handles a client heartbeat. When the sequence number is accepted, the ack
%% timestamp is refreshed, the pending-ack flag cleared and a HEARTBEAT_ACK
%% sent; otherwise the connection is closed with an invalid-sequence error.
-spec handle_heartbeat(term(), pid() | undefined, state(), map()) -> ws_result().
handle_heartbeat(Seq, SessionPid, State, HeartbeatState) ->
    case verify_heartbeat_ack(Seq, SessionPid) of
        false ->
            gateway_metrics_collector:inc_heartbeat_failure(),
            close_with_reason(invalid_seq, <<"Invalid sequence">>, State);
        true ->
            Acked = HeartbeatState#{
                last_ack => erlang:system_time(millisecond),
                waiting_for_ack => false
            },
            gateway_metrics_collector:inc_heartbeat_success(),
            Ack = #{<<"op">> => constants:opcode_to_num(heartbeat_ack)},
            Updated = State#{heartbeat_state => Acked},
            %% The heartbeat counts as acknowledged even if the ack frame
            %% cannot be encoded.
            case encode_and_compress(Ack, Updated) of
                {error, _} -> {ok, Updated};
                {ok, Frame, Sent} -> {[Frame], Sent}
            end
    end.
%% Decides whether a heartbeat sequence number is acceptable.
%% Without a session, or with a null sequence, the heartbeat is accepted as
%% is. An integer sequence is checked with the session process; a reply of
%% true or ok accepts it, any other reply or an exit from the call rejects it.
%% Every other combination is rejected.
-spec verify_heartbeat_ack(term(), pid() | undefined) -> boolean().
verify_heartbeat_ack(_Seq, undefined) ->
    true;
verify_heartbeat_ack(null, _Session) ->
    true;
verify_heartbeat_ack(Seq, Session) when is_integer(Seq), is_pid(Session) ->
    try gen_server:call(Session, {heartbeat_ack, Seq}, 5000) of
        Reply -> Reply =:= true orelse Reply =:= ok
    catch
        exit:_ -> false
    end;
verify_heartbeat_ack(_Seq, _Session) ->
    false.
%% Processes an IDENTIFY payload. Valid data is turned into a session start
%% request (new session id, peer address, identify data and gateway version);
%% invalid data closes the connection with a decode error.
-spec handle_identify(map(), binary(), state()) -> ws_result().
handle_identify(Data, PeerIP, State) ->
    case validate_identify_data(Data) of
        {error, _Why} ->
            close_with_reason(decode_error, <<"Invalid identify payload">>, State);
        {ok, Token, Properties, Presence, IgnoredEvents, Flags, InitialGuildId} ->
            SessionId = utils:generate_session_id(),
            Socket = self(),
            Identify = add_initial_guild_id(
                #{
                    token => Token,
                    properties => Properties,
                    presence => Presence,
                    ignored_events => IgnoredEvents,
                    flags => Flags
                },
                InitialGuildId
            ),
            Request = #{
                session_id => SessionId,
                peer_ip => PeerIP,
                identify_data => Identify,
                version => maps:get(version, State)
            },
            start_session(Request, Socket, State)
    end.
%% Adds the initial_guild_id key to the identify data, unless the id is
%% undefined, in which case the map is returned untouched.
-spec add_initial_guild_id(map(), integer() | undefined) -> map().
add_initial_guild_id(Data, undefined) ->
    Data;
add_initial_guild_id(Data, GuildId) ->
    Data#{initial_guild_id => GuildId}.
%% Asks the session manager to start a session for this socket.
%% On success the session is monitored and attached to the state. Token and
%% rate-limit failures close the connection; an identify rate limit only
%% sends INVALID_SESSION; any other result closes with an unknown error.
-spec start_session(map(), pid(), state()) -> ws_result().
start_session(Request, SocketPid, State) ->
    Outcome = session_manager:start(Request, SocketPid),
    case Outcome of
        {success, Session} when is_pid(Session) ->
            monitor(process, Session),
            {ok, State#{session_pid => Session}};
        {error, invalid_token} ->
            close_with_reason(authentication_failed, <<"Invalid token">>, State);
        {error, rate_limited} ->
            close_with_reason(rate_limited, <<"Rate limited">>, State);
        {error, identify_rate_limited} ->
            gateway_metrics_collector:inc_identify_rate_limited(),
            send_invalid_session(State);
        _Unexpected ->
            close_with_reason(unknown_error, <<"Failed to start session">>, State)
    end.
%% Sends an INVALID_SESSION payload with d = false. If the payload cannot be
%% encoded nothing is sent and the state is returned unchanged.
-spec send_invalid_session(state()) -> ws_result().
send_invalid_session(State) ->
    Notice = #{<<"op">> => constants:opcode_to_num(invalid_session), <<"d">> => false},
    case encode_and_compress(Notice, State) of
        {error, _} -> {ok, State};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end.
%% Forwards a client presence update to the session process.
%%
%% The payload comes from the wire and must be a map containing "status";
%% "afk" and "mobile" are optional and default to false. A payload that is
%% not a map, or that lacks "status", closes the connection with a decode
%% error instead of crashing the handler on maps:get/2.
-spec handle_presence_update(map(), pid(), state()) -> ws_result().
handle_presence_update(#{<<"status">> := RawStatus} = Data, Pid, State) ->
    Status = utils:parse_status(RawStatus),
    AdjustedStatus = adjust_status(Status),
    Afk = maps:get(<<"afk">>, Data, false),
    Mobile = maps:get(<<"mobile">>, Data, false),
    gen_server:cast(
        Pid, {presence_update, #{status => AdjustedStatus, afk => Afk, mobile => Mobile}}
    ),
    {ok, State};
handle_presence_update(_Malformed, _Pid, State) ->
    close_with_reason(decode_error, <<"Invalid presence payload">>, State).
%% A client-requested offline status is stored as invisible; every other
%% status is kept as is.
-spec adjust_status(atom()) -> atom().
adjust_status(Status) ->
    case Status of
        offline -> invisible;
        _ -> Status
    end.
%% Handles a RESUME payload.
%%
%% The payload comes from the wire and must be a map with "token",
%% "session_id" and "seq". A payload that is not a map or lacks one of these
%% keys closes the connection with a decode error instead of crashing the
%% handler on maps:get/2. A non-binary session id is treated as an unknown
%% session. Otherwise the session is looked up and, if found, the resume
%% continues with token verification.
-spec handle_resume(map(), state()) -> ws_result().
handle_resume(
    #{<<"token">> := Token, <<"session_id">> := SessionId, <<"seq">> := Seq}, State
) when is_binary(SessionId) ->
    case session_manager:lookup(SessionId) of
        {ok, Pid} when is_pid(Pid) ->
            handle_resume_with_session(Pid, Token, SessionId, Seq, State);
        {error, _} ->
            handle_resume_session_not_found(State)
    end;
handle_resume(#{<<"token">> := _, <<"session_id">> := _, <<"seq">> := _}, State) ->
    handle_resume_session_not_found(State);
handle_resume(_Malformed, State) ->
    close_with_reason(decode_error, <<"Invalid resume payload">>, State).
%% Forwards a voice state update to the session straight away while the
%% session is under its voice rate limit; otherwise the update is queued and
%% a timer is armed to drain the queue later.
-spec handle_voice_state_update(pid(), map(), state()) -> ws_result().
handle_voice_state_update(Pid, Data, State) ->
    case should_queue_voice_update(Pid) of
        true ->
            queue_voice_update(Pid, Data),
            {ok, ensure_voice_queue_timer(State)};
        false ->
            process_voice_update(Pid, Data, State)
    end.
%% Serves a REQUEST_GUILD_MEMBERS payload without blocking the socket
%% process: an unlinked worker fetches the session state and hands the
%% request to guild_request_members. The work is best-effort; any exception
%% in the worker is swallowed and the client gets no reply in that case.
-spec handle_request_guild_members(map(), pid(), state()) -> ws_result().
handle_request_guild_members(Data, Pid, State) ->
    Socket = self(),
    Worker = fun() ->
        try
            case gen_server:call(Pid, {get_state}, 5000) of
                Session when is_map(Session) ->
                    guild_request_members:handle_request(Data, Socket, Session);
                _NotAState ->
                    ok
            end
        catch
            _:_ -> ok
        end
    end,
    spawn(Worker),
    {ok, State}.
%% Serves a LAZY_REQUEST payload without blocking the socket process: an
%% unlinked worker fetches the session state and hands the request to
%% guild_unified_subscriptions. The work is best-effort; any exception in
%% the worker is swallowed.
-spec handle_lazy_request(map(), pid(), state()) -> ws_result().
handle_lazy_request(Data, Pid, State) ->
    Socket = self(),
    Worker = fun() ->
        try
            case gen_server:call(Pid, {get_state}, 5000) of
                Session when is_map(Session) ->
                    guild_unified_subscriptions:handle_subscriptions(Data, Socket, Session);
                _NotAState ->
                    ok
            end
        catch
            _:_ -> ok
        end
    end,
    spawn(Worker),
    {ok, State}.
%% Schedules the next {heartbeat_check} message to this process after one
%% third of the heartbeat interval and returns the timer reference.
-spec schedule_heartbeat_check() -> reference().
schedule_heartbeat_check() ->
    Delay = constants:heartbeat_interval() div 3,
    erlang:send_after(Delay, self(), {heartbeat_check}).
%% Sliding-window rate limit for inbound payloads: at most 120 payloads per
%% 60 seconds per connection.
%%
%% Returns {ok, State} with the current payload recorded, or rate_limited
%% when the window is already full (in which case nothing is recorded).
%%
%% Timestamps are taken from the monotonic clock because they are only ever
%% compared with each other to measure elapsed time; wall-clock time can jump
%% when the system clock is adjusted, which would wrongly empty or freeze
%% the window. The timestamps are private to this function.
-spec check_rate_limit(state()) -> {ok, state()} | rate_limited.
check_rate_limit(State = #{rate_limit_state := RateLimitState}) ->
    WindowDuration = 60000,
    MaxEvents = 120,
    Now = erlang:monotonic_time(millisecond),
    Events = maps:get(events, RateLimitState, []),
    %% window_start is carried along unchanged; this function does not use it.
    WindowStart = maps:get(window_start, RateLimitState, Now),
    EventsInWindow = [T || T <- Events, (Now - T) < WindowDuration],
    case length(EventsInWindow) >= MaxEvents of
        true ->
            rate_limited;
        false ->
            NewRateLimitState = #{
                events => [Now | EventsInWindow],
                window_start => WindowStart
            },
            {ok, State#{rate_limit_state => NewRateLimitState}}
    end.
%% Determines the client address. The first valid entry of the
%% X-Forwarded-For header wins; if the header is absent or yields no valid
%% address, the socket peer address is used.
-spec extract_client_ip(cowboy_req:req()) -> binary().
extract_client_ip(Req) ->
    Forwarded =
        case cowboy_req:header(<<"x-forwarded-for">>, Req) of
            undefined -> <<>>;
            Header -> parse_forwarded_for(Header)
        end,
    case Forwarded of
        <<>> -> peer_ip_to_binary(cowboy_req:peer(Req));
        Address -> Address
    end.
%% Renders the address part of a {Address, Port} peer tuple as text.
-spec peer_ip_to_binary({inet:ip_address(), inet:port_number()}) -> binary().
peer_ip_to_binary({Address, _Port}) ->
    Text = inet:ntoa(Address),
    list_to_binary(Text).
%% Extracts the first entry of an X-Forwarded-For header value and returns it
%% as a normalised address, or <<>> when that entry is not a valid address.
%% Only the text before the first comma is considered.
-spec parse_forwarded_for(binary()) -> binary().
parse_forwarded_for(HeaderValue) ->
    %% binary:split/2 always yields at least one element.
    [FirstEntry | _Others] = binary:split(HeaderValue, <<",">>),
    case normalize_forwarded_ip(FirstEntry) of
        error -> <<>>;
        {ok, Address} -> Address
    end.
%% Normalises one X-Forwarded-For entry. Surrounding whitespace is trimmed;
%% a bracketed entry is treated as an IPv6 literal (optionally followed by a
%% port), anything else has an IPv4-style port stripped. The result is then
%% validated as an IP address.
-spec normalize_forwarded_ip(binary()) -> {ok, binary()} | error.
normalize_forwarded_ip(Value) ->
    normalize_trimmed_ip(string:trim(Value)).

%% Dispatches on the shape of an already-trimmed entry.
-spec normalize_trimmed_ip(binary()) -> {ok, binary()} | error.
normalize_trimmed_ip(<<>>) ->
    error;
normalize_trimmed_ip(<<"[", _/binary>> = Bracketed) ->
    case strip_ipv6_brackets(Bracketed) of
        error -> error;
        {ok, IPv6} -> validate_ip(IPv6)
    end;
normalize_trimmed_ip(Bare) ->
    validate_ip(strip_ipv4_port(Bare)).
%% Returns the text between a leading "[" and the first "]". Fails when the
%% input does not start with "[", has no closing bracket, or the brackets are
%% empty. Whatever follows the closing bracket (such as ":port") is ignored.
-spec strip_ipv6_brackets(binary()) -> {ok, binary()} | error.
strip_ipv6_brackets(<<"[", Rest/binary>>) ->
    case binary:split(Rest, <<"]">>) of
        [Inner, _Trailing] when Inner =/= <<>> -> {ok, Inner};
        _ -> error
    end;
strip_ipv6_brackets(_Unbracketed) ->
    error.
%% Removes a trailing ":port" from an IPv4-style entry. The port is stripped
%% only when the value contains a dot and exactly one colon; every other
%% value (plain IPv4, IPv6, IPv4-mapped IPv6) is returned unchanged.
-spec strip_ipv4_port(binary()) -> binary().
strip_ipv4_port(IP) ->
    HasDot = binary:match(IP, <<".">>) =/= nomatch,
    case HasDot andalso binary:split(IP, <<":">>, [global]) of
        [Address, _Port] -> Address;
        _ -> IP
    end.
%% Parses the text as an IP address and, when valid, returns it re-rendered
%% by inet:ntoa/1. Invalid text yields error.
-spec validate_ip(binary()) -> {ok, binary()} | error.
validate_ip(IP) ->
    Candidate = binary_to_list(IP),
    case inet:parse_address(Candidate) of
        {error, _} -> error;
        {ok, Address} -> {ok, list_to_binary(inet:ntoa(Address))}
    end.
%% Second resume step: asks the session process to verify the token.
%% A verified token continues the resume; a rejected token closes the
%% connection as an authentication failure.
%%
%% The session may terminate between the lookup and this call, or fail to
%% answer in time. In that case gen_server:call/3 exits; the exit is treated
%% as "session not found" so the client is told to re-identify instead of
%% the handler crashing.
-spec handle_resume_with_session(pid(), binary(), binary(), integer(), state()) -> ws_result().
handle_resume_with_session(Pid, Token, SessionId, Seq, State) ->
    try gen_server:call(Pid, {token_verify, Token}, 5000) of
        true -> handle_resume_with_verified_token(Pid, SessionId, Seq, State);
        false -> handle_resume_invalid_token(State)
    catch
        exit:_ -> handle_resume_session_not_found(State)
    end.
%% Third resume step: asks the session to resume from the given sequence
%% number on behalf of this socket. A list of missed events completes the
%% resume; invalid_seq closes the connection. Any other reply, or an exit
%% from the call, is not handled here and crashes the handler.
-spec handle_resume_with_verified_token(pid(), binary(), integer(), state()) -> ws_result().
handle_resume_with_verified_token(Pid, SessionId, Seq, State) ->
    Reply = gen_server:call(Pid, {resume, Seq, self()}, 5000),
    case Reply of
        invalid_seq ->
            handle_resume_invalid_seq(State);
        {ok, Missed} when is_list(Missed) ->
            handle_resume_success(Pid, SessionId, Seq, Missed, State)
    end.
%% Completes a successful resume: records the metric, monitors the session,
%% replays every missed event to this process as a {dispatch, ...} message,
%% followed by a final `resumed` dispatch, then attaches the session and
%% resets the heartbeat bookkeeping. The events are sent as messages so they
%% are delivered through the normal dispatch path, in order.
-spec handle_resume_success(pid(), binary(), integer(), [map()], state()) -> ws_result().
handle_resume_success(Pid, _SessionId, Seq, MissedEvents, State) ->
    gateway_metrics_collector:inc_resume_success(),
    Socket = self(),
    monitor(process, Pid),
    Replay = fun(Missed) when is_map(Missed) ->
        Socket ! {dispatch, maps:get(event, Missed), maps:get(data, Missed), maps:get(seq, Missed)}
    end,
    lists:foreach(Replay, MissedEvents),
    Socket ! {dispatch, resumed, null, Seq},
    FreshHeartbeat = #{
        last_ack => erlang:system_time(millisecond),
        waiting_for_ack => false
    },
    {ok, State#{session_pid => Pid, heartbeat_state => FreshHeartbeat}}.
%% Resume failed because the session rejected the sequence number: records
%% the failure and closes the connection.
-spec handle_resume_invalid_seq(state()) -> ws_result().
handle_resume_invalid_seq(State) ->
    gateway_metrics_collector:inc_resume_failure(),
    Text = <<"Invalid sequence">>,
    close_with_reason(invalid_seq, Text, State).
%% Resume failed because the session rejected the token: records the failure
%% and closes the connection as an authentication failure.
-spec handle_resume_invalid_token(state()) -> ws_result().
handle_resume_invalid_token(State) ->
    gateway_metrics_collector:inc_resume_failure(),
    Text = <<"Invalid token">>,
    close_with_reason(authentication_failed, Text, State).
%% Resume failed because no matching session exists: records the failure and
%% sends INVALID_SESSION, leaving the connection open.
-spec handle_resume_session_not_found(state()) -> ws_result().
handle_resume_session_not_found(State) ->
    gateway_metrics_collector:inc_resume_failure(),
    Result = send_invalid_session(State),
    Result.
%% Encodes a message with the connection's encoding and passes the result
%% through the compression context. Returns the outbound frame together with
%% the state holding the advanced compression context, or a tagged error:
%% {encode_failed, Reason} or {compress_failed, CompressionType, Reason}.
-spec encode_and_compress(map(), state()) -> {ok, ws_frame(), state()} | {error, term()}.
encode_and_compress(Message, State = #{encoding := Encoding, compress_ctx := Ctx}) ->
    case gateway_codec:encode(Message, Encoding) of
        {error, Why} ->
            {error, {encode_failed, Why}};
        {ok, Encoded, FrameType} ->
            case gateway_compress:compress(Encoded, Ctx) of
                {error, Why} ->
                    {error, {compress_failed, gateway_compress:get_type(Ctx), Why}};
                {ok, Compressed, NextCtx} ->
                    Frame = make_frame(Compressed, FrameType, NextCtx),
                    {ok, Frame, State#{compress_ctx => NextCtx}}
            end
    end.
%% Close message for a compression failure: a dedicated message for
%% zstd_stream, the generic encode failure message for everything else.
-spec compression_error_reason(atom()) -> binary().
compression_error_reason(CompressionType) ->
    case CompressionType of
        zstd_stream -> <<"Compression failed: zstd-stream">>;
        _ -> <<"Encode failed">>
    end.
%% Builds the callback result that closes the connection: records the close
%% metric for the reason, translates the reason to its numeric close code and
%% returns a single close frame carrying the message.
-spec close_with_reason(atom(), binary(), state()) -> ws_result().
close_with_reason(Reason, Message, State) ->
    gateway_metrics_collector:inc_websocket_close(Reason),
    CloseFrame = {close, constants:close_code_to_num(Reason), Message},
    {[CloseFrame], State}.
%% Chooses the frame type for outbound data: the encoder's frame type when
%% compression is off (type none), a binary frame whenever compression is on.
-spec make_frame(binary(), text | binary, gateway_compress:compress_ctx()) -> ws_frame().
make_frame(Data, FrameType, CompressCtx) ->
    Uncompressed = gateway_compress:get_type(CompressCtx) =:= none,
    Kind =
        case Uncompressed of
            true -> FrameType;
            false -> binary
        end,
    {Kind, Data}.
%% Opens the tracing span for this connection, tagged with this module, the
%% gateway version and the peer address taken from the state.
-spec start_websocket_connect_span(1 | undefined, state()) -> term().
start_websocket_connect_span(Version, State) ->
    #{peer_ip := ClientIP} = State,
    gateway_tracing:start_connection_span(?MODULE, Version, ClientIP).
%% Ends the connection's tracing span with the termination reason rendered
%% as text. Does nothing when no span was started.
-spec end_websocket_disconnect_span(term(), term()) -> ok.
end_websocket_disconnect_span(_Reason, undefined) ->
    ok;
end_websocket_disconnect_span(Reason, SpanCtx) ->
    gateway_tracing:end_connection_span(SpanCtx, reason_to_binary(Reason)).
%% Renders a termination reason as text: atoms (such as normal or remote)
%% become their UTF-8 name, every other term becomes <<"unknown">>.
-spec reason_to_binary(term()) -> binary().
reason_to_binary(Reason) when is_atom(Reason) ->
    atom_to_binary(Reason, utf8);
reason_to_binary(_NonAtom) ->
    <<"unknown">>.
%% Voice update rate limiter. Returns false, and records the current time
%% for the session, while fewer than ?VOICE_UPDATE_RATE_LIMIT updates were
%% recorded in the last ?VOICE_RATE_LIMIT_WINDOW milliseconds. Returns true
%% (update should be queued) once the limit is reached; nothing is recorded
%% in that case.
-spec should_queue_voice_update(pid()) -> boolean().
should_queue_voice_update(SessionPid) ->
    ensure_voice_rate_limit_table(),
    Now = erlang:system_time(millisecond),
    Recorded =
        case ets:lookup(?VOICE_RATE_LIMIT_TABLE, SessionPid) of
            [] -> [];
            [{SessionPid, Timestamps}] -> Timestamps
        end,
    Recent = lists:filter(fun(T) -> (Now - T) < ?VOICE_RATE_LIMIT_WINDOW end, Recorded),
    Limited = length(Recent) >= ?VOICE_UPDATE_RATE_LIMIT,
    case Limited of
        true ->
            true;
        false ->
            ets:insert(?VOICE_RATE_LIMIT_TABLE, {SessionPid, [Now | Recent]}),
            false
    end.
%% Forwards a voice state update to the session process and waits up to five
%% seconds for it to be handled. The reply is not used.
%%
%% Delivery is best-effort: if the call exits for any reason (timeout, the
%% session not existing, or the session terminating while the call is in
%% flight) the update is dropped and the socket keeps running. Catching only
%% timeout and noproc would let a session that exits with normal, shutdown
%% or killed mid-call take the socket process down with it; session
%% termination is reported separately through the 'DOWN' monitor message.
-spec process_voice_update(pid(), map(), state()) -> ws_result().
process_voice_update(SessionPid, Data, State) ->
    try
        gen_server:call(
            SessionPid,
            {voice_state_update, Data},
            5000
        ),
        {ok, State}
    catch
        exit:_ ->
            {ok, State}
    end.
%% Appends a voice update to the session's queue in the voice queue table,
%% creating the queue on first use. An existing queue is trimmed first so it
%% does not grow past the configured maximum.
-spec queue_voice_update(pid(), map()) -> ok.
queue_voice_update(SessionPid, Data) ->
    ensure_voice_queue_table(),
    Pending =
        case ets:lookup(?VOICE_QUEUE_TABLE, SessionPid) of
            [{SessionPid, Existing}] -> trim_voice_queue(Existing);
            [] -> queue:new()
        end,
    ets:insert(?VOICE_QUEUE_TABLE, {SessionPid, queue:in(Data, Pending)}),
    ok.
%% Makes room for one more entry: when the queue already holds
%% ?MAX_VOICE_QUEUE_LENGTH items or more, the oldest item is dropped;
%% shorter queues are returned unchanged.
-spec trim_voice_queue(queue:queue()) -> queue:queue().
trim_voice_queue(Queue) ->
    case queue:len(Queue) < ?MAX_VOICE_QUEUE_LENGTH of
        true ->
            Queue;
        false ->
            %% queue:out/1 returns the remaining queue in both of its result
            %% shapes ({empty, Q} and {{value, _}, Q}).
            {_Front, Rest} = queue:out(Queue),
            Rest
    end.
%% Arms the voice queue timer if none is pending: a {process_voice_queue}
%% message is scheduled after ?VOICE_QUEUE_PROCESS_INTERVAL milliseconds and
%% its reference stored in the state. A state with a timer already recorded
%% is returned unchanged.
-spec ensure_voice_queue_timer(state()) -> state().
ensure_voice_queue_timer(State) ->
    case State of
        #{voice_queue_timer := undefined} ->
            Ref = erlang:send_after(
                ?VOICE_QUEUE_PROCESS_INTERVAL, self(), {process_voice_queue}
            ),
            State#{voice_queue_timer => Ref};
        _AlreadyArmed ->
            State
    end.
%% Timer tick for the voice queue: when a session is attached and has a
%% queue in the voice queue table, attempts to forward its next item.
%% Without a session, or without a queue, the state is returned unchanged.
-spec process_queued_voice_updates(state()) -> state().
process_queued_voice_updates(State = #{session_pid := Session}) when is_pid(Session) ->
    ensure_voice_queue_table(),
    case ets:lookup(?VOICE_QUEUE_TABLE, Session) of
        [{Session, Pending}] -> process_queue_item(Pending, Session, State);
        [] -> State
    end;
process_queued_voice_updates(State) ->
    State.
%% Handles at most one queued voice update per call.
%% An empty queue has its table entry removed. Otherwise the oldest update is
%% forwarded if the rate limiter allows it, and the remaining queue is either
%% removed (when empty) or stored back with the timer re-armed. While the
%% session is still rate limited, the queue is left untouched and the timer
%% is re-armed to retry later.
-spec process_queue_item(queue:queue(), pid(), state()) -> state().
process_queue_item(Queue, SessionPid, State) ->
    case queue:out(Queue) of
        {{value, Data}, Remaining} ->
            forward_queued_voice_update(Data, Remaining, SessionPid, State);
        {empty, _} ->
            ets:delete(?VOICE_QUEUE_TABLE, SessionPid),
            State
    end.

%% Forwards one dequeued update, subject to the voice rate limit, and stores
%% or removes what is left of the queue.
-spec forward_queued_voice_update(map(), queue:queue(), pid(), state()) -> state().
forward_queued_voice_update(Data, Remaining, SessionPid, State) ->
    case should_queue_voice_update(SessionPid) of
        true ->
            ensure_voice_queue_timer(State);
        false ->
            process_voice_update(SessionPid, Data, State),
            case queue:is_empty(Remaining) of
                true ->
                    ets:delete(?VOICE_QUEUE_TABLE, SessionPid),
                    State;
                false ->
                    ets:insert(?VOICE_QUEUE_TABLE, {SessionPid, Remaining}),
                    ensure_voice_queue_timer(State)
            end
    end.
%% Creates the public, named voice queue table if it does not exist yet.
%% A badarg from ets:new/2 (for instance when another process created the
%% table in the meantime) is ignored.
%% NOTE(review): ets:new/2 runs in the calling process, which becomes the
%% table owner; per the ets documentation a table without an heir is deleted
%% when its owner exits. Confirm that a long-lived process creates this table
%% before any connection does.
-spec ensure_voice_queue_table() -> ok.
ensure_voice_queue_table() ->
    case ets:whereis(?VOICE_QUEUE_TABLE) =:= undefined of
        false ->
            ok;
        true ->
            try ets:new(?VOICE_QUEUE_TABLE, [named_table, public, set]) of
                _Table -> ok
            catch
                error:badarg -> ok
            end
    end.
%% Creates the public, named voice rate-limit table if it does not exist
%% yet. A badarg from ets:new/2 (for instance when another process created
%% the table in the meantime) is ignored.
%% NOTE(review): ets:new/2 runs in the calling process, which becomes the
%% table owner; per the ets documentation a table without an heir is deleted
%% when its owner exits. Confirm that a long-lived process creates this table
%% before any connection does.
-spec ensure_voice_rate_limit_table() -> ok.
ensure_voice_rate_limit_table() ->
    case ets:whereis(?VOICE_RATE_LIMIT_TABLE) =:= undefined of
        false ->
            ok;
        true ->
            try ets:new(?VOICE_RATE_LIMIT_TABLE, [named_table, public, set]) of
                _Table -> ok
            catch
                error:badarg -> ok
            end
    end.
-ifdef(TEST).

%% X-Forwarded-For parsing: IPv4 entries, with and without a port.
parse_forwarded_for_ipv4_test() ->
    Parsed = parse_forwarded_for(<<"203.0.113.7">>),
    ?assertEqual(<<"203.0.113.7">>, Parsed).

parse_forwarded_for_ipv4_with_port_test() ->
    Parsed = parse_forwarded_for(<<"203.0.113.7:8080">>),
    ?assertEqual(<<"203.0.113.7">>, Parsed).

%% Only the first entry is used; surrounding whitespace is ignored.
parse_forwarded_for_ipv4_with_port_and_extra_entries_test() ->
    Parsed = parse_forwarded_for(<<" 203.0.113.7:8080 , 10.0.0.1">>),
    ?assertEqual(<<"203.0.113.7">>, Parsed).

%% X-Forwarded-For parsing: IPv6 entries, bare and bracketed.
parse_forwarded_for_ipv6_test() ->
    Parsed = parse_forwarded_for(<<"2001:db8::1">>),
    ?assertEqual(<<"2001:db8::1">>, Parsed).

parse_forwarded_for_ipv6_with_brackets_test() ->
    Parsed = parse_forwarded_for(<<"[2001:db8::1]">>),
    ?assertEqual(<<"2001:db8::1">>, Parsed).

parse_forwarded_for_ipv6_with_brackets_and_port_test() ->
    Parsed = parse_forwarded_for(<<"[2001:db8::1]:443">>),
    ?assertEqual(<<"2001:db8::1">>, Parsed).

parse_forwarded_for_ipv6_with_spaces_test() ->
    Parsed = parse_forwarded_for(<<" [2001:db8::1] ">>),
    ?assertEqual(<<"2001:db8::1">>, Parsed).

%% Invalid entries yield an empty binary.
parse_forwarded_for_invalid_ip_test() ->
    Parsed = parse_forwarded_for(<<"not_an_ip">>),
    ?assertEqual(<<>>, Parsed).

parse_forwarded_for_invalid_ipv4_octet_test() ->
    Parsed = parse_forwarded_for(<<"203.0.113.300">>),
    ?assertEqual(<<>>, Parsed).

parse_forwarded_for_unterminated_bracket_test() ->
    Parsed = parse_forwarded_for(<<"[2001:db8::1">>),
    ?assertEqual(<<>>, Parsed).

%% Only version <<"1">> is supported.
parse_version_test() ->
    Supported = parse_version(<<"1">>),
    Unsupported = parse_version(<<"2">>),
    Missing = parse_version(undefined),
    ?assertEqual(1, Supported),
    ?assertEqual(undefined, Unsupported),
    ?assertEqual(undefined, Missing).

%% Absent values mean "nothing ignored"; names are upper-cased; non-binary
%% elements and non-list values are rejected.
parse_ignored_events_test() ->
    Invalid = {error, invalid_ignored_events},
    ?assertEqual({ok, []}, parse_ignored_events(undefined)),
    ?assertEqual({ok, []}, parse_ignored_events(null)),
    ?assertEqual({ok, [<<"TYPING_START">>]}, parse_ignored_events([<<"typing_start">>])),
    ?assertEqual(Invalid, parse_ignored_events([123])),
    ?assertEqual(Invalid, parse_ignored_events(<<"not_a_list">>)).

%% offline is stored as invisible; other statuses pass through.
adjust_status_test() ->
    ?assertEqual(invisible, adjust_status(offline)),
    ?assertEqual(online, adjust_status(online)),
    ?assertEqual(idle, adjust_status(idle)).

-endif.