@@ -236,6 +236,37 @@ def run_server():
236236 async_interpreter .server .run ()
237237
238238
239+ async def wait_for_websocket_complete (websocket , max_attempts = 5 ):
240+ """Wait for WebSocket 'complete' status message with retry limit."""
241+
242+ import asyncio
243+ import json
244+
245+ accumulated_content = ""
246+
247+ for attempt in range (1 , max_attempts + 1 ):
248+ try :
249+ message = await websocket .recv ()
250+ message_data = json .loads (message )
251+ if "error" in message_data :
252+ raise Exception (message_data ["content" ])
253+ print ("Received from WebSocket:" , message_data )
254+ if type (message_data .get ("content" )) == str :
255+ accumulated_content += message_data .get ("content" )
256+ if message_data == {
257+ "role" : "server" ,
258+ "type" : "status" ,
259+ "content" : "complete" ,
260+ }:
261+ print ("Received expected message from server" )
262+ return accumulated_content
263+ except Exception as e :
264+ print (f"WebSocket receive failed (attempt { attempt } /{ max_attempts } ): { e } " )
265+ await asyncio .sleep (1 )
266+ else :
267+ raise Exception (f"Never received 'complete' status after { max_attempts } attempts" )
268+
269+
239270# @pytest.mark.skip(reason="Requires uvicorn, which we don't require by default")
240271def test_server ():
241272 # Start the server in a new process
@@ -299,22 +330,7 @@ async def test_fastapi_server():
299330 print ("WebSocket chunks sent" )
300331
301332 # Wait for a specific response
302- accumulated_content = ""
303- while True :
304- message = await websocket .recv ()
305- message_data = json .loads (message )
306- if "error" in message_data :
307- raise Exception (message_data ["content" ])
308- print ("Received from WebSocket:" , message_data )
309- if type (message_data .get ("content" )) == str :
310- accumulated_content += message_data .get ("content" )
311- if message_data == {
312- "role" : "server" ,
313- "type" : "status" ,
314- "content" : "complete" ,
315- }:
316- print ("Received expected message from server" )
317- break
333+ accumulated_content = await wait_for_websocket_complete (websocket )
318334
319335 assert "crunk" in accumulated_content
320336
@@ -355,22 +371,7 @@ async def test_fastapi_server():
355371 print ("WebSocket chunks sent" )
356372
357373 # Wait for a specific response
358- accumulated_content = ""
359- while True :
360- message = await websocket .recv ()
361- message_data = json .loads (message )
362- if "error" in message_data :
363- raise Exception (message_data ["content" ])
364- print ("Received from WebSocket:" , message_data )
365- if message_data .get ("content" ):
366- accumulated_content += message_data .get ("content" )
367- if message_data == {
368- "role" : "server" ,
369- "type" : "status" ,
370- "content" : "complete" ,
371- }:
372- print ("Received expected message from server" )
373- break
374+ accumulated_content = await wait_for_websocket_complete (websocket )
374375
375376 assert "barloney" in accumulated_content
376377
@@ -404,22 +405,7 @@ async def test_fastapi_server():
404405 print ("WebSocket chunks sent" )
405406
406407 # Wait for response
407- accumulated_content = ""
408- while True :
409- message = await websocket .recv ()
410- message_data = json .loads (message )
411- if "error" in message_data :
412- raise Exception (message_data ["content" ])
413- print ("Received from WebSocket:" , message_data )
414- if message_data .get ("content" ):
415- accumulated_content += message_data .get ("content" )
416- if message_data == {
417- "role" : "server" ,
418- "type" : "status" ,
419- "content" : "complete" ,
420- }:
421- print ("Received expected message from server" )
422- break
408+ accumulated_content = await wait_for_websocket_complete (websocket )
423409
424410 time .sleep (5 )
425411
@@ -454,23 +440,7 @@ async def test_fastapi_server():
454440 )
455441
456442 # Wait for a specific response
457- accumulated_content = ""
458- while True :
459- message = await websocket .recv ()
460- message_data = json .loads (message )
461- if "error" in message_data :
462- raise Exception (message_data ["content" ])
463- print ("Received from WebSocket:" , message_data )
464- if message_data .get ("content" ):
465- if type (message_data .get ("content" )) == str :
466- accumulated_content += message_data .get ("content" )
467- if message_data == {
468- "role" : "server" ,
469- "type" : "status" ,
470- "content" : "complete" ,
471- }:
472- print ("Received expected message from server" )
473- break
443+ accumulated_content = await wait_for_websocket_complete (websocket )
474444
475445 assert "18893094989" in accumulated_content .replace ("," , "" )
476446
@@ -525,22 +495,7 @@ async def test_fastapi_server():
525495 print ("WebSocket chunks sent" )
526496
527497 # Wait for response
528- accumulated_content = ""
529- while True :
530- message = await websocket .recv ()
531- message_data = json .loads (message )
532- if "error" in message_data :
533- raise Exception (message_data ["content" ])
534- print ("Received from WebSocket:" , message_data )
535- if type (message_data .get ("content" )) == str :
536- accumulated_content += message_data .get ("content" )
537- if message_data == {
538- "role" : "server" ,
539- "type" : "status" ,
540- "content" : "complete" ,
541- }:
542- print ("Received expected message from server" )
543- break
498+ accumulated_content = await wait_for_websocket_complete (websocket )
544499
545500 # Get messages
546501 get_url = "http://localhost:8000/settings/messages"
@@ -602,22 +557,7 @@ async def test_fastapi_server():
602557 print ("WebSocket chunks sent" )
603558
604559 # Wait for response
605- accumulated_content = ""
606- while True :
607- message = await websocket .recv ()
608- message_data = json .loads (message )
609- if "error" in message_data :
610- raise Exception (message_data ["content" ])
611- print ("Received from WebSocket:" , message_data )
612- if type (message_data .get ("content" )) == str :
613- accumulated_content += message_data .get ("content" )
614- if message_data == {
615- "role" : "server" ,
616- "type" : "status" ,
617- "content" : "complete" ,
618- }:
619- print ("Received expected message from server" )
620- break
560+ accumulated_content = await wait_for_websocket_complete (websocket )
621561
622562 # Get messages
623563 get_url = "http://localhost:8000/settings/messages"
0 commit comments