1use std::ops::Add;
7
8use num_traits::Zero;
9
10use nalgebra::storage::RawStorage;
11use nalgebra::{ClosedAddAssign, DMatrix, Dim, Matrix, Scalar};
12
13use crate::coo::CooMatrix;
14use crate::cs;
15use crate::csc::CscMatrix;
16use crate::csr::CsrMatrix;
17use crate::utils::{apply_permutation, compute_sort_permutation};
18
19pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
21where
22 T: Scalar + Zero,
23 R: Dim,
24 C: Dim,
25 S: RawStorage<T, R, C>,
26{
27 let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
28
29 for (index, v) in dense.iter().enumerate() {
30 if v != &T::zero() {
31 let i = index % dense.nrows();
33 let j = index / dense.nrows();
34 coo.push(i, j, v.clone());
35 }
36 }
37
38 coo
39}
40
41pub fn convert_coo_dense<T>(coo: &CooMatrix<T>) -> DMatrix<T>
43where
44 T: Scalar + Zero + ClosedAddAssign,
45{
46 let mut output = DMatrix::repeat(coo.nrows(), coo.ncols(), T::zero());
47 for (i, j, v) in coo.triplet_iter() {
48 output[(i, j)] += v.clone();
49 }
50 output
51}
52
53pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
55where
56 T: Scalar + Zero,
57{
58 let (offsets, indices, values) = convert_coo_cs(
59 coo.nrows(),
60 coo.row_indices(),
61 coo.col_indices(),
62 coo.values(),
63 );
64
65 CsrMatrix::try_from_csr_data(coo.nrows(), coo.ncols(), offsets, indices, values)
68 .expect("Internal error: Invalid CSR data during COO->CSR conversion")
69}
70
71pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
73 let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
74 for (i, j, v) in csr.triplet_iter() {
75 result.push(i, j, v.clone());
76 }
77 result
78}
79
80pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
82where
83 T: Scalar + ClosedAddAssign + Zero,
84{
85 let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
86
87 for (i, j, v) in csr.triplet_iter() {
88 output[(i, j)] += v.clone();
89 }
90
91 output
92}
93
94pub fn convert_dense_csr<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CsrMatrix<T>
96where
97 T: Scalar + Zero,
98 R: Dim,
99 C: Dim,
100 S: RawStorage<T, R, C>,
101{
102 let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
103 let mut col_idx = Vec::new();
104 let mut values = Vec::new();
105
106 row_offsets.push(0);
110 for i in 0..dense.nrows() {
111 for j in 0..dense.ncols() {
112 let v = dense.index((i, j));
113 if v != &T::zero() {
114 col_idx.push(j);
115 values.push(v.clone());
116 }
117 }
118 row_offsets.push(col_idx.len());
119 }
120
121 CsrMatrix::try_from_csr_data(dense.nrows(), dense.ncols(), row_offsets, col_idx, values)
124 .expect("Internal error: Invalid CsrMatrix format during dense-> CSR conversion")
125}
126
127pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
129where
130 T: Scalar + Zero,
131{
132 let (offsets, indices, values) = convert_coo_cs(
133 coo.ncols(),
134 coo.col_indices(),
135 coo.row_indices(),
136 coo.values(),
137 );
138
139 CscMatrix::try_from_csc_data(coo.nrows(), coo.ncols(), offsets, indices, values)
142 .expect("Internal error: Invalid CSC data during COO->CSC conversion")
143}
144
145pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
147where
148 T: Scalar,
149{
150 let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
151 for (i, j, v) in csc.triplet_iter() {
152 coo.push(i, j, v.clone());
153 }
154 coo
155}
156
157pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
159where
160 T: Scalar + ClosedAddAssign + Zero,
161{
162 let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
163
164 for (i, j, v) in csc.triplet_iter() {
165 output[(i, j)] += v.clone();
166 }
167
168 output
169}
170
171pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
173where
174 T: Scalar + Zero,
175 R: Dim,
176 C: Dim,
177 S: RawStorage<T, R, C>,
178{
179 let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
180 let mut row_idx = Vec::new();
181 let mut values = Vec::new();
182
183 col_offsets.push(0);
184 for j in 0..dense.ncols() {
185 for i in 0..dense.nrows() {
186 let v = dense.index((i, j));
187 if v != &T::zero() {
188 row_idx.push(i);
189 values.push(v.clone());
190 }
191 }
192 col_offsets.push(row_idx.len());
193 }
194
195 CscMatrix::try_from_csc_data(dense.nrows(), dense.ncols(), col_offsets, row_idx, values)
198 .expect("Internal error: Invalid CscMatrix format during dense-> CSC conversion")
199}
200
201pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
203where
204 T: Scalar,
205{
206 let (offsets, indices, values) = cs::transpose_cs(
207 csr.nrows(),
208 csr.ncols(),
209 csr.row_offsets(),
210 csr.col_indices(),
211 csr.values(),
212 );
213
214 CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
216 .expect("Internal error: Invalid CSC data during CSR->CSC conversion")
217}
218
219pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
221where
222 T: Scalar,
223{
224 let (offsets, indices, values) = cs::transpose_cs(
225 csc.ncols(),
226 csc.nrows(),
227 csc.col_offsets(),
228 csc.row_indices(),
229 csc.values(),
230 );
231
232 CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
234 .expect("Internal error: Invalid CSR data during CSC->CSR conversion")
235}
236
237fn convert_coo_cs<T>(
238 major_dim: usize,
239 major_indices: &[usize],
240 minor_indices: &[usize],
241 values: &[T],
242) -> (Vec<usize>, Vec<usize>, Vec<T>)
243where
244 T: Scalar + Zero,
245{
246 assert_eq!(major_indices.len(), minor_indices.len());
247 assert_eq!(minor_indices.len(), values.len());
248 let nnz = major_indices.len();
249
250 let (unsorted_major_offsets, unsorted_minor_idx, unsorted_vals) = {
251 let mut offsets = vec![0usize; major_dim + 1];
252 let mut minor_idx = vec![0usize; nnz];
253 let mut vals = vec![T::zero(); nnz];
254 coo_to_unsorted_cs(
255 &mut offsets,
256 &mut minor_idx,
257 &mut vals,
258 major_dim,
259 major_indices,
260 minor_indices,
261 values,
262 );
263 (offsets, minor_idx, vals)
264 };
265
266 let mut sorted_major_offsets = Vec::new();
272 let mut sorted_minor_idx = Vec::new();
273 let mut sorted_vals = Vec::new();
274
275 sorted_major_offsets.push(0);
276
277 let mut idx_workspace = Vec::new();
281 let mut perm_workspace = Vec::new();
282 let mut values_workspace = Vec::new();
283
284 for lane in 0..major_dim {
285 let begin = unsorted_major_offsets[lane];
286 let end = unsorted_major_offsets[lane + 1];
287 let count = end - begin;
288 let range = begin..end;
289
290 perm_workspace.resize(count, 0);
292 idx_workspace.resize(count, 0);
293 values_workspace.resize(count, T::zero());
294 sort_lane(
295 &mut idx_workspace[..count],
296 &mut values_workspace[..count],
297 &unsorted_minor_idx[range.clone()],
298 &unsorted_vals[range.clone()],
299 &mut perm_workspace[..count],
300 );
301
302 let sorted_ja_current_len = sorted_minor_idx.len();
303
304 combine_duplicates(
305 |idx| sorted_minor_idx.push(idx),
306 |val| sorted_vals.push(val),
307 &idx_workspace[..count],
308 &values_workspace[..count],
309 Add::add,
310 );
311
312 let new_col_count = sorted_minor_idx.len() - sorted_ja_current_len;
313 sorted_major_offsets.push(sorted_major_offsets.last().unwrap() + new_col_count);
314 }
315
316 (sorted_major_offsets, sorted_minor_idx, sorted_vals)
317}
318
319fn coo_to_unsorted_cs<T: Clone>(
324 major_offsets: &mut [usize],
325 cs_minor_idx: &mut [usize],
326 cs_values: &mut [T],
327 major_dim: usize,
328 major_indices: &[usize],
329 minor_indices: &[usize],
330 coo_values: &[T],
331) {
332 assert_eq!(major_offsets.len(), major_dim + 1);
333 assert_eq!(cs_minor_idx.len(), cs_values.len());
334 assert_eq!(cs_values.len(), major_indices.len());
335 assert_eq!(major_indices.len(), minor_indices.len());
336 assert_eq!(minor_indices.len(), coo_values.len());
337
338 for major_idx in major_indices {
340 major_offsets[*major_idx] += 1;
341 }
342
343 cs::convert_counts_to_offsets(major_offsets);
344
345 {
346 let mut current_counts = vec![0usize; major_dim + 1];
350 let triplet_iter = major_indices.iter().zip(minor_indices).zip(coo_values);
351 for ((i, j), value) in triplet_iter {
352 let current_offset = major_offsets[*i] + current_counts[*i];
353 cs_minor_idx[current_offset] = *j;
354 cs_values[current_offset] = value.clone();
355 current_counts[*i] += 1;
356 }
357 }
358}
359
360fn sort_lane<T: Clone>(
368 minor_idx_result: &mut [usize],
369 values_result: &mut [T],
370 minor_idx: &[usize],
371 values: &[T],
372 workspace: &mut [usize],
373) {
374 assert_eq!(minor_idx_result.len(), values_result.len());
375 assert_eq!(values_result.len(), minor_idx.len());
376 assert_eq!(minor_idx.len(), values.len());
377 assert_eq!(values.len(), workspace.len());
378
379 let permutation = workspace;
380 compute_sort_permutation(permutation, minor_idx);
381
382 apply_permutation(minor_idx_result, minor_idx, permutation);
383 apply_permutation(values_result, values, permutation);
384}
385
/// Collapses runs of equal consecutive indices into a single (index, value) pair.
///
/// Walks `idx_array` / `values` in lockstep. Each maximal run of equal adjacent indices has
/// its values folded left-to-right with `combiner`. For every run, `produce_idx` is called
/// with the run's index, followed by `produce_value` with the combined value. Only *adjacent*
/// duplicates are merged, so callers pass sorted indices to merge all duplicates.
///
/// # Panics
///
/// Panics if `idx_array` and `values` have different lengths.
fn combine_duplicates<T: Clone>(
    mut produce_idx: impl FnMut(usize),
    mut produce_value: impl FnMut(T),
    idx_array: &[usize],
    values: &[T],
    combiner: impl Fn(T, T) -> T,
) {
    assert_eq!(idx_array.len(), values.len());

    let mut entries = idx_array.iter().copied().zip(values.iter()).peekable();
    while let Some((run_idx, first)) = entries.next() {
        let mut acc = first.clone();
        // Absorb every following entry that belongs to the same run.
        while let Some((_, next)) = entries.next_if(|&(idx, _)| idx == run_idx) {
            acc = combiner(acc, next.clone());
        }
        produce_idx(run_idx);
        produce_value(acc);
    }
}