mirror of
https://github.com/Xevion/go-ha.git
synced 2025-12-06 01:15:10 -06:00
refactor: websockets into 'connect' module, rename & adjust generally
This commit is contained in:
181
internal/connect/connection.go
Normal file
181
internal/connect/connection.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Package websocket is used to interact with the Home Assistant
|
||||
// websocket API. All HA interaction is done via websocket
|
||||
// except for cases explicitly called out in http package
|
||||
// documentation.
|
||||
package connect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Xevion/go-ha/internal"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var ErrInvalidToken = errors.New("invalid authentication token")
|
||||
|
||||
// HAConnection is a wrapper around a websocket connection that provides a mutex for thread safety.
|
||||
type HAConnection struct {
|
||||
Conn *websocket.Conn // Note: this is not thread safe except for Close() and WriteControl()
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// WriteMessage writes a message to the websocket connection.
|
||||
func (w *HAConnection) WriteMessage(msg any) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
return w.Conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
// ReadMessageRaw reads a raw message from the websocket connection.
|
||||
func ReadMessageRaw(conn *websocket.Conn) ([]byte, error) {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads a message from the websocket connection and unmarshals it into the given type.
|
||||
func ReadMessage[T any](conn *websocket.Conn) (T, error) {
|
||||
var result T
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(msg, &result)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ConnectionFromUri creates a new websocket connection from the given base URL and authentication token.
|
||||
func ConnectionFromUri(baseUrl *url.URL, token string) (*HAConnection, context.Context, context.CancelFunc, error) {
|
||||
// Build the websocket URL
|
||||
urlWebsockets := *baseUrl
|
||||
urlWebsockets.Path = "/api/websocket"
|
||||
scheme, err := internal.GetEquivalentWebsocketScheme(baseUrl.Scheme)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to build websocket URL: %w", err)
|
||||
}
|
||||
urlWebsockets.Scheme = scheme
|
||||
|
||||
// Create a short timeout context for the connection only
|
||||
connCtx, connCtxCancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer connCtxCancel() // Always cancel the connection context when we're done
|
||||
|
||||
// Init websocket connection
|
||||
dialer := websocket.DefaultDialer
|
||||
conn, _, err := dialer.DialContext(connCtx, urlWebsockets.String(), nil)
|
||||
if err != nil {
|
||||
slog.Error("Failed to connect to websocket. Check URI\n", "url", urlWebsockets)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Read auth_required message
|
||||
msg, err := ReadMessage[struct {
|
||||
MsgType string `json:"type"`
|
||||
}](conn)
|
||||
if err != nil {
|
||||
slog.Error("Unknown error creating websocket client\n")
|
||||
return nil, nil, nil, err
|
||||
} else if msg.MsgType != "auth_required" {
|
||||
slog.Error("Expected auth_required message, got", "msgType", msg.MsgType)
|
||||
return nil, nil, nil, fmt.Errorf("expected auth_required message, got %s", msg.MsgType)
|
||||
}
|
||||
|
||||
// Send auth message
|
||||
err = SendAuthMessage(conn, connCtx, token)
|
||||
if err != nil {
|
||||
slog.Error("Unknown error creating websocket client\n")
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Verify auth message was successful
|
||||
err = VerifyAuthResponse(conn, connCtx)
|
||||
if err != nil {
|
||||
slog.Error("Auth token is invalid. Please double check it or create a new token in your Home Assistant profile\n")
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Create a new background context for the application lifecycle (no timeout)
|
||||
appCtx, appCtxCancel := context.WithCancel(context.Background())
|
||||
|
||||
return &HAConnection{Conn: conn}, appCtx, appCtxCancel, nil
|
||||
}
|
||||
|
||||
// SendAuthMessage sends an auth message to the websocket connection.
|
||||
func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error {
|
||||
type AuthMessage struct {
|
||||
MsgType string `json:"type"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
err := conn.WriteJSON(AuthMessage{MsgType: "auth", AccessToken: token})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyAuthResponse verifies that the auth response is valid.
|
||||
func VerifyAuthResponse(conn *websocket.Conn, ctx context.Context) error {
|
||||
msg, err := ReadMessage[struct {
|
||||
MsgType string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}](conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.MsgType != "auth_ok" {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SubscribeToStateChangedEvents(id int64, conn *HAConnection, ctx context.Context) {
|
||||
SubscribeToEventType("state_changed", conn, ctx, id)
|
||||
}
|
||||
|
||||
// TODO: Instead of using variadic arguments, just use a nillable pointer for the id
|
||||
func SubscribeToEventType(eventType string, conn *HAConnection, ctx context.Context, id ...int64) {
|
||||
type SubEvent struct {
|
||||
Id int64 `json:"id"`
|
||||
Type string `json:"type"`
|
||||
EventType string `json:"event_type"`
|
||||
}
|
||||
|
||||
// If no id is provided, generate a new one
|
||||
var finalId int64
|
||||
if len(id) == 0 {
|
||||
finalId = internal.NextId()
|
||||
} else {
|
||||
finalId = id[0]
|
||||
}
|
||||
e := SubEvent{
|
||||
Id: finalId,
|
||||
Type: "subscribe_events",
|
||||
EventType: eventType,
|
||||
}
|
||||
|
||||
err := conn.WriteMessage(e)
|
||||
// TODO: Handle errors better
|
||||
if err != nil {
|
||||
wrappedErr := fmt.Errorf("error writing to websocket: %w", err)
|
||||
slog.Error(wrappedErr.Error())
|
||||
panic(wrappedErr)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user