diff --git a/fluxer_gateway/src/gateway/gateway_handler.erl b/fluxer_gateway/src/gateway/gateway_handler.erl index 7991294e..22655c0a 100644 --- a/fluxer_gateway/src/gateway/gateway_handler.erl +++ b/fluxer_gateway/src/gateway/gateway_handler.erl @@ -30,6 +30,10 @@ -define(VOICE_RATE_LIMIT_TABLE, voice_update_rate_limit). -define(VOICE_QUEUE_PROCESS_INTERVAL, 100). -define(MAX_VOICE_QUEUE_LENGTH, 64). +-define(RATE_LIMIT_WINDOW_MS, 60000). +-define(RATE_LIMIT_MAX_EVENTS, 120). +-define(REQUEST_GUILD_MEMBERS_RATE_LIMIT_WINDOW_MS, 10000). +-define(REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, 3). -type state() :: #{ version := 1 | undefined, @@ -40,6 +44,7 @@ socket_pid := pid() | undefined, peer_ip := binary() | undefined, rate_limit_state := map(), + request_guild_members_pid := pid() | undefined, otel_span_ctx := term(), voice_queue_timer := reference() | undefined }. @@ -54,7 +59,12 @@ new_state() -> heartbeat_state => #{}, socket_pid => undefined, peer_ip => undefined, - rate_limit_state => #{events => [], window_start => undefined}, + rate_limit_state => #{ + events => [], + request_guild_members_events => [], + window_start => undefined + }, + request_guild_members_pid => undefined, otel_span_ctx => undefined, voice_queue_timer => undefined }. @@ -140,7 +150,7 @@ handle_incoming_data(Data, State = #{encoding := Encoding, compress_ctx := Compr -spec handle_decode({ok, map()} | {error, term()}, state()) -> ws_result(). handle_decode({ok, #{<<"op">> := Op} = Payload}, State) -> OpAtom = constants:gateway_opcode(Op), - case check_rate_limit(State) of + case check_rate_limit(State, OpAtom) of {ok, RateLimitedState} -> handle_gateway_payload(OpAtom, Payload, RateLimitedState); rate_limited -> @@ -160,6 +170,10 @@ websocket_info({session_backpressure_error, Details}, State) -> handle_session_backpressure_error(Details, State); websocket_info({'DOWN', _, process, Pid, _}, State = #{session_pid := Pid}) -> handle_session_down(State); +websocket_info( + {'DOWN', _, process, Pid, _}, State = #{request_guild_members_pid := Pid} +) -> + {ok, State#{request_guild_members_pid => undefined}}; websocket_info({process_voice_queue}, State) -> NewState = process_queued_voice_updates(State#{voice_queue_timer => undefined}), {ok, NewState}; @@ -541,9 +555,13 @@ handle_voice_state_update(Pid, Data, State) -> end. -spec handle_request_guild_members(map(), pid(), state()) -> ws_result(). +handle_request_guild_members( + _Data, _Pid, State = #{request_guild_members_pid := RequestPid} +) when is_pid(RequestPid) -> + {ok, State}; handle_request_guild_members(Data, Pid, State) -> SocketPid = self(), - spawn(fun() -> + {WorkerPid, _Ref} = spawn_monitor(fun() -> try case gen_server:call(Pid, {get_state}, 5000) of SessionState when is_map(SessionState) -> @@ -555,7 +573,7 @@ handle_request_guild_members(Data, Pid, State) -> _:_ -> ok end end), - {ok, State}. + {ok, State#{request_guild_members_pid => WorkerPid}}. -spec handle_lazy_request(map(), pid(), state()) -> ws_result(). handle_lazy_request(Data, Pid, State) -> @@ -578,23 +596,41 @@ handle_lazy_request(Data, Pid, State) -> schedule_heartbeat_check() -> erlang:send_after(constants:heartbeat_interval() div 3, self(), {heartbeat_check}). --spec check_rate_limit(state()) -> {ok, state()} | rate_limited. -check_rate_limit(State = #{rate_limit_state := RateLimitState}) -> +-spec check_rate_limit(state(), atom()) -> {ok, state()} | rate_limited. +check_rate_limit(State = #{rate_limit_state := RateLimitState}, Op) -> Now = erlang:system_time(millisecond), Events = maps:get(events, RateLimitState, []), WindowStart = maps:get(window_start, RateLimitState, Now), - WindowDuration = 60000, - MaxEvents = 120, - EventsInWindow = [T || T <- Events, (Now - T) < WindowDuration], - case length(EventsInWindow) >= MaxEvents of + EventsInWindow = [T || T <- Events, (Now - T) < ?RATE_LIMIT_WINDOW_MS], + case length(EventsInWindow) >= ?RATE_LIMIT_MAX_EVENTS of true -> rate_limited; false -> - NewEvents = [Now | EventsInWindow], - NewRateLimitState = #{events => NewEvents, window_start => WindowStart}, - {ok, State#{rate_limit_state => NewRateLimitState}} + case check_opcode_rate_limit(Op, RateLimitState, Now) of + rate_limited -> + rate_limited; + {ok, OpRateLimitState} -> + NewEvents = [Now | EventsInWindow], + NewRateLimitState = + OpRateLimitState#{events => NewEvents, window_start => WindowStart}, + {ok, State#{rate_limit_state => NewRateLimitState}} + end end. +-spec check_opcode_rate_limit(atom(), map(), integer()) -> {ok, map()} | rate_limited. +check_opcode_rate_limit(request_guild_members, RateLimitState, Now) -> + RequestEvents = maps:get(request_guild_members_events, RateLimitState, []), + RequestEventsInWindow = + [T || T <- RequestEvents, (Now - T) < ?REQUEST_GUILD_MEMBERS_RATE_LIMIT_WINDOW_MS], + case length(RequestEventsInWindow) >= ?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS of + true -> + rate_limited; + false -> + {ok, RateLimitState#{request_guild_members_events => [Now | RequestEventsInWindow]}} + end; +check_opcode_rate_limit(_, RateLimitState, _Now) -> + {ok, RateLimitState}. + -spec extract_client_ip(cowboy_req:req()) -> binary(). extract_client_ip(Req) -> case cowboy_req:header(<<"x-forwarded-for">>, Req) of @@ -954,4 +990,42 @@ adjust_status_test() -> ?assertEqual(online, adjust_status(online)), ?assertEqual(idle, adjust_status(idle)). +check_rate_limit_blocks_general_flood_test() -> + Now = erlang:system_time(millisecond), + Events = lists:duplicate(?RATE_LIMIT_MAX_EVENTS, Now - 1000), + State = (new_state())#{ + rate_limit_state => #{ + events => Events, + request_guild_members_events => [], + window_start => Now + } + }, + ?assertEqual(rate_limited, check_rate_limit(State, heartbeat)). + +check_rate_limit_blocks_request_guild_members_burst_test() -> + Now = erlang:system_time(millisecond), + RequestEvents = + lists:duplicate(?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000), + State = (new_state())#{ + rate_limit_state => #{ + events => [], + request_guild_members_events => RequestEvents, + window_start => Now + } + }, + ?assertEqual(rate_limited, check_rate_limit(State, request_guild_members)). + +check_rate_limit_allows_other_ops_when_request_guild_members_is_hot_test() -> + Now = erlang:system_time(millisecond), + RequestEvents = + lists:duplicate(?REQUEST_GUILD_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000), + State = (new_state())#{ + rate_limit_state => #{ + events => [], + request_guild_members_events => RequestEvents, + window_start => Now + } + }, + ?assertMatch({ok, _}, check_rate_limit(State, heartbeat)). + -endif. diff --git a/fluxer_gateway/src/guild/guild_request_members.erl b/fluxer_gateway/src/guild/guild_request_members.erl index db1f35a9..f6917158 100644 --- a/fluxer_gateway/src/guild/guild_request_members.erl +++ b/fluxer_gateway/src/guild/guild_request_members.erl @@ -24,6 +24,15 @@ -define(CHUNK_SIZE, 1000). -define(MAX_USER_IDS, 100). -define(MAX_NONCE_LENGTH, 32). +-define(FULL_MEMBER_LIST_LIMIT, 100000). +-define(DEFAULT_QUERY_LIMIT, 25). +-define(MAX_MEMBER_QUERY_LIMIT, 100). +-define(REQUEST_MEMBERS_RATE_LIMIT_TABLE, guild_request_members_rate_limit). +-define(REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS, 10000). +-define(REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, 5). +-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, guild_request_members_guild_rate_limit). +-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS, 10000). +-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, 25). -type session_state() :: map(). -type request_data() :: map(). @@ -121,7 +130,8 @@ ensure_binary(Value) when is_binary(Value) -> Value; ensure_binary(_) -> <<>>. -spec ensure_limit(term()) -> non_neg_integer(). -ensure_limit(Limit) when is_integer(Limit), Limit >= 0 -> Limit; +ensure_limit(Limit) when is_integer(Limit), Limit >= 0 -> + min(Limit, ?MAX_MEMBER_QUERY_LIMIT); ensure_limit(_) -> 0. -spec normalize_nonce(term()) -> binary() | null. @@ -135,13 +145,99 @@ process_request(Request, SocketPid, SessionState) -> #{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request, UserIdBin = maps:get(user_id, SessionState), UserId = type_conv:to_integer(UserIdBin), - case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of + case check_request_rate_limit(UserId) of ok -> - fetch_and_send_members(Request, SocketPid, SessionState); + case check_guild_request_rate_limit(GuildId) of + ok -> + case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of + ok -> + fetch_and_send_members(Request, SocketPid, SessionState); + {error, Reason} -> + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end; {error, Reason} -> {error, Reason} end. +-spec check_request_rate_limit(integer() | undefined) -> ok | {error, atom()}. +check_request_rate_limit(UserId) when is_integer(UserId), UserId > 0 -> + ensure_request_rate_limit_table(), + Now = erlang:system_time(millisecond), + case ets:lookup(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, UserId) of + [] -> + ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, [Now]}), + ok; + [{UserId, Timestamps}] -> + RecentTimestamps = + [T || T <- Timestamps, (Now - T) < ?REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS], + case length(RecentTimestamps) >= ?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS of + true -> + {error, rate_limited}; + false -> + ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, [Now | RecentTimestamps]}), + ok + end + end; +check_request_rate_limit(_) -> + {error, invalid_session}. + +-spec check_guild_request_rate_limit(integer()) -> ok | {error, atom()}. +check_guild_request_rate_limit(GuildId) when is_integer(GuildId), GuildId > 0 -> + ensure_guild_request_rate_limit_table(), + Now = erlang:system_time(millisecond), + case ets:lookup(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, GuildId) of + [] -> + ets:insert(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, [Now]}), + ok; + [{GuildId, Timestamps}] -> + RecentTimestamps = + [T || T <- Timestamps, (Now - T) < ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS], + case length(RecentTimestamps) >= ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS of + true -> + {error, rate_limited}; + false -> + ets:insert( + ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, [Now | RecentTimestamps]} + ), + ok + end + end; +check_guild_request_rate_limit(_) -> + {error, invalid_guild_id}. + +-spec ensure_request_rate_limit_table() -> ok. +ensure_request_rate_limit_table() -> + case ets:whereis(?REQUEST_MEMBERS_RATE_LIMIT_TABLE) of + undefined -> + try + ets:new(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, [named_table, public, set]), + ok + catch + error:badarg -> + ok + end; + _ -> + ok + end. + +-spec ensure_guild_request_rate_limit_table() -> ok. +ensure_guild_request_rate_limit_table() -> + case ets:whereis(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE) of + undefined -> + try + ets:new(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, [named_table, public, set]), + ok + catch + error:badarg -> + ok + end; + _ -> + ok + end. + -spec check_permission( integer(), integer(), binary(), non_neg_integer(), [integer()], session_state() ) -> @@ -215,18 +311,9 @@ fetch_and_send_members(Request, _SocketPid, SessionState) -> -spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()]. fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] -> - case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of - #{members := AllMembers} -> - filter_members_by_ids(AllMembers, UserIds); - _ -> - [] - end; + fetch_members_by_user_ids(GuildPid, UserIds); fetch_members(GuildPid, Query, Limit, []) -> - ActualLimit = - case Limit of - 0 -> 100000; - L -> L - end, + ActualLimit = resolve_member_limit(Query, Limit), case gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000) of @@ -241,17 +328,33 @@ fetch_members(GuildPid, Query, Limit, []) -> [] end. --spec filter_members_by_ids([member()], [integer()]) -> [member()]. -filter_members_by_ids(Members, UserIds) -> - UserIdSet = sets:from_list(UserIds), - lists:filter( - fun(Member) -> - UserId = extract_user_id(Member), - UserId =/= undefined andalso sets:is_element(UserId, UserIdSet) +-spec fetch_members_by_user_ids(pid(), [integer()]) -> [member()]. +fetch_members_by_user_ids(GuildPid, UserIds) -> + lists:filtermap( + fun(UserId) -> + try + case gen_server:call(GuildPid, {get_guild_member, #{user_id => UserId}}, 5000) of + #{success := true, member_data := Member} when is_map(Member) -> + {true, Member}; + _ -> + false + end + catch + exit:_ -> + false + end end, - Members + lists:usort(UserIds) ). +-spec resolve_member_limit(binary(), non_neg_integer()) -> pos_integer(). +resolve_member_limit(<<>>, 0) -> + ?FULL_MEMBER_LIST_LIMIT; +resolve_member_limit(_Query, 0) -> + ?DEFAULT_QUERY_LIMIT; +resolve_member_limit(_Query, Limit) -> + Limit. + -spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()]. filter_members_by_query(Members, Query, Limit) -> NormalizedQuery = string:lowercase(binary_to_list(Query)), @@ -518,6 +621,64 @@ ensure_limit_negative_test() -> ensure_limit_non_integer_test() -> ?assertEqual(0, ensure_limit(<<"10">>)). +ensure_limit_clamped_test() -> + ?assertEqual(?MAX_MEMBER_QUERY_LIMIT, ensure_limit(?MAX_MEMBER_QUERY_LIMIT + 1)). + +resolve_member_limit_full_scan_test() -> + ?assertEqual(?FULL_MEMBER_LIST_LIMIT, resolve_member_limit(<<>>, 0)). + +resolve_member_limit_query_default_test() -> + ?assertEqual(?DEFAULT_QUERY_LIMIT, resolve_member_limit(<<"ab">>, 0)). + +resolve_member_limit_explicit_test() -> + ?assertEqual(25, resolve_member_limit(<<"ab">>, 25)). + +check_request_rate_limit_allows_initial_request_test() -> + UserId = 987654321, + clear_request_rate_limit(UserId), + ?assertEqual(ok, check_request_rate_limit(UserId)), + clear_request_rate_limit(UserId). + +check_request_rate_limit_blocks_burst_test() -> + UserId = 987654322, + clear_request_rate_limit(UserId), + ensure_request_rate_limit_table(), + Now = erlang:system_time(millisecond), + Timestamps = lists:duplicate(?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000), + ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, Timestamps}), + ?assertEqual({error, rate_limited}, check_request_rate_limit(UserId)), + clear_request_rate_limit(UserId). + +check_request_rate_limit_invalid_user_test() -> + ?assertEqual({error, invalid_session}, check_request_rate_limit(undefined)). + +check_guild_request_rate_limit_allows_initial_request_test() -> + GuildId = 87654321, + clear_guild_request_rate_limit(GuildId), + ?assertEqual(ok, check_guild_request_rate_limit(GuildId)), + clear_guild_request_rate_limit(GuildId). + +check_guild_request_rate_limit_blocks_burst_test() -> + GuildId = 87654322, + clear_guild_request_rate_limit(GuildId), + ensure_guild_request_rate_limit_table(), + Now = erlang:system_time(millisecond), + Timestamps = lists:duplicate(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, Now - 1000), + ets:insert(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, Timestamps}), + ?assertEqual({error, rate_limited}, check_guild_request_rate_limit(GuildId)), + clear_guild_request_rate_limit(GuildId). + +check_guild_request_rate_limit_invalid_guild_test() -> + ?assertEqual({error, invalid_guild_id}, check_guild_request_rate_limit(undefined)). + +clear_request_rate_limit(UserId) -> + ensure_request_rate_limit_table(), + ets:delete(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, UserId). + +clear_guild_request_rate_limit(GuildId) -> + ensure_guild_request_rate_limit_table(), + ets:delete(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, GuildId). + validate_guild_id_integer_test() -> ?assertEqual({ok, 123}, validate_guild_id(123)). @@ -589,30 +750,6 @@ chunk_presences_no_matching_presences_test() -> Result = chunk_presences(Presences, [Members]), ?assertEqual([[]], Result). -filter_members_by_ids_basic_test() -> - Members = [ - #{<<"user">> => #{<<"id">> => <<"1">>}}, - #{<<"user">> => #{<<"id">> => <<"2">>}}, - #{<<"user">> => #{<<"id">> => <<"3">>}} - ], - Result = filter_members_by_ids(Members, [1, 3]), - ?assertEqual(2, length(Result)). - -filter_members_by_ids_empty_ids_test() -> - Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}], - Result = filter_members_by_ids(Members, []), - ?assertEqual([], Result). - -filter_members_by_ids_no_match_test() -> - Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}], - Result = filter_members_by_ids(Members, [999]), - ?assertEqual([], Result). - -filter_members_by_ids_skips_invalid_members_test() -> - Members = [#{}, #{<<"user">> => #{}}, #{<<"user">> => #{<<"id">> => <<"1">>}}], - Result = filter_members_by_ids(Members, [1]), - ?assertEqual(1, length(Result)). - filter_members_by_query_case_insensitive_test() -> Members = [ #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alice">>}},