fix(gateway): harden REQUEST_GUILD_MEMBERS path against DoS floods

This commit is contained in:
Hampus Kraft 2026-02-22 13:41:25 +00:00
parent 4f5704fa1f
commit d843d6f3f8
No known key found for this signature in database
GPG Key ID: 6090864C465A454D
2 changed files with 270 additions and 59 deletions

View File

@ -30,6 +30,10 @@
-define(VOICE_RATE_LIMIT_TABLE, voice_update_rate_limit).
-define(VOICE_QUEUE_PROCESS_INTERVAL, 100).
-define(MAX_VOICE_QUEUE_LENGTH, 64).
%% General sliding window applied to every decoded gateway payload
%% (see check_rate_limit/2): at most RATE_LIMIT_MAX_EVENTS payloads are
%% accepted within any RATE_LIMIT_WINDOW_MS milliseconds.
-define(RATE_LIMIT_WINDOW_MS, 60000).
-define(RATE_LIMIT_MAX_EVENTS, 120).
%% Tighter, additional window that only counts request_guild_members payloads
%% (see check_opcode_rate_limit/3).
-define(REQUEST_GUILD_MEMBERS_RATE_LIMIT_WINDOW_MS, 10000).
-define(REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, 3).
-type state() :: #{
version := 1 | undefined,
@ -40,6 +44,7 @@
socket_pid := pid() | undefined,
peer_ip := binary() | undefined,
rate_limit_state := map(),
request_guild_members_pid := pid() | undefined,
otel_span_ctx := term(),
voice_queue_timer := reference() | undefined
}.
@ -54,7 +59,12 @@ new_state() ->
heartbeat_state => #{},
socket_pid => undefined,
peer_ip => undefined,
rate_limit_state => #{events => [], window_start => undefined},
rate_limit_state => #{
events => [],
request_guild_members_events => [],
window_start => undefined
},
request_guild_members_pid => undefined,
otel_span_ctx => undefined,
voice_queue_timer => undefined
}.
@ -140,7 +150,7 @@ handle_incoming_data(Data, State = #{encoding := Encoding, compress_ctx := Compr
-spec handle_decode({ok, map()} | {error, term()}, state()) -> ws_result().
handle_decode({ok, #{<<"op">> := Op} = Payload}, State) ->
OpAtom = constants:gateway_opcode(Op),
case check_rate_limit(State) of
case check_rate_limit(State, OpAtom) of
{ok, RateLimitedState} ->
handle_gateway_payload(OpAtom, Payload, RateLimitedState);
rate_limited ->
@ -160,6 +170,10 @@ websocket_info({session_backpressure_error, Details}, State) ->
handle_session_backpressure_error(Details, State);
websocket_info({'DOWN', _, process, Pid, _}, State = #{session_pid := Pid}) ->
handle_session_down(State);
websocket_info(
{'DOWN', _, process, Pid, _}, State = #{request_guild_members_pid := Pid}
) ->
{ok, State#{request_guild_members_pid => undefined}};
websocket_info({process_voice_queue}, State) ->
NewState = process_queued_voice_updates(State#{voice_queue_timer => undefined}),
{ok, NewState};
@ -541,9 +555,13 @@ handle_voice_state_update(Pid, Data, State) ->
end.
-spec handle_request_guild_members(map(), pid(), state()) -> ws_result().
handle_request_guild_members(
_Data, _Pid, State = #{request_guild_members_pid := RequestPid}
) when is_pid(RequestPid) ->
{ok, State};
handle_request_guild_members(Data, Pid, State) ->
SocketPid = self(),
spawn(fun() ->
{WorkerPid, _Ref} = spawn_monitor(fun() ->
try
case gen_server:call(Pid, {get_state}, 5000) of
SessionState when is_map(SessionState) ->
@ -555,7 +573,7 @@ handle_request_guild_members(Data, Pid, State) ->
_:_ -> ok
end
end),
{ok, State}.
{ok, State#{request_guild_members_pid => WorkerPid}}.
-spec handle_lazy_request(map(), pid(), state()) -> ws_result().
handle_lazy_request(Data, Pid, State) ->
@ -578,23 +596,41 @@ handle_lazy_request(Data, Pid, State) ->
%% Arms a one-shot timer that delivers `{heartbeat_check}` to the calling
%% process after one third of the configured heartbeat interval.
schedule_heartbeat_check() ->
    Delay = constants:heartbeat_interval() div 3,
    erlang:send_after(Delay, self(), {heartbeat_check}).
%% Sliding-window rate limiter run on every decoded gateway payload.
%%
%% Returns `rate_limited` when the general window is already full, or when the
%% per-opcode check (check_opcode_rate_limit/3) rejects the payload. Otherwise
%% returns `{ok, State}` with the current timestamp recorded in
%% `rate_limit_state`.
%%
%% NOTE(review): this span shows pre- and post-change lines together with no
%% diff markers. The /1 spec and head, the WindowDuration/MaxEvents bindings
%% and the first filter/case pair look like the removed version; the /2 spec
%% and head, the macro-based filter/case pair and the nested opcode check look
%% like their replacements. Confirm against the committed module.
-spec check_rate_limit(state()) -> {ok, state()} | rate_limited.
check_rate_limit(State = #{rate_limit_state := RateLimitState}) ->
-spec check_rate_limit(state(), atom()) -> {ok, state()} | rate_limited.
check_rate_limit(State = #{rate_limit_state := RateLimitState}, Op) ->
Now = erlang:system_time(millisecond),
Events = maps:get(events, RateLimitState, []),
%% window_start is carried through unchanged; it is not used for the decision.
WindowStart = maps:get(window_start, RateLimitState, Now),
WindowDuration = 60000,
MaxEvents = 120,
EventsInWindow = [T || T <- Events, (Now - T) < WindowDuration],
case length(EventsInWindow) >= MaxEvents of
%% Drop timestamps that have aged out of the general window.
EventsInWindow = [T || T <- Events, (Now - T) < ?RATE_LIMIT_WINDOW_MS],
case length(EventsInWindow) >= ?RATE_LIMIT_MAX_EVENTS of
true ->
rate_limited;
false ->
NewEvents = [Now | EventsInWindow],
NewRateLimitState = #{events => NewEvents, window_start => WindowStart},
{ok, State#{rate_limit_state => NewRateLimitState}}
%% The opcode check receives the original rate-limit map; its `events` and
%% `window_start` keys are overwritten below, so only the per-opcode
%% timestamps it updates are taken from its result.
case check_opcode_rate_limit(Op, RateLimitState, Now) of
rate_limited ->
rate_limited;
{ok, OpRateLimitState} ->
NewEvents = [Now | EventsInWindow],
NewRateLimitState =
OpRateLimitState#{events => NewEvents, window_start => WindowStart},
{ok, State#{rate_limit_state => NewRateLimitState}}
end
end.
%% Applies the extra, opcode-specific limit on top of the general one.
%%
%% For `request_guild_members` the timestamps stored under
%% `request_guild_members_events` are pruned to the configured window; if the
%% window is already full the payload is rejected, otherwise `Now` is recorded
%% and the updated rate-limit map is returned. Every other opcode passes
%% through with the map untouched.
-spec check_opcode_rate_limit(atom(), map(), integer()) -> {ok, map()} | rate_limited.
check_opcode_rate_limit(request_guild_members, RateLimitState, Now) ->
    WindowMs = ?REQUEST_GUILD_MEMBERS_RATE_LIMIT_WINDOW_MS,
    Previous = maps:get(request_guild_members_events, RateLimitState, []),
    Recent = [Seen || Seen <- Previous, (Now - Seen) < WindowMs],
    if
        length(Recent) < ?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS ->
            {ok, RateLimitState#{request_guild_members_events => [Now | Recent]}};
        true ->
            rate_limited
    end;
check_opcode_rate_limit(_OtherOp, RateLimitState, _Now) ->
    {ok, RateLimitState}.
-spec extract_client_ip(cowboy_req:req()) -> binary().
extract_client_ip(Req) ->
case cowboy_req:header(<<"x-forwarded-for">>, Req) of
@ -954,4 +990,42 @@ adjust_status_test() ->
?assertEqual(online, adjust_status(online)),
?assertEqual(idle, adjust_status(idle)).
%% A payload arriving while the general window already holds the maximum
%% number of recent events is rejected.
check_rate_limit_blocks_general_flood_test() ->
    Now = erlang:system_time(millisecond),
    FullWindow = #{
        events => lists:duplicate(?RATE_LIMIT_MAX_EVENTS, Now - 1000),
        request_guild_members_events => [],
        window_start => Now
    },
    Base = new_state(),
    Outcome = check_rate_limit(Base#{rate_limit_state => FullWindow}, heartbeat),
    ?assertEqual(rate_limited, Outcome).
%% With the general window empty but the request_guild_members window full,
%% another request_guild_members payload is rejected.
check_rate_limit_blocks_request_guild_members_burst_test() ->
    Now = erlang:system_time(millisecond),
    HotOpcodeWindow = #{
        events => [],
        request_guild_members_events =>
            lists:duplicate(?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000),
        window_start => Now
    },
    Base = new_state(),
    Outcome = check_rate_limit(
        Base#{rate_limit_state => HotOpcodeWindow}, request_guild_members
    ),
    ?assertEqual(rate_limited, Outcome).
%% A full request_guild_members window must not block unrelated opcodes.
check_rate_limit_allows_other_ops_when_request_guild_members_is_hot_test() ->
    Now = erlang:system_time(millisecond),
    HotOpcodeWindow = #{
        events => [],
        request_guild_members_events =>
            lists:duplicate(?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000),
        window_start => Now
    },
    Base = new_state(),
    Outcome = check_rate_limit(Base#{rate_limit_state => HotOpcodeWindow}, heartbeat),
    ?assertMatch({ok, _}, Outcome).
-endif.

View File

@ -24,6 +24,15 @@
-define(CHUNK_SIZE, 1000).
-define(MAX_USER_IDS, 100).
-define(MAX_NONCE_LENGTH, 32).
%% Limit used when the client asks for the whole member list (empty query,
%% limit 0) — see resolve_member_limit/2.
-define(FULL_MEMBER_LIST_LIMIT, 100000).
%% Limit used when any other query arrives with limit 0.
-define(DEFAULT_QUERY_LIMIT, 25).
%% Upper bound applied to a client-supplied limit by ensure_limit/1.
-define(MAX_MEMBER_QUERY_LIMIT, 100).
%% Per-user limiter (check_request_rate_limit/1): ETS table name, window
%% length in milliseconds, and maximum accepted requests per window.
-define(REQUEST_MEMBERS_RATE_LIMIT_TABLE, guild_request_members_rate_limit).
-define(REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS, 10000).
-define(REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, 5).
%% Per-guild limiter (check_guild_request_rate_limit/1): same three settings,
%% keyed by guild id instead of user id.
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, guild_request_members_guild_rate_limit).
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS, 10000).
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, 25).
-type session_state() :: map().
-type request_data() :: map().
@ -121,7 +130,8 @@ ensure_binary(Value) when is_binary(Value) -> Value;
ensure_binary(_) -> <<>>.
%% Normalises a client-supplied limit: anything that is not a non-negative
%% integer becomes 0.
%%
%% NOTE(review): two clauses with an identical head and guard appear here,
%% which looks like the old and new versions shown together without diff
%% markers. The `min/2` clause, which caps the value at
%% ?MAX_MEMBER_QUERY_LIMIT, appears to be the replacement for the pass-through
%% clause — confirm against the committed module.
-spec ensure_limit(term()) -> non_neg_integer().
ensure_limit(Limit) when is_integer(Limit), Limit >= 0 -> Limit;
ensure_limit(Limit) when is_integer(Limit), Limit >= 0 ->
min(Limit, ?MAX_MEMBER_QUERY_LIMIT);
ensure_limit(_) -> 0.
-spec normalize_nonce(term()) -> binary() | null.
@ -135,13 +145,99 @@ process_request(Request, SocketPid, SessionState) ->
#{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request,
UserIdBin = maps:get(user_id, SessionState),
UserId = type_conv:to_integer(UserIdBin),
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
case check_request_rate_limit(UserId) of
ok ->
fetch_and_send_members(Request, SocketPid, SessionState);
case check_guild_request_rate_limit(GuildId) of
ok ->
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
ok ->
fetch_and_send_members(Request, SocketPid, SessionState);
{error, Reason} ->
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end.
%% Per-user sliding-window limiter for member requests.
%%
%% Returns `ok` and records the request when the user has made fewer than
%% ?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS requests in the last
%% ?REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS milliseconds, `{error, rate_limited}`
%% when the window is full, and `{error, invalid_session}` when the user id is
%% not a positive integer.
-spec check_request_rate_limit(integer() | undefined) -> ok | {error, atom()}.
check_request_rate_limit(UserId) when is_integer(UserId), UserId > 0 ->
    ensure_request_rate_limit_table(),
    check_sliding_window(
        ?REQUEST_MEMBERS_RATE_LIMIT_TABLE,
        UserId,
        ?REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS,
        ?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS
    );
check_request_rate_limit(_) ->
    {error, invalid_session}.

%% Per-guild sliding-window limiter for member requests. Same contract as
%% check_request_rate_limit/1, keyed by guild id; a guild id that is not a
%% positive integer yields `{error, invalid_guild_id}`.
-spec check_guild_request_rate_limit(integer()) -> ok | {error, atom()}.
check_guild_request_rate_limit(GuildId) when is_integer(GuildId), GuildId > 0 ->
    ensure_guild_request_rate_limit_table(),
    check_sliding_window(
        ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE,
        GuildId,
        ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS,
        ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS
    );
check_guild_request_rate_limit(_) ->
    {error, invalid_guild_id}.

%% Shared sliding-window check over an ETS table of `{Key, [TimestampMs]}`.
%%
%% Timestamps older than WindowMs are discarded; if MaxEvents or more remain
%% the request is rejected and the table is left untouched, otherwise the
%% current time is prepended and stored.
%%
%% The lookup and the insert are two separate ETS operations, so concurrent
%% callers using the same key can each read the old list and both be admitted.
%% That over-admission is bounded by the number of concurrent callers and is
%% accepted here.
-spec check_sliding_window(atom(), integer(), pos_integer(), pos_integer()) ->
    ok | {error, rate_limited}.
check_sliding_window(Table, Key, WindowMs, MaxEvents) ->
    Now = erlang:system_time(millisecond),
    Recent =
        case ets:lookup(Table, Key) of
            [] ->
                [];
            [{Key, Timestamps}] ->
                [T || T <- Timestamps, (Now - T) < WindowMs]
        end,
    case length(Recent) >= MaxEvents of
        true ->
            {error, rate_limited};
        false ->
            ets:insert(Table, {Key, [Now | Recent]}),
            ok
    end.

%% Makes sure the per-user rate-limit table exists.
-spec ensure_request_rate_limit_table() -> ok.
ensure_request_rate_limit_table() ->
    ensure_rate_limit_table(?REQUEST_MEMBERS_RATE_LIMIT_TABLE).

%% Makes sure the per-guild rate-limit table exists.
-spec ensure_guild_request_rate_limit_table() -> ok.
ensure_guild_request_rate_limit_table() ->
    ensure_rate_limit_table(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE).

%% Creates the named table on first use. If the table already exists (for
%% example because a supervisor-owned process created it), nothing is done.
-spec ensure_rate_limit_table(atom()) -> ok.
ensure_rate_limit_table(Table) ->
    case ets:whereis(Table) of
        undefined -> start_rate_limit_table_owner(Table);
        _Tid -> ok
    end.

%% Spawns a dedicated process to own the table and waits until the table is
%% usable.
%%
%% An ETS table is destroyed when the process that created it exits. Creating
%% the table directly in the caller would tie the limiter's state to that
%% caller's lifetime; a short-lived request worker would wipe every recorded
%% timestamp when it finished. The owner process is deliberately unlinked so it
%% outlives the caller.
-spec start_rate_limit_table_owner(atom()) -> ok.
start_rate_limit_table_owner(Table) ->
    Caller = self(),
    Tag = make_ref(),
    {Owner, MonitorRef} =
        spawn_monitor(fun() -> rate_limit_table_owner(Table, Caller, Tag) end),
    receive
        {Tag, ready} ->
            erlang:demonitor(MonitorRef, [flush]),
            ok;
        {'DOWN', MonitorRef, process, Owner, _Reason} ->
            %% The owner exited without creating the table; the expected cause
            %% is losing a creation race, in which case another owner holds it.
            ok
    end.

%% Body of the owner process: create the table, tell the caller, then stay
%% alive. `badarg` from ets:new/2 means the name is already taken, so this
%% process is redundant and simply exits.
-spec rate_limit_table_owner(atom(), pid(), reference()) -> ok.
rate_limit_table_owner(Table, Caller, Tag) ->
    try ets:new(Table, [named_table, public, set]) of
        _ ->
            Caller ! {Tag, ready},
            hold_rate_limit_table()
    catch
        error:badarg ->
            ok
    end.

%% Keeps the owner process alive, discarding any stray message so its mailbox
%% cannot grow.
-spec hold_rate_limit_table() -> no_return().
hold_rate_limit_table() ->
    receive
        _ -> hold_rate_limit_table()
    end.
-spec check_permission(
integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()
) ->
@ -215,18 +311,9 @@ fetch_and_send_members(Request, _SocketPid, SessionState) ->
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()].
fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] ->
case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of
#{members := AllMembers} ->
filter_members_by_ids(AllMembers, UserIds);
_ ->
[]
end;
fetch_members_by_user_ids(GuildPid, UserIds);
fetch_members(GuildPid, Query, Limit, []) ->
ActualLimit =
case Limit of
0 -> 100000;
L -> L
end,
ActualLimit = resolve_member_limit(Query, Limit),
case
gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000)
of
@ -241,17 +328,33 @@ fetch_members(GuildPid, Query, Limit, []) ->
[]
end.
%% NOTE(review): this span interleaves two functions with no diff markers.
%% filter_members_by_ids/2 (filter an already-fetched member list against a set
%% of ids) looks like the removed version; its closing `end, Members )` lines
%% are shared with the tail of the new function below. Confirm against the
%% committed module.
-spec filter_members_by_ids([member()], [integer()]) -> [member()].
filter_members_by_ids(Members, UserIds) ->
UserIdSet = sets:from_list(UserIds),
lists:filter(
fun(Member) ->
UserId = extract_user_id(Member),
UserId =/= undefined andalso sets:is_element(UserId, UserIdSet)
%% Fetches the requested members one id at a time from the guild process.
%%
%% Ids are de-duplicated and sorted with lists:usort/1 first, so results come
%% back in ascending id order. Each id is looked up with a `get_guild_member`
%% call (5 s timeout); only replies of the form
%% `#{success := true, member_data := Map}` are kept. A call that exits
%% (timeout, guild process gone) drops that id and the loop continues.
%%
%% NOTE(review): the calls are sequential, so the worst case is roughly
%% (number of ids x 5 s). How many ids can reach this function is not visible
%% here — presumably bounded by ?MAX_USER_IDS; verify in the request parser.
-spec fetch_members_by_user_ids(pid(), [integer()]) -> [member()].
fetch_members_by_user_ids(GuildPid, UserIds) ->
lists:filtermap(
fun(UserId) ->
try
case gen_server:call(GuildPid, {get_guild_member, #{user_id => UserId}}, 5000) of
#{success := true, member_data := Member} when is_map(Member) ->
{true, Member};
_ ->
false
end
catch
exit:_ ->
false
end
end,
Members
lists:usort(UserIds)
).
%% Turns the normalised client limit into the limit sent to the guild process.
%%
%% A limit of 0 means "not specified": with an empty query the full member
%% list limit is used, with any other query the default query limit. Any other
%% limit is returned as given.
-spec resolve_member_limit(binary(), non_neg_integer()) -> pos_integer().
resolve_member_limit(Query, 0) ->
    case Query of
        <<>> -> ?FULL_MEMBER_LIST_LIMIT;
        _ -> ?DEFAULT_QUERY_LIMIT
    end;
resolve_member_limit(_Query, ExplicitLimit) ->
    ExplicitLimit.
-spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()].
filter_members_by_query(Members, Query, Limit) ->
NormalizedQuery = string:lowercase(binary_to_list(Query)),
@ -518,6 +621,64 @@ ensure_limit_negative_test() ->
%% A binary that merely looks numeric is not an integer and normalises to 0.
ensure_limit_non_integer_test() ->
    Normalised = ensure_limit(<<"10">>),
    ?assertEqual(0, Normalised).
%% A limit one above the maximum is clamped down to the maximum.
ensure_limit_clamped_test() ->
    Cap = ?MAX_MEMBER_QUERY_LIMIT,
    ?assertEqual(Cap, ensure_limit(Cap + 1)).
%% Empty query with limit 0 resolves to the full member list limit.
resolve_member_limit_full_scan_test() ->
    Resolved = resolve_member_limit(<<>>, 0),
    ?assertEqual(?FULL_MEMBER_LIST_LIMIT, Resolved).
%% A non-empty query with limit 0 resolves to the default query limit.
resolve_member_limit_query_default_test() ->
    Resolved = resolve_member_limit(<<"ab">>, 0),
    ?assertEqual(?DEFAULT_QUERY_LIMIT, Resolved).
%% An explicit non-zero limit is returned as given.
%%
%% The value is derived from ?DEFAULT_QUERY_LIMIT so it can never equal the
%% default: with a limit that happens to match the default, a regression that
%% substitutes the default for every query would still pass.
resolve_member_limit_explicit_test() ->
    Explicit = ?DEFAULT_QUERY_LIMIT + 1,
    ?assertEqual(Explicit, resolve_member_limit(<<"ab">>, Explicit)).
%% A user with no recorded requests is admitted. The entry is cleared before
%% and after so the test does not depend on, or leak into, other tests.
check_request_rate_limit_allows_initial_request_test() ->
    FreshUser = 987654321,
    clear_request_rate_limit(FreshUser),
    Outcome = check_request_rate_limit(FreshUser),
    ?assertEqual(ok, Outcome),
    clear_request_rate_limit(FreshUser).
%% A user whose window already holds the maximum number of recent requests is
%% rejected. The window is pre-filled directly in the ETS table.
check_request_rate_limit_blocks_burst_test() ->
    BusyUser = 987654322,
    clear_request_rate_limit(BusyUser),
    ensure_request_rate_limit_table(),
    OneSecondAgo = erlang:system_time(millisecond) - 1000,
    FullWindow = lists:duplicate(?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, OneSecondAgo),
    ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {BusyUser, FullWindow}),
    Outcome = check_request_rate_limit(BusyUser),
    ?assertEqual({error, rate_limited}, Outcome),
    clear_request_rate_limit(BusyUser).
%% A user id that is not a positive integer is reported as an invalid session.
check_request_rate_limit_invalid_user_test() ->
    Outcome = check_request_rate_limit(undefined),
    ?assertEqual({error, invalid_session}, Outcome).
%% A guild with no recorded requests is admitted. The entry is cleared before
%% and after so the test is independent of others.
check_guild_request_rate_limit_allows_initial_request_test() ->
    FreshGuild = 87654321,
    clear_guild_request_rate_limit(FreshGuild),
    Outcome = check_guild_request_rate_limit(FreshGuild),
    ?assertEqual(ok, Outcome),
    clear_guild_request_rate_limit(FreshGuild).
%% A guild whose window already holds the maximum number of recent requests is
%% rejected. The window is pre-filled directly in the ETS table.
check_guild_request_rate_limit_blocks_burst_test() ->
    BusyGuild = 87654322,
    clear_guild_request_rate_limit(BusyGuild),
    ensure_guild_request_rate_limit_table(),
    OneSecondAgo = erlang:system_time(millisecond) - 1000,
    FullWindow =
        lists:duplicate(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, OneSecondAgo),
    ets:insert(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {BusyGuild, FullWindow}),
    Outcome = check_guild_request_rate_limit(BusyGuild),
    ?assertEqual({error, rate_limited}, Outcome),
    clear_guild_request_rate_limit(BusyGuild).
%% A guild id that is not a positive integer is reported as invalid.
check_guild_request_rate_limit_invalid_guild_test() ->
    Outcome = check_guild_request_rate_limit(undefined),
    ?assertEqual({error, invalid_guild_id}, Outcome).
%% Test helper: removes any recorded timestamps for UserId from the per-user
%% table, creating the table first if it does not exist yet.
clear_request_rate_limit(UserId) ->
    Table = ?REQUEST_MEMBERS_RATE_LIMIT_TABLE,
    ok = ensure_request_rate_limit_table(),
    ets:delete(Table, UserId).
%% Test helper: removes any recorded timestamps for GuildId from the per-guild
%% table, creating the table first if it does not exist yet.
clear_guild_request_rate_limit(GuildId) ->
    Table = ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE,
    ok = ensure_guild_request_rate_limit_table(),
    ets:delete(Table, GuildId).
%% An integer guild id is accepted and returned unchanged.
validate_guild_id_integer_test() ->
    Validated = validate_guild_id(123),
    ?assertEqual({ok, 123}, Validated).
@ -589,30 +750,6 @@ chunk_presences_no_matching_presences_test() ->
Result = chunk_presences(Presences, [Members]),
?assertEqual([[]], Result).
%% Of three members with ids 1..3, asking for ids 1 and 3 keeps two of them.
filter_members_by_ids_basic_test() ->
    Members = [#{<<"user">> => #{<<"id">> => Id}} || Id <- [<<"1">>, <<"2">>, <<"3">>]],
    Kept = filter_members_by_ids(Members, [1, 3]),
    ?assertEqual(2, length(Kept)).
%% An empty id list selects nothing.
filter_members_by_ids_empty_ids_test() ->
    OnlyMember = #{<<"user">> => #{<<"id">> => <<"1">>}},
    Kept = filter_members_by_ids([OnlyMember], []),
    ?assertEqual([], Kept).
%% An id that no member carries selects nothing.
filter_members_by_ids_no_match_test() ->
    OnlyMember = #{<<"user">> => #{<<"id">> => <<"1">>}},
    Kept = filter_members_by_ids([OnlyMember], [999]),
    ?assertEqual([], Kept).
%% Members without a user, or with a user lacking an id, are skipped rather
%% than causing a failure; only the well-formed member with id 1 is kept.
filter_members_by_ids_skips_invalid_members_test() ->
    NoUser = #{},
    NoId = #{<<"user">> => #{}},
    WellFormed = #{<<"user">> => #{<<"id">> => <<"1">>}},
    Kept = filter_members_by_ids([NoUser, NoId, WellFormed], [1]),
    ?assertEqual(1, length(Kept)).
filter_members_by_query_case_insensitive_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alice">>}},