Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func (cfg *RawConfig) ParameterKeyExcludeModelWeights() string {
return cfg.ServiceName + "/exclude-model-weights"
}

func (cfg *RawConfig) ParameterKeyExcludeFiles() string {
return cfg.ServiceName + "/exclude-files"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe return cfg.ServiceName + "/exclude-file-patterns".

}

// /var/lib/dragonfly/model-csi/volumes
func (cfg *RawConfig) GetVolumesDir() string {
return filepath.Join(cfg.RootDir, "volumes")
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type mockPuller struct {
}

func (puller *mockPuller) Pull(
ctx context.Context, reference, targetDir string, excludeModelWeights bool,
ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string,
) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
return err
Expand Down Expand Up @@ -560,7 +560,7 @@ func TestServer(t *testing.T) {
cfg.Get().PullConfig.ProxyURL = ""
service.CacheScanInterval = 1 * time.Second

service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker) service.Puller {
service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker, excludeFilePatterns []string) service.Puller {
return &mockPuller{
pullCfg: pullCfg,
duration: time.Second * 2,
Expand Down
23 changes: 21 additions & 2 deletions pkg/service/controller_local.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -68,6 +69,24 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
}
}

excludeFilePatternsParam := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyExcludeFiles()])
var excludeFilePatterns []string
if excludeFilePatternsParam != "" {
if err := json.Unmarshal([]byte(excludeFilePatternsParam), &excludeFilePatterns); err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using a comma-separated format for the parameter value? For example: /foo,/bar.

return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid JSON array: %v", s.cfg.Get().ParameterKeyExcludeFiles(), err)
}

// Validate patterns for security
for _, p := range excludeFilePatterns {
if strings.HasPrefix(p, "/") && len(p) > 1 {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: absolute paths not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
if strings.Contains(p, "..") {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: parent directory reference not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
}
}
Comment on lines +79 to +88

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This validation logic for excludeFilePatterns is duplicated in pkg/service/dynamic_server_handler.go, pkg/service/node.go, and pkg/service/patterns.go. To improve maintainability and ensure consistency, this logic should be extracted into a single, reusable function (e.g., ValidateExcludePatterns in the service package). This new function can then be called from all the places where validation is needed.


parentSpan := trace.SpanFromContext(ctx)
parentSpan.SetAttributes(attribute.String("volume_name", volumeName))
parentSpan.SetAttributes(attribute.String("reference", modelReference))
Expand All @@ -78,7 +97,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
startedAt := time.Now()
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
span.SetAttributes(attribute.String("model_dir", modelDir))
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
span.SetStatus(otelCodes.Error, "failed to pull model")
span.RecordError(err)
span.End()
Expand Down Expand Up @@ -111,7 +130,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
startedAt := time.Now()
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
span.SetAttributes(attribute.String("model_dir", modelDir))
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
span.SetStatus(otelCodes.Error, "failed to pull model")
span.RecordError(err)
span.End()
Expand Down
27 changes: 27 additions & 0 deletions pkg/service/dynamic_server_handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -86,6 +87,31 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
})
}

// Validate exclude_file_patterns
for _, p := range req.ExcludeFilePatterns {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this path security check is unnecessary because, after using gitignore to match the results of modelArtifact.GetPatterns, a secure list of absolute paths can be obtained.

if strings.HasPrefix(p, "/") && len(p) > 1 {
return c.JSON(http.StatusBadRequest, ErrorResponse{
Code: ERR_CODE_INVALID_ARGUMENT,
Message: fmt.Sprintf("exclude_file_patterns: absolute paths not allowed: %s", p),
})
}
if strings.Contains(p, "..") {
return c.JSON(http.StatusBadRequest, ErrorResponse{
Code: ERR_CODE_INVALID_ARGUMENT,
Message: fmt.Sprintf("exclude_file_patterns: parent directory reference not allowed: %s", p),
})
}
}

excludeFilesJSON := "[]"
if len(req.ExcludeFilePatterns) > 0 {
jsonBytes, err := json.Marshal(req.ExcludeFilePatterns)
if err != nil {
return handleError(c, fmt.Errorf("marshal exclude_file_patterns: %w", err))
}
excludeFilesJSON = string(jsonBytes)
}

