958 lines
36 KiB
Erlang
958 lines
36 KiB
Erlang
%% Copyright (C) 2026 Fluxer Contributors
|
|
%%
|
|
%% This file is part of Fluxer.
|
|
%%
|
|
%% Fluxer is free software: you can redistribute it and/or modify
|
|
%% it under the terms of the GNU Affero General Public License as published by
|
|
%% the Free Software Foundation, either version 3 of the License, or
|
|
%% (at your option) any later version.
|
|
%%
|
|
%% Fluxer is distributed in the hope that it will be useful,
|
|
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
%% GNU Affero General Public License for more details.
|
|
%%
|
|
%% You should have received a copy of the GNU Affero General Public License
|
|
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
-module(gateway_handler).
|
|
-behaviour(cowboy_websocket).
|
|
|
|
-export([init/2, websocket_init/1, websocket_handle/2, websocket_info/2, terminate/3]).
|
|
|
|
-ifdef(TEST).
|
|
-include_lib("eunit/include/eunit.hrl").
|
|
-endif.
|
|
|
|
%% Voice state update throttling.
%% A session may send at most ?VOICE_UPDATE_RATE_LIMIT voice state updates
%% within any sliding window of ?VOICE_RATE_LIMIT_WINDOW milliseconds. Excess
%% updates are queued (at most ?MAX_VOICE_QUEUE_LENGTH entries; the oldest is
%% dropped first) and drained by a timer armed every
%% ?VOICE_QUEUE_PROCESS_INTERVAL milliseconds.
-define(VOICE_UPDATE_RATE_LIMIT, 10).
-define(VOICE_RATE_LIMIT_WINDOW, 1000).
-define(VOICE_QUEUE_PROCESS_INTERVAL, 100).
-define(MAX_VOICE_QUEUE_LENGTH, 64).

%% Named ETS tables holding the per-session voice queue and its rate limiter.
-define(VOICE_QUEUE_TABLE, voice_update_queue).
-define(VOICE_RATE_LIMIT_TABLE, voice_update_rate_limit).
|
|
|
|
%% Per-connection state threaded through every cowboy_websocket callback.
-type state() :: #{
    %% Negotiated API version; websocket_init/1 only accepts 1.
    version := 1 | undefined,
    %% Wire encoding and optional transport compression for outgoing frames.
    encoding := gateway_codec:encoding(),
    compress_ctx := gateway_compress:compress_ctx() | undefined,
    %% Process and client address associated with the connection.
    socket_pid := pid() | undefined,
    peer_ip := binary() | undefined,
    %% Session process bound by a successful IDENTIFY or RESUME.
    session_pid := pid() | undefined,
    %% Liveness and flood-control bookkeeping.
    heartbeat_state := map(),
    rate_limit_state := map(),
    voice_queue_timer := reference() | undefined,
    %% Tracing context of the connection span.
    otel_span_ctx := term()
}.
|
|
|
|
%% Returns the default connection state; init/2 then overrides the fields
%% derived from the upgrade request.
-spec new_state() -> state().
new_state() ->
    Defaults = [
        {version, undefined},
        {encoding, json},
        {compress_ctx, undefined},
        {session_pid, undefined},
        {heartbeat_state, #{}},
        {socket_pid, undefined},
        {peer_ip, undefined},
        {rate_limit_state, #{events => [], window_start => undefined}},
        {otel_span_ctx, undefined},
        {voice_queue_timer, undefined}
    ],
    maps:from_list(Defaults).
|
|
|
|
%% Data frames this handler emits (built by make_frame/3).
-type ws_frame() :: {binary, binary()} | {text, binary()}.
%% Return shape shared by the callbacks in this module: either nothing to send,
%% or a list of frames and/or close commands.
-type ws_result() :: {[ws_frame() | {close, integer(), binary()}], state()} | {ok, state()}.
|
|
|
|
%% cowboy_websocket init/2: reads the "v", "encoding" and "compress" query
%% parameters and the client address from the upgrade request, then switches
%% the connection to the WebSocket protocol with the resulting state.
%% NOTE(review): per the cowboy docs init/2 runs in the temporary request
%% process, so socket_pid captured here is not the WebSocket connection
%% process — confirm whether anything relies on it before websocket_init/1.
-spec init(cowboy_req:req(), term()) -> {cowboy_websocket, cowboy_req:req(), state()}.
init(Req, _Opts) ->
    Query = cowboy_req:parse_qs(Req),
    Param = fun(Name) -> proplists:get_value(Name, Query) end,
    CompressCtx = gateway_compress:new_context(
        gateway_compress:parse_compression(Param(<<"compress">>))
    ),
    Overrides = #{
        version => parse_version(Param(<<"v">>)),
        encoding => gateway_codec:parse_encoding(Param(<<"encoding">>)),
        compress_ctx => CompressCtx,
        socket_pid => self(),
        peer_ip => extract_client_ip(Req)
    },
    {cowboy_websocket, Req, maps:merge(new_state(), Overrides)}.
|
|
|
|
%% Maps the "v" query parameter to a supported API version. Only the exact
%% binary "1" is recognised; any other value, or an absent parameter, yields
%% undefined.
-spec parse_version(binary() | undefined) -> 1 | undefined.
parse_version(Raw) ->
    case Raw of
        <<"1">> -> 1;
        _Unsupported -> undefined
    end.
|
|
|
|
%% cowboy_websocket websocket_init/1: entry point in the connection process
%% after the upgrade.
%%
%% For API version 1 it:
%% - counts the connection and opens the connection trace span;
%% - replaces the compression context created in init/2 with a fresh one of
%%   the same type created here;
%% - arms the heartbeat check and seeds heartbeat_state;
%% - sends HELLO with the heartbeat interval.
%% If HELLO cannot be encoded, the connection closes with decode_error.
%% Any other version closes with invalid_api_version.
%%
%% socket_pid is re-pointed at self(). init/2 runs in cowboy's temporary
%% request process, so the pid it captured is not the process that owns the
%% WebSocket.
-spec websocket_init(state()) -> ws_result().
websocket_init(State = #{version := 1}) ->
    BaseState = begin_connection(1, State),
    CompressionType = gateway_compress:get_type(maps:get(compress_ctx, BaseState)),
    FreshState0 = BaseState#{compress_ctx => gateway_compress:new_context(CompressionType)},
    HelloMessage = #{
        <<"op">> => constants:opcode_to_num(hello),
        <<"d">> => #{<<"heartbeat_interval">> => constants:heartbeat_interval()}
    },
    schedule_heartbeat_check(),
    NewState1 = FreshState0#{
        heartbeat_state => #{
            last_ack => erlang:system_time(millisecond),
            waiting_for_ack => false
        }
    },
    case encode_and_compress(HelloMessage, NewState1) of
        {ok, Frame, NewState2} ->
            {[Frame], NewState2};
        {error, {compress_failed, CT, _Reason}} ->
            close_with_reason(decode_error, compression_error_reason(CT), FreshState0);
        {error, _Reason} ->
            close_with_reason(decode_error, <<"Encode failed">>, FreshState0)
    end;
websocket_init(State) ->
    close_with_reason(
        invalid_api_version, <<"Invalid API version">>, begin_connection(undefined, State)
    ).

%% Bookkeeping shared by both websocket_init/1 clauses: count the connection,
%% record the connection process as socket_pid, and start the trace span.
-spec begin_connection(1 | undefined, state()) -> state().
begin_connection(Version, State) ->
    gateway_metrics_collector:inc_connections(),
    ConnState = State#{socket_pid => self()},
    ConnState#{otel_span_ctx => start_websocket_connect_span(Version, ConnState)}.
|
|
|
|
%% cowboy_websocket websocket_handle/2: text and binary frames both enter the
%% gateway pipeline; every other frame is ignored without output.
-spec websocket_handle({text, binary()} | {binary, binary()} | term(), state()) -> ws_result().
websocket_handle(Frame, State) ->
    case Frame of
        {Kind, Payload} when Kind =:= text; Kind =:= binary ->
            handle_incoming_data(Payload, State);
        _OtherFrame ->
            {ok, State}
    end.
