diff --git a/checkers.go b/checkers.go index aa09c64..bb2291c 100644 --- a/checkers.go +++ b/checkers.go @@ -86,35 +86,60 @@ func checkExceptionRanges(eList []timeRange) conditionCheck { return cc } -func checkEnabledEntity(s *State, eid, expectedState string, runOnNetworkError bool) conditionCheck { +func checkEnabledEntity(s *State, infos []internal.EnabledDisabledInfo) conditionCheck { cc := conditionCheck{fail: false} - if eid == "" || expectedState == "" { + if len(infos) == 0 { return cc } - matches, err := s.Equals(eid, expectedState) - if err != nil { - cc.fail = !runOnNetworkError - return cc - } + for _, edi := range infos { + matches, err := s.Equals(edi.Entity, edi.State) - cc.fail = !matches + if err != nil { + if edi.RunOnError { + // keep checking + continue + } else { + // don't run this automation + cc.fail = true + break + } + } + + if !matches { + cc.fail = true + break + } + } return cc } -func checkDisabledEntity(s *State, eid, expectedState string, runOnNetworkError bool) conditionCheck { +func checkDisabledEntity(s *State, infos []internal.EnabledDisabledInfo) conditionCheck { cc := conditionCheck{fail: false} - if eid == "" || expectedState == "" { + if len(infos) == 0 { return cc } - matches, err := s.Equals(eid, expectedState) - if err != nil { - cc.fail = !runOnNetworkError - return cc + for _, edi := range infos { + matches, err := s.Equals(edi.Entity, edi.State) + + if err != nil { + if edi.RunOnError { + // keep checking + continue + } else { + // don't run this automation + cc.fail = true + break + } + } + + if matches { + cc.fail = true + break + } } - cc.fail = matches return cc } diff --git a/entitylistener.go b/entitylistener.go index e6dd49d..06d84a8 100644 --- a/entitylistener.go +++ b/entitylistener.go @@ -29,12 +29,8 @@ type EntityListener struct { runOnStartup bool runOnStartupCompleted bool - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } type EntityListenerCallback func(*Service, *State, EntityData) diff --git a/eventListener.go b/eventListener.go index 007e413..33e1fbd 100644 --- a/eventListener.go +++ b/eventListener.go @@ -21,12 +21,8 @@ type EventListener struct { exceptionDates []time.Time exceptionRanges []timeRange - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } type EventListenerCallback func(*Service, *State, EventData) diff --git a/internal/internal.go b/internal/internal.go index 9742c58..f8c752d 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -10,6 +10,12 @@ import ( "github.com/golang-module/carbon" ) +type EnabledDisabledInfo struct { + Entity string + State string + RunOnError bool +} + var id int64 = 0 func GetId() int64 { diff --git a/interval.go b/interval.go index 0d207eb..73e2541 100644 --- a/interval.go +++ b/interval.go @@ -19,12 +19,8 @@ type Interval struct { exceptionDates []time.Time exceptionRanges []timeRange - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } func (i Interval) Hash() string { diff --git a/schedule.go b/schedule.go index c7d8e01..9375ec9 100644 --- a/schedule.go +++ b/schedule.go @@ -27,12 +27,8 @@ type DailySchedule struct { exceptionDates []time.Time allowlistDates []time.Time - enabledEntity string - enabledEntityState string - enabledEntityRunOnError bool - disabledEntity string - disabledEntityState string - disabledEntityRunOnError bool + enabledEntities []internal.EnabledDisabledInfo + disabledEntities []internal.EnabledDisabledInfo } func (s DailySchedule) Hash() string { @@ -125,12 +121,15 @@ func (sb scheduleBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkErr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if sb.schedule.disabledEntity != "" { + if len(sb.schedule.disabledEntities) != 0 { panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting EnabledWhen on a schedule with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) } - sb.schedule.enabledEntity = entityId - sb.schedule.enabledEntityState = state - sb.schedule.enabledEntityRunOnError = runOnNetworkError + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, + } + sb.schedule.enabledEntities = append(sb.schedule.enabledEntities, i) return sb } @@ -142,12 +141,15 @@ func (sb scheduleBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkEr if entityId == "" { panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) } - if sb.schedule.enabledEntity != "" { + if len(sb.schedule.enabledEntities) != 0 { panic(fmt.Sprintf("You can't use EnabledWhen and DisabledWhen together. Error occurred while setting DisabledWhen on a schedule with params entityId=%s state=%s runOnNetworkError=%t", entityId, state, runOnNetworkError)) } - sb.schedule.disabledEntity = entityId - sb.schedule.disabledEntityState = state - sb.schedule.disabledEntityRunOnError = runOnNetworkError + i := internal.EnabledDisabledInfo{ + Entity: entityId, + State: state, + RunOnError: runOnNetworkError, + } + sb.schedule.disabledEntities = append(sb.schedule.disabledEntities, i) return sb } @@ -186,10 +188,10 @@ func (s DailySchedule) maybeRunCallback(a *App) { if c := checkAllowlistDates(s.allowlistDates); c.fail { return } - if c := checkEnabledEntity(a.state, s.enabledEntity, s.enabledEntityState, s.enabledEntityRunOnError); c.fail { + if c := checkEnabledEntity(a.state, s.enabledEntities); c.fail { return } - if c := checkDisabledEntity(a.state, s.disabledEntity, s.disabledEntityState, s.disabledEntityRunOnError); c.fail { + if c := checkDisabledEntity(a.state, s.disabledEntities); c.fail { return } go s.callback(a.service, a.state) diff --git a/state.go b/state.go index 84a7d02..2eb0ad0 100644 --- a/state.go +++ b/state.go @@ -10,6 +10,15 @@ import ( "saml.dev/gome-assistant/internal/http" ) +type StateInterface interface { + AfterSunrise() bool + BeforeSunrise() bool + AfterSunset() bool + BeforeSunset() bool + Get() (EntityState, error) + Equals() (bool, error) +} + // State is used to retrieve state from Home Assistant. type State struct { httpClient *http.HttpClient