_, err := h.svc.CreateVolume(c.Request().Context(), &csi.CreateVolumeRequest{
Name: volumeName,
Parameters: map[string]string{
Expand All @@ -94,6 +120,7 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
h.cfg.Get().ParameterKeyMountID(): req.MountID,
h.cfg.Get().ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
h.cfg.Get().ParameterKeyExcludeModelWeights(): strconv.FormatBool(req.ExcludeModelWeights),
h.cfg.Get().ParameterKeyExcludeFiles(): excludeFilesJSON,
},
})
if err != nil {
Expand Down
19 changes: 18 additions & 1 deletion pkg/service/node.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"encoding/json"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -104,8 +105,24 @@ func (s *Service) nodePublishVolume(
}
}

excludeFilePatternsParam := volumeAttributes[s.cfg.Get().ParameterKeyExcludeFiles()]
var excludeFilePatterns []string
if excludeFilePatternsParam != "" {
if err := json.Unmarshal([]byte(excludeFilePatternsParam), &excludeFilePatterns); err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid JSON array: %v", s.cfg.Get().ParameterKeyExcludeFiles(), err)
}
for _, p := range excludeFilePatterns {
if strings.HasPrefix(p, "/") && len(p) > 1 {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: absolute paths not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
if strings.Contains(p, "..") {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: parent directory reference not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
}
}

logger.WithContext(ctx).Infof("publishing static inline volume: %s", staticInlineModelReference)
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights)
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights, excludeFilePatterns)
return resp, isStaticVolume, err
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/service/node_static_inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import (
"google.golang.org/grpc/status"
)

func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool) (*csi.NodePublishVolumeResponse, error) {
func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) {
modelDir := s.cfg.Get().GetModelDir(volumeName)

startedAt := time.Now()
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights, excludeFilePatterns); err != nil {
return nil, status.Error(codes.Internal, errors.Wrap(err, "pull model").Error())
}
duration := time.Since(startedAt)
Expand Down
173 changes: 173 additions & 0 deletions pkg/service/patterns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package service

import (
"os"
"path/filepath"
"sort"
"strings"

gitignore "github.com/go-git/go-git/v5/plumbing/format/gitignore"
"github.com/modelpack/model-csi-driver/pkg/logger"
"github.com/pkg/errors"
)

// FilePatternMatcher wraps gitignore pattern matching functionality.
// It keeps the compiled gitignore matcher together with the raw pattern
// strings that the matcher was built from.
type FilePatternMatcher struct {
// matcher is the gitignore matcher that Match delegates to.
matcher gitignore.Matcher
// patterns holds the original pattern strings; Excludes reports whether
// this slice is non-empty.
patterns []string
}

// NewFilePatternMatcher creates a new pattern matcher from a list of gitignore-compatible patterns
func NewFilePatternMatcher(patterns []string) (*FilePatternMatcher, error) {
// Validate patterns for security
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

for _, p := range patterns {
// Check for absolute paths (starts with / and has more characters)
if strings.HasPrefix(p, "/") && len(p) > 1 {
return nil, errors.Errorf("absolute path patterns are not allowed: %s", p)
}
if strings.Contains(p, "..") {
return nil, errors.Errorf("parent directory reference is not allowed: %s", p)
}
}

// Create gitignore matcher from patterns
// Parse each string pattern into gitignore.Pattern
var gitPatterns []gitignore.Pattern
for _, p := range patterns {
gitPatterns = append(gitPatterns, gitignore.ParsePattern(p, nil))
}
matcher := gitignore.NewMatcher(gitPatterns)

return &FilePatternMatcher{
matcher: matcher,
patterns: patterns,
}, nil
}

// Match reports whether path is matched by any of the exclusion patterns.
//
// OS-specific separators in path are converted to forward slashes before the
// path is split into components for the gitignore matcher. A trailing slash
// marks the path as a directory.
func (m *FilePatternMatcher) Match(path string) bool {
	path = filepath.ToSlash(path)

	// The trailing slash only signals "this is a directory". Strip it before
	// splitting so the matcher is not handed a spurious empty final component
	// (e.g. "a/" would otherwise become ["a", ""], and a wildcard segment
	// could match the empty component).
	isDir := strings.HasSuffix(path, "/")
	if isDir {
		path = strings.TrimRight(path, "/")
	}

	return m.matcher.Match(strings.Split(path, "/"), isDir)
}

