nostr_types/client/
mod.rs

1use crate::{
2    ClientMessage, Error, Event, EventKind, Filter, Id, PreEvent, RelayInformationDocument,
3    RelayMessage, Signer, SubscriptionId, Tag, Unixtime,
4};
5use http::Uri;
6use std::ops::DerefMut;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::RwLock;
10use tungstenite::protocol::Message;
11
12mod auth;
13pub use auth::AuthState;
14
15mod connection;
16pub use connection::ClientConnection;
17
18/// A client connection to a relay.
19#[derive(Debug)]
20pub struct Client {
21    // read-only URL of the remote relay
22    relay_url: String,
23
24    // The connection information
25    // We only write-lock this to create or disconnect. Normal operations are
26    // read-locked with multiple readers allowed at once.
27    connection: RwLock<Option<ClientConnection>>,
28}
29
30impl Client {
31    /// Connect to a relay
32    pub fn new(relay_url: &str) -> Client {
33        Client {
34            relay_url: relay_url.to_string(),
35            connection: RwLock::new(None),
36        }
37    }
38
39    /// Is connected
40    pub async fn is_connected(&self) -> bool {
41        if let Some(ref cc) = *self.connection.read().await {
42            !cc.is_disconnected()
43        } else {
44            false
45        }
46    }
47
48    /// Reconnect to the relay if needed
49    async fn maybe_reconnect(&self, reconnect_timeout: Duration) -> Result<(), Error> {
50        let maybe_data = if let Some(ref cc) = *self.connection.read().await {
51            if cc.is_disconnected() {
52                Some(cc.incoming())
53            } else {
54                return Ok(());
55            }
56        } else {
57            None
58        };
59
60        match maybe_data {
61            Some(data) => {
62                let new_cc =
63                    ClientConnection::new_with_data(&self.relay_url, reconnect_timeout, data)
64                        .await?;
65                *self.connection.write().await = Some(new_cc);
66            }
67            None => {
68                let cc = ClientConnection::new(&self.relay_url, reconnect_timeout).await?;
69                *self.connection.write().await = Some(cc);
70            }
71        }
72
73        Ok(())
74    }
75
76    /// Disconnect from the relay
77    pub async fn disconnect(&self) -> Result<(), Error> {
78        let cc = std::mem::take(self.connection.write().await.deref_mut());
79        if let Some(cc) = cc {
80            cc.disconnect().await?
81        }
82        Ok(())
83    }
84
85    /// Get auth state
86    pub async fn get_auth_state(&self) -> Result<AuthState, Error> {
87        let lock = self.connection.read().await;
88        let Some(ref cc) = *lock else {
89            return Err(Error::Disconnected);
90        };
91        Ok(cc.get_auth_state().await)
92    }
93
94    /// Wait for auth state
95    pub async fn wait_for_auth_state_change(&self, timeout: Duration) -> Result<AuthState, Error> {
96        let lock = self.connection.read().await;
97        let Some(ref cc) = *lock else {
98            return Err(Error::Disconnected);
99        };
100        cc.wait_for_auth_state_change(timeout).await
101    }
102
103    /// Authenticate
104    /// This does not wait for any reply.
105    pub async fn send_authenticate(
106        &self,
107        challenge: String,
108        signer: Arc<dyn Signer>,
109        reconnect_timeout: Duration,
110    ) -> Result<Id, Error> {
111        let pre_event = PreEvent {
112            pubkey: signer.public_key(),
113            created_at: Unixtime::now(),
114            kind: EventKind::Auth,
115            tags: vec![
116                Tag::new(&["relay", &self.relay_url]),
117                Tag::new(&["challenge", &challenge]),
118            ],
119            content: "".to_string(),
120        };
121        let event = signer.sign_event(pre_event).await?;
122        let id = event.id;
123        self.maybe_reconnect(reconnect_timeout).await?;
124        let lock = self.connection.read().await;
125        let Some(ref cc) = *lock else {
126            return Err(Error::Disconnected);
127        };
128        cc.send_authenticate(event).await?;
129        Ok(id)
130    }
131
132    /// Full Authentication process.
133    /// Run this when you get "auth-required" in an OK or CLOSED message if you
134    /// wish to authenticate.
135    pub async fn full_authenticate(
136        &self,
137        signer: Arc<dyn Signer>,
138        timeout: Duration,
139    ) -> Result<(), Error> {
140        match self.get_auth_state().await? {
141            AuthState::NotYetRequested => Err(Error::RelayDidNotAuth),
142            AuthState::Challenged(ch) => {
143                let _ = self.send_authenticate(ch, signer, timeout).await?;
144                let auth_state = self.wait_for_auth_state_change(timeout).await?;
145                match auth_state {
146                    AuthState::Success => Ok(()),
147                    AuthState::Failure(_) => Err(Error::RelayRejectedAuth),
148                    _ => Err(Error::InvalidState(
149                        "AuthState in unexpected state".to_owned(),
150                    )),
151                }
152            }
153            AuthState::InProgress(_id) => {
154                let auth_state = self.wait_for_auth_state_change(timeout).await?;
155                match auth_state {
156                    AuthState::Success => Ok(()),
157                    AuthState::Failure(_) => Err(Error::RelayRejectedAuth),
158                    _ => Err(Error::InvalidState(
159                        "AuthState in unexpected state".to_owned(),
160                    )),
161                }
162            }
163            AuthState::Success => Err(Error::RelayForgotAuth),
164            AuthState::Failure(_) => Err(Error::RelayRejectedPost),
165        }
166    }
167
168    /// Post an event to the relay
169    pub async fn post_event(&self, event: Event, reconnect_timeout: Duration) -> Result<(), Error> {
170        let message = ClientMessage::Event(Box::new(event));
171        self.maybe_reconnect(reconnect_timeout).await?;
172        let lock = self.connection.read().await;
173        let Some(ref cc) = *lock else {
174            return Err(Error::Disconnected);
175        };
176        cc.send_message(message).await?;
177        Ok(())
178    }
179
180    /// Post a raw event to the relay
181    pub async fn post_raw_event(
182        &self,
183        json: String,
184        reconnect_timeout: Duration,
185    ) -> Result<(), Error> {
186        let wire = format!("[\"EVENT\",{}]", json);
187        let msg = Message::Text(wire.into());
188        self.maybe_reconnect(reconnect_timeout).await?;
189        let lock = self.connection.read().await;
190        let Some(ref cc) = *lock else {
191            return Err(Error::Disconnected);
192        };
193        cc.send_ws_message(msg).await?;
194        Ok(())
195    }
196
197    /// This posts the event, and waits for the OK result, authenticating
198    /// if requested if auth is Some.
199    pub async fn post_event_and_wait_for_result(
200        &self,
201        event: Event,
202        timeout: Duration,
203        auth: Option<Arc<dyn Signer>>,
204    ) -> Result<(bool, String), Error> {
205        self.post_event(event.clone(), timeout).await?;
206        let (ok, why) = self.wait_for_ok(event.id, timeout).await?;
207        if !ok && why.starts_with("auth-required:") {
208            match auth {
209                None => Err(Error::RelayRequiresAuth),
210                Some(signer) => {
211                    self.full_authenticate(signer, timeout).await?;
212                    self.post_event(event.clone(), timeout).await?;
213                    let (ok, why) = self.wait_for_ok(event.id, timeout).await?;
214                    Ok((ok, why))
215                }
216            }
217        } else {
218            Ok((ok, why))
219        }
220    }
221
222    /// Wait for an Ok
223    pub async fn wait_for_ok(&self, id: Id, timeout: Duration) -> Result<(bool, String), Error> {
224        let lock = self.connection.read().await;
225        let Some(ref cc) = *lock else {
226            return Err(Error::Disconnected);
227        };
228        let rm = cc
229            .wait_for_relay_message(
230                |rm| matches!(rm, RelayMessage::Ok(i, _, _) if *i==id),
231                timeout,
232            )
233            .await?;
234
235        match rm {
236            RelayMessage::Ok(_, ok, msg) => Ok((ok, msg)),
237            _ => unreachable!(),
238        }
239    }
240
241    /// Subscribe to a filter. This does not wait for results.
242    pub async fn subscribe(
243        &self,
244        filter: Filter,
245        reconnect_timeout: Duration,
246    ) -> Result<SubscriptionId, Error> {
247        self.maybe_reconnect(reconnect_timeout).await?;
248        let lock = self.connection.read().await;
249        let Some(ref cc) = *lock else {
250            return Err(Error::Disconnected);
251        };
252        cc.subscribe(filter).await
253    }
254
255    /// Close a subscription
256    pub async fn close_subscription(&self, sub_id: SubscriptionId) -> Result<(), Error> {
257        let lock = self.connection.read().await;
258        let Some(ref cc) = *lock else {
259            return Err(Error::Disconnected);
260        };
261        cc.close_subscription(sub_id).await
262    }
263
264    /// Wait for an event on the given subscription
265    pub async fn wait_for_subscribed_event(
266        &self,
267        sub_id: SubscriptionId,
268        timeout: Duration,
269    ) -> Result<Event, Error> {
270        let lock = self.connection.read().await;
271        let Some(ref cc) = *lock else {
272            return Err(Error::Disconnected);
273        };
274        let rm = cc
275            .wait_for_relay_message(
276                |rm| matches!(rm, RelayMessage::Event(s, _) if *s==sub_id),
277                timeout,
278            )
279            .await?;
280        match rm {
281            RelayMessage::Event(_, event) => Ok(*event),
282            _ => unreachable!(),
283        }
284    }
285
286    /// Wait for an event on the given subscription
287    pub async fn wait_for_subscribed_event_or_eose(
288        &self,
289        sub_id: SubscriptionId,
290        timeout: Duration,
291    ) -> Result<Option<Event>, Error> {
292        let lock = self.connection.read().await;
293        let Some(ref cc) = *lock else {
294            return Err(Error::Disconnected);
295        };
296        let rm = cc
297            .wait_for_relay_message(
298                |rm| {
299                    matches!(rm, RelayMessage::Event(s, _) if *s==sub_id)
300                        || matches!(rm, RelayMessage::Eose(s) if *s==sub_id)
301                },
302                timeout,
303            )
304            .await?;
305        match rm {
306            RelayMessage::Event(_, event) => Ok(Some(*event)),
307            RelayMessage::Eose(_) => Ok(None),
308            _ => unreachable!(),
309        }
310    }
311
312    /// Subscribe and collect all results up to the EOSE
313    pub async fn subscribe_and_wait_for_events(
314        &self,
315        filter: Filter,
316        timeout: Duration,
317        signer: Option<Arc<dyn Signer>>,
318    ) -> Result<Vec<Event>, Error> {
319        let mut output: Vec<Event> = Vec::new();
320
321        let mut sub_id = self.subscribe(filter.clone(), timeout).await?;
322
323        let lock = self.connection.read().await;
324        let Some(ref cc) = *lock else {
325            return Err(Error::Disconnected);
326        };
327
328        loop {
329            // Wait for any of EVENT or EOSE or CLOSED on this subscription_id
330            let rm = cc
331                .wait_for_relay_message(
332                    |rm| {
333                        matches!(rm, RelayMessage::Event(sid, _) if *sid==sub_id)
334                            || matches!(rm, RelayMessage::Eose(sid) if *sid==sub_id)
335                            || matches!(rm, RelayMessage::Closed(sid, _) if *sid==sub_id)
336                    },
337                    timeout,
338                )
339                .await?;
340
341            match rm {
342                RelayMessage::Event(_, event) => output.push(*event),
343                RelayMessage::Eose(_) => return Ok(output),
344                RelayMessage::Closed(_, message) => {
345                    if message.starts_with("auth-required:") {
346                        match signer {
347                            Some(ref signer) => {
348                                self.full_authenticate(signer.clone(), timeout).await?;
349                                sub_id = self.subscribe(filter.clone(), timeout).await?;
350                                continue;
351                            }
352                            None => {
353                                return Err(Error::RelayRequiresAuth);
354                            }
355                        }
356                    }
357                }
358                _ => unreachable!(),
359            }
360        }
361    }
362}
363
364fn url_to_host_and_uri(url: &str) -> Result<(String, Uri), Error> {
365    let uri: http::Uri = url.parse::<http::Uri>()?;
366    let authority = match uri.authority() {
367        Some(auth) => auth.as_str(),
368        None => return Err(Error::Url(url.to_string())),
369    };
370    let host = authority
371        .find('@')
372        .map(|idx| authority.split_at(idx + 1).1)
373        .unwrap_or_else(|| authority);
374    if host.is_empty() {
375        Err(Error::Url(url.to_string()))
376    } else {
377        Ok((host.to_owned(), uri))
378    }
379}
380
381/// Fetch a NIP-11 for a relay
382pub async fn fetch_nip11(relay_url: &str) -> Result<RelayInformationDocument, Error> {
383    use reqwest::redirect::Policy;
384    use reqwest::Client;
385    use std::time::Duration;
386
387    let (host, uri) = url_to_host_and_uri(relay_url)?;
388    let scheme = match uri.scheme() {
389        Some(refscheme) => match refscheme.as_str() {
390            "wss" => "https",
391            "ws" => "http",
392            u => panic!("Unknown scheme {}", u),
393        },
394        None => panic!("Relay URL has no scheme."),
395    };
396    let url = format!("{}://{}{}", scheme, host, uri.path());
397    let client = Client::builder()
398        .redirect(Policy::none())
399        .connect_timeout(Duration::from_secs(60))
400        .timeout(Duration::from_secs(60))
401        .connection_verbose(true)
402        .build()?;
403    let response = client
404        .get(url)
405        .header("Host", host)
406        .header("Accept", "application/nostr+json")
407        .send()
408        .await?;
409    let json = response.text().await?;
410    let rid: RelayInformationDocument = serde_json::from_str(&json)?;
411    Ok(rid)
412}