Skip to content

Commit 775b586

Browse files
committed
add tests, fix small glitches in code
1 parent ec7caa8 commit 775b586

File tree

3 files changed

+145
-19
lines changed

3 files changed

+145
-19
lines changed

code/numpy/compare.c

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
8080
if((size_t)minlength > length) {
8181
length = minlength;
8282
}
83+
} else {
84+
if(input->len == 0) {
85+
length = 0;
86+
}
8387
}
8488

8589
ndarray_obj_t *result = NULL;
@@ -92,6 +96,15 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
9296
mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray"));
9397
}
9498
weights = MP_OBJ_TO_PTR(args[1].u_obj);
99+
if(weights->len < input->len) {
100+
mp_raise_ValueError(MP_ERROR_TEXT("the weights and list don't have the same length"));
101+
}
102+
#if ULAB_SUPPORTS_COMPLEX
103+
if(weights->dtype == NDARRAY_COMPLEX) {
104+
mp_raise_TypeError(MP_ERROR_TEXT("cannot cast weigths to float"));
105+
}
106+
#endif /* ULAB_SUPPORTS_COMPLEX */
107+
95108
result = ndarray_new_linear_array(length, NDARRAY_FLOAT);
96109
}
97110

@@ -113,37 +126,27 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
113126
}
114127
} else {
115128
mp_float_t *rarray = (mp_float_t *)result->array;
129+
130+
mp_float_t (*get_weights)(void *) = ndarray_get_float_function(weights->dtype);
131+
uint8_t *warray = (uint8_t *)weights->array;
132+
116133
if(input->dtype == NDARRAY_UINT8) {
117134
uint8_t *iarray = (uint8_t *)input->array;
118135
for(size_t i = 0; i < input->len; i++) {
119-
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
136+
rarray[*iarray] += get_weights(warray);
120137
iarray += stride;
138+
warray += weights->strides[ULAB_MAX_DIMS - 1];
121139
}
122140
} else if(input->dtype == NDARRAY_UINT16) {
123141
uint16_t *iarray = (uint16_t *)input->array;
124142
for(size_t i = 0; i < input->len; i++) {
125-
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
143+
rarray[*iarray] += get_weights(warray);
126144
iarray += stride;
145+
warray += weights->strides[ULAB_MAX_DIMS - 1];
127146
}
128147
}
129148
}
130-
131-
if(weights != NULL) {
132-
mp_float_t (*get_weights)(void *) = ndarray_get_float_function(weights->dtype);
133-
mp_float_t *rarray = (mp_float_t *)result->array;
134-
uint8_t *warray = (uint8_t *)weights->array;
135-
136-
size_t fill_length = result->len;
137-
if(weights->len < result->len) {
138-
fill_length = weights->len;
139-
}
140-
141-
for(size_t i = 0; i < fill_length; i++) {
142-
*rarray = *rarray * get_weights(warray);
143-
rarray++;
144-
warray += weights->strides[ULAB_MAX_DIMS - 1];
145-
}
146-
}
149+
147150
return MP_OBJ_FROM_PTR(result);
148151
}
149152

