diff --git a/pkg/config/config.go b/pkg/config/config.go
index 78cdee5..26925c0 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -99,6 +99,10 @@ func (cfg *RawConfig) ParameterKeyExcludeModelWeights() string {
 	return cfg.ServiceName + "/exclude-model-weights"
 }
 
+func (cfg *RawConfig) ParameterKeyExcludeFiles() string {
+	return cfg.ServiceName + "/exclude-files"
+}
+
 // /var/lib/dragonfly/model-csi/volumes
 func (cfg *RawConfig) GetVolumesDir() string {
 	return filepath.Join(cfg.RootDir, "volumes")
diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go
index 528fff8..e044fae 100644
--- a/pkg/server/server_test.go
+++ b/pkg/server/server_test.go
@@ -33,7 +33,7 @@ type mockPuller struct {
 }
 
 func (puller *mockPuller) Pull(
-	ctx context.Context, reference, targetDir string, excludeModelWeights bool,
+	ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string,
 ) error {
 	if err := os.MkdirAll(targetDir, 0755); err != nil {
 		return err
@@ -560,7 +560,7 @@ func TestServer(t *testing.T) {
 	cfg.Get().PullConfig.ProxyURL = ""
 	service.CacheScanInterval = 1 * time.Second
 
-	service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker) service.Puller {
+	service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker, excludeFilePatterns []string) service.Puller {
 		return &mockPuller{
 			pullCfg:  pullCfg,
 			duration: time.Second * 2,
diff --git a/pkg/service/controller_local.go b/pkg/service/controller_local.go
index 4a02997..36d69b6 100644
--- a/pkg/service/controller_local.go
+++ b/pkg/service/controller_local.go
@@ -1,6 +1,7 @@
 package service
 
 import (
+	"encoding/json"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -68,6 +69,24 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
 		}
 	}
 
+	excludeFilePatternsParam := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyExcludeFiles()])
+	var excludeFilePatterns []string
+	if excludeFilePatternsParam != "" {
+		if err := json.Unmarshal([]byte(excludeFilePatternsParam), &excludeFilePatterns); err != nil {
+			return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid JSON array: %v", s.cfg.Get().ParameterKeyExcludeFiles(), err)
+		}
+
+		// Validate patterns for security
+		for _, p := range excludeFilePatterns {
+			if strings.HasPrefix(p, "/") && len(p) > 1 {
+				return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: absolute paths not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
+			}
+			if strings.Contains(p, "..") {
+				return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: parent directory reference not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
+			}
+		}
+	}
+
 	parentSpan := trace.SpanFromContext(ctx)
 	parentSpan.SetAttributes(attribute.String("volume_name", volumeName))
 	parentSpan.SetAttributes(attribute.String("reference", modelReference))
@@ -78,7 +97,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
 		startedAt := time.Now()
 		ctx, span := tracing.Tracer.Start(ctx, "PullModel")
 		span.SetAttributes(attribute.String("model_dir", modelDir))
-		if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
+		if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
 			span.SetStatus(otelCodes.Error, "failed to pull model")
 			span.RecordError(err)
 			span.End()
@@ -111,7 +130,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
 		startedAt := time.Now()
 		ctx, span := tracing.Tracer.Start(ctx, "PullModel")
 		span.SetAttributes(attribute.String("model_dir", modelDir))
-		if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
+		if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
 			span.SetStatus(otelCodes.Error, "failed to pull model")
 			span.RecordError(err)
 			span.End()
diff --git a/pkg/service/dynamic_server_handler.go b/pkg/service/dynamic_server_handler.go
index 9a8dbac..57980f9 100644
--- a/pkg/service/dynamic_server_handler.go
+++ b/pkg/service/dynamic_server_handler.go
@@ -1,6 +1,7 @@
 package service
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
@@ -86,6 +87,31 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
 		})
 	}
 
