diff --git a/.env b/.env index ca2d149fe946..769b8fe386b4 100644 --- a/.env +++ b/.env @@ -94,3 +94,7 @@ # # Time in duration format (e.g. 1h30m) after which a backend is considered busy # LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m + +# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1 +# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1" +# LOCALAI_IP_ALLOWLIST=192.168.1.0/24 diff --git a/core/cli/run.go b/core/cli/run.go index df84ef790eb3..4df9fbb970fc 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -55,6 +55,7 @@ type RunCMD struct { ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"` Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` + IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"` CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` @@ -192,6 +193,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile) xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels) }), + config.WithIPAllowList(r.IpAllowList), } if r.DisableMetricsEndpoint { diff --git a/core/config/application_config.go b/core/config/application_config.go index 8edc22c00483..28030a22e425 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -6,6 +6,7 @@ import ( "regexp" "time" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/xlog" @@ -93,6 +94,11 @@ type ApplicationConfig struct { PathWithoutAuth []string + // ie: 192.168.1.0/24,10.0.0.1,127.0.0.1 + IpAllowList string + + IPAllowListHelper *utils.IPAllowList + // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig } @@ -205,6 +211,18 @@ func WithP2PToken(s string) AppOption { } } +func WithIPAllowList(s string) AppOption { + return func(o *ApplicationConfig) { + xlog.Info("Application IpAllowList($LOCALAI_IP_ALLOWLIST)", "value", s) + o.IpAllowList = s + ipAllowListHelper, err := utils.NewIPAllowList(s) + if err != nil { + xlog.Error("Failed to parse IpAllowList", "error", err, "value", s) + } + o.IPAllowListHelper = ipAllowListHelper + } +} + var EnableWatchDog = func(o *ApplicationConfig) { o.WatchDog = true } diff --git a/core/http/app.go b/core/http/app.go index faf343a385b2..d44add1dffa3 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -152,6 +152,21 @@ func API(application *application.Application) (*echo.Echo, error) { e.Use(middleware.Recover()) } + // IP restriction middleware + if application.ApplicationConfig().IPAllowListHelper != nil { + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + clientIP := c.RealIP() + if !application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{Message: "Forbidden: your IP is not allowed", Code: http.StatusForbidden}, + }) + } + return next(c) + } + }) + } + // Metrics middleware if !application.ApplicationConfig().DisableMetrics { metricsService, err := services.NewLocalAIMetricsService() diff --git a/core/http/utils/ipallowlist.go b/core/http/utils/ipallowlist.go new file mode 100644 index 000000000000..d3f318608ea4 --- /dev/null +++ b/core/http/utils/ipallowlist.go @@ -0,0 +1,96 @@ +package utils + +import ( + "fmt" + "net" + "net/netip" + "strings" + "sync" +) + +type IPAllowList struct { + allowList string + cidrs []*net.IPNet + ips []net.IP + mu sync.RWMutex + enabled bool +} + +func NewIPAllowList(allowList string) (*IPAllowList, error) { + + w := &IPAllowList{} + err := w.Update(allowList) + return w, err +} + +func (w *IPAllowList) GetAllowList() string { + return w.allowList +} + +func (w *IPAllowList) Update(allowListStr string) error { + var cidrs []*net.IPNet + var ips []net.IP + + allowList := make([]string, 0) + if allowListStr != "" { + allowList = strings.Split(allowListStr, ",") + } + + for _, item := range allowList { + _, cidrNet, err := net.ParseCIDR(item) + if err == nil { + cidrs = append(cidrs, cidrNet) + } else { + ip := net.ParseIP(item) + if ip != nil { + ips = append(ips, ip) + } else { + return fmt.Errorf("invalid allowList item: %s", item) + } + } + } + + w.mu.Lock() + defer w.mu.Unlock() + w.allowList = allowListStr + w.cidrs = cidrs + w.ips = ips + w.enabled = len(cidrs) > 0 || len(ips) > 0 + return nil +} + +func (w *IPAllowList) IsAllowed(ip interface{}) bool { + if !w.enabled { + return true + } + + var parsedIP net.IP + switch v := ip.(type) { + case string: + parsedIP = net.ParseIP(v) + case net.IP: + parsedIP = v + case netip.Addr: + parsedIP = net.IP(v.AsSlice()) + } + + if parsedIP == nil { + return false + } + + w.mu.RLock() + defer w.mu.RUnlock() + + for _, cidr := range w.cidrs { + if cidr.Contains(parsedIP) { + return true + } + } + + for _, allowedIP := range w.ips { + if parsedIP.Equal(allowedIP) { + return true + } + } + return false +} \ No newline at end of file diff --git a/core/http/utils/ipallowlist_test.go b/core/http/utils/ipallowlist_test.go new file mode 100644 index 000000000000..878fd0edfc59 --- /dev/null +++ b/core/http/utils/ipallowlist_test.go @@ -0,0 +1,36 @@ +package utils_test + +import ( + . "github.com/mudler/LocalAI/core/http/utils" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("IPAllowList", func() { + It("allows all IPs when allowlist is empty", func() { + w, err := NewIPAllowList("") + Expect(err).ToNot(HaveOccurred()) + Expect(w.IsAllowed("192.168.1.100")).To(BeTrue()) + }) + + It("respects CIDRs and explicit IPs", func() { + allowList := "192.168.1.0/24,10.0.0.1,127.0.0.1" + w, err := NewIPAllowList(allowList) + Expect(err).ToNot(HaveOccurred()) + + cases := []struct { + ip string + expected bool + }{ + {"192.168.1.100", true}, + {"10.0.0.1", true}, + {"127.0.0.1", true}, + {"10.0.0.2", false}, + {"172.16.0.1", false}, + } + + for _, tc := range cases { + Expect(w.IsAllowed(tc.ip)).To(Equal(tc.expected), "IP: %s", tc.ip) + } + }) +}) \ No newline at end of file diff --git a/pkg/utils/ip_allowlist.go b/pkg/utils/ip_allowlist.go new file mode 100644 index 000000000000..cd42374281e6 --- /dev/null +++ b/pkg/utils/ip_allowlist.go @@ -0,0 +1,96 @@ +package utils + +import ( + "fmt" + "net" + "net/netip" + "strings" + "sync" +) + +type IPAllowList struct { + allowList string + cidrs []*net.IPNet + ips []net.IP + mu sync.RWMutex + enabled bool +} + +func NewIPAllowList(allowList string) (*IPAllowList, error) { + + w := &IPAllowList{} + err := w.Update(allowList) + return w, err +} + +func (w *IPAllowList) GetAllowList() string { + return w.allowList +} + +func (w *IPAllowList) Update(allowListStr string) error { + var cidrs []*net.IPNet + var ips []net.IP + + allowList := make([]string, 0) + if allowListStr != "" { + allowList = strings.Split(allowListStr, ",") + } + + for _, item := range allowList { + _, cidrNet, err := net.ParseCIDR(item) + if err == nil { + cidrs = append(cidrs, cidrNet) + } else { + ip := net.ParseIP(item) + if ip != nil { + ips = append(ips, ip) + } else { + return fmt.Errorf("invalid allowList item: %s", item) + } + } + } + + w.mu.Lock() + defer w.mu.Unlock() + w.allowList = allowListStr + w.cidrs = cidrs + w.ips = ips + w.enabled = len(cidrs) > 0 || len(ips) > 0 + return nil +} + +func (w *IPAllowList) IsAllowed(ip interface{}) bool { + if !w.enabled { + return true + } + + var parsedIP net.IP + switch v := ip.(type) { + case string: + parsedIP = net.ParseIP(v) + case net.IP: + parsedIP = v + case netip.Addr: + parsedIP = net.IP(v.AsSlice()) + } + + if parsedIP == nil { + return false + } + + w.mu.RLock() + defer w.mu.RUnlock() + + for _, cidr := range w.cidrs { + if cidr.Contains(parsedIP) { + return true + } + } + + for _, allowedIP := range w.ips { + if parsedIP.Equal(allowedIP) { + return true + } + } + return false +} diff --git a/pkg/utils/ip_allowlist_test.go b/pkg/utils/ip_allowlist_test.go new file mode 100644 index 000000000000..8cb7bb37e299 --- /dev/null +++ b/pkg/utils/ip_allowlist_test.go @@ -0,0 +1,36 @@ +package utils_test + +import ( + . "github.com/mudler/LocalAI/pkg/utils" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("IPAllowList", func() { + It("allows all IPs when allowlist is empty", func() { + w, err := NewIPAllowList("") + Expect(err).ToNot(HaveOccurred()) + Expect(w.IsAllowed("192.168.1.100")).To(BeTrue()) + }) + + It("respects CIDRs and explicit IPs", func() { + allowList := "192.168.1.0/24,10.0.0.1,127.0.0.1" + w, err := NewIPAllowList(allowList) + Expect(err).ToNot(HaveOccurred()) + + cases := []struct { + ip string + expected bool + }{ + {"192.168.1.100", true}, + {"10.0.0.1", true}, + {"127.0.0.1", true}, + {"10.0.0.2", false}, + {"172.16.0.1", false}, + } + + for _, tc := range cases { + Expect(w.IsAllowed(tc.ip)).To(Equal(tc.expected), "IP: %s", tc.ip) + } + }) +})