Fix panics due to concurrent websocket writes

This commit is contained in:
Lubos Dolezel
2023-02-24 11:49:26 +01:00
parent 8bee96aeff
commit f27cbfb299
21 changed files with 131 additions and 141 deletions

12
app.go
View File

@@ -19,6 +19,10 @@ type App struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
conn *websocket.Conn conn *websocket.Conn
// Wraps the ws connection with added mutex locking
wsWriter *ws.WebsocketWriter
httpClient *http.HttpClient httpClient *http.HttpClient
service *Service service *Service
@@ -85,11 +89,13 @@ func NewApp(request NewAppRequest) *App {
httpClient := http.NewHttpClient(request.IpAddress, port, request.HAAuthToken) httpClient := http.NewHttpClient(request.IpAddress, port, request.HAAuthToken)
service := newService(conn, ctx, httpClient) wsWriter := &ws.WebsocketWriter{Conn: conn}
service := newService(wsWriter, ctx, httpClient)
state := newState(httpClient, request.HomeZoneEntityId) state := newState(httpClient, request.HomeZoneEntityId)
return &App{ return &App{
conn: conn, conn: conn,
wsWriter: wsWriter,
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
httpClient: httpClient, httpClient: httpClient,
@@ -169,7 +175,7 @@ func (a *App) RegisterEventListeners(evls ...EventListener) {
if elList, ok := a.eventListeners[eventType]; ok { if elList, ok := a.eventListeners[eventType]; ok {
a.eventListeners[eventType] = append(elList, &evl) a.eventListeners[eventType] = append(elList, &evl)
} else { } else {
ws.SubscribeToEventType(eventType, a.conn, a.ctx) ws.SubscribeToEventType(eventType, a.wsWriter, a.ctx)
a.eventListeners[eventType] = []*EventListener{&evl} a.eventListeners[eventType] = []*EventListener{&evl}
} }
} }
@@ -227,7 +233,7 @@ func (a *App) Start() {
// subscribe to state_changed events // subscribe to state_changed events
id := internal.GetId() id := internal.GetId()
ws.SubscribeToStateChangedEvents(id, a.conn, a.ctx) ws.SubscribeToStateChangedEvents(id, a.wsWriter, a.ctx)
a.entityListenersId = id a.entityListenersId = id
// entity listeners runOnStartup // entity listeners runOnStartup

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type AlarmControlPanel struct { type AlarmControlPanel struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -27,7 +26,7 @@ func (acp AlarmControlPanel) ArmAway(entityId string, serviceData ...map[string]
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for arm away. // Send the alarm the command for arm away.
@@ -41,7 +40,7 @@ func (acp AlarmControlPanel) ArmWithCustomBypass(entityId string, serviceData ..
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for arm home. // Send the alarm the command for arm home.
@@ -55,7 +54,7 @@ func (acp AlarmControlPanel) ArmHome(entityId string, serviceData ...map[string]
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for arm night. // Send the alarm the command for arm night.
@@ -69,7 +68,7 @@ func (acp AlarmControlPanel) ArmNight(entityId string, serviceData ...map[string
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for arm vacation. // Send the alarm the command for arm vacation.
@@ -83,7 +82,7 @@ func (acp AlarmControlPanel) ArmVacation(entityId string, serviceData ...map[str
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for disarm. // Send the alarm the command for disarm.
@@ -97,7 +96,7 @@ func (acp AlarmControlPanel) Disarm(entityId string, serviceData ...map[string]a
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }
// Send the alarm the command for trigger. // Send the alarm the command for trigger.
@@ -111,5 +110,5 @@ func (acp AlarmControlPanel) Trigger(entityId string, serviceData ...map[string]
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, acp.conn, acp.ctx) acp.conn.WriteMessage(req, acp.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Cover struct { type Cover struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -22,7 +21,7 @@ func (c Cover) Close(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "close_cover" req.Service = "close_cover"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Close all or specified cover tilt. Takes an entityId. // Close all or specified cover tilt. Takes an entityId.
@@ -31,7 +30,7 @@ func (c Cover) CloseTilt(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "close_cover_tilt" req.Service = "close_cover_tilt"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Open all or specified cover. Takes an entityId. // Open all or specified cover. Takes an entityId.
@@ -40,7 +39,7 @@ func (c Cover) Open(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "open_cover" req.Service = "open_cover"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Open all or specified cover tilt. Takes an entityId. // Open all or specified cover tilt. Takes an entityId.
@@ -49,7 +48,7 @@ func (c Cover) OpenTilt(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "open_cover_tilt" req.Service = "open_cover_tilt"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Move to specific position all or specified cover. Takes an entityId and an optional // Move to specific position all or specified cover. Takes an entityId and an optional
@@ -62,7 +61,7 @@ func (c Cover) SetPosition(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Move to specific position all or specified cover tilt. Takes an entityId and an optional // Move to specific position all or specified cover tilt. Takes an entityId and an optional
@@ -75,7 +74,7 @@ func (c Cover) SetTiltPosition(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Stop a cover entity. Takes an entityId. // Stop a cover entity. Takes an entityId.
@@ -84,7 +83,7 @@ func (c Cover) Stop(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "stop_cover" req.Service = "stop_cover"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Stop a cover entity tilt. Takes an entityId. // Stop a cover entity tilt. Takes an entityId.
@@ -93,7 +92,7 @@ func (c Cover) StopTilt(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "stop_cover_tilt" req.Service = "stop_cover_tilt"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Toggle a cover open/closed. Takes an entityId. // Toggle a cover open/closed. Takes an entityId.
@@ -102,7 +101,7 @@ func (c Cover) Toggle(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "toggle" req.Service = "toggle"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }
// Toggle a cover tilt open/closed. Takes an entityId. // Toggle a cover tilt open/closed. Takes an entityId.
@@ -111,5 +110,5 @@ func (c Cover) ToggleTilt(entityId string) {
req.Domain = "cover" req.Domain = "cover"
req.Service = "toggle_cover_tilt" req.Service = "toggle_cover_tilt"
ws.WriteMessage(req, c.conn, c.ctx) c.conn.WriteMessage(req, c.ctx)
} }

View File

@@ -3,12 +3,11 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
type HomeAssistant struct { type HomeAssistant struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -22,7 +21,7 @@ func (ha *HomeAssistant) TurnOn(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, ha.conn, ha.ctx) ha.conn.WriteMessage(req, ha.ctx)
} }
// Toggle a Home Assistant entity. Takes an entityId and an optional // Toggle a Home Assistant entity. Takes an entityId and an optional
@@ -35,7 +34,7 @@ func (ha *HomeAssistant) Toggle(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, ha.conn, ha.ctx) ha.conn.WriteMessage(req, ha.ctx)
} }
func (ha *HomeAssistant) TurnOff(entityId string) { func (ha *HomeAssistant) TurnOff(entityId string) {
@@ -43,5 +42,5 @@ func (ha *HomeAssistant) TurnOff(entityId string) {
req.Domain = "homeassistant" req.Domain = "homeassistant"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, ha.conn, ha.ctx) ha.conn.WriteMessage(req, ha.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type InputBoolean struct { type InputBoolean struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -21,7 +20,7 @@ func (ib InputBoolean) TurnOn(entityId string) {
req.Domain = "input_boolean" req.Domain = "input_boolean"
req.Service = "turn_on" req.Service = "turn_on"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputBoolean) Toggle(entityId string) { func (ib InputBoolean) Toggle(entityId string) {
@@ -29,19 +28,19 @@ func (ib InputBoolean) Toggle(entityId string) {
req.Domain = "input_boolean" req.Domain = "input_boolean"
req.Service = "toggle" req.Service = "toggle"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputBoolean) TurnOff(entityId string) { func (ib InputBoolean) TurnOff(entityId string) {
req := NewBaseServiceRequest(entityId) req := NewBaseServiceRequest(entityId)
req.Domain = "input_boolean" req.Domain = "input_boolean"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputBoolean) Reload() { func (ib InputBoolean) Reload() {
req := NewBaseServiceRequest("") req := NewBaseServiceRequest("")
req.Domain = "input_boolean" req.Domain = "input_boolean"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type InputButton struct { type InputButton struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -21,12 +20,12 @@ func (ib InputButton) Press(entityId string) {
req.Domain = "input_button" req.Domain = "input_button"
req.Service = "press" req.Service = "press"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputButton) Reload() { func (ib InputButton) Reload() {
req := NewBaseServiceRequest("") req := NewBaseServiceRequest("")
req.Domain = "input_button" req.Domain = "input_button"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -5,14 +5,13 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type InputDatetime struct { type InputDatetime struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -26,12 +25,12 @@ func (ib InputDatetime) Set(entityId string, value time.Time) {
"timestamp": fmt.Sprint(value.Unix()), "timestamp": fmt.Sprint(value.Unix()),
} }
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputDatetime) Reload() { func (ib InputDatetime) Reload() {
req := NewBaseServiceRequest("") req := NewBaseServiceRequest("")
req.Domain = "input_datetime" req.Domain = "input_datetime"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type InputNumber struct { type InputNumber struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -22,7 +21,7 @@ func (ib InputNumber) Set(entityId string, value float32) {
req.Service = "set_value" req.Service = "set_value"
req.ServiceData = map[string]any{"value": value} req.ServiceData = map[string]any{"value": value}
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputNumber) Increment(entityId string) { func (ib InputNumber) Increment(entityId string) {
@@ -30,7 +29,7 @@ func (ib InputNumber) Increment(entityId string) {
req.Domain = "input_number" req.Domain = "input_number"
req.Service = "increment" req.Service = "increment"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputNumber) Decrement(entityId string) { func (ib InputNumber) Decrement(entityId string) {
@@ -38,12 +37,12 @@ func (ib InputNumber) Decrement(entityId string) {
req.Domain = "input_number" req.Domain = "input_number"
req.Service = "decrement" req.Service = "decrement"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputNumber) Reload() { func (ib InputNumber) Reload() {
req := NewBaseServiceRequest("") req := NewBaseServiceRequest("")
req.Domain = "input_number" req.Domain = "input_number"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type InputText struct { type InputText struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -24,12 +23,12 @@ func (ib InputText) Set(entityId string, value string) {
"value": value, "value": value,
} }
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }
func (ib InputText) Reload() { func (ib InputText) Reload() {
req := NewBaseServiceRequest("") req := NewBaseServiceRequest("")
req.Domain = "input_text" req.Domain = "input_text"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Light struct { type Light struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -26,7 +25,7 @@ func (l Light) TurnOn(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, l.conn, l.ctx) l.conn.WriteMessage(req, l.ctx)
} }
// Toggle a light entity. Takes an entityId and an optional // Toggle a light entity. Takes an entityId and an optional
@@ -39,12 +38,12 @@ func (l Light) Toggle(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, l.conn, l.ctx) l.conn.WriteMessage(req, l.ctx)
} }
func (l Light) TurnOff(entityId string) { func (l Light) TurnOff(entityId string) {
req := NewBaseServiceRequest(entityId) req := NewBaseServiceRequest(entityId)
req.Domain = "light" req.Domain = "light"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, l.conn, l.ctx) l.conn.WriteMessage(req, l.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Lock struct { type Lock struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -26,7 +25,7 @@ func (l Lock) Lock(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, l.conn, l.ctx) l.conn.WriteMessage(req, l.ctx)
} }
// Unlock a lock entity. Takes an entityId and an optional // Unlock a lock entity. Takes an entityId and an optional
@@ -39,5 +38,5 @@ func (l Lock) Unlock(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, l.conn, l.ctx) l.conn.WriteMessage(req, l.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type MediaPlayer struct { type MediaPlayer struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -23,7 +22,7 @@ func (mp MediaPlayer) ClearPlaylist(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "clear_playlist" req.Service = "clear_playlist"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Group players together. Only works on platforms with support for player groups. // Group players together. Only works on platforms with support for player groups.
@@ -37,7 +36,7 @@ func (mp MediaPlayer) Join(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command for next track. // Send the media player the command for next track.
@@ -47,7 +46,7 @@ func (mp MediaPlayer) Next(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_next_track" req.Service = "media_next_track"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command for pause. // Send the media player the command for pause.
@@ -57,7 +56,7 @@ func (mp MediaPlayer) Pause(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_pause" req.Service = "media_pause"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command for play. // Send the media player the command for play.
@@ -67,7 +66,7 @@ func (mp MediaPlayer) Play(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_play" req.Service = "media_play"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Toggle media player play/pause state. // Toggle media player play/pause state.
@@ -77,7 +76,7 @@ func (mp MediaPlayer) PlayPause(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_play_pause" req.Service = "media_play_pause"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command for previous track. // Send the media player the command for previous track.
@@ -87,7 +86,7 @@ func (mp MediaPlayer) Previous(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_previous_track" req.Service = "media_previous_track"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command to seek in current playing media. // Send the media player the command to seek in current playing media.
@@ -101,7 +100,7 @@ func (mp MediaPlayer) Seek(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the stop command. // Send the media player the stop command.
@@ -111,7 +110,7 @@ func (mp MediaPlayer) Stop(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "media_stop" req.Service = "media_stop"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command for playing media. // Send the media player the command for playing media.
@@ -125,7 +124,7 @@ func (mp MediaPlayer) PlayMedia(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Set repeat mode. Takes an entityId and an optional // Set repeat mode. Takes an entityId and an optional
@@ -138,7 +137,7 @@ func (mp MediaPlayer) RepeatSet(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command to change sound mode. // Send the media player the command to change sound mode.
@@ -152,7 +151,7 @@ func (mp MediaPlayer) SelectSoundMode(entityId string, serviceData ...map[string
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Send the media player the command to change input source. // Send the media player the command to change input source.
@@ -166,7 +165,7 @@ func (mp MediaPlayer) SelectSource(entityId string, serviceData ...map[string]an
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Set shuffling state. // Set shuffling state.
@@ -180,7 +179,7 @@ func (mp MediaPlayer) Shuffle(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Toggles a media player power state. // Toggles a media player power state.
@@ -190,7 +189,7 @@ func (mp MediaPlayer) Toggle(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "toggle" req.Service = "toggle"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Turn a media player power off. // Turn a media player power off.
@@ -200,7 +199,7 @@ func (mp MediaPlayer) TurnOff(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Turn a media player power on. // Turn a media player power on.
@@ -210,7 +209,7 @@ func (mp MediaPlayer) TurnOn(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "turn_on" req.Service = "turn_on"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Unjoin the player from a group. Only works on // Unjoin the player from a group. Only works on
@@ -221,7 +220,7 @@ func (mp MediaPlayer) Unjoin(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "unjoin" req.Service = "unjoin"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Turn a media player volume down. // Turn a media player volume down.
@@ -231,7 +230,7 @@ func (mp MediaPlayer) VolumeDown(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "volume_down" req.Service = "volume_down"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Mute a media player's volume. // Mute a media player's volume.
@@ -245,7 +244,7 @@ func (mp MediaPlayer) VolumeMute(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Set a media player's volume level. // Set a media player's volume level.
@@ -259,7 +258,7 @@ func (mp MediaPlayer) VolumeSet(entityId string, serviceData ...map[string]any)
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }
// Turn a media player volume up. // Turn a media player volume up.
@@ -269,5 +268,5 @@ func (mp MediaPlayer) VolumeUp(entityId string) {
req.Domain = "media_player" req.Domain = "media_player"
req.Service = "volume_up" req.Service = "volume_up"
ws.WriteMessage(req, mp.conn, mp.ctx) mp.conn.WriteMessage(req, mp.ctx)
} }

View File

@@ -3,13 +3,12 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
"saml.dev/gome-assistant/types" "saml.dev/gome-assistant/types"
) )
type Notify struct { type Notify struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -27,5 +26,5 @@ func (ha *Notify) Notify(reqData types.NotifyRequest) {
} }
req.ServiceData = serviceData req.ServiceData = serviceData
ws.WriteMessage(req, ha.conn, ha.ctx) ha.conn.WriteMessage(req, ha.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Number struct { type Number struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -22,5 +21,5 @@ func (ib Number) SetValue(entityId string, value int) {
req.Service = "set_value" req.Service = "set_value"
req.ServiceData = map[string]any{"value": value} req.ServiceData = map[string]any{"value": value}
ws.WriteMessage(req, ib.conn, ib.ctx) ib.conn.WriteMessage(req, ib.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Scene struct { type Scene struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -25,7 +24,7 @@ func (s Scene) Apply(serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }
// Create a scene entity. Takes an entityId and an optional // Create a scene entity. Takes an entityId and an optional
@@ -38,7 +37,7 @@ func (s Scene) Create(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }
// Reload the scenes. // Reload the scenes.
@@ -47,7 +46,7 @@ func (s Scene) Reload() {
req.Domain = "scene" req.Domain = "scene"
req.Service = "reload" req.Service = "reload"
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }
// TurnOn a scene entity. Takes an entityId and an optional // TurnOn a scene entity. Takes an entityId and an optional
@@ -60,5 +59,5 @@ func (s Scene) TurnOn(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/gorilla/websocket"
"saml.dev/gome-assistant/internal" "saml.dev/gome-assistant/internal"
ws "saml.dev/gome-assistant/internal/websocket"
) )
func BuildService[ func BuildService[
@@ -26,7 +26,7 @@ func BuildService[
Scene | Scene |
TTS | TTS |
Vacuum, Vacuum,
](conn *websocket.Conn, ctx context.Context) *T { ](conn *ws.WebsocketWriter, ctx context.Context) *T {
return &T{conn: conn, ctx: ctx} return &T{conn: conn, ctx: ctx}
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Switch struct { type Switch struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -21,7 +20,7 @@ func (s Switch) TurnOn(entityId string) {
req.Domain = "switch" req.Domain = "switch"
req.Service = "turn_on" req.Service = "turn_on"
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }
func (s Switch) Toggle(entityId string) { func (s Switch) Toggle(entityId string) {
@@ -29,12 +28,12 @@ func (s Switch) Toggle(entityId string) {
req.Domain = "switch" req.Domain = "switch"
req.Service = "toggle" req.Service = "toggle"
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }
func (s Switch) TurnOff(entityId string) { func (s Switch) TurnOff(entityId string) {
req := NewBaseServiceRequest(entityId) req := NewBaseServiceRequest(entityId)
req.Domain = "switch" req.Domain = "switch"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, s.conn, s.ctx) s.conn.WriteMessage(req, s.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type TTS struct { type TTS struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -22,7 +21,7 @@ func (tts TTS) ClearCache() {
req.Domain = "tts" req.Domain = "tts"
req.Service = "clear_cache" req.Service = "clear_cache"
ws.WriteMessage(req, tts.conn, tts.ctx) tts.conn.WriteMessage(req, tts.ctx)
} }
// Say something using text-to-speech on a media player with cloud. // Say something using text-to-speech on a media player with cloud.
@@ -36,7 +35,7 @@ func (tts TTS) CloudSay(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, tts.conn, tts.ctx) tts.conn.WriteMessage(req, tts.ctx)
} }
// Say something using text-to-speech on a media player with google_translate. // Say something using text-to-speech on a media player with google_translate.
@@ -50,5 +49,5 @@ func (tts TTS) GoogleTranslateSay(entityId string, serviceData ...map[string]any
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, tts.conn, tts.ctx) tts.conn.WriteMessage(req, tts.ctx)
} }

View File

@@ -3,14 +3,13 @@ package services
import ( import (
"context" "context"
"github.com/gorilla/websocket"
ws "saml.dev/gome-assistant/internal/websocket" ws "saml.dev/gome-assistant/internal/websocket"
) )
/* Structs */ /* Structs */
type Vacuum struct { type Vacuum struct {
conn *websocket.Conn conn *ws.WebsocketWriter
ctx context.Context ctx context.Context
} }
@@ -23,7 +22,7 @@ func (v Vacuum) CleanSpot(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "clean_spot" req.Service = "clean_spot"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Locate the vacuum cleaner robot. // Locate the vacuum cleaner robot.
@@ -33,7 +32,7 @@ func (v Vacuum) Locate(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "locate" req.Service = "locate"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Pause the cleaning task. // Pause the cleaning task.
@@ -43,7 +42,7 @@ func (v Vacuum) Pause(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "pause" req.Service = "pause"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Tell the vacuum cleaner to return to its dock. // Tell the vacuum cleaner to return to its dock.
@@ -53,7 +52,7 @@ func (v Vacuum) ReturnToBase(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "return_to_base" req.Service = "return_to_base"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Send a raw command to the vacuum cleaner. Takes an entityId and an optional // Send a raw command to the vacuum cleaner. Takes an entityId and an optional
@@ -66,7 +65,7 @@ func (v Vacuum) SendCommand(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Set the fan speed of the vacuum cleaner. Takes an entityId and an optional // Set the fan speed of the vacuum cleaner. Takes an entityId and an optional
@@ -80,7 +79,7 @@ func (v Vacuum) SetFanSpeed(entityId string, serviceData ...map[string]any) {
req.ServiceData = serviceData[0] req.ServiceData = serviceData[0]
} }
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Start or resume the cleaning task. // Start or resume the cleaning task.
@@ -90,7 +89,7 @@ func (v Vacuum) Start(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "start" req.Service = "start"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Start, pause, or resume the cleaning task. // Start, pause, or resume the cleaning task.
@@ -100,7 +99,7 @@ func (v Vacuum) StartPause(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "start_pause" req.Service = "start_pause"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Stop the current cleaning task. // Stop the current cleaning task.
@@ -110,7 +109,7 @@ func (v Vacuum) Stop(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "stop" req.Service = "stop"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Stop the current cleaning task and return to home. // Stop the current cleaning task and return to home.
@@ -120,7 +119,7 @@ func (v Vacuum) TurnOff(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "turn_off" req.Service = "turn_off"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }
// Start a new cleaning task. // Start a new cleaning task.
@@ -130,5 +129,5 @@ func (v Vacuum) TurnOn(entityId string) {
req.Domain = "vacuum" req.Domain = "vacuum"
req.Service = "turn_on" req.Service = "turn_on"
ws.WriteMessage(req, v.conn, v.ctx) v.conn.WriteMessage(req, v.ctx)
} }

View File

@@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@@ -21,16 +22,16 @@ type AuthMessage struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
} }
// TODO: use a mutex to prevent concurrent writes panic here type WebsocketWriter struct {
// https://github.com/gorilla/websocket/issues/119 Conn *websocket.Conn
func WriteMessage[T any](msg T, conn *websocket.Conn, ctx context.Context) error { mutex sync.Mutex
msgJson, err := json.Marshal(msg)
// fmt.Println(string(msgJson))
if err != nil {
return err
} }
err = conn.WriteMessage(websocket.TextMessage, msgJson) func (w *WebsocketWriter) WriteMessage(msg interface{}, ctx context.Context) error {
w.mutex.Lock()
defer w.mutex.Unlock()
err := w.Conn.WriteJSON(msg)
if err != nil { if err != nil {
return err return err
} }
@@ -82,7 +83,7 @@ func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Conte
} }
func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error { func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error {
err := WriteMessage(AuthMessage{MsgType: "auth", AccessToken: token}, conn, ctx) err := conn.WriteJSON(AuthMessage{MsgType: "auth", AccessToken: token})
if err != nil { if err != nil {
return err return err
} }
@@ -116,11 +117,11 @@ type SubEvent struct {
EventType string `json:"event_type"` EventType string `json:"event_type"`
} }
func SubscribeToStateChangedEvents(id int64, conn *websocket.Conn, ctx context.Context) { func SubscribeToStateChangedEvents(id int64, conn *WebsocketWriter, ctx context.Context) {
SubscribeToEventType("state_changed", conn, ctx, id) SubscribeToEventType("state_changed", conn, ctx, id)
} }
func SubscribeToEventType(eventType string, conn *websocket.Conn, ctx context.Context, id ...int64) { func SubscribeToEventType(eventType string, conn *WebsocketWriter, ctx context.Context, id ...int64) {
var finalId int64 var finalId int64
if len(id) == 0 { if len(id) == 0 {
finalId = i.GetId() finalId = i.GetId()
@@ -132,7 +133,7 @@ func SubscribeToEventType(eventType string, conn *websocket.Conn, ctx context.Co
Type: "subscribe_events", Type: "subscribe_events",
EventType: eventType, EventType: eventType,
} }
err := WriteMessage(e, conn, ctx) err := conn.WriteMessage(e, ctx)
if err != nil { if err != nil {
log.Fatalf("Error writing to websocket: %s\n", err) log.Fatalf("Error writing to websocket: %s\n", err)
} }

View File

@@ -3,9 +3,9 @@ package gomeassistant
import ( import (
"context" "context"
"github.com/gorilla/websocket"
"saml.dev/gome-assistant/internal/http" "saml.dev/gome-assistant/internal/http"
"saml.dev/gome-assistant/internal/services" "saml.dev/gome-assistant/internal/services"
ws "saml.dev/gome-assistant/internal/websocket"
) )
type Service struct { type Service struct {
@@ -28,7 +28,7 @@ type Service struct {
Vacuum *services.Vacuum Vacuum *services.Vacuum
} }
func newService(conn *websocket.Conn, ctx context.Context, httpClient *http.HttpClient) *Service { func newService(conn *ws.WebsocketWriter, ctx context.Context, httpClient *http.HttpClient) *Service {
return &Service{ return &Service{
AlarmControlPanel: services.BuildService[services.AlarmControlPanel](conn, ctx), AlarmControlPanel: services.BuildService[services.AlarmControlPanel](conn, ctx),
Cover: services.BuildService[services.Cover](conn, ctx), Cover: services.BuildService[services.Cover](conn, ctx),