+	// Validate exclude_file_patterns
+	for _, p := range req.ExcludeFilePatterns {
+		if strings.HasPrefix(p, "/") && len(p) > 1 {
+			return c.JSON(http.StatusBadRequest, ErrorResponse{
+				Code:    ERR_CODE_INVALID_ARGUMENT,
+				Message: fmt.Sprintf("exclude_file_patterns: absolute paths not allowed: %s", p),
+			})
+		}
+		if strings.Contains(p, "..") {
+			return c.JSON(http.StatusBadRequest, ErrorResponse{
+				Code:    ERR_CODE_INVALID_ARGUMENT,
+				Message: fmt.Sprintf("exclude_file_patterns: parent directory reference not allowed: %s", p),
+			})
+		}
+	}
+
+	excludeFilesJSON := "[]"
+	if len(req.ExcludeFilePatterns) > 0 {
+		jsonBytes, err := json.Marshal(req.ExcludeFilePatterns)
+		if err != nil {
+			return handleError(c, fmt.Errorf("marshal exclude_file_patterns: %w", err))
+		}
+		excludeFilesJSON = string(jsonBytes)
+	}
+
 	_, err := h.svc.CreateVolume(c.Request().Context(), &csi.CreateVolumeRequest{
 		Name: volumeName,
 		Parameters: map[string]string{
@@ -94,6 +120,7 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
 			h.cfg.Get().ParameterKeyMountID():             req.MountID,
 			h.cfg.Get().ParameterKeyCheckDiskQuota():      strconv.FormatBool(req.CheckDiskQuota),
 			h.cfg.Get().ParameterKeyExcludeModelWeights(): strconv.FormatBool(req.ExcludeModelWeights),
+			h.cfg.Get().ParameterKeyExcludeFiles():        excludeFilesJSON,
 		},
 	})
 	if err != nil {
diff --git a/pkg/service/node.go b/pkg/service/node.go
index 6a0ec6c..151cd91 100644
--- a/pkg/service/node.go
+++ b/pkg/service/node.go
@@ -1,6 +1,7 @@
 package service
 
 import (
+	"encoding/json"
 	"path/filepath"
 	"strconv"
 	"strings"
@@ -104,8 +105,24 @@ func (s *Service) nodePublishVolume(
 		}
 	}
 
+	excludeFilePatternsParam := volumeAttributes[s.cfg.Get().ParameterKeyExcludeFiles()]
+	var excludeFilePatterns []string
+	if excludeFilePatternsParam != "" {
+		if err := json.Unmarshal([]byte(excludeFilePatternsParam), &excludeFilePatterns); err != nil {
+			return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid JSON array: %v", s.cfg.Get().ParameterKeyExcludeFiles(), err)
+		}
+		for _, p := range excludeFilePatterns {
+			if strings.HasPrefix(p, "/") && len(p) > 1 {
+				return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: absolute paths not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
+			}
+			if strings.Contains(p, "..") {
+				return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: parent directory reference not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
+			}
+		}
+	}
+
 	logger.WithContext(ctx).Infof("publishing static inline volume: %s", staticInlineModelReference)
-	resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights)
+	resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights, excludeFilePatterns)
 	return resp, isStaticVolume, err
 }
 
diff --git a/pkg/service/node_static_inline.go b/pkg/service/node_static_inline.go
index 4882b72..5c064bb 100644
--- a/pkg/service/node_static_inline.go
+++ b/pkg/service/node_static_inline.go
@@ -15,11 +15,11 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool) (*csi.NodePublishVolumeResponse, error) {
+func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) {
 	modelDir := s.cfg.Get().GetModelDir(volumeName)
 
 	startedAt := time.Now()
-	if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights); err != nil {
+	if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights, excludeFilePatterns); err != nil {
 		return nil, status.Error(codes.Internal, errors.Wrap(err, "pull model").Error())
 	}
 	duration := time.Since(startedAt)
