summaryrefslogtreecommitdiff
path: root/lib/pleroma/web/mastodon_api/websocket_handler.ex
blob: 5652a37c19f08c5c1431e79c46e66d63d58e9ebc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Pleroma: A lightweight social networking server
# Copyright © 2017-2020 Pleroma Authors <https://pleroma.social/>
# SPDX-License-Identifier: AGPL-3.0-only

defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
  @moduledoc """
  Cowboy WebSocket handler that subscribes a connection to a
  `Pleroma.Web.Streamer` topic.

  The upgrade request names a stream with the `stream` query parameter; the
  `"hashtag"` stream additionally needs `tag` and the `"list"` stream needs
  `list`. Credentials are an OAuth access token taken from the `access_token`
  query parameter or, when that is absent, from the `Sec-WebSocket-Protocol`
  header. The anonymous streams may be opened with no credentials at all; any
  credentials that are supplied are validated, even for anonymous streams.

  Rejections are answered with a plain HTTP status: 404 for an unknown stream,
  403 for missing or invalid credentials, 400 for a missing `stream`/`tag`/`list`
  parameter. Once upgraded, `{:text, message}` process messages are relayed to
  the client as text frames.
  """

  require Logger

  alias Pleroma.Repo
  alias Pleroma.User
  alias Pleroma.Web.OAuth.Token
  alias Pleroma.Web.Streamer

  @behaviour :cowboy_websocket

  @streams [
    "public",
    "public:local",
    "public:media",
    "public:local:media",
    "user",
    "user:notification",
    "direct",
    "list",
    "hashtag"
  ]
  @anonymous_streams ["public", "public:local", "hashtag"]

  # Handled by periodic keepalive in Pleroma.Web.Streamer.Ping.
  @timeout :infinity

  # Authenticates the upgrade request and resolves its topic. On success the
  # connection is upgraded with `%{user: user, topic: topic}` as its state;
  # otherwise an HTTP error status is sent and the original `state` (the
  # handler options) is returned without upgrading.
  @impl true
  def init(%{qs: qs} = req, state) do
    params = :cow_qs.parse_qs(qs)
    sec_websocket = :cowboy_req.header("sec-websocket-protocol", req, nil)
    # `{"access_token", value}` when the query parameter is present, else nil.
    access_token = List.keyfind(params, "access_token", 0)

    with {_, stream} <- List.keyfind(params, "stream", 0),
         {:ok, user} <- allow_request(stream, [access_token, sec_websocket]),
         topic when is_binary(topic) <- expand_topic(stream, params) do
      # Echo the Sec-WebSocket-Protocol value back. Clients that offered a
      # subprotocol generally expect the server to select one.
      req =
        if sec_websocket do
          :cowboy_req.set_resp_header("sec-websocket-protocol", sec_websocket, req)
        else
          req
        end

      {:cowboy_websocket, req, %{user: user, topic: topic}, %{idle_timeout: @timeout}}
    else
      # NOTE(review): `inspect(req)` includes the query string and headers, which
      # may carry the access token; consider logging less of the request.
      {:error, code} ->
        Logger.debug("#{__MODULE__} denied connection: #{inspect(code)} - #{inspect(req)}")
        deny(req, code, state)

      # Missing `stream` parameter, or a topic that could not be expanded.
      error ->
        Logger.debug("#{__MODULE__} denied connection: #{inspect(error)} - #{inspect(req)}")
        deny(req, 400, state)
    end
  end

  # Subscription is deferred to `websocket_info(:subscribe, state)` by sending
  # that message to this process.
  @impl true
  def websocket_init(state) do
    send(self(), :subscribe)
    {:ok, state}
  end

  # Frames sent by the client are ignored; this stream only pushes to the client.
  @impl true
  def websocket_handle(_frame, state) do
    {:ok, state}
  end

  # Registers this connection with the Streamer under its topic.
  @impl true
  def websocket_info(:subscribe, state) do
    Logger.debug(
      "#{__MODULE__} accepted websocket connection for user #{
        (state.user || %{id: "anonymous"}).id
      }, topic #{state.topic}"
    )

    Streamer.add_socket(state.topic, streamer_socket(state))
    {:ok, state}
  end

  # Relays a text message to the client as a text frame.
  def websocket_info({:text, message}, state) do
    {:reply, {:text, message}, state}
  end

  # Unregisters the connection from the Streamer.
  @impl true
  def terminate(reason, _req, %{user: user, topic: topic} = state) do
    Logger.debug(
      "#{__MODULE__} terminating websocket connection for user #{
        (user || %{id: "anonymous"}).id
      }, topic #{topic || "?"}: #{inspect(reason)}"
    )

    Streamer.remove_socket(topic, streamer_socket(state))
    :ok
  end

  # Cowboy also calls terminate/3 when init/2 refused the upgrade. `state` is
  # then the handler options returned by init/2, not the connection map, and
  # no socket was ever registered, so there is nothing to clean up.
  def terminate(_reason, _req, _state), do: :ok

  # Sends `status` as a plain HTTP reply and ends the request without
  # upgrading. In Cowboy 2, `:cowboy_req.reply/2` returns the updated request
  # itself rather than `{:ok, req}`.
  defp deny(req, status, state) do
    req = :cowboy_req.reply(status, req)
    {:ok, req, state}
  end

  # Decides whether `stream` may be opened with the given credentials.
  # The second argument is `[access_token, sec_websocket]`: the
  # `{"access_token", value}` query tuple (or nil) and the raw
  # Sec-WebSocket-Protocol header (or nil).
  # Returns `{:ok, %User{} | nil}` or `{:error, http_status}`.

  # Public streams without authentication.
  defp allow_request(stream, [nil, nil]) when stream in @anonymous_streams do
    {:ok, nil}
  end

  # Authenticated streams: the query parameter wins over the header.
  defp allow_request(stream, [access_token, sec_websocket]) when stream in @streams do
    token =
      with {"access_token", token} <- access_token do
        token
      else
        _ -> sec_websocket
      end

    # A valueless `?access_token` parses to `true`, which fails the binary check.
    with true <- is_binary(token),
         %Token{user_id: user_id} <- Repo.get_by(Token, token: token),
         user = %User{} <- User.get_cached_by_id(user_id) do
      {:ok, user}
    else
      _ -> {:error, 403}
    end
  end

  # Not authenticated.
  defp allow_request(stream, _) when stream in @streams, do: {:error, 403}

  # No matching stream.
  defp allow_request(_, _), do: {:error, 404}

  # Builds the Streamer topic name. "hashtag" and "list" need an extra query
  # parameter and return nil when it is missing; other streams map to themselves.
  defp expand_topic("hashtag", params) do
    case List.keyfind(params, "tag", 0) do
      {_, tag} -> "hashtag:#{tag}"
      _ -> nil
    end
  end

  defp expand_topic("list", params) do
    case List.keyfind(params, "list", 0) do
      {_, list} -> "list:#{list}"
      _ -> nil
    end
  end

  defp expand_topic(topic, _), do: topic

  # The socket descriptor handed to the Streamer: this process plus its state.
  defp streamer_socket(state) do
    %{transport_pid: self(), assigns: state}
  end
end