// Excludes reports whether the matcher holds at least one exclusion pattern.
func (m *FilePatternMatcher) Excludes() bool {
	return len(m.patterns) != 0
}

// filterFilesByPatterns deletes every non-directory entry under targetDir
// whose path, relative to targetDir, is matched by matcher. Afterwards it calls
// removeEmptyDirectories on targetDir.
//
// It returns the relative paths of the removed files, sorted. If the walk
// fails or a matched file cannot be removed, it returns a nil slice and a
// wrapped error; files removed before the failure stay removed.
func filterFilesByPatterns(targetDir string, matcher *FilePatternMatcher) ([]string, error) {
	removed := []string{}

	visit := func(path string, info os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case path == targetDir:
			// The root directory itself is never a candidate for exclusion.
			return nil
		}

		// Patterns are evaluated against the path relative to targetDir.
		rel, err := filepath.Rel(targetDir, path)
		if err != nil {
			return errors.Wrap(err, "get relative path")
		}

		// Only files are deleted during the walk; directories are left for
		// the cleanup step that follows it.
		if !matcher.Match(rel) || info.IsDir() {
			return nil
		}

		logger.Logger().Infof("Excluding file: %s", rel)
		removed = append(removed, rel)

		if err := os.Remove(path); err != nil {
			return errors.Wrapf(err, "remove excluded file: %s", rel)
		}
		return nil
	}

	if err := filepath.Walk(targetDir, visit); err != nil {
		return nil, errors.Wrap(err, "walk directory for pattern matching")
	}

	// Second pass: remove empty directories.
	removeEmptyDirectories(targetDir, matcher)

	// Sorted so the result and the log output are stable across runs.
	sort.Strings(removed)
	logger.Logger().Infof("Excluded %d file(s) matching patterns", len(removed))

	return removed, nil
}

// removeEmptyDirectories removes empty directories that were created after file removal
func removeEmptyDirectories(targetDir string, matcher *FilePatternMatcher) {
dirsToRemove := []string{}

// First, find all empty directories
err := filepath.Walk(targetDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Continue on error
}

if info.IsDir() && path != targetDir {
isEmpty, _ := isDirEmpty(path)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error returned from isDirEmpty is being ignored. This function can return errors other than io.EOF (e.g., permission errors when opening the directory), which would be silently ignored. This could lead to incorrect behavior, such as not removing a directory that should be removed. The error should be checked and handled, for example by logging it and continuing the walk.

isEmpty, err := isDirEmpty(path)
if err != nil {
    logger.Logger().WithError(err).Warnf("Failed to check if directory is empty: %s", path)
    return nil
}
if isEmpty {

if isEmpty {
dirsToRemove = append(dirsToRemove, path)
}
}

return nil
})

if err != nil {
logger.Logger().WithError(err).Warn("Failed to walk directories for cleanup")
return
}

// Remove empty directories in reverse order (deepest first)
for i := len(dirsToRemove) - 1; i >= 0; i-- {
dir := dirsToRemove[i]
if err := os.Remove(dir); err != nil {
logger.Logger().WithError(err).Warnf("Failed to remove empty directory: %s", dir)
} else {
relPath, _ := filepath.Rel(targetDir, dir)
logger.Logger().Infof("Removed empty directory: %s", relPath)
}
}
}

// isDirEmpty checks if a directory is empty
func isDirEmpty(dir string) (bool, error) {
f, err := os.Open(dir)
if err != nil {
return false, err
}
defer func(f *os.File) {
err = f.Close()
if err != nil {
return
}
}(f)

_, err = f.Readdirnames(1)
if err == nil {
return false, nil // Directory is not empty
}
if err.Error() == "EOF" {
return true, nil // Directory is empty
}
Comment on lines +169 to +171

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Comparing the error string with "EOF" is brittle and not idiomatic Go. It's better to use errors.Is(err, io.EOF) or err == io.EOF for this check. This will make the code more robust against changes in the error message. Note that this will require importing the io package.

if err == io.EOF {
		return true, nil // Directory is empty
	}

return false, err // Error reading directory
}
Loading