nostr_types/client/
connection.rs

1use super::AuthState;
2use crate::{ClientMessage, Error, Event, Filter, RelayMessage, SubscriptionId};
3use base64::Engine;
4use futures_util::stream::SplitSink;
5use futures_util::{SinkExt, StreamExt};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tracing::{event, span, Level};
11use tungstenite::protocol::Message;
12
/// The concrete WebSocket stream type used for relay connections: a
/// `tokio_tungstenite` stream over a TCP socket that may or may not be wrapped in TLS.
type Ws =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
16
/// A live connection to a relay, and all related state.
///
/// This connects when created, but may persist beyond disconnection, so we can't say that
/// it is always connected. Reconnection is not done here; if it becomes disconnected and you
/// want to reconnect, drop this value and create a new one. You probably want to take the
/// incoming data first (see `incoming()`) and hand it to `new_with_data()` so that
/// already-received messages are not lost.
#[derive(Debug)]
pub struct ClientConnection {
    // Write half of the split websocket. Send messages with this.
    sink: Mutex<SplitSink<Ws, Message>>,

    // Counter that keeps subscription ids unique
    next_sub_id: AtomicUsize,

    // Authentication data (shared with the listener task)
    auth_state: Arc<RwLock<AuthState>>,

    // The listener (stream) task handle (currently not retained)
    // listener_task: JoinHandle<()>,

    // Incoming messages deposited by the listener task
    incoming: Arc<RwLock<Vec<RelayMessage>>>,

    // A signal that a new message has arrived, OR
    // that the connection has been closed, OR
    // that the authentication state has changed
    wake: Arc<Notify>,