tests/2d/numpy/bincount.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
try:
2+
from ulab import numpy as np
3+
except:
4+
import numpy as np
5+
6+
for dtype in (np.uint8, np.uint16):
7+
a = np.array([0, 1, 1, 3, 3, 3], dtype=dtype)
8+
print(np.bincount(a))
9+
10+
for dtype in (np.uint8, np.uint16):
11+
a = np.array([0, 2, 2, 4], dtype=dtype)
12+
print(np.bincount(a, minlength=3))
13+
14+
for dtype in (np.uint8, np.uint16):
15+
a = np.array([0, 2, 2, 4], dtype=dtype)
16+
print(np.bincount(a, minlength=8))
17+
18+
for dtype in (np.uint8, np.uint16):
19+
a = np.array([], dtype=dtype)
20+
print(np.bincount(a))
21+
print(np.bincount(a, minlength=8))
22+
23+
for dtype in (np.uint8, np.uint16):
24+
a = np.array([0, 1, 1, 3], dtype=dtype)
25+
w = np.array([0.5, 1.0, 2.5, 0.25])
26+
print(np.where(abs(np.bincount(a, weights=w) - np.array([0.5, 2.0, 0.0, 0.25])) < 0.001, 1, 0))
27+
28+
w = np.array([1, 2, 3, 4], dtype=np.uint8)
29+
print(np.bincount(a, weights=w))
30+
31+
for dtype in (np.uint8, np.uint16):
32+
a = np.array([1, 1], dtype=dtype)
33+
w = np.array([0.5, 1.5])
34+
print(np.bincount(a, weights=w, minlength=4))
35+
36+
for dtype in (np.uint8, np.uint16):
37+
a = np.array([2, 2, 2, 3], dtype=dtype)
38+
for wtype in (np.uint8, np.uint16, np.int8, np.int16, np.float):
39+
w = np.array([1, 2, 3, 4], dtype=wtype)
40+
print(np.bincount(a, weights=w))
41+
42+
for dtype in (np.int8, np.int16, np.float):
43+
a = np.array([2, 2, 2, 3], dtype=dtype)
44+
try:
45+
np.bincount(a)
46+
except Exception as e:
47+
print(e)
48+
49+
for dtype in (np.uint8, np.int8, np.uint16, np.int16, np.float):
50+
a = np.array(range(4), dtype=dtype).reshape((2, 2))
51+
try:
52+
np.bincount(a)
53+
except Exception as e:
54+
print(e)
55+
56+
for dtype in (np.uint8, np.uint16):
57+
a = np.array([1, 2, 3], dtype=dtype)
58+
w = np.array([1, 2])
59+
try:
60+
np.bincount(a, weights=w)
61+
except Exception as e:
62+
print(e)
63+
64+
for dtype in (np.uint8, np.uint16):
65+
a = np.array([1, 2, 3], dtype=dtype)
66+
try:
67+
np.bincount(a, minlength=-1)
68+
except Exception as e:
69+
print(e)
70+
71+
for dtype in (np.uint8, np.uint16):
72+
a = np.array([1, 2, 3], dtype=dtype)
73+
w = np.array([1j, 2j, 3j], dtype=np.complex)
74+
try:
75+
np.bincount(a, weights=w)
76+
except Exception as e:
77+
print(e)
78+
79+
80+
a = np.array([0, 1000], dtype=np.uint16)
81+
y = np.bincount(a)
82+
print(y[0], y[1000], len(y))

tests/2d/numpy/bincount.py.exp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
array([1, 2, 0, 3], dtype=uint16)
2+
array([1, 2, 0, 3], dtype=uint16)
3+
array([1, 0, 2, 0, 1], dtype=uint16)
4+
array([1, 0, 2, 0, 1], dtype=uint16)
5+
array([1, 0, 2, 0, 1, 0, 0, 0], dtype=uint16)
6+
array([1, 0, 2, 0, 1, 0, 0, 0], dtype=uint16)
7+
array([], dtype=uint16)
8+
array([0, 0, 0, 0, 0, 0, 0, 0], dtype=uint16)
9+
array([], dtype=uint16)
10+
array([0, 0, 0, 0, 0, 0, 0, 0], dtype=uint16)
11+
array([1, 0, 1, 1], dtype=uint8)
12+
array([1.0, 5.0, 0.0, 4.0], dtype=float64)
13+
array([1, 0, 1, 1], dtype=uint8)
14+
array([1.0, 5.0, 0.0, 4.0], dtype=float64)
15+
array([0.0, 2.0, 0.0, 0.0], dtype=float64)
16+
array([0.0, 2.0, 0.0, 0.0], dtype=float64)
17+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
18+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
19+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
20+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
21+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
22+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
23+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
24+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
25+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
26+
array([0.0, 0.0, 6.0, 4.0], dtype=float64)
27+
cannot cast array data from dtype
28+
cannot cast array data from dtype
29+
cannot cast array data from dtype
30+
object too deep for desired array
31+
object too deep for desired array
32+
object too deep for desired array
33+
object too deep for desired array
34+
object too deep for desired array
35+
the weights and list don't have the same length
36+
the weights and list don't have the same length
37+
minlength must not be negative
38+
minlength must not be negative
39+
cannot cast weigths to float
40+
cannot cast weigths to float
41+
1 1 1001

0 commit comments

Comments
 (0)