diff --git a/app.go b/app.go index b06dc11..4866b56 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,9 @@ import ( ws "saml.dev/gome-assistant/internal/websocket" ) +// Returned by NewApp() if authentication fails +var ErrInvalidToken = ws.ErrInvalidToken + type App struct { ctx context.Context ctxCancel context.CancelFunc @@ -77,7 +80,7 @@ type NewAppRequest struct { NewApp establishes the websocket connection and returns an object you can use to register schedules and listeners. */ -func NewApp(request NewAppRequest) *App { +func NewApp(request NewAppRequest) (*App, error) { if request.IpAddress == "" || request.HAAuthToken == "" || request.HomeZoneEntityId == "" { log.Fatalln("IpAddress, HAAuthToken, and HomeZoneEntityId are all required arguments in NewAppRequest.") } @@ -85,7 +88,11 @@ func NewApp(request NewAppRequest) *App { if port == "" { port = "8123" } - conn, ctx, ctxCancel := ws.SetupConnection(request.IpAddress, port, request.HAAuthToken) + conn, ctx, ctxCancel, err := ws.SetupConnection(request.IpAddress, port, request.HAAuthToken) + + if conn == nil { + return nil, err + } httpClient := http.NewHttpClient(request.IpAddress, port, request.HAAuthToken) @@ -105,7 +112,7 @@ func NewApp(request NewAppRequest) *App { intervals: pq.New(), entityListeners: map[string][]*EntityListener{}, eventListeners: map[string][]*EventListener{}, - } + }, nil } func (a *App) Cleanup() { diff --git a/example/example.go b/example/example.go index 206a366..f2599d7 100644 --- a/example/example.go +++ b/example/example.go @@ -10,11 +10,16 @@ import ( ) func main() { - app := ga.NewApp(ga.NewAppRequest{ + app, err := ga.NewApp(ga.NewAppRequest{ IpAddress: "192.168.86.67", // Replace with your Home Assistant IP Address HAAuthToken: os.Getenv("HA_AUTH_TOKEN"), HomeZoneEntityId: "zone.home", }) + + if err != nil { + log.Fatalln("Error connecting to HASS:", err) + } + defer app.Cleanup() pantryDoor := ga. diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index f3ea45e..f8c81ec 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -17,6 +17,10 @@ import ( i "saml.dev/gome-assistant/internal" ) +var ( + ErrInvalidToken = errors.New("invalid authentication token") +) + type AuthMessage struct { MsgType string `json:"type"` AccessToken string `json:"access_token"` @@ -47,7 +51,7 @@ func ReadMessage(conn *websocket.Conn, ctx context.Context) ([]byte, error) { return msg, nil } -func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Context, context.CancelFunc) { +func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Context, context.CancelFunc, error) { ctx, ctxCancel := context.WithTimeout(context.Background(), time.Second*3) // Init websocket connection @@ -55,31 +59,35 @@ func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Conte conn, _, err := dialer.DialContext(ctx, fmt.Sprintf("ws://%s:%s/api/websocket", ip, port), nil) if err != nil { ctxCancel() - log.Fatalf("ERROR: Failed to connect to websocket at ws://%s:%s/api/websocket. Check IP address and port\n", ip, port) + log.Printf("ERROR: Failed to connect to websocket at ws://%s:%s/api/websocket. Check IP address and port\n", ip, port) + return nil, nil, nil, err } // Read auth_required message _, err = ReadMessage(conn, ctx) if err != nil { ctxCancel() - log.Fatalf("Unknown error creating websocket client\n") + log.Printf("Unknown error creating websocket client\n") + return nil, nil, nil, err } // Send auth message err = SendAuthMessage(conn, ctx, authToken) if err != nil { ctxCancel() - log.Fatalf("Unknown error creating websocket client\n") + log.Printf("Unknown error creating websocket client\n") + return nil, nil, nil, err } // Verify auth message was successful err = VerifyAuthResponse(conn, ctx) if err != nil { ctxCancel() - log.Fatalf("ERROR: Auth token is invalid. Please double check it or create a new token in your Home Assistant profile\n") + log.Printf("ERROR: Auth token is invalid. Please double check it or create a new token in your Home Assistant profile\n") + return nil, nil, nil, err } - return conn, ctx, ctxCancel + return conn, ctx, ctxCancel, nil } func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error { @@ -105,7 +113,7 @@ func VerifyAuthResponse(conn *websocket.Conn, ctx context.Context) error { json.Unmarshal(msg, &authResp) // log.Println(authResp.MsgType) if authResp.MsgType != "auth_ok" { - return errors.New("invalid auth token") + return ErrInvalidToken } return nil