diff --git a/app.go b/app.go index 58f8f0e..81aba9f 100644 --- a/app.go +++ b/app.go @@ -16,9 +16,13 @@ import ( ) type App struct { - ctx context.Context - ctxCancel context.CancelFunc - conn *websocket.Conn + ctx context.Context + ctxCancel context.CancelFunc + conn *websocket.Conn + + // Wraps the ws connection with added mutex locking + wsWriter *ws.WebsocketWriter + httpClient *http.HttpClient service *Service @@ -85,11 +89,13 @@ func NewApp(request NewAppRequest) *App { httpClient := http.NewHttpClient(request.IpAddress, port, request.HAAuthToken) - service := newService(conn, ctx, httpClient) + wsWriter := &ws.WebsocketWriter{Conn: conn} + service := newService(wsWriter, ctx, httpClient) state := newState(httpClient, request.HomeZoneEntityId) return &App{ conn: conn, + wsWriter: wsWriter, ctx: ctx, ctxCancel: ctxCancel, httpClient: httpClient, @@ -169,7 +175,7 @@ func (a *App) RegisterEventListeners(evls ...EventListener) { if elList, ok := a.eventListeners[eventType]; ok { a.eventListeners[eventType] = append(elList, &evl) } else { - ws.SubscribeToEventType(eventType, a.conn, a.ctx) + ws.SubscribeToEventType(eventType, a.wsWriter, a.ctx) a.eventListeners[eventType] = []*EventListener{&evl} } } @@ -227,7 +233,7 @@ func (a *App) Start() { // subscribe to state_changed events id := internal.GetId() - ws.SubscribeToStateChangedEvents(id, a.conn, a.ctx) + ws.SubscribeToStateChangedEvents(id, a.wsWriter, a.ctx) a.entityListenersId = id // entity listeners runOnStartup diff --git a/internal/services/alarm_control_panel.go b/internal/services/alarm_control_panel.go index 4dee6c6..5b0756e 100644 --- a/internal/services/alarm_control_panel.go +++ b/internal/services/alarm_control_panel.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type AlarmControlPanel struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -27,7 +26,7 @@ func (acp AlarmControlPanel) ArmAway(entityId string, serviceData ...map[string] req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for arm away. @@ -41,7 +40,7 @@ func (acp AlarmControlPanel) ArmWithCustomBypass(entityId string, serviceData .. req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for arm home. @@ -55,7 +54,7 @@ func (acp AlarmControlPanel) ArmHome(entityId string, serviceData ...map[string] req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for arm night. @@ -69,7 +68,7 @@ func (acp AlarmControlPanel) ArmNight(entityId string, serviceData ...map[string req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for arm vacation. @@ -83,7 +82,7 @@ func (acp AlarmControlPanel) ArmVacation(entityId string, serviceData ...map[str req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for disarm. @@ -97,7 +96,7 @@ func (acp AlarmControlPanel) Disarm(entityId string, serviceData ...map[string]a req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } // Send the alarm the command for trigger. @@ -111,5 +110,5 @@ func (acp AlarmControlPanel) Trigger(entityId string, serviceData ...map[string] req.ServiceData = serviceData[0] } - ws.WriteMessage(req, acp.conn, acp.ctx) + acp.conn.WriteMessage(req, acp.ctx) } diff --git a/internal/services/cover.go b/internal/services/cover.go index c38d599..8fb6e75 100644 --- a/internal/services/cover.go +++ b/internal/services/cover.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Cover struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -22,7 +21,7 @@ func (c Cover) Close(entityId string) { req.Domain = "cover" req.Service = "close_cover" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Close all or specified cover tilt. Takes an entityId. @@ -31,7 +30,7 @@ func (c Cover) CloseTilt(entityId string) { req.Domain = "cover" req.Service = "close_cover_tilt" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Open all or specified cover. Takes an entityId. @@ -40,7 +39,7 @@ func (c Cover) Open(entityId string) { req.Domain = "cover" req.Service = "open_cover" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Open all or specified cover tilt. Takes an entityId. @@ -49,7 +48,7 @@ func (c Cover) OpenTilt(entityId string) { req.Domain = "cover" req.Service = "open_cover_tilt" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Move to specific position all or specified cover. Takes an entityId and an optional @@ -62,7 +61,7 @@ func (c Cover) SetPosition(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Move to specific position all or specified cover tilt. Takes an entityId and an optional @@ -75,7 +74,7 @@ func (c Cover) SetTiltPosition(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Stop a cover entity. Takes an entityId. @@ -84,7 +83,7 @@ func (c Cover) Stop(entityId string) { req.Domain = "cover" req.Service = "stop_cover" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Stop a cover entity tilt. Takes an entityId. @@ -93,7 +92,7 @@ func (c Cover) StopTilt(entityId string) { req.Domain = "cover" req.Service = "stop_cover_tilt" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Toggle a cover open/closed. Takes an entityId. @@ -102,7 +101,7 @@ func (c Cover) Toggle(entityId string) { req.Domain = "cover" req.Service = "toggle" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } // Toggle a cover tilt open/closed. Takes an entityId. @@ -111,5 +110,5 @@ func (c Cover) ToggleTilt(entityId string) { req.Domain = "cover" req.Service = "toggle_cover_tilt" - ws.WriteMessage(req, c.conn, c.ctx) + c.conn.WriteMessage(req, c.ctx) } diff --git a/internal/services/homeassistant.go b/internal/services/homeassistant.go index 22cfacd..8400509 100644 --- a/internal/services/homeassistant.go +++ b/internal/services/homeassistant.go @@ -3,12 +3,11 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) type HomeAssistant struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -22,7 +21,7 @@ func (ha *HomeAssistant) TurnOn(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, ha.conn, ha.ctx) + ha.conn.WriteMessage(req, ha.ctx) } // Toggle a Home Assistant entity. Takes an entityId and an optional @@ -35,7 +34,7 @@ func (ha *HomeAssistant) Toggle(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, ha.conn, ha.ctx) + ha.conn.WriteMessage(req, ha.ctx) } func (ha *HomeAssistant) TurnOff(entityId string) { @@ -43,5 +42,5 @@ func (ha *HomeAssistant) TurnOff(entityId string) { req.Domain = "homeassistant" req.Service = "turn_off" - ws.WriteMessage(req, ha.conn, ha.ctx) + ha.conn.WriteMessage(req, ha.ctx) } diff --git a/internal/services/input_boolean.go b/internal/services/input_boolean.go index 0056b2c..ac589f7 100644 --- a/internal/services/input_boolean.go +++ b/internal/services/input_boolean.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type InputBoolean struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -21,7 +20,7 @@ func (ib InputBoolean) TurnOn(entityId string) { req.Domain = "input_boolean" req.Service = "turn_on" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputBoolean) Toggle(entityId string) { @@ -29,19 +28,19 @@ func (ib InputBoolean) Toggle(entityId string) { req.Domain = "input_boolean" req.Service = "toggle" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputBoolean) TurnOff(entityId string) { req := NewBaseServiceRequest(entityId) req.Domain = "input_boolean" req.Service = "turn_off" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputBoolean) Reload() { req := NewBaseServiceRequest("") req.Domain = "input_boolean" req.Service = "reload" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/input_button.go b/internal/services/input_button.go index 5155892..e0ec541 100644 --- a/internal/services/input_button.go +++ b/internal/services/input_button.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type InputButton struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -21,12 +20,12 @@ func (ib InputButton) Press(entityId string) { req.Domain = "input_button" req.Service = "press" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputButton) Reload() { req := NewBaseServiceRequest("") req.Domain = "input_button" req.Service = "reload" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/input_datetime.go b/internal/services/input_datetime.go index 2b3259c..92c12d5 100644 --- a/internal/services/input_datetime.go +++ b/internal/services/input_datetime.go @@ -5,14 +5,13 @@ import ( "fmt" "time" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type InputDatetime struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -26,12 +25,12 @@ func (ib InputDatetime) Set(entityId string, value time.Time) { "timestamp": fmt.Sprint(value.Unix()), } - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputDatetime) Reload() { req := NewBaseServiceRequest("") req.Domain = "input_datetime" req.Service = "reload" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/input_number.go b/internal/services/input_number.go index 097be93..59409f6 100644 --- a/internal/services/input_number.go +++ b/internal/services/input_number.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type InputNumber struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -22,7 +21,7 @@ func (ib InputNumber) Set(entityId string, value float32) { req.Service = "set_value" req.ServiceData = map[string]any{"value": value} - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputNumber) Increment(entityId string) { @@ -30,7 +29,7 @@ func (ib InputNumber) Increment(entityId string) { req.Domain = "input_number" req.Service = "increment" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputNumber) Decrement(entityId string) { @@ -38,12 +37,12 @@ func (ib InputNumber) Decrement(entityId string) { req.Domain = "input_number" req.Service = "decrement" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputNumber) Reload() { req := NewBaseServiceRequest("") req.Domain = "input_number" req.Service = "reload" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/input_text.go b/internal/services/input_text.go index 9e0848f..b7a0d1a 100644 --- a/internal/services/input_text.go +++ b/internal/services/input_text.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type InputText struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -24,12 +23,12 @@ func (ib InputText) Set(entityId string, value string) { "value": value, } - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } func (ib InputText) Reload() { req := NewBaseServiceRequest("") req.Domain = "input_text" req.Service = "reload" - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/light.go b/internal/services/light.go index efdee22..c1a2179 100644 --- a/internal/services/light.go +++ b/internal/services/light.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Light struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -26,7 +25,7 @@ func (l Light) TurnOn(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, l.conn, l.ctx) + l.conn.WriteMessage(req, l.ctx) } // Toggle a light entity. Takes an entityId and an optional @@ -39,12 +38,12 @@ func (l Light) Toggle(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, l.conn, l.ctx) + l.conn.WriteMessage(req, l.ctx) } func (l Light) TurnOff(entityId string) { req := NewBaseServiceRequest(entityId) req.Domain = "light" req.Service = "turn_off" - ws.WriteMessage(req, l.conn, l.ctx) + l.conn.WriteMessage(req, l.ctx) } diff --git a/internal/services/lock.go b/internal/services/lock.go index 0fe15b7..e122b25 100644 --- a/internal/services/lock.go +++ b/internal/services/lock.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Lock struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -26,7 +25,7 @@ func (l Lock) Lock(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, l.conn, l.ctx) + l.conn.WriteMessage(req, l.ctx) } // Unlock a lock entity. Takes an entityId and an optional @@ -39,5 +38,5 @@ func (l Lock) Unlock(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, l.conn, l.ctx) + l.conn.WriteMessage(req, l.ctx) } diff --git a/internal/services/media_player.go b/internal/services/media_player.go index 07fb8d7..727d7a9 100644 --- a/internal/services/media_player.go +++ b/internal/services/media_player.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type MediaPlayer struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -23,7 +22,7 @@ func (mp MediaPlayer) ClearPlaylist(entityId string) { req.Domain = "media_player" req.Service = "clear_playlist" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Group players together. Only works on platforms with support for player groups. @@ -37,7 +36,7 @@ func (mp MediaPlayer) Join(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command for next track. @@ -47,7 +46,7 @@ func (mp MediaPlayer) Next(entityId string) { req.Domain = "media_player" req.Service = "media_next_track" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command for pause. @@ -57,7 +56,7 @@ func (mp MediaPlayer) Pause(entityId string) { req.Domain = "media_player" req.Service = "media_pause" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command for play. @@ -67,7 +66,7 @@ func (mp MediaPlayer) Play(entityId string) { req.Domain = "media_player" req.Service = "media_play" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Toggle media player play/pause state. @@ -77,7 +76,7 @@ func (mp MediaPlayer) PlayPause(entityId string) { req.Domain = "media_player" req.Service = "media_play_pause" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command for previous track. @@ -87,7 +86,7 @@ func (mp MediaPlayer) Previous(entityId string) { req.Domain = "media_player" req.Service = "media_previous_track" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command to seek in current playing media. @@ -101,7 +100,7 @@ func (mp MediaPlayer) Seek(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the stop command. @@ -111,7 +110,7 @@ func (mp MediaPlayer) Stop(entityId string) { req.Domain = "media_player" req.Service = "media_stop" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command for playing media. @@ -125,7 +124,7 @@ func (mp MediaPlayer) PlayMedia(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Set repeat mode. Takes an entityId and an optional @@ -138,7 +137,7 @@ func (mp MediaPlayer) RepeatSet(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command to change sound mode. @@ -152,7 +151,7 @@ func (mp MediaPlayer) SelectSoundMode(entityId string, serviceData ...map[string req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Send the media player the command to change input source. @@ -166,7 +165,7 @@ func (mp MediaPlayer) SelectSource(entityId string, serviceData ...map[string]an req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Set shuffling state. @@ -180,7 +179,7 @@ func (mp MediaPlayer) Shuffle(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Toggles a media player power state. @@ -190,7 +189,7 @@ func (mp MediaPlayer) Toggle(entityId string) { req.Domain = "media_player" req.Service = "toggle" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Turn a media player power off. @@ -200,7 +199,7 @@ func (mp MediaPlayer) TurnOff(entityId string) { req.Domain = "media_player" req.Service = "turn_off" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Turn a media player power on. @@ -210,7 +209,7 @@ func (mp MediaPlayer) TurnOn(entityId string) { req.Domain = "media_player" req.Service = "turn_on" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Unjoin the player from a group. Only works on @@ -221,7 +220,7 @@ func (mp MediaPlayer) Unjoin(entityId string) { req.Domain = "media_player" req.Service = "unjoin" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Turn a media player volume down. @@ -231,7 +230,7 @@ func (mp MediaPlayer) VolumeDown(entityId string) { req.Domain = "media_player" req.Service = "volume_down" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Mute a media player's volume. @@ -245,7 +244,7 @@ func (mp MediaPlayer) VolumeMute(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Set a media player's volume level. @@ -259,7 +258,7 @@ func (mp MediaPlayer) VolumeSet(entityId string, serviceData ...map[string]any) req.ServiceData = serviceData[0] } - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } // Turn a media player volume up. @@ -269,5 +268,5 @@ func (mp MediaPlayer) VolumeUp(entityId string) { req.Domain = "media_player" req.Service = "volume_up" - ws.WriteMessage(req, mp.conn, mp.ctx) + mp.conn.WriteMessage(req, mp.ctx) } diff --git a/internal/services/notify.go b/internal/services/notify.go index 50ffc7f..e76dd42 100644 --- a/internal/services/notify.go +++ b/internal/services/notify.go @@ -3,13 +3,12 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" "saml.dev/gome-assistant/types" ) type Notify struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -27,5 +26,5 @@ func (ha *Notify) Notify(reqData types.NotifyRequest) { } req.ServiceData = serviceData - ws.WriteMessage(req, ha.conn, ha.ctx) + ha.conn.WriteMessage(req, ha.ctx) } diff --git a/internal/services/number.go b/internal/services/number.go index 761784b..8ea98fc 100644 --- a/internal/services/number.go +++ b/internal/services/number.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Number struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -22,5 +21,5 @@ func (ib Number) SetValue(entityId string, value int) { req.Service = "set_value" req.ServiceData = map[string]any{"value": value} - ws.WriteMessage(req, ib.conn, ib.ctx) + ib.conn.WriteMessage(req, ib.ctx) } diff --git a/internal/services/scene.go b/internal/services/scene.go index 4b5d9b9..e17ada9 100644 --- a/internal/services/scene.go +++ b/internal/services/scene.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Scene struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -25,7 +24,7 @@ func (s Scene) Apply(serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } // Create a scene entity. Takes an entityId and an optional @@ -38,7 +37,7 @@ func (s Scene) Create(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } // Reload the scenes. @@ -47,7 +46,7 @@ func (s Scene) Reload() { req.Domain = "scene" req.Service = "reload" - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } // TurnOn a scene entity. Takes an entityId and an optional @@ -60,5 +59,5 @@ func (s Scene) TurnOn(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } diff --git a/internal/services/services.go b/internal/services/services.go index 6a2f082..a8324f9 100644 --- a/internal/services/services.go +++ b/internal/services/services.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/gorilla/websocket" "saml.dev/gome-assistant/internal" + ws "saml.dev/gome-assistant/internal/websocket" ) func BuildService[ @@ -26,7 +26,7 @@ func BuildService[ Scene | TTS | Vacuum, -](conn *websocket.Conn, ctx context.Context) *T { +](conn *ws.WebsocketWriter, ctx context.Context) *T { return &T{conn: conn, ctx: ctx} } diff --git a/internal/services/switch.go b/internal/services/switch.go index d5ba0dc..0e7be52 100644 --- a/internal/services/switch.go +++ b/internal/services/switch.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Switch struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -21,7 +20,7 @@ func (s Switch) TurnOn(entityId string) { req.Domain = "switch" req.Service = "turn_on" - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } func (s Switch) Toggle(entityId string) { @@ -29,12 +28,12 @@ func (s Switch) Toggle(entityId string) { req.Domain = "switch" req.Service = "toggle" - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } func (s Switch) TurnOff(entityId string) { req := NewBaseServiceRequest(entityId) req.Domain = "switch" req.Service = "turn_off" - ws.WriteMessage(req, s.conn, s.ctx) + s.conn.WriteMessage(req, s.ctx) } diff --git a/internal/services/tts.go b/internal/services/tts.go index 712b802..74b4963 100644 --- a/internal/services/tts.go +++ b/internal/services/tts.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type TTS struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -22,7 +21,7 @@ func (tts TTS) ClearCache() { req.Domain = "tts" req.Service = "clear_cache" - ws.WriteMessage(req, tts.conn, tts.ctx) + tts.conn.WriteMessage(req, tts.ctx) } // Say something using text-to-speech on a media player with cloud. @@ -36,7 +35,7 @@ func (tts TTS) CloudSay(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, tts.conn, tts.ctx) + tts.conn.WriteMessage(req, tts.ctx) } // Say something using text-to-speech on a media player with google_translate. @@ -50,5 +49,5 @@ func (tts TTS) GoogleTranslateSay(entityId string, serviceData ...map[string]any req.ServiceData = serviceData[0] } - ws.WriteMessage(req, tts.conn, tts.ctx) + tts.conn.WriteMessage(req, tts.ctx) } diff --git a/internal/services/vacuum.go b/internal/services/vacuum.go index 1186410..fbc71b0 100644 --- a/internal/services/vacuum.go +++ b/internal/services/vacuum.go @@ -3,14 +3,13 @@ package services import ( "context" - "github.com/gorilla/websocket" ws "saml.dev/gome-assistant/internal/websocket" ) /* Structs */ type Vacuum struct { - conn *websocket.Conn + conn *ws.WebsocketWriter ctx context.Context } @@ -23,7 +22,7 @@ func (v Vacuum) CleanSpot(entityId string) { req.Domain = "vacuum" req.Service = "clean_spot" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Locate the vacuum cleaner robot. @@ -33,7 +32,7 @@ func (v Vacuum) Locate(entityId string) { req.Domain = "vacuum" req.Service = "locate" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Pause the cleaning task. @@ -43,7 +42,7 @@ func (v Vacuum) Pause(entityId string) { req.Domain = "vacuum" req.Service = "pause" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Tell the vacuum cleaner to return to its dock. @@ -53,7 +52,7 @@ func (v Vacuum) ReturnToBase(entityId string) { req.Domain = "vacuum" req.Service = "return_to_base" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Send a raw command to the vacuum cleaner. Takes an entityId and an optional @@ -66,7 +65,7 @@ func (v Vacuum) SendCommand(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Set the fan speed of the vacuum cleaner. Takes an entityId and an optional @@ -80,7 +79,7 @@ func (v Vacuum) SetFanSpeed(entityId string, serviceData ...map[string]any) { req.ServiceData = serviceData[0] } - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Start or resume the cleaning task. @@ -90,7 +89,7 @@ func (v Vacuum) Start(entityId string) { req.Domain = "vacuum" req.Service = "start" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Start, pause, or resume the cleaning task. @@ -100,7 +99,7 @@ func (v Vacuum) StartPause(entityId string) { req.Domain = "vacuum" req.Service = "start_pause" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Stop the current cleaning task. @@ -110,7 +109,7 @@ func (v Vacuum) Stop(entityId string) { req.Domain = "vacuum" req.Service = "stop" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Stop the current cleaning task and return to home. @@ -120,7 +119,7 @@ func (v Vacuum) TurnOff(entityId string) { req.Domain = "vacuum" req.Service = "turn_off" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } // Start a new cleaning task. @@ -130,5 +129,5 @@ func (v Vacuum) TurnOn(entityId string) { req.Domain = "vacuum" req.Service = "turn_on" - ws.WriteMessage(req, v.conn, v.ctx) + v.conn.WriteMessage(req, v.ctx) } diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index ce1fcee..f3ea45e 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "log" + "sync" "time" "github.com/gorilla/websocket" @@ -21,16 +22,16 @@ type AuthMessage struct { AccessToken string `json:"access_token"` } -// TODO: use a mutex to prevent concurrent writes panic here -// https://github.com/gorilla/websocket/issues/119 -func WriteMessage[T any](msg T, conn *websocket.Conn, ctx context.Context) error { - msgJson, err := json.Marshal(msg) - // fmt.Println(string(msgJson)) - if err != nil { - return err - } +type WebsocketWriter struct { + Conn *websocket.Conn + mutex sync.Mutex +} - err = conn.WriteMessage(websocket.TextMessage, msgJson) +func (w *WebsocketWriter) WriteMessage(msg interface{}, ctx context.Context) error { + w.mutex.Lock() + defer w.mutex.Unlock() + + err := w.Conn.WriteJSON(msg) if err != nil { return err } @@ -82,7 +83,7 @@ func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Conte } func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error { - err := WriteMessage(AuthMessage{MsgType: "auth", AccessToken: token}, conn, ctx) + err := conn.WriteJSON(AuthMessage{MsgType: "auth", AccessToken: token}) if err != nil { return err } @@ -116,11 +117,11 @@ type SubEvent struct { EventType string `json:"event_type"` } -func SubscribeToStateChangedEvents(id int64, conn *websocket.Conn, ctx context.Context) { +func SubscribeToStateChangedEvents(id int64, conn *WebsocketWriter, ctx context.Context) { SubscribeToEventType("state_changed", conn, ctx, id) } -func SubscribeToEventType(eventType string, conn *websocket.Conn, ctx context.Context, id ...int64) { +func SubscribeToEventType(eventType string, conn *WebsocketWriter, ctx context.Context, id ...int64) { var finalId int64 if len(id) == 0 { finalId = i.GetId() @@ -132,7 +133,7 @@ func SubscribeToEventType(eventType string, conn *websocket.Conn, ctx context.Co Type: "subscribe_events", EventType: eventType, } - err := WriteMessage(e, conn, ctx) + err := conn.WriteMessage(e, ctx) if err != nil { log.Fatalf("Error writing to websocket: %s\n", err) } diff --git a/service.go b/service.go index 9692d84..e1c8eef 100644 --- a/service.go +++ b/service.go @@ -3,9 +3,9 @@ package gomeassistant import ( "context" - "github.com/gorilla/websocket" "saml.dev/gome-assistant/internal/http" "saml.dev/gome-assistant/internal/services" + ws "saml.dev/gome-assistant/internal/websocket" ) type Service struct { @@ -28,7 +28,7 @@ type Service struct { Vacuum *services.Vacuum } -func newService(conn *websocket.Conn, ctx context.Context, httpClient *http.HttpClient) *Service { +func newService(conn *ws.WebsocketWriter, ctx context.Context, httpClient *http.HttpClient) *Service { return &Service{ AlarmControlPanel: services.BuildService[services.AlarmControlPanel](conn, ctx), Cover: services.BuildService[services.Cover](conn, ctx),