    // Disconnection data: set to true once the connection is known to be gone
    disconnected: Arc<AtomicBool>,
}
48
49impl ClientConnection {
50    /// Create a new ClientConnection by connecting.
51    pub async fn new(relay_url: &str, timeout: Duration) -> Result<ClientConnection, Error> {
52        let incoming: Arc<RwLock<Vec<RelayMessage>>> = Arc::new(RwLock::new(Vec::new()));
53        Self::new_with_data(relay_url, timeout, incoming).await
54    }
55
56    /// Create a new ClientConnectdion by connecting, preserving data from a previous connection.
57    pub async fn new_with_data(
58        relay_url: &str,
59        timeout: Duration,
60        incoming: Arc<RwLock<Vec<RelayMessage>>>,
61    ) -> Result<ClientConnection, Error> {
62        let (host, uri) = super::url_to_host_and_uri(relay_url)?;
63        let key: [u8; 16] = rand::random();
64        let request = http::request::Request::builder()
65            .method("GET")
66            .header("Host", host)
67            .header("Connection", "Upgrade")
68            .header("Upgrade", "websocket")
69            .header("Sec-WebSocket-Version", "13")
70            .header(
71                "Sec-WebSocket-Key",
72                base64::engine::general_purpose::STANDARD.encode(key),
73            )
74            .uri(uri)
75            .body(())?;
76
77        let (websocket, response) =
78            tokio::time::timeout(timeout, tokio_tungstenite::connect_async(request)).await??;
79
80        let status = response.status();
81        if status.is_redirection() || status.is_client_error() || status.is_server_error() {
82            return Err(Error::WebsocketConnectionFailed(status));
83        }
84
85        // Split the websocket
86        let (sink, mut stream) = websocket.split();
87
88        let incoming2 = incoming.clone();
89
90        let disconnected: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
91        let disconnected2 = disconnected.clone();
92
93        let wake = Arc::new(Notify::new());
94        let wake2 = wake.clone();
95
96        let auth_state = Arc::new(RwLock::new(AuthState::NotYetRequested));
97        let auth_state2 = auth_state.clone();
98
99        // Start a task to handle the incoming stream
100        let _listener_task = tokio::task::spawn(Box::pin(async move {
101            let span = span!(Level::DEBUG, "connection listener thread");
102            let _enter = span.enter();
103            while let Some(message) = stream.next().await {
104                match message {
105                    Ok(Message::Text(s)) => {
106                        match serde_json::from_str(&s) {
107                            Ok(rm) => {
108                                // Maybe update authentication state
109                                match rm {
110                                    RelayMessage::Auth(challenge) => {
111                                        let mut lock = auth_state2.write().await;
112
113                                        event!(Level::DEBUG, "AUTH CHALLENGED");
114                                        *lock = AuthState::Challenged(challenge);
115
116                                        // No need to store into incoming
117                                        wake2.notify_waiters();
118                                        continue;
119                                    }
120                                    RelayMessage::Ok(id, is_ok, ref reason) => {
121                                        let mut lock = auth_state2.write().await;
122                                        if let AuthState::InProgress(sent_id) = *lock {
123                                            if id == sent_id {
124                                                *lock = if is_ok {
125                                                    AuthState::Success
126                                                } else {
127                                                    AuthState::Failure(reason.clone())
128                                                };
129                                                // No need to store into incoming
130                                                wake2.notify_waiters();
131                                                continue;
132                                            }
133                                        }
134                                    }
135                                    _ => {}
136                                }
137
138                                (*incoming2.write().await).push(rm);
139                                wake2.notify_waiters();
140                            }
141                            Err(e) => {
142                                event!(Level::INFO, "websocket wessage failed to deserialize: {e}")
143                            }
144                        }
145                    }
146                    Ok(Message::Close(_)) => {
147                        break;
148                    }
149                    Ok(_) => { }
150                    Err(e) => {
151                        event!(Level::ERROR, "{e}");
152                        break;
153                    }
154                }
155            }
156
157            disconnected2.store(true, Ordering::Relaxed);
158        }));
159
160        Ok(ClientConnection {
161            sink: Mutex::new(sink),
162            next_sub_id: AtomicUsize::new(0),
163            auth_state,
164            incoming,
165            wake,
166            disconnected,
167        })
168    }
169
170    /// Is disconnected
171    pub fn is_disconnected(&self) -> bool {
172        self.disconnected.load(Ordering::Relaxed)
173    }
174
175    /// Disconnect from the relay, consuming self
176    pub async fn disconnect(self) -> Result<(), Error> {
177        let msg = Message::Close(None);
178        let mut sink = self.sink.lock().await;
179        sink.send(msg).await?;
180        sink.close().await?;
181        Ok(())
182    }
183
184    /// Copy an Arc reference to the Incoming relay messages.
185    pub fn incoming(&self) -> Arc<RwLock<Vec<RelayMessage>>> {
186        self.incoming.clone()
187    }
188
189    fn fail_if_disconnected(&self) -> Result<(), Error> {
190        if self.disconnected.load(Ordering::Relaxed) {
191            Err(Error::Disconnected)
192        } else {
193            Ok(())
194        }
195    }
196
197    /// Subscribe to a filter. This does not wait for results.
198    pub async fn subscribe(&self, filter: Filter) -> Result<SubscriptionId, Error> {
199        self.fail_if_disconnected()?;
200        let sub_id_usize = self.next_sub_id.fetch_add(1, Ordering::Relaxed);
201        let sub_id = SubscriptionId(format!("sub{}", sub_id_usize));
202        let client_message = ClientMessage::Req(sub_id.clone(), filter.clone());
203        self.send_message(client_message).await?;
204        Ok(sub_id)
205    }
206
207    /// Close a subscription
208    pub async fn close_subscription(&self, sub_id: SubscriptionId) -> Result<(), Error> {
209        self.fail_if_disconnected()?;
210        let client_message = ClientMessage::Close(sub_id);
211        self.send_message(client_message).await?;
212        Ok(())
213    }
214
215    /// Send a `ClientMessage`
216    pub async fn send_message(&self, message: ClientMessage) -> Result<(), Error> {
217        let wire = serde_json::to_string(&message)?;
218        let msg = Message::Text(wire.into());
219        self.inner_send_message(msg).await?;
220        Ok(())
221    }
222
223    /// Send a websocket `Message`
224    pub async fn send_ws_message(&self, message: Message) -> Result<(), Error> {
225        self.inner_send_message(message).await?;
226        Ok(())
227    }
228
229    async fn inner_send_message(&self, msg: Message) -> Result<(), Error> {
230        self.fail_if_disconnected()?;
231        if let Err(e) = self.sink.lock().await.send(msg).await {
232            self.disconnected.store(true, Ordering::Relaxed);
233            Err(e)?
234        } else {
235            Ok(())
236        }
237    }
238
239    /// Wait for some matching RelayMessage.
240    ///
241    /// The timeout will be reset when any event happens, so it make take
242    /// longer than the timeout to give up.
243    pub async fn wait_for_relay_message<P>(
244        &self,
245        predicate: P,
246        timeout: Duration,
247    ) -> Result<RelayMessage, Error>
248    where
249        P: Fn(&RelayMessage) -> bool,
250    {
251        loop {
252            {
253                // Check incoming for a match
254                let mut incoming = self.incoming.write().await;
255                if let Some(found) = incoming.iter().position(&predicate) {
256                    let relay_message = incoming.remove(found);
257                    return Ok(relay_message);
258                }
259            }
260
261            // Wait for something to happen, or timeout
262            tokio::time::timeout(timeout, self.wake.notified()).await?;
263
264            self.fail_if_disconnected()?;
265        }
266    }
267
268    /// Get AuthState
269    pub async fn get_auth_state(&self) -> AuthState {
270        self.auth_state.read().await.clone()
271    }
272
273    /// Wait for the given AuthState to occur.
274    pub async fn wait_for_auth_state_change(&self, timeout: Duration) -> Result<AuthState, Error> {
275        let start = self.auth_state.read().await.clone();
276        loop {
277            let current = self.auth_state.read().await.clone();
278            if current != start {
279                return Ok(current);
280            }
281
282            // Wait for something to happen, or timeout
283            tokio::time::timeout(timeout, self.wake.notified()).await?;
284
285            self.fail_if_disconnected()?;
286        }
287    }
288
289    /// Authenticate
290    pub async fn send_authenticate(&self, event: Event) -> Result<(), Error> {
291        *self.auth_state.write().await = AuthState::InProgress(event.id);
292        self.send_message(ClientMessage::Auth(Box::new(event)))
293            .await?;
294        Ok(())
295    }
296}