99import operator
1010import re
1111import typing
12+ from functools import partial
1213
1314import marshmallow
1415
@@ -419,6 +420,7 @@ class PredicateFilter:
419420 "!=" : operator .ne ,
420421 "<>" : operator .ne ,
421422 "is not" : operator .is_not ,
423+ "in" : operator .contains ,
422424 }
423425
424426 def __init__ (self , predicate , schema ):
@@ -481,6 +483,8 @@ def filter_field(
481483 for child_doc in obj
482484 )
483485
486+ is_sequence_value = False
487+
484488 if isinstance (schema_field , marshmallow .fields .Dict ):
485489 obj = schema_field ._deserialize (obj , None , None )
486490 assert len (path ) == 1
@@ -489,15 +493,27 @@ def filter_field(
489493 if obj is not None :
490494 obj = schema_field ._deserialize (obj , None , None )
491495 if value is not None :
492- value = schema_field ._deserialize (value , None , None )
496+ is_sequence_value = isinstance (value , tuple ) or isinstance (value , list )
497+ if is_sequence_value :
498+ deserialize = partial (
499+ schema_field ._deserialize , attr = None , data = None
500+ )
501+ value = list (map (deserialize , value ))
502+ else :
503+ value = schema_field ._deserialize (value , None , None )
493504
494- # Case insensitve comparison for strings
505+ op = self .operators [operator_value ]
506+
507+ # Case insensitive comparison for strings
495508 if isinstance (obj , str ):
496509 obj = obj .lower ()
497510 if isinstance (value , str ):
498511 value = value .lower ()
512+ elif is_sequence_value :
513+ value = [x .lower () if isinstance (x , str ) else x for x in value ]
514+ if operator_value == "in" :
515+ return op (value , obj )
499516
500- op = self .operators [operator_value ]
501517 return op (obj , value )
502518
503519 def case_insensitive_get (sef , dict , key , default = None ):
0 commit comments