|
|
|
|
%% Entry point for every text or binary frame received from the client.
%%
%% Payloads larger than constants:max_payload_size() bytes close the
%% connection with decode_error before any decoding. Otherwise the payload is
%% decoded with the connection's negotiated encoding and handed to
%% handle_decode/2. State is passed on as-is; the old code rebuilt the map
%% with an identical compress_ctx value on every frame.
%% NOTE(review): cowboy has already buffered the whole frame when this check
%% runs; consider cowboy_websocket's max_frame_size option to reject it earlier.
-spec handle_incoming_data(binary(), state()) -> ws_result().
handle_incoming_data(Data, State = #{encoding := Encoding}) ->
    case byte_size(Data) > constants:max_payload_size() of
        true ->
            close_with_reason(decode_error, <<"Payload too large">>, State);
        false ->
            handle_decode(gateway_codec:decode(Data, Encoding), State)
    end.
|
|
|
|
%% Acts on the result of decoding a client frame.
%%
%% A map with an "op" field is charged against the per-connection message rate
%% limit. If allowed, it is routed by opcode; if not, the connection closes
%% with rate_limited. Any other decoded value, and any decode error, closes
%% with decode_error.
-spec handle_decode({ok, map()} | {error, term()}, state()) -> ws_result().
handle_decode({error, _Reason}, State) ->
    close_with_reason(decode_error, <<"Decode failed">>, State);
handle_decode({ok, #{<<"op">> := Op} = Payload}, State) ->
    Opcode = constants:gateway_opcode(Op),
    case check_rate_limit(State) of
        rate_limited ->
            close_with_reason(rate_limited, <<"Rate limited">>, State);
        {ok, ChargedState} ->
            handle_gateway_payload(Opcode, Payload, ChargedState)
    end;
handle_decode({ok, _WithoutOp}, State) ->
    close_with_reason(decode_error, <<"Invalid payload">>, State).
|
|
|
|
%% cowboy_websocket websocket_info/2: handles Erlang messages sent to the
%% connection process.
%% - {heartbeat_check}: the periodic liveness check.
%% - {dispatch, Event, Data, Seq}: an event to forward to the client.
%% - {session_backpressure_error, Details}: closes with ack_backpressure.
%% - {'DOWN', ...} for the current session: the session is detached and the
%%   client is sent INVALID_SESSION.
%% - {process_voice_queue}: the voice queue timer fired. The timer field is
%%   cleared first so draining the queue can re-arm it.
%% Anything else, including a 'DOWN' for another process, is ignored.
-spec websocket_info(term(), state()) -> ws_result().
websocket_info(Message, State) ->
    case Message of
        {heartbeat_check} when is_map_key(heartbeat_state, State) ->
            handle_heartbeat_check(State, map_get(heartbeat_state, State));
        {dispatch, Event, Data, Seq} ->
            handle_dispatch(Event, Data, Seq, State);
        {session_backpressure_error, Details} ->
            handle_session_backpressure_error(Details, State);
        {'DOWN', _MonitorRef, process, DownPid, _Info} when
            DownPid =:= map_get(session_pid, State)
        ->
            handle_session_down(State);
        {process_voice_queue} ->
            {ok, process_queued_voice_updates(State#{voice_queue_timer => undefined})};
        _Unhandled ->
            {ok, State}
    end.
|
|
|
|
%% Gathers the inputs of the heartbeat decision and delegates to
%% handle_heartbeat_state/6:
%% - the current time in milliseconds;
%% - the last acknowledgement time (defaults to now);
%% - whether a server heartbeat is outstanding (defaults to false);
%% - the configured timeout and interval.
%% The second argument is ignored; heartbeat_state is read from State.
-spec handle_heartbeat_check(state(), map()) -> ws_result().
handle_heartbeat_check(State, _HeartbeatStateArg) ->
    #{heartbeat_state := Beat} = State,
    Now = erlang:system_time(millisecond),
    handle_heartbeat_state(
        State,
        Now,
        maps:get(last_ack, Beat, Now),
        maps:get(waiting_for_ack, Beat, false),
        constants:heartbeat_timeout(),
        constants:heartbeat_interval()
    ).
|
|
|
|
%% Decides the outcome of one heartbeat check:
%% - a server heartbeat is outstanding and more than HeartbeatTimeout ms have
%%   passed since the last ack: count a failure and close with session_timeout;
%% - at least 90% of HeartbeatInterval has passed: send a server HEARTBEAT
%%   (d = null), mark it outstanding and re-arm the check. If encoding fails,
%%   the outstanding flag is kept anyway;
%% - otherwise: just re-arm the check.
-spec handle_heartbeat_state(state(), integer(), integer(), boolean(), integer(), integer()) ->
    ws_result().
handle_heartbeat_state(State, Now, LastAck, WaitingForAck, HeartbeatTimeout, HeartbeatInterval) ->
    SinceAck = Now - LastAck,
    Expired = WaitingForAck =:= true andalso SinceAck > HeartbeatTimeout,
    Due = SinceAck >= HeartbeatInterval * 0.9,
    if
        Expired ->
            gateway_metrics_collector:inc_heartbeat_failure(),
            close_with_reason(session_timeout, <<"Heartbeat timeout">>, State);
        Due ->
            #{heartbeat_state := Beat} = State,
            Probe = #{<<"op">> => constants:opcode_to_num(heartbeat), <<"d">> => null},
            schedule_heartbeat_check(),
            Awaiting = State#{heartbeat_state => Beat#{waiting_for_ack => true}},
            case encode_and_compress(Probe, Awaiting) of
                {ok, Frame, Sent} -> {[Frame], Sent};
                {error, _} -> {ok, Awaiting}
            end;
        true ->
            schedule_heartbeat_check(),
            {ok, State}
    end.
|
|
|
|
%% Sends a DISPATCH frame (op, t, d, s) for an event pushed to this connection.
%% If encoding or compression fails, the event is dropped and State is
%% returned unchanged.
%% NOTE(review): a dropped event leaves a gap in the client's sequence numbers.
-spec handle_dispatch(atom() | binary(), map() | null, integer(), state()) -> ws_result().
handle_dispatch(Event, Data, Seq, State) ->
    Payload = maps:from_list([
        {<<"op">>, constants:opcode_to_num(dispatch)},
        {<<"t">>, dispatch_event_name(Event)},
        {<<"d">>, Data},
        {<<"s">>, Seq}
    ]),
    case encode_and_compress(Payload, State) of
        {error, _Reason} -> {ok, State};
        {ok, Frame, Next} -> {[Frame], Next}
    end.
|
|
|
|
%% Resolves the wire name ("t") of a dispatch event:
%% - binaries pass through unchanged;
%% - atoms are translated by constants:dispatch_event_atom/1;
%% - anything else becomes <<"UNKNOWN">>.
-spec dispatch_event_name(atom() | binary()) -> binary().
dispatch_event_name(Event) ->
    if
        is_binary(Event) -> Event;
        is_atom(Event) -> constants:dispatch_event_atom(Event);
        true -> <<"UNKNOWN">>
    end.
|
|
|
|
%% Handles the exit of the monitored session process. The session is
%% forgotten and the client is sent INVALID_SESSION with d = false. No close
%% frame is emitted.
-spec handle_session_down(state()) -> ws_result().
handle_session_down(State) ->
    Detached = maps:put(session_pid, undefined, State),
    Invalid = #{<<"d">> => false, <<"op">> => constants:opcode_to_num(invalid_session)},
    case encode_and_compress(Invalid, Detached) of
        {error, _} -> {ok, Detached};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end.
|
|
|
|
%% Handles a {session_backpressure_error, Details} message by closing the
%% connection with ack_backpressure and a reason built from Details.
-spec handle_session_backpressure_error(map(), state()) -> ws_result().
handle_session_backpressure_error(Details, State) ->
    close_with_reason(ack_backpressure, format_backpressure_close_message(Details), State).
|
|
|
|
%% Builds the human-readable close reason for an ack-backpressure disconnect
%% from the details map provided by the session.
%%
%% The result is capped at 123 bytes. RFC 6455 limits a close frame payload to
%% 125 bytes, two of which carry the status code, and this message can exceed
%% that with large counters. Numeric fields default to 0 when absent; a
%% non-integer value is rendered as "unknown" instead of crashing
%% integer_to_binary/1.
-spec format_backpressure_close_message(map()) -> binary().
format_backpressure_close_message(Details) ->
    Kind = detail_value_to_binary(map_utils:get_safe(Details, kind, <<"unknown">>)),
    Count = fun(Key) -> backpressure_count_to_binary(map_utils:get_safe(Details, Key, 0)) end,
    Message = <<
        "Acknowledgement backlog exceeded: kind=",
        Kind/binary,
        " unacked=",
        (Count(unacked_events))/binary,
        " current=",
        (Count(current_size))/binary,
        " limit=",
        (Count(limit))/binary,
        " seq=",
        (Count(seq))/binary,
        " ack_seq=",
        (Count(ack_seq))/binary
    >>,
    truncate_close_reason(Message).

%% Renders a numeric backpressure detail, tolerating unexpected value types.
-spec backpressure_count_to_binary(term()) -> binary().
backpressure_count_to_binary(Value) when is_integer(Value) -> integer_to_binary(Value);
backpressure_count_to_binary(_) -> <<"unknown">>.

%% Trims a close reason to at most 123 bytes without splitting a UTF-8 sequence.
-spec truncate_close_reason(binary()) -> binary().
truncate_close_reason(Reason) when byte_size(Reason) =< 123 ->
    Reason;
truncate_close_reason(Reason) ->
    case unicode:characters_to_binary(binary:part(Reason, 0, 123)) of
        Valid when is_binary(Valid) -> Valid;
        {incomplete, Valid, _Rest} -> Valid;
        {error, Valid, _Rest} -> Valid
    end.
|
|
|
|
%% Converts a backpressure detail value to a binary for the close reason.
%% Binaries pass through, atoms become their UTF-8 name, and anything else
%% becomes <<"unknown">>.
-spec detail_value_to_binary(term()) -> binary().
detail_value_to_binary(Value) ->
    case Value of
        Bin when is_binary(Bin) -> Bin;
        Atom when is_atom(Atom) -> atom_to_binary(Atom, utf8);
        _Other -> <<"unknown">>
    end.
|
|
|
|
%% cowboy_websocket terminate/3. Always counts the disconnection. When the
%% state carries compress_ctx and otel_span_ctx, it also ends the connection
%% span with the termination reason and closes the compression context.
-spec terminate(term(), cowboy_req:req(), state() | term()) -> ok.
terminate(Reason, _Req, State) ->
    gateway_metrics_collector:inc_disconnections(),
    case State of
        #{compress_ctx := CompressCtx, otel_span_ctx := SpanCtx} ->
            end_websocket_disconnect_span(Reason, SpanCtx),
            gateway_compress:close_context(CompressCtx);
        _NotAConnectionState ->
            ok
    end,
    ok.
|
|
|
|
%% Validates the "d" payload of a client IDENTIFY.
%%
%% "token" and "properties" are required. "ignored_events" (default []),
%% "initial_guild_id", "presence" and "flags" are optional. A missing required
%% key becomes {error, missing_required_field]; the try also covers the
%% "os"/"browser"/"device" lookups in validate_properties/5.
%%
%% "d" is untrusted client input. A value that is not a map returns
%% {error, invalid_payload} instead of raising {badmap, _}, which the catch
%% below does not handle and which would crash the connection process.
-spec validate_identify_data(term()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_identify_data(Data) when is_map(Data) ->
    try
        Token = maps:get(<<"token">>, Data),
        Properties = maps:get(<<"properties">>, Data),
        IgnoredEventsRaw = maps:get(<<"ignored_events">>, Data, []),
        InitialGuildIdRaw = maps:get(<<"initial_guild_id">>, Data, undefined),
        validate_properties(Token, Properties, IgnoredEventsRaw, InitialGuildIdRaw, Data)
    catch
        error:{badkey, _} -> {error, missing_required_field}
    end;
validate_identify_data(_NotAMap) ->
    {error, invalid_payload}.
|
|
|
|
%% Checks the IDENTIFY "properties" object. It must be a map whose "os",
%% "browser" and "device" values are all binaries; a missing key raises
%% {badkey, _}, which the caller turns into missing_required_field. On success
%% the optional "presence" (default null) is read from Data and validation
%% continues with the ignored events. Otherwise {error, invalid_properties}.
-spec validate_properties(binary(), term(), term(), term(), map()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_properties(Token, Properties, IgnoredEventsRaw, InitialGuildIdRaw, Data) when
    is_map(Properties)
->
    Required = [maps:get(Key, Properties) || Key <- [<<"os">>, <<"browser">>, <<"device">>]],
    case lists:all(fun erlang:is_binary/1, Required) of
        false ->
            {error, invalid_properties};
        true ->
            Presence = maps:get(<<"presence">>, Data, null),
            validate_ignored_events(
                Token, Properties, Presence, IgnoredEventsRaw, InitialGuildIdRaw, Data
            )
    end;
validate_properties(_Token, _Properties, _IgnoredEventsRaw, _InitialGuildIdRaw, _Data) ->
    {error, invalid_properties}.
|
|
|
|
%% Parses "ignored_events". On error the {error, Reason} is returned as-is;
%% otherwise validation continues with "flags" (default 0) from Data.
-spec validate_ignored_events(binary(), map(), term(), term(), term(), map()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_ignored_events(Token, Properties, Presence, IgnoredEventsRaw, InitialGuildIdRaw, Data) ->
    case parse_ignored_events(IgnoredEventsRaw) of
        {error, _Reason} = Failure ->
            Failure;
        {ok, IgnoredEvents} ->
            validate_flags(
                Token,
                Properties,
                Presence,
                IgnoredEvents,
                InitialGuildIdRaw,
                maps:get(<<"flags">>, Data, 0)
            )
    end.
|
|
|
|
%% Final IDENTIFY check. "flags" must be a non-negative integer. On success it
%% returns the full validated tuple, including the parsed initial guild id;
%% otherwise {error, invalid_properties}.
-spec validate_flags(binary(), map(), term(), [binary()], term(), term()) ->
    {ok, binary(), map(), term(), [binary()], non_neg_integer(), integer() | undefined}
    | {error, atom()}.
validate_flags(Token, Properties, Presence, IgnoredEvents, InitialGuildIdRaw, Flags) ->
    case is_integer(Flags) andalso Flags >= 0 of
        false ->
            {error, invalid_properties};
        true ->
            GuildId = parse_initial_guild_id(InitialGuildIdRaw),
            {ok, Token, Properties, Presence, IgnoredEvents, Flags, GuildId}
    end.
|
|
|
|
%% Parses the IDENTIFY "ignored_events" field.
%% - undefined / null: no ignored events.
%% - a list of binaries: normalised via normalize_event_name/1, then sorted and
%%   de-duplicated.
%% - anything else, or a list with a non-binary element:
%%   {error, invalid_ignored_events}.
-spec parse_ignored_events(term()) -> {ok, [binary()]} | {error, invalid_ignored_events}.
parse_ignored_events(Raw) when Raw =:= undefined; Raw =:= null ->
    {ok, []};
parse_ignored_events(Raw) when is_list(Raw) ->
    case [Entry || Entry <- Raw, not is_binary(Entry)] of
        [] -> {ok, lists:usort(lists:map(fun normalize_event_name/1, Raw))};
        [_ | _] -> {error, invalid_ignored_events}
    end;
parse_ignored_events(_Invalid) ->
    {error, invalid_ignored_events}.
|
|
|
|
%% Parses the optional IDENTIFY "initial_guild_id". Only a binary that
%% validation:validate_snowflake/2 accepts yields an id; everything else,
%% including undefined, null and a rejected snowflake, yields undefined.
-spec parse_initial_guild_id(term()) -> integer() | undefined.
parse_initial_guild_id(Value) when is_binary(Value) ->
    Result = validation:validate_snowflake(<<"initial_guild_id">>, Value),
    case Result of
        {ok, GuildId} -> GuildId;
        {error, _Field, _Reason} -> undefined
    end;
parse_initial_guild_id(_NotABinary) ->
    undefined.
|
|
|
|
%% Canonicalises a client-supplied event name for the ignored_events set by
%% upper-casing the ASCII letters a-z; every other byte is kept as-is.
%%
%% Only ASCII is folded so arbitrary client bytes can never corrupt multi-byte
%% UTF-8 sequences or make this function fail. Upper-casing the raw bytes as
%% Latin-1 code points, as before, rewrote UTF-8 lead bytes. For bytes such as
%% 16#B5 or 16#FF the upper-case form lies above 255, so list_to_binary/1
%% raised badarg and crashed the connection.
-spec normalize_event_name(binary()) -> binary().
normalize_event_name(Event) ->
    << <<(ascii_upper(Byte))>> || <<Byte>> <= Event >>.

%% Upper-cases one ASCII lowercase letter; any other byte passes through.
-spec ascii_upper(byte()) -> byte().
ascii_upper(Byte) when Byte >= $a, Byte =< $z -> Byte - ($a - $A);
ascii_upper(Byte) -> Byte.
|
|
|
|
%% Routes a decoded, rate-limit-approved client payload by opcode.
%%
%% - heartbeat: verified and acknowledged.
%% - identify / resume: accepted only while no session is bound. A second
%%   attempt closes with already_authenticated. Before, RESUME on an
%%   authenticated connection silently replaced session_pid, leaving the
%%   previous session still attached to this socket.
%% - identify without "d" on an unauthenticated connection: closes with
%%   decode_error. Before, it was misreported as "Already authenticated".
%% - presence_update, voice_state_update, request_guild_members,
%%   lazy_request: close with not_authenticated until a session is bound.
%% - anything else, including heartbeat/resume without "d": unknown_opcode.
-spec handle_gateway_payload(atom(), map(), state()) -> ws_result().
handle_gateway_payload(
    heartbeat,
    #{<<"d">> := Seq},
    State = #{heartbeat_state := HeartbeatState, session_pid := SessionPid}
) ->
    handle_heartbeat(Seq, SessionPid, State, HeartbeatState);
handle_gateway_payload(
    identify, #{<<"d">> := Data}, State = #{session_pid := undefined, peer_ip := PeerIP}
) ->
    handle_identify(Data, PeerIP, State);
handle_gateway_payload(identify, _, State = #{session_pid := undefined}) ->
    close_with_reason(decode_error, <<"Invalid identify payload">>, State);
handle_gateway_payload(identify, _, State) ->
    close_with_reason(already_authenticated, <<"Already authenticated">>, State);
handle_gateway_payload(resume, _, State = #{session_pid := Pid}) when is_pid(Pid) ->
    close_with_reason(already_authenticated, <<"Already authenticated">>, State);
handle_gateway_payload(resume, #{<<"d">> := Data}, State) ->
    handle_resume(Data, State);
handle_gateway_payload(presence_update, #{<<"d">> := _}, State = #{session_pid := undefined}) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(presence_update, #{<<"d">> := Data}, State = #{session_pid := Pid}) when
    is_pid(Pid)
->
    handle_presence_update(Data, Pid, State);
handle_gateway_payload(
    voice_state_update, #{<<"d">> := _}, State = #{session_pid := undefined}
) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(
    voice_state_update, #{<<"d">> := Data}, State = #{session_pid := Pid}
) when is_pid(Pid) ->
    handle_voice_state_update(Pid, Data, State);
handle_gateway_payload(
    request_guild_members, #{<<"d">> := _}, State = #{session_pid := undefined}
) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(
    request_guild_members, #{<<"d">> := Data}, State = #{session_pid := Pid}
) when is_pid(Pid) ->
    handle_request_guild_members(Data, Pid, State);
handle_gateway_payload(lazy_request, #{<<"d">> := _}, State = #{session_pid := undefined}) ->
    close_with_reason(not_authenticated, <<"Not authenticated">>, State);
handle_gateway_payload(lazy_request, #{<<"d">> := Data}, State = #{session_pid := Pid}) when
    is_pid(Pid)
->
    handle_lazy_request(Data, Pid, State);
handle_gateway_payload(_, _, State) ->
    close_with_reason(unknown_opcode, <<"Unknown opcode">>, State).
|
|
|
|
%% Handles a client HEARTBEAT carrying sequence Seq.
%% - Rejected by verify_heartbeat_ack/2: count a failure and close with
%%   invalid_seq.
%% - Accepted: record the ack time, clear the outstanding flag, count a
%%   success and reply with HEARTBEAT_ACK. If the ack cannot be encoded, the
%%   updated state is still kept.
-spec handle_heartbeat(term(), pid() | undefined, state(), map()) -> ws_result().
handle_heartbeat(Seq, SessionPid, State, HeartbeatState) ->
    case verify_heartbeat_ack(Seq, SessionPid) of
        false ->
            gateway_metrics_collector:inc_heartbeat_failure(),
            close_with_reason(invalid_seq, <<"Invalid sequence">>, State);
        true ->
            Acked = State#{
                heartbeat_state => HeartbeatState#{
                    last_ack => erlang:system_time(millisecond),
                    waiting_for_ack => false
                }
            },
            gateway_metrics_collector:inc_heartbeat_success(),
            Ack = #{<<"op">> => constants:opcode_to_num(heartbeat_ack)},
            case encode_and_compress(Ack, Acked) of
                {ok, Frame, Sent} -> {[Frame], Sent};
                {error, _} -> {ok, Acked}
            end
    end.
|
|
|
|
%% Decides whether a client HEARTBEAT's sequence number is acceptable.
%% - No session yet, or d = null: true, without contacting anyone.
%% - Integer sequence and a session pid: the session is asked synchronously
%%   ({heartbeat_ack, Seq}, 5 s timeout). A reply of true or ok is accepted;
%%   any other reply, or an exit from the call, is false.
%% - Any other argument shape: false.
-spec verify_heartbeat_ack(term(), pid() | undefined) -> boolean().
verify_heartbeat_ack(Seq, SessionPid) ->
    if
        SessionPid =:= undefined; Seq =:= null ->
            true;
        is_integer(Seq), is_pid(SessionPid) ->
            try gen_server:call(SessionPid, {heartbeat_ack, Seq}, 5000) of
                Reply -> Reply =:= true orelse Reply =:= ok
            catch
                exit:_ -> false
            end;
        true ->
            false
    end.
|
|
|
|
%% Handles a client IDENTIFY.
%% - Invalid payload: close with decode_error.
%% - Valid payload: build a session start request (fresh session id, client
%%   address, identify data, API version) and start the session bound to this
%%   connection process.
-spec handle_identify(map(), binary(), state()) -> ws_result().
handle_identify(Data, PeerIP, State) ->
    case validate_identify_data(Data) of
        {error, _Reason} ->
            close_with_reason(decode_error, <<"Invalid identify payload">>, State);
        {ok, Token, Properties, Presence, IgnoredEvents, Flags, InitialGuildId} ->
            BaseIdentify = #{
                token => Token,
                properties => Properties,
                presence => Presence,
                ignored_events => IgnoredEvents,
                flags => Flags
            },
            Request = #{
                session_id => utils:generate_session_id(),
                peer_ip => PeerIP,
                identify_data => add_initial_guild_id(BaseIdentify, InitialGuildId),
                version => maps:get(version, State)
            },
            start_session(Request, self(), State)
    end.
|
|
|
|
%% Adds initial_guild_id to the identify data unless it is undefined.
-spec add_initial_guild_id(map(), integer() | undefined) -> map().
add_initial_guild_id(Data, Id) ->
    case Id of
        undefined -> Data;
        _GuildId -> Data#{initial_guild_id => Id}
    end.
|
|
|
|
%% Asks the session manager to start a session for this connection.
%% - {success, Pid}: monitor the session and bind it.
%% - identify_rate_limited: count it and send INVALID_SESSION.
%% - invalid_token: close with authentication_failed.
%% - rate_limited: close with rate_limited.
%% - Any other result: close with unknown_error.
-spec start_session(map(), pid(), state()) -> ws_result().
start_session(Request, SocketPid, State) ->
    Outcome = session_manager:start(Request, SocketPid),
    case Outcome of
        {success, Pid} when is_pid(Pid) ->
            _MonitorRef = monitor(process, Pid),
            {ok, State#{session_pid => Pid}};
        {error, identify_rate_limited} ->
            gateway_metrics_collector:inc_identify_rate_limited(),
            send_invalid_session(State);
        {error, invalid_token} ->
            close_with_reason(authentication_failed, <<"Invalid token">>, State);
        {error, rate_limited} ->
            close_with_reason(rate_limited, <<"Rate limited">>, State);
        _Unexpected ->
            close_with_reason(unknown_error, <<"Failed to start session">>, State)
    end.
|
|
|
|
%% Sends INVALID_SESSION with d = false. If encoding fails, nothing is sent
%% and State is kept.
-spec send_invalid_session(state()) -> ws_result().
send_invalid_session(State) ->
    Opcode = constants:opcode_to_num(invalid_session),
    case encode_and_compress(#{<<"op">> => Opcode, <<"d">> => false}, State) of
        {error, _} -> {ok, State};
        {ok, Frame, Sent} -> {[Frame], Sent}
    end.
|
|
|
|
%% Forwards a client presence change to the session as a cast.
%%
%% Data is untrusted client input, so it is checked here:
%% - A payload that is not a map, or lacks "status", closes with decode_error.
%%   Previously this raised badkey/badmap and crashed the connection process.
%% - "afk" and "mobile" default to false and are coerced to booleans; only
%%   the literal true counts as true.
%% - The parsed status goes through adjust_status/1, which maps offline to
%%   invisible.
-spec handle_presence_update(term(), pid(), state()) -> ws_result().
handle_presence_update(#{<<"status">> := RawStatus} = Data, Pid, State) ->
    Status = adjust_status(utils:parse_status(RawStatus)),
    Afk = maps:get(<<"afk">>, Data, false) =:= true,
    Mobile = maps:get(<<"mobile">>, Data, false) =:= true,
    gen_server:cast(Pid, {presence_update, #{status => Status, afk => Afk, mobile => Mobile}}),
    {ok, State};
handle_presence_update(_Data, _Pid, State) ->
    close_with_reason(decode_error, <<"Invalid presence payload">>, State).
|
|
|
|
%% A client asking to appear offline is recorded as invisible; every other
%% status is kept.
-spec adjust_status(atom()) -> atom().
adjust_status(Status) ->
    case Status of
        offline -> invisible;
        _Kept -> Status
    end.
|
|
|
|
%% Handles the "d" payload of a client RESUME.
%%
%% - "token", "session_id" and "seq" must all be present. A payload that is
%%   not a map, or lacks any of them, closes with decode_error. Previously
%%   maps:get/2 raised badkey/badmap on this untrusted input and crashed the
%%   connection process.
%% - A non-binary session_id, or one the session manager cannot find, is
%%   answered with INVALID_SESSION.
%% - Otherwise the token check and resume continue against the located
%%   session process.
-spec handle_resume(term(), state()) -> ws_result().
handle_resume(
    #{<<"token">> := Token, <<"session_id">> := SessionId, <<"seq">> := Seq}, State
) when is_binary(SessionId) ->
    case session_manager:lookup(SessionId) of
        {ok, Pid} when is_pid(Pid) ->
            handle_resume_with_session(Pid, Token, SessionId, Seq, State);
        {error, _} ->
            handle_resume_session_not_found(State)
    end;
handle_resume(#{<<"token">> := _, <<"session_id">> := _, <<"seq">> := _}, State) ->
    handle_resume_session_not_found(State);
handle_resume(_Data, State) ->
    close_with_reason(decode_error, <<"Invalid resume payload">>, State).
|
|
|
|
%% Handles a client VOICE_STATE_UPDATE. Within the per-session rate limit it is
%% forwarded to the session immediately. Otherwise it is queued and the queue
%% timer is armed if not already pending.
-spec handle_voice_state_update(pid(), map(), state()) -> ws_result().
handle_voice_state_update(Pid, Data, State) ->
    case should_queue_voice_update(Pid) of
        true ->
            ok = queue_voice_update(Pid, Data),
            {ok, ensure_voice_queue_timer(State)};
        false ->
            process_voice_update(Pid, Data, State)
    end.
|
|
|
|
%% Serves REQUEST_GUILD_MEMBERS off the connection process.
%%
%% A separate process fetches the session state ({get_state}, 5 s timeout).
%% If the result is a map, it passes the request, this connection's pid and
%% that state to guild_request_members:handle_request/3. The worker is
%% best-effort: every exception is swallowed. The connection returns at once
%% without output.
%% NOTE(review): the worker is neither linked nor monitored, so its failures
%% are invisible; consider logging them.
-spec handle_request_guild_members(map(), pid(), state()) -> ws_result().
handle_request_guild_members(Data, Pid, State) ->
    Connection = self(),
    _ = spawn(fun() ->
        try
            SessionState = gen_server:call(Pid, {get_state}, 5000),
            is_map(SessionState) andalso
                guild_request_members:handle_request(Data, Connection, SessionState)
        catch
            _:_ -> ok
        end
    end),
    {ok, State}.
|
|
|
|
%% Serves a lazy (subscription) request off the connection process.
%%
%% A separate process fetches the session state ({get_state}, 5 s timeout).
%% If the result is a map, it passes the request, this connection's pid and
%% that state to guild_unified_subscriptions:handle_subscriptions/3. Every
%% exception is swallowed. The connection returns at once without output.
%% NOTE(review): unlinked, unmonitored worker; failures are invisible.
-spec handle_lazy_request(map(), pid(), state()) -> ws_result().
handle_lazy_request(Data, Pid, State) ->
    Connection = self(),
    Worker = fun() ->
        try
            case gen_server:call(Pid, {get_state}, 5000) of
                #{} = SessionState ->
                    guild_unified_subscriptions:handle_subscriptions(
                        Data, Connection, SessionState
                    );
                _NotAMap ->
                    ok
            end
        catch
            _:_ -> ok
        end
    end,
    _ = spawn(Worker),
    {ok, State}.
|
|
|
|
%% Sends {heartbeat_check} to this process after a third of the heartbeat
%% interval and returns the timer reference.
-spec schedule_heartbeat_check() -> reference().
schedule_heartbeat_check() ->
    Delay = constants:heartbeat_interval() div 3,
    erlang:send_after(Delay, self(), {heartbeat_check}).
|
|
|
|
%% Sliding-window flood control for client payloads: at most 120 payloads in
%% any 60 000 ms.
%% - Window full: returns rate_limited and leaves the state untouched.
%% - Otherwise: prunes expired timestamps, records now, returns {ok, State2}.
%% window_start is carried along but plays no part in the decision.
-spec check_rate_limit(state()) -> {ok, state()} | rate_limited.
check_rate_limit(State = #{rate_limit_state := Limiter}) ->
    Now = erlang:system_time(millisecond),
    Recent = lists:filter(
        fun(Stamp) -> Now - Stamp < 60000 end,
        maps:get(events, Limiter, [])
    ),
    case length(Recent) < 120 of
        false ->
            rate_limited;
        true ->
            Updated = #{
                events => [Now | Recent],
                window_start => maps:get(window_start, Limiter, Now)
            },
            {ok, State#{rate_limit_state => Updated}}
    end.
|
|
|
|
%% Determines the client address. Uses the first valid X-Forwarded-For entry
%% when present, otherwise the TCP peer address.
%% NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
%% rewrites it, and the leftmost entry is the one a client can forge. If
%% peer_ip feeds IP-based limits, confirm the edge strips or overwrites this
%% header.
-spec extract_client_ip(cowboy_req:req()) -> binary().
extract_client_ip(Req) ->
    Forwarded =
        case cowboy_req:header(<<"x-forwarded-for">>, Req) of
            undefined -> <<>>;
            Header -> parse_forwarded_for(Header)
        end,
    case Forwarded of
        <<>> -> peer_ip_to_binary(cowboy_req:peer(Req));
        IP -> IP
    end.
|
|
|
|
%% Formats the address part of a cowboy peer tuple as text; the port is
%% discarded.
-spec peer_ip_to_binary({inet:ip_address(), inet:port_number()}) -> binary().
peer_ip_to_binary(Peer) ->
    {Address, _Port} = Peer,
    iolist_to_binary(inet:ntoa(Address)).
|
|
|
|
%% Returns the normalised first (leftmost) address of an X-Forwarded-For
%% value, or <<>> when that entry is not a valid IP address.
%% binary:split/2 always returns at least one element, so the match cannot fail.
-spec parse_forwarded_for(binary()) -> binary().
parse_forwarded_for(HeaderValue) ->
    [First | _Rest] = binary:split(HeaderValue, <<",">>),
    case normalize_forwarded_ip(First) of
        error -> <<>>;
        {ok, IP} -> IP
    end.
|
|
|
|
%% Normalises a single X-Forwarded-For entry. Surrounding whitespace is
%% trimmed, a bracketed IPv6 form ("[addr]:port") is unwrapped, and an
%% IPv4 ":port" suffix is removed before validation.
-spec normalize_forwarded_ip(binary()) -> {ok, binary()} | error.
normalize_forwarded_ip(Value) ->
    normalize_trimmed_ip(string:trim(Value)).

%% Dispatches a trimmed entry on its shape: empty, bracketed, or plain.
-spec normalize_trimmed_ip(binary()) -> {ok, binary()} | error.
normalize_trimmed_ip(<<>>) ->
    error;
normalize_trimmed_ip(<<"[", _/binary>> = Bracketed) ->
    case strip_ipv6_brackets(Bracketed) of
        {ok, Inner} -> validate_ip(Inner);
        error -> error
    end;
normalize_trimmed_ip(Plain) ->
    validate_ip(strip_ipv4_port(Plain)).
|
|
|
|
%% Extracts the text between a leading "[" and the first "]". Returns error
%% when the input does not start with "[", has no closing bracket, or the
%% brackets are empty.
-spec strip_ipv6_brackets(binary()) -> {ok, binary()} | error.
strip_ipv6_brackets(<<"[", Rest/binary>>) ->
    case binary:split(Rest, <<"]">>) of
        [Inner, _After] when Inner =/= <<>> -> {ok, Inner};
        _ -> error
    end;
strip_ipv6_brackets(_NoBracket) ->
    error.
|
|
|
|
%% Drops a ":port" suffix from a dotted (IPv4-looking) address. Input without a
%% ".", or with other than exactly one ":", is returned unchanged.
-spec strip_ipv4_port(binary()) -> binary().
strip_ipv4_port(IP) ->
    Parts =
        case binary:match(IP, <<".">>) =/= nomatch of
            true -> binary:split(IP, <<":">>, [global]);
            false -> [IP]
        end,
    case Parts of
        [Addr, _Port] -> Addr;
        _ -> IP
    end.
|
|
|
|
%% Parses a textual IPv4/IPv6 address with inet:parse_address/1 and returns it
%% in canonical inet:ntoa/1 form, or error when it does not parse.
-spec validate_ip(binary()) -> {ok, binary()} | error.
validate_ip(IP) ->
    Parsed = inet:parse_address(binary_to_list(IP)),
    case Parsed of
        {error, _Reason} -> error;
        {ok, Address} -> {ok, iolist_to_binary(inet:ntoa(Address))}
    end.
|
|
|
|
%% Verifies the RESUME token with the located session ({token_verify, Token},
%% 5 s timeout). true continues the resume; false closes with
%% authentication_failed.
%%
%% The session can exit between lookup and call, and an uncaught exit from
%% gen_server:call/3 used to crash the connection. Now:
%% - a timeout (session alive but unresponsive) counts a resume failure and
%%   closes with unknown_error, so the client may retry;
%% - any other exit means the session is gone, and is answered like a missing
%%   session (INVALID_SESSION).
%% Only the call itself is guarded; the continuations run outside the try.
-spec handle_resume_with_session(pid(), binary(), binary(), integer(), state()) -> ws_result().
handle_resume_with_session(Pid, Token, SessionId, Seq, State) ->
    try gen_server:call(Pid, {token_verify, Token}, 5000) of
        true -> handle_resume_with_verified_token(Pid, SessionId, Seq, State);
        false -> handle_resume_invalid_token(State)
    catch
        exit:{timeout, _} ->
            gateway_metrics_collector:inc_resume_failure(),
            close_with_reason(unknown_error, <<"Failed to resume session">>, State);
        exit:_ ->
            handle_resume_session_not_found(State)
    end.
|
|
|
|
%% Asks the session to resume from Seq onto this connection
%% ({resume, Seq, self()}, 5 s timeout).
%% - {ok, MissedEvents}: completes the resume.
%% - invalid_seq: closes with invalid_seq.
%%
%% Exits from gen_server:call/3 no longer crash the connection:
%% - a timeout counts a resume failure and closes with unknown_error. Closing,
%%   rather than keeping the socket, matters because the session may still
%%   process the request later and attach to this process;
%% - any other exit means the session is gone, and is answered with
%%   INVALID_SESSION.
-spec handle_resume_with_verified_token(pid(), binary(), integer(), state()) -> ws_result().
handle_resume_with_verified_token(Pid, SessionId, Seq, State) ->
    try gen_server:call(Pid, {resume, Seq, self()}, 5000) of
        {ok, MissedEvents} when is_list(MissedEvents) ->
            handle_resume_success(Pid, SessionId, Seq, MissedEvents, State);
        invalid_seq ->
            handle_resume_invalid_seq(State)
    catch
        exit:{timeout, _} ->
            gateway_metrics_collector:inc_resume_failure(),
            close_with_reason(unknown_error, <<"Failed to resume session">>, State);
        exit:_ ->
            handle_resume_session_not_found(State)
    end.
|
|
|
|
%% Completes a successful RESUME:
%% - counts it and monitors the session;
%% - binds the session and resets heartbeat tracking;
%% - returns the replayed events, then a RESUMED dispatch at the client's
%%   Seq, as the frames of this reply.
%%
%% Frames are returned directly rather than re-sent to self(). Once the
%% session has accepted the resume it may already be delivering live
%% dispatches to this process, and self-sent messages would queue behind them,
%% putting older sequence numbers after newer ones. Events are encoded in
%% order through one threaded state so a streaming compression context stays
%% consistent. An event that fails to encode is skipped, as handle_dispatch/4
%% does for live dispatches.
-spec handle_resume_success(pid(), binary(), integer(), [map()], state()) -> ws_result().
handle_resume_success(Pid, _SessionId, Seq, MissedEvents, State) ->
    gateway_metrics_collector:inc_resume_success(),
    monitor(process, Pid),
    BoundState = State#{
        session_pid => Pid,
        heartbeat_state => #{
            last_ack => erlang:system_time(millisecond),
            waiting_for_ack => false
        }
    },
    Replay = [missed_event_dispatch(Event) || Event <- MissedEvents] ++ [{resumed, null, Seq}],
    {Frames, FinalState} = lists:foldl(fun encode_replayed_dispatch/2, {[], BoundState}, Replay),
    {lists:reverse(Frames), FinalState}.

%% Extracts {Event, Data, Seq} from a replayed event map supplied by the session.
-spec missed_event_dispatch(map()) -> {atom() | binary(), map() | null, integer()}.
missed_event_dispatch(Event) when is_map(Event) ->
    {maps:get(event, Event), maps:get(data, Event), maps:get(seq, Event)}.

%% Encodes one replayed dispatch, prepending its frame to the accumulator.
-spec encode_replayed_dispatch(
    {atom() | binary(), map() | null, integer()}, {[ws_frame()], state()}
) -> {[ws_frame()], state()}.
encode_replayed_dispatch({EventName, Data, EventSeq}, {Acc, AccState}) ->
    case handle_dispatch(EventName, Data, EventSeq, AccState) of
        {[Frame], NextState} -> {[Frame | Acc], NextState};
        {ok, NextState} -> {Acc, NextState}
    end.
|
|
|
|
%% RESUME rejected because the session refused the sequence number: count the
%% failure and close with invalid_seq.
-spec handle_resume_invalid_seq(state()) -> ws_result().
handle_resume_invalid_seq(State) ->
    _ = gateway_metrics_collector:inc_resume_failure(),
    close_with_reason(invalid_seq, <<"Invalid sequence">>, State).
|
|
|
|
%% RESUME rejected because the session did not accept the token: count the
%% failure and close with authentication_failed.
-spec handle_resume_invalid_token(state()) -> ws_result().
handle_resume_invalid_token(State) ->
    _ = gateway_metrics_collector:inc_resume_failure(),
    close_with_reason(authentication_failed, <<"Invalid token">>, State).
|
|
|
|
%% RESUME for a session that cannot be used: count the failure and send
%% INVALID_SESSION. The socket stays open.
-spec handle_resume_session_not_found(state()) -> ws_result().
handle_resume_session_not_found(State) ->
    _ = gateway_metrics_collector:inc_resume_failure(),
    send_invalid_session(State).
|
|
|
|
%% Encodes an outgoing message with the connection's encoding, compresses it
%% and wraps it in a frame. Returns the frame plus the state carrying the
%% advanced compression context, {error, {encode_failed, R}}, or
%% {error, {compress_failed, CompressionType, R}}.
-spec encode_and_compress(map(), state()) -> {ok, ws_frame(), state()} | {error, term()}.
encode_and_compress(Message, State = #{encoding := Encoding, compress_ctx := CompressCtx}) ->
    case gateway_codec:encode(Message, Encoding) of
        {error, Reason} ->
            {error, {encode_failed, Reason}};
        {ok, Encoded, FrameType} ->
            compress_encoded_frame(Encoded, FrameType, CompressCtx, State)
    end.

%% Compression step of encode_and_compress/2.
-spec compress_encoded_frame(term(), text | binary, gateway_compress:compress_ctx(), state()) ->
    {ok, ws_frame(), state()} | {error, term()}.
compress_encoded_frame(Encoded, FrameType, CompressCtx, State) ->
    case gateway_compress:compress(Encoded, CompressCtx) of
        {error, Reason} ->
            {error, {compress_failed, gateway_compress:get_type(CompressCtx), Reason}};
        {ok, Compressed, NextCtx} ->
            {ok, make_frame(Compressed, FrameType, NextCtx), State#{compress_ctx => NextCtx}}
    end.
|
|
|
|
%% Close reason used when HELLO cannot be compressed. zstd_stream gets a
%% specific message; every other type reports a generic encode failure.
-spec compression_error_reason(atom()) -> binary().
compression_error_reason(CompressionType) ->
    case CompressionType of
        zstd_stream -> <<"Compression failed: zstd-stream">>;
        _Other -> <<"Encode failed">>
    end.
|
|
|
|
%% Counts the close under Reason and returns a single close command whose
%% code comes from constants:close_code_to_num/1 and whose text is Message.
-spec close_with_reason(atom(), binary(), state()) -> ws_result().
close_with_reason(Reason, Message, State) ->
    gateway_metrics_collector:inc_websocket_close(Reason),
    Close = {close, constants:close_code_to_num(Reason), Message},
    {[Close], State}.
|
|
|
|
%% Wraps payload bytes in a frame. Without compression the codec's frame type
%% is kept; any compression forces a binary frame.
-spec make_frame(binary(), text | binary, gateway_compress:compress_ctx()) -> ws_frame().
make_frame(Data, FrameType, CompressCtx) ->
    Type =
        case gateway_compress:get_type(CompressCtx) of
            none -> FrameType;
            _Compressed -> binary
        end,
    {Type, Data}.
|
|
|
|
%% Opens the connection trace span, tagged with this module, the requested
%% API version and the client address.
-spec start_websocket_connect_span(1 | undefined, state()) -> term().
start_websocket_connect_span(Version, State) ->
    PeerIP = maps:get(peer_ip, State),
    gateway_tracing:start_connection_span(?MODULE, Version, PeerIP).
|
|
|
|
%% Ends the connection trace span with a textual termination reason. Does
%% nothing when no span was started (context undefined).
-spec end_websocket_disconnect_span(term(), term()) -> ok.
end_websocket_disconnect_span(Reason, Ctx) ->
    case Ctx of
        undefined -> ok;
        _Started -> gateway_tracing:end_connection_span(Ctx, reason_to_binary(Reason))
    end.
|
|
|
|
%% Renders a cowboy_websocket terminate/3 reason as a short span label.
%%
%% Atoms (normal, stop, timeout, remote, ...) map to their own name. Cowboy's
%% tagged reasons are summarised by their tag instead of all collapsing to
%% "unknown":
%% - {remote, Code, Payload} -> "remote"
%% - {error, Atom}           -> "error:Atom"
%% - {crash, Class, _}       -> "crash:Class"
%% Anything else is "unknown".
-spec reason_to_binary(term()) -> binary().
reason_to_binary(Reason) when is_atom(Reason) ->
    atom_to_binary(Reason, utf8);
reason_to_binary({remote, _Code, _Payload}) ->
    <<"remote">>;
reason_to_binary({error, Detail}) when is_atom(Detail) ->
    <<"error:", (atom_to_binary(Detail, utf8))/binary>>;
reason_to_binary({crash, Class, _Detail}) when is_atom(Class) ->
    <<"crash:", (atom_to_binary(Class, utf8))/binary>>;
reason_to_binary(_) ->
    <<"unknown">>.
|
|
|
|
%% Sliding-window limiter for voice state updates, keyed by session pid in
%% ?VOICE_RATE_LIMIT_TABLE. Returns true (caller should queue) when
%% ?VOICE_UPDATE_RATE_LIMIT updates already fall within the last
%% ?VOICE_RATE_LIMIT_WINDOW ms. Otherwise it records now and returns false.
-spec should_queue_voice_update(pid()) -> boolean().
should_queue_voice_update(SessionPid) ->
    ensure_voice_rate_limit_table(),
    Now = erlang:system_time(millisecond),
    Recent =
        case ets:lookup(?VOICE_RATE_LIMIT_TABLE, SessionPid) of
            [] -> [];
            [{SessionPid, Stamps}] -> [T || T <- Stamps, Now - T < ?VOICE_RATE_LIMIT_WINDOW]
        end,
    case length(Recent) >= ?VOICE_UPDATE_RATE_LIMIT of
        true ->
            true;
        false ->
            ets:insert(?VOICE_RATE_LIMIT_TABLE, {SessionPid, [Now | Recent]}),
            false
    end.
|
|
|
|
%% Forwards one voice state update to the session synchronously
%% ({voice_state_update, Data}, 5 s timeout). The reply is not inspected.
%%
%% Failure to deliver is never fatal to the connection. Every exit raised by
%% gen_server:call/3 is absorbed and the state returned unchanged: timeout,
%% noproc, and also a session stopping mid-call (normal, shutdown, killed, ...),
%% which previously crashed this process. Session loss is reported separately
%% through the session monitor's 'DOWN' message.
-spec process_voice_update(pid(), map(), state()) -> ws_result().
process_voice_update(SessionPid, Data, State) ->
    try gen_server:call(SessionPid, {voice_state_update, Data}, 5000) of
        _Reply -> {ok, State}
    catch
        exit:_Reason -> {ok, State}
    end.
|
|
|
|
%% Appends a rate-limited voice update to the session's FIFO in
%% ?VOICE_QUEUE_TABLE. An existing queue is first trimmed so it never exceeds
%% ?MAX_VOICE_QUEUE_LENGTH after the append.
-spec queue_voice_update(pid(), map()) -> ok.
queue_voice_update(SessionPid, Data) ->
    ensure_voice_queue_table(),
    Existing =
        case ets:lookup(?VOICE_QUEUE_TABLE, SessionPid) of
            [] -> queue:new();
            [{SessionPid, Queue}] -> trim_voice_queue(Queue)
        end,
    true = ets:insert(?VOICE_QUEUE_TABLE, {SessionPid, queue:in(Data, Existing)}),
    ok.
|
|
|
|
%% Drops the oldest entry when the queue has reached ?MAX_VOICE_QUEUE_LENGTH,
%% making room for one more; otherwise returns the queue unchanged.
-spec trim_voice_queue(queue:queue()) -> queue:queue().
trim_voice_queue(Queue) ->
    IsFull = queue:len(Queue) >= ?MAX_VOICE_QUEUE_LENGTH,
    HasEntries = not queue:is_empty(Queue),
    if
        IsFull andalso HasEntries -> queue:drop(Queue);
        true -> Queue
    end.
|
|
|
|
%% Arms the {process_voice_queue} timer (?VOICE_QUEUE_PROCESS_INTERVAL ms)
%% only when voice_queue_timer is undefined; a pending timer is left alone.
-spec ensure_voice_queue_timer(state()) -> state().
ensure_voice_queue_timer(State) ->
    case maps:get(voice_queue_timer, State, none) of
        undefined ->
            Timer = erlang:send_after(
                ?VOICE_QUEUE_PROCESS_INTERVAL, self(), {process_voice_queue}
            ),
            State#{voice_queue_timer => Timer};
        _Pending ->
            State
    end.
|
|
|
|
%% Handles the pending voice-update queue of the connection's session, if any.
%% The queue is looked up in ETS by the session pid and handed to
%% process_queue_item/3. A state without a session pid, or a session with
%% nothing queued, is returned as-is.
-spec process_queued_voice_updates(state()) -> state().
process_queued_voice_updates(#{session_pid := SessionPid} = State) when is_pid(SessionPid) ->
    ensure_voice_queue_table(),
    case ets:lookup(?VOICE_QUEUE_TABLE, SessionPid) of
        [{SessionPid, Pending}] -> process_queue_item(Pending, SessionPid, State);
        [] -> State
    end;
process_queued_voice_updates(NoSession) ->
    NoSession.
|
|
|
|
%% Handles the head of a session's pending voice-update queue.
%%   - Empty queue: the session's ETS entry is removed.
%%   - Rate limited (should_queue_voice_update/1 returns true): nothing is
%%     consumed; the drain timer is armed so the item is retried later.
%%   - Otherwise: the head item is sent via process_voice_update/3 (its
%%     result is not used) and the remainder is written back.
-spec process_queue_item(queue:queue(), pid(), state()) -> state().
process_queue_item(Queue, SessionPid, State) ->
    case queue:out(Queue) of
        {empty, _} ->
            ets:delete(?VOICE_QUEUE_TABLE, SessionPid),
            State;
        {{value, Data}, Rest} ->
            case should_queue_voice_update(SessionPid) of
                true ->
                    ensure_voice_queue_timer(State);
                false ->
                    process_voice_update(SessionPid, Data, State),
                    store_voice_queue_remainder(Rest, SessionPid, State)
            end
    end.

%% Persists what is left of a session's queue after one item was sent.
%% The ETS entry is dropped when nothing remains; otherwise the remainder is
%% stored and the drain timer is armed.
-spec store_voice_queue_remainder(queue:queue(), pid(), state()) -> state().
store_voice_queue_remainder(Rest, SessionPid, State) ->
    case queue:is_empty(Rest) of
        true ->
            ets:delete(?VOICE_QUEUE_TABLE, SessionPid),
            State;
        false ->
            ets:insert(?VOICE_QUEUE_TABLE, {SessionPid, Rest}),
            ensure_voice_queue_timer(State)
    end.
|
|
|
|
%% Makes sure the named, public ETS set ?VOICE_QUEUE_TABLE exists, creating
%% it on demand. A badarg from ets:new/2 is treated as "another process
%% created it concurrently" and ignored.
%%
%% NOTE(review): a table created here is owned by the calling process. This
%% module is a cowboy_websocket handler, so that is presumably a single
%% connection process. ETS deletes a table when its owner exits, which would
%% discard queued updates for every session when that connection closes,
%% unless the table is also created by a long-lived process elsewhere.
%% Confirm, and consider creating it from a supervisor/application start.
-spec ensure_voice_queue_table() -> ok.
ensure_voice_queue_table() ->
    case ets:whereis(?VOICE_QUEUE_TABLE) =:= undefined of
        false ->
            ok;
        true ->
            try ets:new(?VOICE_QUEUE_TABLE, [named_table, public, set]) of
                _Tid -> ok
            catch
                error:badarg -> ok
            end
    end.
|
|
|
|
%% Makes sure the named, public ETS set ?VOICE_RATE_LIMIT_TABLE exists,
%% creating it on demand. A badarg from ets:new/2 is treated as "another
%% process created it concurrently" and ignored.
%%
%% NOTE(review): as with the voice queue table, a table created here is owned
%% by the calling (presumably per-connection) process and is deleted by ETS
%% when that process exits, silently resetting rate-limit history for all
%% sessions. Confirm whether a long-lived process creates it elsewhere.
-spec ensure_voice_rate_limit_table() -> ok.
ensure_voice_rate_limit_table() ->
    case ets:whereis(?VOICE_RATE_LIMIT_TABLE) =:= undefined of
        false ->
            ok;
        true ->
            try ets:new(?VOICE_RATE_LIMIT_TABLE, [named_table, public, set]) of
                _Tid -> ok
            catch
                error:badarg -> ok
            end
    end.
|
|
|
|
-ifdef(TEST).

%% parse_forwarded_for/1: one X-Forwarded-For entry, IPv4 and IPv6, with and
%% without ports, brackets, surrounding whitespace and trailing entries.
%% Invalid input yields an empty binary.

parse_forwarded_for_ipv4_test() ->
    ?assertEqual(<<"203.0.113.7">>, parse_forwarded_for(<<"203.0.113.7">>)).

parse_forwarded_for_ipv4_with_port_test() ->
    ?assertEqual(<<"203.0.113.7">>, parse_forwarded_for(<<"203.0.113.7:8080">>)).

parse_forwarded_for_ipv4_with_port_and_extra_entries_test() ->
    Header = <<" 203.0.113.7:8080 , 10.0.0.1">>,
    ?assertEqual(<<"203.0.113.7">>, parse_forwarded_for(Header)).

parse_forwarded_for_ipv6_test() ->
    ?assertEqual(<<"2001:db8::1">>, parse_forwarded_for(<<"2001:db8::1">>)).

parse_forwarded_for_ipv6_with_brackets_test() ->
    ?assertEqual(<<"2001:db8::1">>, parse_forwarded_for(<<"[2001:db8::1]">>)).

parse_forwarded_for_ipv6_with_brackets_and_port_test() ->
    ?assertEqual(<<"2001:db8::1">>, parse_forwarded_for(<<"[2001:db8::1]:443">>)).

parse_forwarded_for_ipv6_with_spaces_test() ->
    ?assertEqual(<<"2001:db8::1">>, parse_forwarded_for(<<" [2001:db8::1] ">>)).

parse_forwarded_for_invalid_ip_test() ->
    ?assertEqual(<<>>, parse_forwarded_for(<<"not_an_ip">>)).

parse_forwarded_for_invalid_ipv4_octet_test() ->
    ?assertEqual(<<>>, parse_forwarded_for(<<"203.0.113.300">>)).

parse_forwarded_for_unterminated_bracket_test() ->
    ?assertEqual(<<>>, parse_forwarded_for(<<"[2001:db8::1">>)).

%% Connection parameter parsing and presence status mapping.

parse_version_test() ->
    ?assertEqual(1, parse_version(<<"1">>)),
    ?assertEqual(undefined, parse_version(<<"2">>)),
    ?assertEqual(undefined, parse_version(undefined)).

parse_ignored_events_test() ->
    ?assertEqual({ok, []}, parse_ignored_events(undefined)),
    ?assertEqual({ok, []}, parse_ignored_events(null)),
    ?assertEqual({ok, [<<"TYPING_START">>]}, parse_ignored_events([<<"typing_start">>])),
    ?assertEqual({error, invalid_ignored_events}, parse_ignored_events([123])),
    ?assertEqual({error, invalid_ignored_events}, parse_ignored_events(<<"not_a_list">>)).

adjust_status_test() ->
    ?assertEqual(invisible, adjust_status(offline)),
    ?assertEqual(online, adjust_status(online)),
    ?assertEqual(idle, adjust_status(idle)).

%% Voice update queue helpers.

trim_voice_queue_below_cap_test() ->
    Queue = queue:from_list([first, second]),
    ?assertEqual([first, second], queue:to_list(trim_voice_queue(Queue))).

trim_voice_queue_at_cap_drops_oldest_test() ->
    Items = lists:seq(1, ?MAX_VOICE_QUEUE_LENGTH),
    Trimmed = trim_voice_queue(queue:from_list(Items)),
    ?assertEqual(tl(Items), queue:to_list(Trimmed)).

trim_voice_queue_empty_test() ->
    ?assert(queue:is_empty(trim_voice_queue(queue:new()))).

ensure_voice_queue_timer_arms_once_test() ->
    Armed = ensure_voice_queue_timer(#{voice_queue_timer => undefined}),
    TimerRef = maps:get(voice_queue_timer, Armed),
    ?assert(is_reference(TimerRef)),
    %% A second call must not replace an already-armed timer.
    ?assertEqual(Armed, ensure_voice_queue_timer(Armed)),
    erlang:cancel_timer(TimerRef).

-endif.
|