@@ -10,6 +10,26 @@ from .util import fsspec_cache_open
1010
1111np.import_array()
1212
13+ cdef convert_detections_to_tuples(dn.detection* detections, int num_dets, str nms_type, float nms_threshold):
14+ if nms_threshold > 0 and num_dets > 0 :
15+ if nms_type == " obj" :
16+ dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
17+ elif nms_type == " sort" :
18+ dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
19+ else :
20+ raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
21+ rv = [
22+ (j,
23+ detections[i].prob[j],
24+ (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h)
25+ )
26+ for i in range (num_dets)
27+ for j in range (detections[i].classes)
28+ if detections[i].prob[j] > 0
29+ ]
30+ return sorted (rv, key = lambda x : x[1 ], reverse = True )
31+
32+
1333cdef class Metadata:
1434 classes = [] # typing: List[AnyStr]
1535
@@ -26,15 +46,16 @@ cdef class Network:
2646 cdef dn.network* _c_network
2747
2848 @staticmethod
29- def open (config_url , weights_url ):
49+ def open (config_url , weights_url , batch_size = 1 ):
3050 with fsspec_cache_open(config_url, mode = " rt" ) as config:
3151 with fsspec_cache_open(weights_url, mode = " rb" ) as weights:
32- return Network(config.name, weights.name)
33-
52+ return Network(config.name, weights.name, batch_size)
3453
35- def __cinit__ (self , config_file , weights_file ):
36- clear = 1
37- self ._c_network = dn.load_network(config_file.encode(), weights_file.encode(), clear)
54+ def __cinit__ (self , str config_file , str weights_file , int batch_size , bint clear = True ):
55+ self ._c_network = dn.load_network_custom(config_file.encode(),
56+ weights_file.encode(),
57+ clear,
58+ batch_size)
3859 if self ._c_network is NULL :
3960 raise RuntimeError (" Failed to create the DarkNet Network..." )
4061
@@ -43,10 +64,26 @@ cdef class Network:
4364 dn.free_network(self ._c_network[0 ])
4465 free(self ._c_network)
4566
67+ @property
68+ def batch_size (self ):
69+ return dn.network_batch_size(self ._c_network)
70+
4671 @property
4772 def shape (self ):
4873 return dn.network_width(self ._c_network), dn.network_height(self ._c_network)
4974
75+ @property
76+ def width (self ):
77+ return dn.network_width(self ._c_network)
78+
79+ @property
80+ def height (self ):
81+ return dn.network_height(self ._c_network)
82+
83+ @property
84+ def depth (self ):
85+ return dn.network_depth(self ._c_network)
86+
5087 def input_size (self ):
5188 return dn.network_input_size(self ._c_network)
5289
@@ -81,38 +118,79 @@ cdef class Network:
81118 output_shape[0] = self.output_size()
82119 return np.PyArray_SimpleNewFromData(1, output_shape , np.NPY_FLOAT32 , output )
83120
84- def detect(self , frame_size = None ,
85- float threshold = .5 , float hierarchical_threshold = .5 ,
86- int relative = 0 , int letterbox = 1 ,
87- str nms_type = " sort" , float nms_threshold = .45 ,
121+ def detect(self ,
122+ frame_size = None ,
123+ float threshold = .5 ,
124+ float hierarchical_threshold = .5 ,
125+ int relative = 0 ,
126+ int letterbox = 1 ,
127+ str nms_type = " sort" ,
128+ float nms_threshold = .45 ,
88129 ):
89- frame_size = self .shape if frame_size is None else frame_size
130+ pred_width, pred_height = self .shape if frame_size is None else frame_size
131+
90132 cdef int num_dets = 0
91133 cdef dn.detection* detections
92-
93134 detections = dn.get_network_boxes(self ._c_network,
94- frame_size[0 ], frame_size[1 ],
95- threshold, hierarchical_threshold,
135+ pred_width,
136+ pred_height,
137+ threshold,
138+ hierarchical_threshold,
96139 < int * > 0 ,
97140 relative,
98141 & num_dets,
99142 letterbox)
143+ rv = convert_detections_to_tuples(detections, num_dets, nms_type, nms_threshold)
144+ dn.free_detections(detections, num_dets)
100145
101- if nms_threshold > 0 and num_dets:
102- if nms_type == " obj" :
103- dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
104- elif nms_type == " sort" :
105- dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
106- else :
107- raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
146+ return rv
108147
148+ def detect_batch (self ,
149+ np.ndarray[dtype = np.float32_t, ndim = 1 , mode = " c" ] frames,
150+ frame_size = None ,
151+ float threshold = .5 ,
152+ float hierarchical_threshold = .5 ,
153+ int relative = 0 ,
154+ int letterbox = 1 ,
155+ str nms_type = " sort" ,
156+ float nms_threshold = .45
157+ ):
158+ pred_width, pred_height = self .shape if frame_size is None else frame_size
159+
160+ cdef dn.image imr
161+ # This looks awkward, but the batch predict *does not* use c, w, h.
162+ imr.c = 0
163+ imr.w = 0
164+ imr.h = 0
165+ imr.data = < float * > frames.data
166+
167+ if frames.size % self .input_size() != 0 :
168+ raise TypeError (" The frames array is not divisible by network input size. "
169+ f" ({frames.size} % {self.input_size()} != 0)" )
170+
171+ num_frames = frames.size // self .input_size()
172+ if num_frames > self .batch_size:
173+ raise TypeError (" There are more frames than the configured batch size. "
174+ f" ({num_frames} > {self.batch_size})" )
175+
176+ cdef dn.det_num_pair* batch_detections
177+ batch_detections = dn.network_predict_batch(
178+ self ._c_network,
179+ imr,
180+ num_frames,
181+ pred_width,
182+ pred_height,
183+ threshold,
184+ hierarchical_threshold,
185+ < int * > 0 ,
186+ relative,
187+ letterbox
188+ )
109189 rv = [
110- (j, detections[i].prob[j],
111- (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h))
112- for i in range (num_dets)
113- for j in range (detections[i].classes)
114- if detections[i].prob[j] > 0
190+ convert_detections_to_tuples(batch_detections[b].dets, batch_detections[b].num, nms_type, nms_threshold)
191+ for b in range (num_frames)
115192 ]
193+ dn.free_batch_detections(batch_detections, num_frames)
194+ return rv
195+
116196
117- dn.free_detections(detections, num_dets)
118- return sorted (rv, key = lambda x : x[1 ], reverse = True )
0 commit comments