diff --git a/app.go b/app.go index 3497c53..e7db2fe 100644 --- a/app.go +++ b/app.go @@ -29,7 +29,7 @@ type App struct { httpClient *http.HttpClient service *Service - state *State + state *StateImpl schedules pq.PriorityQueue intervals pq.PriorityQueue @@ -192,7 +192,7 @@ func (a *App) RegisterEventListeners(evls ...EventListener) { } } -func getSunriseSunset(s *State, sunrise bool, dateToUse carbon.Carbon, offset ...DurationString) carbon.Carbon { +func getSunriseSunset(s *StateImpl, sunrise bool, dateToUse carbon.Carbon, offset ...DurationString) carbon.Carbon { date := dateToUse.Carbon2Time() rise, set := sunriseLib.SunriseSunset(s.latitude, s.longitude, date.Year(), date.Month(), date.Day()) rise, set = rise.Local(), set.Local() @@ -291,6 +291,6 @@ func (a *App) GetService() *Service { return a.service } -func (a *App) GetState() *State { +func (a *App) GetState() State { return a.state } diff --git a/checkers.go b/checkers.go index aa09c64..1936190 100644 --- a/checkers.go +++ b/checkers.go @@ -86,35 +86,60 @@ func checkExceptionRanges(eList []timeRange) conditionCheck { return cc } -func checkEnabledEntity(s *State, eid, expectedState string, runOnNetworkError bool) conditionCheck { +func checkEnabledEntity(s State, infos []internal.EnabledDisabledInfo) conditionCheck { cc := conditionCheck{fail: false} - if eid == "" || expectedState == "" { + if len(infos) == 0 { return cc } - matches, err := s.Equals(eid, expectedState) - if err != nil { - cc.fail = !runOnNetworkError - return cc - } + for _, edi := range infos { + matches, err := s.Equals(edi.Entity, edi.State) - cc.fail = !matches + if err != nil { + if edi.RunOnError { + // keep checking + continue + } else { + // don't run this automation + cc.fail = true + break + } + } + + if !matches { + cc.fail = true + break + } + } return cc } -func checkDisabledEntity(s *State, eid, expectedState string, runOnNetworkError bool) conditionCheck { +func checkDisabledEntity(s State, infos []internal.EnabledDisabledInfo) conditionCheck { cc := conditionCheck{fail: false} - if eid == "" || expectedState == "" { + if len(infos) == 0 { return cc } - matches, err := s.Equals(eid, expectedState) - if err != nil { - cc.fail = !runOnNetworkError - return cc + for _, edi := range infos { + matches, err := s.Equals(edi.Entity, edi.State) + + if err != nil { + if edi.RunOnError { + // keep checking + continue + } else { + // don't run this automation + cc.fail = true + break + } + } + + if matches { + cc.fail = true + break + } } - cc.fail = matches return cc } diff --git a/checkers_test.go b/checkers_test.go new file mode 100644 index 0000000..22dd451 --- /dev/null +++ b/checkers_test.go @@ -0,0 +1,133 @@ +package gomeassistant + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "saml.dev/gome-assistant/internal" +) + +type MockState struct { + EqualsReturn bool + EqualsError bool + GetReturn EntityState + GetError bool +} + +func (s MockState) AfterSunrise(_ ...DurationString) bool { + return true +} +func (s MockState) BeforeSunrise(_ ...DurationString) bool { + return true +} +func (s MockState) AfterSunset(_ ...DurationString) bool { + return true +} +func (s MockState) BeforeSunset(_ ...DurationString) bool { + return true +} +func (s MockState) Get(eid string) (EntityState, error) { + if s.GetError { + return EntityState{}, errors.New("some error") + } + return s.GetReturn, nil +} +func (s MockState) Equals(eid, state string) (bool, error) { + if s.EqualsError { + return false, errors.New("some error") + } + return s.EqualsReturn, nil +} + +var runOnError = internal.EnabledDisabledInfo{ + Entity: "eid", + State: "state", + RunOnError: true, +} + +var dontRunOnError = internal.EnabledDisabledInfo{ + Entity: "eid", + State: "state", + RunOnError: false, +} + +func list(infos ...internal.EnabledDisabledInfo) []internal.EnabledDisabledInfo { + ret := []internal.EnabledDisabledInfo{} + ret = append(ret, infos...) + return ret +} + +func TestEnabledEntity_StateEqual_Passes(t *testing.T) { + state := MockState{ + EqualsReturn: true, + } + c := checkEnabledEntity(state, list(runOnError)) + assert.False(t, c.fail, "should pass") +} + +func TestEnabledEntity_StateNotEqual_Fails(t *testing.T) { + state := MockState{ + EqualsReturn: false, + } + c := checkEnabledEntity(state, list(runOnError)) + assert.True(t, c.fail, "should fail") +} + +func TestEnabledEntity_NetworkError_DontRun_Fails(t *testing.T) { + state := MockState{ + EqualsError: true, + } + c := checkEnabledEntity(state, list(dontRunOnError)) + assert.True(t, c.fail, "should fail") +} + +func TestEnabledEntity_NetworkError_StillRun_Passes(t *testing.T) { + state := MockState{ + EqualsError: true, + } + c := checkEnabledEntity(state, list(runOnError)) + assert.False(t, c.fail, "should fail") +} + +func TestDisabledEntity_StateEqual_Fails(t *testing.T) { + state := MockState{ + EqualsReturn: true, + } + c := checkDisabledEntity(state, list(runOnError)) + assert.True(t, c.fail, "should pass") +} + +func TestDisabledEntity_StateNotEqual_Passes(t *testing.T) { + state := MockState{ + EqualsReturn: false, + } + c := checkDisabledEntity(state, list(runOnError)) + assert.False(t, c.fail, "should fail") +} + +func TestDisabledEntity_NetworkError_DontRun_Fails(t *testing.T) { + state := MockState{ + EqualsError: true, + } + c := checkDisabledEntity(state, list(dontRunOnError)) + assert.True(t, c.fail, "should fail") +} + +func TestDisabledEntity_NetworkError_StillRun_Passes(t *testing.T) { + state := MockState{ + EqualsError: true, + } + c := checkDisabledEntity(state, list(runOnError)) + assert.False(t, c.fail, "should fail") +} + +func TestStatesMatch(t *testing.T) { + c := checkStatesMatch("hey", "hey") + assert.False(t, c.fail, "should pass") +} + +func TestStatesDontMatch(t *testing.T) { + c := checkStatesMatch("hey", "bye") + assert.True(t, c.fail, "should fail") +} diff --git a/entitylistener.go b/entitylistener.go index e6dd49d..2dd3a3a 100644 --- a/entitylistener.go +++ b/entitylistener.go @@ -29,15 +29,11 @@ type EntityListener struct { runOnStartup bool runOnStartupCompleted bool - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } -type EntityListenerCallback func(*Service, *State, EntityData) +type EntityListenerCallback func(*Service, State, EntityData) type EntityData struct { TriggerEntityId string @@ -164,12 +160,12 @@ func (b elBuilder3) EnabledWhen(entityId, state string, runOnNetworkError bool) if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if b.entityListener.disabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting EnabledWhen on an entity listener with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - b.entityListener.enabledEntity = entityId - b.entityListener.enabledEntityState = state - b.entityListener.enabledEntityRunOnError = runOnNetworkError + b.entityListener.enabledEntities = append(b.entityListener.enabledEntities, i) return b } @@ -181,12 +177,12 @@ func (b elBuilder3) DisabledWhen(entityId, state string, runOnNetworkError bool) if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if b.entityListener.enabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting DisabledWhen on an entity listener with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - b.entityListener.disabledEntity = entityId - b.entityListener.disabledEntityState = state - b.entityListener.disabledEntityRunOnError = runOnNetworkError + b.entityListener.disabledEntities = append(b.entityListener.disabledEntities, i) return b } @@ -237,10 +233,10 @@ func callEntityListeners(app *App, msgBytes []byte) { if c := checkExceptionRanges(l.exceptionRanges); c.fail { continue } - if c := checkEnabledEntity(app.state, l.enabledEntity, l.enabledEntityState, l.enabledEntityRunOnError); c.fail { + if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { continue } - if c := checkDisabledEntity(app.state, l.disabledEntity, l.disabledEntityState, l.disabledEntityRunOnError); c.fail { + if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { continue } diff --git a/eventListener.go b/eventListener.go index 007e413..37fe98a 100644 --- a/eventListener.go +++ b/eventListener.go @@ -21,15 +21,11 @@ type EventListener struct { exceptionDates []time.Time exceptionRanges []timeRange - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } -type EventListenerCallback func(*Service, *State, EventData) +type EventListenerCallback func(*Service, State, EventData) type EventData struct { Type string @@ -106,12 +102,12 @@ func (b eventListenerBuilder3) EnabledWhen(entityId, state string, runOnNetworkE if entityId == "" { panic(fmt.Sprintf("entityId is empty in eventListener EnabledWhen entityId='%s' state='%s' runOnNetworkError='%t'", entityId, state, runOnNetworkError)) } - if b.eventListener.disabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting EnabledWhen entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - b.eventListener.enabledEntity = entityId - b.eventListener.enabledEntityState = state - b.eventListener.enabledEntityRunOnError = runOnNetworkError + b.eventListener.enabledEntities = append(b.eventListener.enabledEntities, i) return b } @@ -123,12 +119,12 @@ func (b eventListenerBuilder3) DisabledWhen(entityId, state string, runOnNetwork if entityId == "" { panic(fmt.Sprintf("entityId is empty in eventListener EnabledWhen entityId='%s' state='%s' runOnNetworkError='%t'", entityId, state, runOnNetworkError)) } - if b.eventListener.enabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting DisabledWhen entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - b.eventListener.disabledEntity = entityId - b.eventListener.disabledEntityState = state - b.eventListener.disabledEntityRunOnError = runOnNetworkError + b.eventListener.disabledEntities = append(b.eventListener.disabledEntities, i) return b } @@ -166,10 +162,10 @@ func callEventListeners(app *App, msg ws.ChanMsg) { if c := checkExceptionRanges(l.exceptionRanges); c.fail { continue } - if c := checkEnabledEntity(app.state, l.enabledEntity, l.enabledEntityState, l.enabledEntityRunOnError); c.fail { + if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { continue } - if c := checkDisabledEntity(app.state, l.disabledEntity, l.disabledEntityState, l.disabledEntityRunOnError); c.fail { + if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { continue } diff --git a/example/example.go b/example/example.go index f2599d7..8d3dc05 100644 --- a/example/example.go +++ b/example/example.go @@ -54,7 +54,7 @@ func main() { } -func pantryLights(service *ga.Service, state *ga.State, sensor ga.EntityData) { +func pantryLights(service *ga.Service, state ga.State, sensor ga.EntityData) { l := "light.pantry" if sensor.ToState == "on" { service.HomeAssistant.TurnOn(l) @@ -63,7 +63,7 @@ func pantryLights(service *ga.Service, state *ga.State, sensor ga.EntityData) { } } -func onEvent(service *ga.Service, state *ga.State, data ga.EventData) { +func onEvent(service *ga.Service, state ga.State, data ga.EventData) { // Since the structure of the event changes depending // on the event type, you can Unmarshal the raw json // into a Go type. If a type for your event doesn't @@ -74,7 +74,7 @@ func onEvent(service *ga.Service, state *ga.State, data ga.EventData) { log.Default().Println(ev) } -func lightsOut(service *ga.Service, state *ga.State) { +func lightsOut(service *ga.Service, state ga.State) { // always turn off outside lights service.Light.TurnOff("light.outside_lights") s, err := state.Get("binary_sensor.living_room_motion") @@ -89,7 +89,7 @@ func lightsOut(service *ga.Service, state *ga.State) { } } -func sunriseSched(service *ga.Service, state *ga.State) { +func sunriseSched(service *ga.Service, state ga.State) { service.Light.TurnOn("light.living_room_lamps") service.Light.TurnOff("light.christmas_lights") } diff --git a/go.mod b/go.mod index 62fe886..e4a89fe 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,14 @@ require ( ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gobuffalo/envy v1.10.2 // indirect github.com/gobuffalo/packd v1.0.2 // indirect github.com/gobuffalo/packr v1.30.1 // indirect github.com/joho/godotenv v1.4.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8f25b0c..e5e5f2f 100644 --- a/go.sum +++ b/go.sum @@ -57,12 +57,16 @@ github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DM github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/internal/internal.go b/internal/internal.go index 9742c58..f8c752d 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -10,6 +10,12 @@ import ( "github.com/golang-module/carbon" ) +type EnabledDisabledInfo struct { + Entity string + State string + RunOnError bool +} + var id int64 = 0 func GetId() int64 { diff --git a/interval.go b/interval.go index 0d207eb..3d525d9 100644 --- a/interval.go +++ b/interval.go @@ -7,7 +7,7 @@ import ( "saml.dev/gome-assistant/internal" ) -type IntervalCallback func(*Service, *State) +type IntervalCallback func(*Service, State) type Interval struct { frequency time.Duration @@ -19,12 +19,8 @@ type Interval struct { exceptionDates []time.Time exceptionRanges []timeRange - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } func (i Interval) Hash() string { @@ -118,12 +114,12 @@ func (ib intervalBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkErr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if ib.interval.disabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting EnabledWhen on an entity listener with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - ib.interval.enabledEntity = entityId - ib.interval.enabledEntityState = state - ib.interval.enabledEntityRunOnError = runOnNetworkError + ib.interval.enabledEntities = append(ib.interval.enabledEntities, i) return ib } @@ -135,12 +131,12 @@ func (ib intervalBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkEr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if ib.interval.enabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting DisabledWhen on an entity listener with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - ib.interval.disabledEntity = entityId - ib.interval.disabledEntityState = state - ib.interval.disabledEntityRunOnError = runOnNetworkError + ib.interval.disabledEntities = append(ib.interval.disabledEntities, i) return ib } @@ -184,10 +180,10 @@ func (i Interval) maybeRunCallback(a *App) { if c := checkExceptionRanges(i.exceptionRanges); c.fail { return } - if c := checkEnabledEntity(a.state, i.enabledEntity, i.enabledEntityState, i.enabledEntityRunOnError); c.fail { + if c := checkEnabledEntity(a.state, i.enabledEntities); c.fail { return } - if c := checkDisabledEntity(a.state, i.disabledEntity, i.disabledEntityState, i.disabledEntityRunOnError); c.fail { + if c := checkDisabledEntity(a.state, i.disabledEntities); c.fail { return } go i.callback(a.service, a.state) diff --git a/schedule.go b/schedule.go index c7d8e01..2325890 100644 --- a/schedule.go +++ b/schedule.go @@ -9,7 +9,7 @@ import ( "saml.dev/gome-assistant/internal" ) -type ScheduleCallback func(*Service, *State) +type ScheduleCallback func(*Service, State) type DailySchedule struct { // 0-23 @@ -27,12 +27,8 @@ type DailySchedule struct { exceptionDates []time.Time allowlistDates []time.Time - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } func (s DailySchedule) Hash() string { @@ -125,12 +121,12 @@ func (sb scheduleBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkErr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if sb.schedule.disabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting EnabledWhen on a schedule with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - sb.schedule.enabledEntity = entityId - sb.schedule.enabledEntityState = state - sb.schedule.enabledEntityRunOnError = runOnNetworkError + sb.schedule.enabledEntities = append(sb.schedule.enabledEntities, i) return sb } @@ -142,12 +138,12 @@ func (sb scheduleBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkEr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if sb.schedule.enabledEntity != "" { - panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting DisabledWhen on a schedule with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, } - sb.schedule.disabledEntity = entityId - sb.schedule.disabledEntityState = state - sb.schedule.disabledEntityRunOnError = runOnNetworkError + sb.schedule.disabledEntities = append(sb.schedule.disabledEntities, i) return sb } @@ -186,10 +182,10 @@ func (s DailySchedule) maybeRunCallback(a *App) { if c := checkAllowlistDates(s.allowlistDates); c.fail { return } - if c := checkEnabledEntity(a.state, s.enabledEntity, s.enabledEntityState, s.enabledEntityRunOnError); c.fail { + if c := checkEnabledEntity(a.state, s.enabledEntities); c.fail { return } - if c := checkDisabledEntity(a.state, s.disabledEntity, s.disabledEntityState, s.disabledEntityRunOnError); c.fail { + if c := checkDisabledEntity(a.state, s.disabledEntities); c.fail { return } go s.callback(a.service, a.state) diff --git a/state.go b/state.go index 84a7d02..edc9c91 100644 --- a/state.go +++ b/state.go @@ -10,8 +10,17 @@ import ( "saml.dev/gome-assistant/internal/http" ) +type State interface { + AfterSunrise(...DurationString) bool + BeforeSunrise(...DurationString) bool + AfterSunset(...DurationString) bool + BeforeSunset(...DurationString) bool + Get(entityId string) (EntityState, error) + Equals(entityId, state string) (bool, error) +} + // State is used to retrieve state from Home Assistant. -type State struct { +type StateImpl struct { httpClient *http.HttpClient latitude float64 longitude float64 @@ -24,8 +33,8 @@ type EntityState struct { LastChanged time.Time `json:"last_changed"` } -func newState(c *http.HttpClient, homeZoneEntityId string) (*State, error) { - state := &State{httpClient: c} +func newState(c *http.HttpClient, homeZoneEntityId string) (*StateImpl, error) { + state := &StateImpl{httpClient: c} err := state.getLatLong(c, homeZoneEntityId) if err != nil { return nil, err @@ -33,7 +42,7 @@ func newState(c *http.HttpClient, homeZoneEntityId string) (*State, error) { return state, nil } -func (s *State) getLatLong(c *http.HttpClient, homeZoneEntityId string) error { +func (s *StateImpl) getLatLong(c *http.HttpClient, homeZoneEntityId string) error { resp, err := s.Get(homeZoneEntityId) if err != nil { return fmt.Errorf("couldn't get latitude/longitude from home assistant entity '%s'. Did you type it correctly? It should be a zone like 'zone.home'", homeZoneEntityId) @@ -54,7 +63,7 @@ func (s *State) getLatLong(c *http.HttpClient, homeZoneEntityId string) error { return nil } -func (s *State) Get(entityId string) (EntityState, error) { +func (s *StateImpl) Get(entityId string) (EntityState, error) { resp, err := s.httpClient.GetState(entityId) if err != nil { return EntityState{}, err @@ -64,7 +73,7 @@ func (s *State) Get(entityId string) (EntityState, error) { return es, nil } -func (s *State) Equals(entityId string, expectedState string) (bool, error) { +func (s *StateImpl) Equals(entityId string, expectedState string) (bool, error) { currentState, err := s.Get(entityId) if err != nil { return false, err @@ -72,20 +81,20 @@ func (s *State) Equals(entityId string, expectedState string) (bool, error) { return currentState.State == expectedState, nil } -func (s *State) BeforeSunrise(offset ...DurationString) bool { +func (s *StateImpl) BeforeSunrise(offset ...DurationString) bool { sunrise := getSunriseSunset(s /* sunrise = */, true, carbon.Now(), offset...) return carbon.Now().Lt(sunrise) } -func (s *State) AfterSunrise(offset ...DurationString) bool { +func (s *StateImpl) AfterSunrise(offset ...DurationString) bool { return !s.BeforeSunrise(offset...) } -func (s *State) BeforeSunset(offset ...DurationString) bool { +func (s *StateImpl) BeforeSunset(offset ...DurationString) bool { sunset := getSunriseSunset(s /* sunrise = */, false, carbon.Now(), offset...) return carbon.Now().Lt(sunset) } -func (s *State) AfterSunset(offset ...DurationString) bool { +func (s *StateImpl) AfterSunset(offset ...DurationString) bool { return !s.BeforeSunset(offset...) }