nostr_types/client/
connection.rs1use super::AuthState;
2use crate::{ClientMessage, Error, Event, Filter, RelayMessage, SubscriptionId};
3use base64::Engine;
4use futures_util::stream::SplitSink;
5use futures_util::{SinkExt, StreamExt};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tracing::{event, span, Level};
11use tungstenite::protocol::Message;
12
13type Ws =
15 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
16
17#[derive(Debug)]
24pub struct ClientConnection {
25 sink: Mutex<SplitSink<Ws, Message>>,
27
28 next_sub_id: AtomicUsize,
30
31 auth_state: Arc<RwLock<AuthState>>,
33
34 incoming: Arc<RwLock<Vec<RelayMessage>>>,
39
40 wake: Arc<Notify>,
44
45 disconnected: Arc<AtomicBool>,
47}
48
49impl ClientConnection {
50 pub async fn new(relay_url: &str, timeout: Duration) -> Result<ClientConnection, Error> {
52 let incoming: Arc<RwLock<Vec<RelayMessage>>> = Arc::new(RwLock::new(Vec::new()));
53 Self::new_with_data(relay_url, timeout, incoming).await
54 }
55
56 pub async fn new_with_data(
58 relay_url: &str,
59 timeout: Duration,
60 incoming: Arc<RwLock<Vec<RelayMessage>>>,
61 ) -> Result<ClientConnection, Error> {
62 let (host, uri) = super::url_to_host_and_uri(relay_url)?;
63 let key: [u8; 16] = rand::random();
64 let request = http::request::Request::builder()
65 .method("GET")
66 .header("Host", host)
67 .header("Connection", "Upgrade")
68 .header("Upgrade", "websocket")
69 .header("Sec-WebSocket-Version", "13")
70 .header(
71 "Sec-WebSocket-Key",
72 base64::engine::general_purpose::STANDARD.encode(key),
73 )
74 .uri(uri)
75 .body(())?;
76
77 let (websocket, response) =
78 tokio::time::timeout(timeout, tokio_tungstenite::connect_async(request)).await??;
79
80 let status = response.status();
81 if status.is_redirection() || status.is_client_error() || status.is_server_error() {
82 return Err(Error::WebsocketConnectionFailed(status));
83 }
84
85 let (sink, mut stream) = websocket.split();
87
88 let incoming2 = incoming.clone();
89
90 let disconnected: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
91 let disconnected2 = disconnected.clone();
92
93 let wake = Arc::new(Notify::new());
94 let wake2 = wake.clone();
95
96 let auth_state = Arc::new(RwLock::new(AuthState::NotYetRequested));
97 let auth_state2 = auth_state.clone();
98
99 let _listener_task = tokio::task::spawn(Box::pin(async move {
101 let span = span!(Level::DEBUG, "connection listener thread");
102 let _enter = span.enter();
103 while let Some(message) = stream.next().await {
104 match message {
105 Ok(Message::Text(s)) => {
106 match serde_json::from_str(&s) {
107 Ok(rm) => {
108 match rm {
110 RelayMessage::Auth(challenge) => {
111 let mut lock = auth_state2.write().await;
112
113 event!(Level::DEBUG, "AUTH CHALLENGED");
114 *lock = AuthState::Challenged(challenge);
115
116 wake2.notify_waiters();
118 continue;
119 }
120 RelayMessage::Ok(id, is_ok, ref reason) => {
121 let mut lock = auth_state2.write().await;
122 if let AuthState::InProgress(sent_id) = *lock {
123 if id == sent_id {
124 *lock = if is_ok {
125 AuthState::Success
126 } else {
127 AuthState::Failure(reason.clone())
128 };
129 wake2.notify_waiters();
131 continue;
132 }
133 }
134 }
135 _ => {}
136 }
137
138 (*incoming2.write().await).push(rm);
139 wake2.notify_waiters();
140 }
141 Err(e) => {
142 event!(Level::INFO, "websocket wessage failed to deserialize: {e}")
143 }
144 }
145 }
146 Ok(Message::Close(_)) => {
147 break;
148 }
149 Ok(_) => { }
150 Err(e) => {
151 event!(Level::ERROR, "{e}");
152 break;
153 }
154 }
155 }
156
157 disconnected2.store(true, Ordering::Relaxed);
158 }));
159
160 Ok(ClientConnection {
161 sink: Mutex::new(sink),
162 next_sub_id: AtomicUsize::new(0),
163 auth_state,
164 incoming,
165 wake,
166 disconnected,
167 })
168 }
169
170 pub fn is_disconnected(&self) -> bool {
172 self.disconnected.load(Ordering::Relaxed)
173 }
174
175 pub async fn disconnect(self) -> Result<(), Error> {
177 let msg = Message::Close(None);
178 let mut sink = self.sink.lock().await;
179 sink.send(msg).await?;
180 sink.close().await?;
181 Ok(())
182 }
183
184 pub fn incoming(&self) -> Arc<RwLock<Vec<RelayMessage>>> {
186 self.incoming.clone()
187 }
188
189 fn fail_if_disconnected(&self) -> Result<(), Error> {
190 if self.disconnected.load(Ordering::Relaxed) {
191 Err(Error::Disconnected)
192 } else {
193 Ok(())
194 }
195 }
196
197 pub async fn subscribe(&self, filter: Filter) -> Result<SubscriptionId, Error> {
199 self.fail_if_disconnected()?;
200 let sub_id_usize = self.next_sub_id.fetch_add(1, Ordering::Relaxed);
201 let sub_id = SubscriptionId(format!("sub{}", sub_id_usize));
202 let client_message = ClientMessage::Req(sub_id.clone(), filter.clone());
203 self.send_message(client_message).await?;
204 Ok(sub_id)
205 }
206
207 pub async fn close_subscription(&self, sub_id: SubscriptionId) -> Result<(), Error> {
209 self.fail_if_disconnected()?;
210 let client_message = ClientMessage::Close(sub_id);
211 self.send_message(client_message).await?;
212 Ok(())
213 }
214
215 pub async fn send_message(&self, message: ClientMessage) -> Result<(), Error> {
217 let wire = serde_json::to_string(&message)?;
218 let msg = Message::Text(wire.into());
219 self.inner_send_message(msg).await?;
220 Ok(())
221 }
222
223 pub async fn send_ws_message(&self, message: Message) -> Result<(), Error> {
225 self.inner_send_message(message).await?;
226 Ok(())
227 }
228
229 async fn inner_send_message(&self, msg: Message) -> Result<(), Error> {
230 self.fail_if_disconnected()?;
231 if let Err(e) = self.sink.lock().await.send(msg).await {
232 self.disconnected.store(true, Ordering::Relaxed);
233 Err(e)?
234 } else {
235 Ok(())
236 }
237 }
238
239 pub async fn wait_for_relay_message<P>(
244 &self,
245 predicate: P,
246 timeout: Duration,
247 ) -> Result<RelayMessage, Error>
248 where
249 P: Fn(&RelayMessage) -> bool,
250 {
251 loop {
252 {
253 let mut incoming = self.incoming.write().await;
255 if let Some(found) = incoming.iter().position(&predicate) {
256 let relay_message = incoming.remove(found);
257 return Ok(relay_message);
258 }
259 }
260
261 tokio::time::timeout(timeout, self.wake.notified()).await?;
263
264 self.fail_if_disconnected()?;
265 }
266 }
267
268 pub async fn get_auth_state(&self) -> AuthState {
270 self.auth_state.read().await.clone()
271 }
272
273 pub async fn wait_for_auth_state_change(&self, timeout: Duration) -> Result<AuthState, Error> {
275 let start = self.auth_state.read().await.clone();
276 loop {
277 let current = self.auth_state.read().await.clone();
278 if current != start {
279 return Ok(current);
280 }
281
282 tokio::time::timeout(timeout, self.wake.notified()).await?;
284
285 self.fail_if_disconnected()?;
286 }
287 }
288
289 pub async fn send_authenticate(&self, event: Event) -> Result<(), Error> {
291 *self.auth_state.write().await = AuthState::InProgress(event.id);
292 self.send_message(ClientMessage::Auth(Box::new(event)))
293 .await?;
294 Ok(())
295 }
296}