diff --git a/pkg/service/patterns.go b/pkg/service/patterns.go
new file mode 100644
index 0000000..f2597dc
--- /dev/null
+++ b/pkg/service/patterns.go
@@ -0,0 +1,173 @@
+package service
+
+import (
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+
+	gitignore "github.com/go-git/go-git/v5/plumbing/format/gitignore"
+	"github.com/modelpack/model-csi-driver/pkg/logger"
+	"github.com/pkg/errors"
+)
+
+// FilePatternMatcher wraps gitignore pattern matching functionality
+type FilePatternMatcher struct {
+	matcher  gitignore.Matcher
+	patterns []string
+}
+
+// NewFilePatternMatcher creates a new pattern matcher from a list of gitignore-compatible patterns
+func NewFilePatternMatcher(patterns []string) (*FilePatternMatcher, error) {
+	// Validate patterns for security
+	for _, p := range patterns {
+		// Check for absolute paths (starts with / and has more characters)
+		if strings.HasPrefix(p, "/") && len(p) > 1 {
+			return nil, errors.Errorf("absolute path patterns are not allowed: %s", p)
+		}
+		if strings.Contains(p, "..") {
+			return nil, errors.Errorf("parent directory reference is not allowed: %s", p)
+		}
+	}
+
+	// Create gitignore matcher from patterns
+	// Parse each string pattern into gitignore.Pattern
+	var gitPatterns []gitignore.Pattern
+	for _, p := range patterns {
+		gitPatterns = append(gitPatterns, gitignore.ParsePattern(p, nil))
+	}
+	matcher := gitignore.NewMatcher(gitPatterns)
+
+	return &FilePatternMatcher{
+		matcher:  matcher,
+		patterns: patterns,
+	}, nil
+}
+
+// Match returns true if the given path matches any of the exclusion patterns
+func (m *FilePatternMatcher) Match(path string) bool {
+	// gitignore matcher expects paths in forward-slash format
+	// and uses a slice of strings for path components
+	path = filepath.ToSlash(path)
+	pathParts := strings.Split(path, "/")
+	isDir := strings.HasSuffix(path, "/")
+	return m.matcher.Match(pathParts, isDir)
+}
+
+// Excludes returns true if any exclusion patterns are defined
+func (m *FilePatternMatcher) Excludes() bool {
+	return len(m.patterns) > 0
+}
+
+// filterFilesByPatterns walks the target directory and removes files matching the exclusion patterns
+// Returns a list of excluded file paths (relative to targetDir)
+func filterFilesByPatterns(targetDir string, matcher *FilePatternMatcher) ([]string, error) {
+	excludedFiles := []string{}
+
+	// First pass: identify and remove matched files
+	err := filepath.Walk(targetDir, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+
+		// Skip the target directory itself
+		if path == targetDir {
+			return nil
+		}
+
+		// Get relative path for pattern matching
+		relPath, err := filepath.Rel(targetDir, path)
+		if err != nil {
+			return errors.Wrap(err, "get relative path")
+		}
+
+		// Check if file/directory matches exclusion pattern
+		if matcher.Match(relPath) {
+			if !info.IsDir() {
+				logger.Logger().Infof("Excluding file: %s", relPath)
+				excludedFiles = append(excludedFiles, relPath)
+
+				// Remove the file
+				if err := os.Remove(path); err != nil {
+					return errors.Wrapf(err, "remove excluded file: %s", relPath)
+				}
+			}
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		return nil, errors.Wrap(err, "walk directory for pattern matching")
+	}
+
+	// Second pass: remove empty directories
+	removeEmptyDirectories(targetDir, matcher)
+
+	// Sort excluded files for consistent logging
+	sort.Strings(excludedFiles)
+
+	logger.Logger().Infof("Excluded %d file(s) matching patterns", len(excludedFiles))
+
+	return excludedFiles, nil
+}
+
+// removeEmptyDirectories removes empty directories that were created after file removal
+func removeEmptyDirectories(targetDir string, matcher *FilePatternMatcher) {
+	dirsToRemove := []string{}
+
+	// First, find all empty directories
+	err := filepath.Walk(targetDir, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return nil // Continue on error
+		}
+
+		if info.IsDir() && path != targetDir {
+			isEmpty, _ := isDirEmpty(path)
+			if isEmpty {
+				dirsToRemove = append(dirsToRemove, path)
+			}
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		logger.Logger().WithError(err).Warn("Failed to walk directories for cleanup")
+		return
+	}
+
+	// Remove empty directories in reverse order (deepest first)
+	for i := len(dirsToRemove) - 1; i >= 0; i-- {
+		dir := dirsToRemove[i]
+		if err := os.Remove(dir); err != nil {
+			logger.Logger().WithError(err).Warnf("Failed to remove empty directory: %s", dir)
+		} else {
+			relPath, _ := filepath.Rel(targetDir, dir)
+			logger.Logger().Infof("Removed empty directory: %s", relPath)
+		}
+	}
+}
+
+// isDirEmpty checks if a directory is empty
+func isDirEmpty(dir string) (bool, error) {
+	f, err := os.Open(dir)
+	if err != nil {
+		return false, err
+	}
+	defer func(f *os.File) {
+		err = f.Close()
+		if err != nil {
+			return
+		}
+	}(f)
+
+	_, err = f.Readdirnames(1)
+	if err == nil {
+		return false, nil // Directory is not empty
+	}
+	if err.Error() == "EOF" {
+		return true, nil // Directory is empty
+	}
+	return false, err // Error reading directory
+}
diff --git a/pkg/service/patterns_integration_test.go b/pkg/service/patterns_integration_test.go
new file mode 100644
index 0000000..a40cfd1
--- /dev/null
+++ b/pkg/service/patterns_integration_test.go
@@ -0,0 +1,112 @@
+package service
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/modelpack/model-csi-driver/pkg/config"
+	"github.com/modelpack/model-csi-driver/pkg/status"
+)
+
+func TestPullModel_WithExcludeFilePatterns(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	// This test requires a real model registry to be available
+	// Skip if not in CI environment or if test registry not configured
+	if os.Getenv("TEST_MODEL_REGISTRY") == "" {
+		t.Skip("TEST_MODEL_REGISTRY not set")
+	}
+
+	ctx := context.Background()
+
+	// Create a minimal test config
+	tmpDir := t.TempDir()
+	cfgPath := filepath.Join(tmpDir, "config.yaml")
+	cfgContent := `service_name: "model.csi.modelpack.org"
+csi_endpoint: "/tmp/csi.sock"
+root_dir: "` + tmpDir + `"
+pull_config:
+  concurrency: 5
+`
+	if err := os.WriteFile(cfgPath, []byte(cfgContent), 0644); err != nil {
+		t.Fatalf("Failed to create test config: %v", err)
+	}
+
+	cfg, err := config.New(cfgPath)
+	if err != nil {
+		t.Fatalf("Failed to load config: %v", err)
+	}
+
+	sm, err := status.NewStatusManager()
+	if err != nil {
+		t.Fatalf("Failed to create status manager: %v", err)
+	}
+	worker, err := NewWorker(cfg, sm)
+	if err != nil {
+		t.Fatalf("Failed to create worker: %v", err)
+	}
+
+	modelDir := filepath.Join(tmpDir, "model")
+
+	testReference := os.Getenv("TEST_MODEL_REFERENCE")
+	if testReference == "" {
+		testReference = "docker.io/library/test-model:latest"
+	}
+
+	excludeFilePatterns := []string{"*.safetensors"}
+
+	// Pull model with file pattern exclusion
+	err = worker.PullModel(
+		ctx,
+		true, // isStaticVolume
+		"test-volume",
+		"",
+		testReference,
+		modelDir,
+		false, // checkDiskQuota
+		false, // excludeModelWeights
+		excludeFilePatterns,
+	)
+
+	if err != nil {
+		t.Fatalf("PullModel failed: %v", err)
+	}
+
+	// Verify .safetensors files were excluded
+	files, err := os.ReadDir(modelDir)
+	if err != nil {
+		t.Fatalf("Failed to read model dir: %v", err)
+	}
+
+	for _, f := range files {
+		if filepath.Ext(f.Name()) == ".safetensors" {
+			t.Errorf("Found .safetensors file that should have been excluded: %s", f.Name())
+		}
+	}
+
+	// Verify other files still exist
+	foundConfig := false
+	err = filepath.Walk(modelDir, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if !info.IsDir() && filepath.Ext(path) == ".json" {
+			foundConfig = true
+		}
+		return nil
+	})
+	if err != nil {
+		return
+	}
+
+	if !foundConfig {
+		t.Error("Expected to find .json files, but none were found")
+	}
+
+	// Cleanup
+	_ = worker.DeleteModel(ctx, true, "test-volume", "")
+}
diff --git a/pkg/service/patterns_test.go b/pkg/service/patterns_test.go
new file mode 100644
index 0000000..2c8cfa5
--- /dev/null
+++ b/pkg/service/patterns_test.go
@@ -0,0 +1,242 @@
+package service
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+func TestFilePatternMatcher_Match(t *testing.T) {
+	tests := []struct {
+		name     string
+		patterns []string
+		path     string
+		want     bool
+	}{
+		{
+			name:     "wildcard matches safetensors",
+			patterns: []string{"*.safetensors"},
+			path:     "model.safetensors",
+			want:     true,
+		},
+		{
+			name:     "wildcard does not match json",
+			patterns: []string{"*.safetensors"},
+			path:     "config.json",
+			want:     false,
+		},
+		{
+			name:     "negative pattern overrides positive",
+			patterns: []string{"*.safetensors", "!tiktoken.model"},
+			path:     "tiktoken.model",
+			want:     false,
+		},
+		{
+			name:     "empty patterns matches nothing",
+			patterns: []string{},
+			path:     "any.file",
+			want:     false,
+		},
+		{
+			name:     "directory pattern",
+			patterns: []string{".git/*"},
+			path:     ".git/config",
+			want:     true,
+		},
+		{
+			name:     "nested directory pattern",
+			patterns: []string{"**/*.bin"},
+			path:     "layers/model-00001.bin",
+			want:     true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			matcher, err := NewFilePatternMatcher(tt.patterns)
+			if err != nil {
+				t.Fatalf("NewFilePatternMatcher() error = %v", err)
+			}
+			if got := matcher.Match(tt.path); got != tt.want {
+				t.Errorf("FilePatternMatcher.Match() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestFilePatternMatcher_Excludes(t *testing.T) {
+	t.Run("empty patterns returns false", func(t *testing.T) {
+		matcher, err := NewFilePatternMatcher([]string{})
+		if err != nil {
+			t.Fatalf("NewFilePatternMatcher() error = %v", err)
+		}
+		if matcher.Excludes() {
+			t.Error("Excludes() should return false for empty patterns")
+		}
+	})
+
+	t.Run("non-empty patterns returns true", func(t *testing.T) {
+		matcher, err := NewFilePatternMatcher([]string{"*.safetensors"})
+		if err != nil {
+			t.Fatalf("NewFilePatternMatcher() error = %v", err)
+		}
+		if !matcher.Excludes() {
+			t.Error("Excludes() should return true for non-empty patterns")
+		}
+	})
+}
+
+func TestNewFilePatternMatcher_Validation(t *testing.T) {
+	tests := []struct {
+		name     string
+		patterns []string
+		wantErr  bool
+	}{
+		{
+			name:     "absolute path is rejected",
+			patterns: []string{"/absolute/path"},
+			wantErr:  true,
+		},
+		{
+			name:     "parent directory reference is rejected",
+			patterns: []string{"../escape"},
+			wantErr:  true,
+		},
+		{
+			name:     "valid patterns",
+			patterns: []string{"*.safetensors", "!tiktoken.model", "config/*"},
+			wantErr:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewFilePatternMatcher(tt.patterns)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("NewFilePatternMatcher() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestFilterFilesByPatterns(t *testing.T) {
+	// Create a temporary directory structure for testing
+	tmpDir := t.TempDir()
+
+	// Create test files
+	testFiles := []string{
+		"model.safetensors",
+		"model.safetensors.index.json",
+		"config.json",
+		"tokenizer/tiktoken.model",
+		"layers/model-00001.bin",
+		".git/config",
+	}
+
+	for _, f := range testFiles {
+		fullPath := filepath.Join(tmpDir, f)
+		if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
+			t.Fatalf("Failed to create directory: %v", err)
+		}
+		if err := os.WriteFile(fullPath, []byte("test"), 0644); err != nil {
+			t.Fatalf("Failed to create file: %v", err)
+		}
+	}
+
+	t.Run("excludes safetensors files", func(t *testing.T) {
+		matcher, err := NewFilePatternMatcher([]string{"*.safetensors"})
+		if err != nil {
+			t.Fatalf("NewFilePatternMatcher() error = %v", err)
+		}
+
+		excluded, err := filterFilesByPatterns(tmpDir, matcher)
+		if err != nil {
+			t.Fatalf("filterFilesByPatterns() error = %v", err)
+		}
+
+		if len(excluded) != 1 {
+			t.Errorf("Expected 1 excluded file, got %d", len(excluded))
+		}
+
+		// Verify file was actually deleted
+		if _, err := os.Stat(filepath.Join(tmpDir, "model.safetensors")); !os.IsNotExist(err) {
+			t.Error("model.safetensors should have been deleted")
+		}
+
+		// Verify other files still exist
+		if _, err := os.Stat(filepath.Join(tmpDir, "config.json")); err != nil {
+			t.Error("config.json should still exist")
+		}
+	})
+
+	t.Run("negative pattern includes file", func(t *testing.T) {
+		// Recreate test files for this subtest
+		for _, f := range testFiles {
+			fullPath := filepath.Join(tmpDir, f)
+			if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
+				t.Fatalf("Failed to create directory: %v", err)
+			}
+			if err := os.WriteFile(fullPath, []byte("test"), 0644); err != nil {
+				t.Fatalf("Failed to create file: %v", err)
+			}
+		}
+
+		matcher, err := NewFilePatternMatcher([]string{"*.bin", "!tokenizer/*"})
+		if err != nil {
+			t.Fatalf("NewFilePatternMatcher() error = %v", err)
+		}
+
+		excluded, err := filterFilesByPatterns(tmpDir, matcher)
+		if err != nil {
+			t.Fatalf("filterFilesByPatterns() error = %v", err)
+		}
+
+		// Only .bin file should be excluded, not tokenizer files
+		hasBin := false
+		hasTokenizer := false
+		for _, f := range excluded {
+			if strings.Contains(f, ".bin") {
+				hasBin = true
+			}
+			if strings.Contains(f, "tokenizer") {
+				hasTokenizer = true
+			}
+		}
+
+		if !hasBin {
+			t.Error("Expected .bin file to be excluded")
+		}
+		if hasTokenizer {
+			t.Error("Tokenizer file should NOT be excluded due to negative pattern")
+		}
+	})
+
+	t.Run("removes empty directories", func(t *testing.T) {
+		// Recreate test files for this subtest
+		for _, f := range testFiles {
+			fullPath := filepath.Join(tmpDir, f)
+			if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
+				t.Fatalf("Failed to create directory: %v", err)
+			}
+			if err := os.WriteFile(fullPath, []byte("test"), 0644); err != nil {
+				t.Fatalf("Failed to create file: %v", err)
+			}
+		}
+
+		matcher, err := NewFilePatternMatcher([]string{"layers/*"})
+		if err != nil {
+			t.Fatalf("NewFilePatternMatcher() error = %v", err)
+		}
+
+		_, err = filterFilesByPatterns(tmpDir, matcher)
+		if err != nil {
+			t.Fatalf("filterFilesByPatterns() error = %v", err)
+		}
+
+		// Verify empty layers directory was removed
+		if _, err := os.Stat(filepath.Join(tmpDir, "layers")); !os.IsNotExist(err) {
+			t.Error("Empty layers directory should have been removed")
+		}
+	})
+}
diff --git a/pkg/service/puller.go b/pkg/service/puller.go
index 1f28726..143108c 100644
--- a/pkg/service/puller.go
+++ b/pkg/service/puller.go
@@ -2,7 +2,6 @@ package service
 
 import (
 	"context"
-	"io"
 	"os"
 	"strings"
 
@@ -22,24 +21,26 @@ type PullHook interface {
 }
 
 type Puller interface {
-	Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error
+	Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error
 }
 
-var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller {
+var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker, excludeFilePatterns []string) Puller {
 	return &puller{
-		pullCfg:          pullCfg,
-		hook:             hook,
-		diskQuotaChecker: diskQuotaChecker,
+		pullCfg:             pullCfg,
+		hook:                hook,
+		diskQuotaChecker:    diskQuotaChecker,
+		excludeFilePatterns: excludeFilePatterns,
 	}
 }
 
 type puller struct {
-	pullCfg          *config.PullConfig
-	hook             *status.Hook
-	diskQuotaChecker *DiskQuotaChecker
+	pullCfg             *config.PullConfig
+	hook                *status.Hook
+	diskQuotaChecker    *DiskQuotaChecker
+	excludeFilePatterns []string
 }
 
-func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error {
+func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error {
 	keyChain, err := auth.GetKeyChainByRef(reference)
 	if err != nil {
 		return errors.Wrapf(err, "get auth for model: %s", reference)
@@ -63,49 +64,79 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeM
 		return errors.Wrapf(err, "create model dir: %s", targetDir)
 	}
 
-	if !excludeModelWeights {
-		pullConfig := modctlConfig.NewPull()
-		pullConfig.Concurrency = int(p.pullCfg.Concurrency)
-		pullConfig.PlainHTTP = plainHTTP
-		pullConfig.Proxy = p.pullCfg.ProxyURL
-		pullConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint
-		pullConfig.Insecure = true
-		pullConfig.ExtractDir = targetDir
-		pullConfig.ExtractFromRemote = true
-		pullConfig.Hooks = p.hook
-		pullConfig.ProgressWriter = io.Discard
-		pullConfig.DisableProgress = true
-
-		if err := b.Pull(ctx, reference, pullConfig); err != nil {
-			logger.WithContext(ctx).WithError(err).Errorf("failed to pull model image: %s", reference)
-			return errors.Wrap(err, "pull model image")
+	// Determine which files to fetch/pull based on patterns
+	var fetchPatterns []string
+	if len(excludeFilePatterns) > 0 {
+		// Apply exclude patterns to all available files
+		// First, get all layers (files) from the model
+		patterns, err := modelArtifact.GetPatterns(ctx, excludeModelWeights)
+		if err != nil {
+			return errors.Wrap(err, "get all model layers")
 		}
 
-		return nil
-	}
+		// Create matcher from user-provided patterns
+		matcher, err := NewFilePatternMatcher(excludeFilePatterns)
+		if err != nil {
+			return errors.Wrap(err, "create file pattern matcher")
+		}
 
-	patterns, err := modelArtifact.GetPatterns(ctx, excludeModelWeights)
-	if err != nil {
-		return errors.Wrap(err, "get model file patterns without weights")
-	}
+		// Filter files: include only those NOT matched by exclusion patterns
+		for _, pattern := range patterns {
+			// Check if this file should be included (not matched by exclusion patterns)
+			if !matcher.Match(pattern) {
+				fetchPatterns = append(fetchPatterns, pattern)
+				logger.WithContext(ctx).Infof("Including file from fetch: %s", pattern)
+			} else {
+				logger.WithContext(ctx).Infof("Excluding file from fetch: %s", pattern)
+			}
+		}
+
+		if len(fetchPatterns) == 0 {
+			logger.WithContext(ctx).Warn("No files matched include patterns, all files would be excluded")
+		}
+	} else {
+		// No exclude patterns, fetch all non-weight files (original behavior)
+		patterns, err := modelArtifact.GetPatterns(ctx, excludeModelWeights)
+		if err != nil {
+			return errors.Wrap(err, "get model file patterns without weights")
+		}
 
-	logger.WithContext(ctx).Infof(
-		"fetching model without weights: %s, file patterns: %s",
-		reference, strings.Join(patterns, ", "),
-	)
+		logger.WithContext(ctx).Infof(
+			"fetching model without weights: %s, file patterns: %s",
+			reference, strings.Join(patterns, ", "),
+		)
+		fetchPatterns = patterns
+	}
+
+	// Fetch files
 	fetchConfig := modctlConfig.NewFetch()
 	fetchConfig.Concurrency = int(p.pullCfg.Concurrency)
 	fetchConfig.PlainHTTP = plainHTTP
 	fetchConfig.Proxy = p.pullCfg.ProxyURL
 	fetchConfig.Insecure = true
 	fetchConfig.Output = targetDir
-	fetchConfig.Patterns = patterns
+	fetchConfig.Patterns = fetchPatterns
 
 	if err := b.Fetch(ctx, reference, fetchConfig); err != nil {
 		logger.WithContext(ctx).WithError(err).Errorf("failed to fetch model: %s", reference)
 		return errors.Wrap(err, "fetch model")
 	}
 
+	// Apply file pattern filtering if exclude_file_patterns are provided
+	if len(excludeFilePatterns) > 0 {
+		matcher, err := NewFilePatternMatcher(excludeFilePatterns)
+		if err != nil {
+			return errors.Wrap(err, "create file pattern matcher")
+		}
+
+		logger.WithContext(ctx).Infof("Applying file exclusion patterns: %v", excludeFilePatterns)
+
+		_, err = filterFilesByPatterns(targetDir, matcher)
+		if err != nil {
+			return errors.Wrap(err, "filter files by patterns")
+		}
+	}
+
 	return nil
 }
 
diff --git a/pkg/service/request.go b/pkg/service/request.go
index 7b182bf..4de6a98 100644
--- a/pkg/service/request.go
+++ b/pkg/service/request.go
@@ -1,8 +1,9 @@
 package service
 
 type MountRequest struct {
-	MountID             string `json:"mount_id"`
-	Reference           string `json:"reference"`
-	CheckDiskQuota      bool   `json:"check_disk_quota"`
-	ExcludeModelWeights bool   `json:"exclude_model_weights"`
+	MountID             string   `json:"mount_id"`
+	Reference           string   `json:"reference"`
+	CheckDiskQuota      bool     `json:"check_disk_quota"`
+	ExcludeModelWeights bool     `json:"exclude_model_weights"`
+	ExcludeFilePatterns []string `json:"exclude_file_patterns"`
 }
diff --git a/pkg/service/testdata/config.yaml b/pkg/service/testdata/config.yaml
new file mode 100644
index 0000000..b993c46
--- /dev/null
+++ b/pkg/service/testdata/config.yaml
@@ -0,0 +1,5 @@
+service_name: "model.csi.modelpack.org"
+csi_endpoint: "/tmp/csi.sock"
+root_dir: "/tmp/model-csi"
+pull_config:
+  concurrency: 5
diff --git a/pkg/service/worker.go b/pkg/service/worker.go
index 8678655..756c8e7 100644
--- a/pkg/service/worker.go
+++ b/pkg/service/worker.go
@@ -52,7 +52,7 @@ func (cm *ContextMap) Get(key string) *context.CancelFunc {
 
 type Worker struct {
 	cfg        *config.Config
-	newPuller  func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller
+	newPuller  func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker, excludeFilePatterns []string) Puller
 	sm         *status.StatusManager
 	inflight   singleflight.Group
 	contextMap *ContextMap
@@ -126,11 +126,12 @@ func (worker *Worker) PullModel(
 	modelDir string,
 	checkDiskQuota bool,
 	excludeModelWeights bool,
+	excludeFilePatterns []string,
 ) error {
 	start := time.Now()
 	statusPath := filepath.Join(filepath.Dir(modelDir), "status.json")
 
-	err := worker.pullModel(ctx, statusPath, volumeName, mountID, reference, modelDir, checkDiskQuota, excludeModelWeights)
+	err := worker.pullModel(ctx, statusPath, volumeName, mountID, reference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns)
 	metrics.NodeOpObserve("pull_image", start, err)
 
 	if err != nil && !errors.Is(err, ErrConflict) {
@@ -142,7 +143,7 @@ func (worker *Worker) PullModel(
 	return err
 }
 
-func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mountID, reference, modelDir string, checkDiskQuota, excludeModelWeights bool) error {
+func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mountID, reference, modelDir string, checkDiskQuota, excludeModelWeights bool, excludeFilePatterns []string) error {
 	setStatus := func(state status.State) (*status.Status, error) {
 		status, err := worker.sm.Set(statusPath, status.Status{
 			VolumeName: volumeName,
@@ -192,12 +193,12 @@ func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mou
 		if checkDiskQuota {
 			diskQuotaChecker = NewDiskQuotaChecker(worker.cfg)
 		}
-		puller := worker.newPuller(ctx, &worker.cfg.Get().PullConfig, hook, diskQuotaChecker)
+		puller := worker.newPuller(ctx, &worker.cfg.Get().PullConfig, hook, diskQuotaChecker, excludeFilePatterns)
 		_, err := setStatus(status.StatePullRunning)
 		if err != nil {
 			return nil, errors.Wrapf(err, "set status before pull model")
 		}
-		if err := puller.Pull(ctx, reference, modelDir, excludeModelWeights); err != nil {
+		if err := puller.Pull(ctx, reference, modelDir, excludeModelWeights, excludeFilePatterns); err != nil {
 			if errors.Is(err, context.Canceled) {
 				err = errors.Wrapf(err, "pull model canceled")
 				if _, err2 := setStatus(status.StatePullCanceled); err2 != nil {